44 file_handle_ = CreateFileA(filename.c_str(), GENERIC_READ, FILE_SHARE_READ,
45 nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL,
nullptr);
46 if (file_handle_ == INVALID_HANDLE_VALUE)
throw std::runtime_error(
"Cannot open: " + filename);
48 if (!GetFileSizeEx(file_handle_, &sz)) { CloseHandle(file_handle_); file_handle_ = INVALID_HANDLE_VALUE;
49 throw std::runtime_error(
"Cannot get file size"); }
50 file_size_ =
static_cast<size_t>(sz.QuadPart);
51 if (file_size_ <
HEADER_SIZE) { CloseHandle(file_handle_); file_handle_ = INVALID_HANDLE_VALUE;
52 throw std::runtime_error(
"File too small"); }
53 mapping_handle_ = CreateFileMappingA(file_handle_,
nullptr, PAGE_READONLY, 0, 0,
nullptr);
54 if (!mapping_handle_) { CloseHandle(file_handle_); file_handle_ = INVALID_HANDLE_VALUE;
55 throw std::runtime_error(
"Cannot create mapping"); }
56 mapped_ = MapViewOfFile(mapping_handle_, FILE_MAP_READ, 0, 0, 0);
57 if (!mapped_) { CloseHandle(mapping_handle_); CloseHandle(file_handle_);
58 mapping_handle_ =
nullptr; file_handle_ = INVALID_HANDLE_VALUE;
59 throw std::runtime_error(
"Cannot map file"); }
61 fd_ = open(filename.c_str(), O_RDONLY);
62 if (fd_ < 0)
throw std::runtime_error(
"Cannot open: " + filename);
64 if (fstat(fd_, &sb) < 0) { close(fd_); fd_ = -1;
throw std::runtime_error(
"Cannot stat file"); }
65 file_size_ =
static_cast<size_t>(sb.st_size);
66 if (file_size_ <
HEADER_SIZE) { close(fd_); fd_ = -1;
throw std::runtime_error(
"File too small"); }
67 mapped_ = mmap(
nullptr, file_size_, PROT_READ, MAP_PRIVATE, fd_, 0);
68 if (mapped_ == MAP_FAILED) { close(fd_); fd_ = -1; mapped_ =
nullptr;
69 throw std::runtime_error(
"Cannot mmap file"); }
71 const uint8_t* p =
static_cast<const uint8_t*
>(mapped_);
72 uint32_t magic; std::memcpy(&magic, p, 4);
73 if (magic !=
MAGIC) { cleanup();
throw std::runtime_error(
"Invalid magic"); }
75 uint32_t ver; std::memcpy(&ver, p, 4);
76 if (ver !=
VERSION) { cleanup();
throw std::runtime_error(
"Unsupported version"); }
78 std::memcpy(&dim_, p, 8); p += 8;
79 std::memcpy(&num_vectors_, p, 8); p += 8;
80 uint32_t met; std::memcpy(&met, p, 4);
81 if (met > 2) { cleanup();
throw std::runtime_error(
"Invalid metric"); }
85 if (num_vectors_ > SIZE_MAX /
sizeof(uint64_t)) {
86 cleanup();
throw std::runtime_error(
"File corrupted: size overflow");
88 if (dim_ == 0 && num_vectors_ > 0) {
89 cleanup();
throw std::runtime_error(
"File corrupted: zero dimension with vectors");
91 if (dim_ > SIZE_MAX /
sizeof(
float)) {
92 cleanup();
throw std::runtime_error(
"File corrupted: size overflow");
94 size_t vec_bytes_per = dim_ *
sizeof(float);
95 if (vec_bytes_per > 0 && num_vectors_ > SIZE_MAX / vec_bytes_per) {
96 cleanup();
throw std::runtime_error(
"File corrupted: size overflow");
98 size_t ids_size = num_vectors_ *
sizeof(uint64_t);
99 size_t vecs_size = num_vectors_ * vec_bytes_per;
101 cleanup();
throw std::runtime_error(
"File corrupted: size overflow");
103 size_t expected =
HEADER_SIZE + ids_size + vecs_size;
104 if (file_size_ < expected) { cleanup();
throw std::runtime_error(
"File truncated"); }
106 ids_ptr_ =
reinterpret_cast<const uint64_t*
>(
static_cast<const uint8_t*
>(mapped_) +
HEADER_SIZE);
107 vectors_ptr_ =
reinterpret_cast<const float*
>(
108 static_cast<const uint8_t*
>(mapped_) +
HEADER_SIZE + num_vectors_ *
sizeof(uint64_t));
111 id_map_.reserve(num_vectors_);
112 for (
size_t i = 0; i < num_vectors_; ++i) id_map_[ids_ptr_[i]] = i;
113 }
catch (...) { cleanup();
throw; }
122 file_handle_(o.file_handle_), mapping_handle_(o.mapping_handle_),
126 mapped_(o.mapped_), file_size_(o.file_size_), dim_(o.dim_), num_vectors_(o.num_vectors_),
127 metric_(o.metric_), dist_(o.dist_), ids_ptr_(o.ids_ptr_), vectors_ptr_(o.vectors_ptr_), id_map_(std::move(o.id_map_)) {
129 o.file_handle_ = INVALID_HANDLE_VALUE; o.mapping_handle_ =
nullptr;
140 file_handle_ = o.file_handle_; mapping_handle_ = o.mapping_handle_;
141 o.file_handle_ = INVALID_HANDLE_VALUE; o.mapping_handle_ =
nullptr;
143 fd_ = o.fd_; o.fd_ = -1;
145 mapped_ = o.mapped_; file_size_ = o.file_size_; dim_ = o.dim_; num_vectors_ = o.num_vectors_;
146 metric_ = o.metric_; dist_ = o.dist_; ids_ptr_ = o.ids_ptr_; vectors_ptr_ = o.vectors_ptr_;
147 id_map_ = std::move(o.id_map_); o.mapped_ =
nullptr;
159 std::vector<SearchResult>
search(
const float* query,
size_t k)
const {
160 if (!query)
throw std::invalid_argument(
"Query must not be null");
161 if (k == 0)
throw std::invalid_argument(
"k must be > 0");
162 std::vector<SearchResult> res;
163 res.reserve(num_vectors_);
164 for (
size_t i = 0; i < num_vectors_; ++i)
165 res.push_back({ids_ptr_[i], dist_(query, vectors_ptr_ + i * dim_)});
166 size_t n = std::min(k, res.size());
167 std::partial_sort(res.begin(), res.begin() + n, res.end());