// Magic number identifying the index file format. load() throws
// "Invalid magic" when the value it checks differs from this constant.
// The four bytes 0x51 0x56 0x52 0x44 are the ASCII characters 'Q' 'V' 'R' 'D'.
57 static constexpr uint32_t
MAGIC = 0x51565244;
// Sentinel internal id meaning "no element": the largest size_t value.
// add() compares ep_ against it to detect that no entry point has been set yet.
60 static constexpr size_t INVALID_ID =
static_cast<size_t>(-1);
63 size_t max_elements = 100000,
size_t M = 16,
size_t ef_construction = 200, uint32_t seed = 42)
// Member initializers (visible part):
//  - M_max_  = M     : link cap applied on layers above 0.
//  - M_max0_ = 2 * M : link cap applied on layer 0.
//  - ef_construction_ is raised to at least M.
//  - ef_search_ starts at 50.
//  - mult_ = 1 / ln(M) scales the random level drawn for each new element;
//    the M > 1 guard keeps the expression finite for M of 0 or 1, which are
//    only rejected further down in the body.
//  - level_gen_ is seeded with `seed`.
65 max_elements_(max_elements), M_(M), M_max_(M),
66 M_max0_(M * 2), ef_construction_(std::max(ef_construction, M)), ef_search_(50),
67 mult_(M > 1 ? 1.0 / std::log(static_cast<double>(M)) : 1.0), level_gen_(seed) {
// Argument validation; every failure throws std::invalid_argument.
68 if (
dimension == 0)
throw std::invalid_argument(
"Dimension must be > 0");
69 if (max_elements == 0)
throw std::invalid_argument(
"max_elements must be > 0");
70 if (M < 2)
throw std::invalid_argument(
"M must be >= 2");
// Reject sizes for which max_elements * dim_ (computed just below) would
// overflow size_t. NOTE(review): dim_ is presumably initialized from
// `dimension` on a line not shown here — confirm; the division relies on it
// being non-zero, which the first check establishes for `dimension`.
71 if (max_elements > SIZE_MAX / dim_)
throw std::invalid_argument(
"max_elements * dimension overflow");
// Allocate all per-element storage up front for max_elements entries.
72 vectors_.resize(max_elements * dim_);
73 ext_ids_.resize(max_elements);
74 levels_.resize(max_elements, 0);
75 neighbors_.resize(max_elements);
// Inserts one vector under the external id `id` and links it into the graph.
//
// `vec` must point to at least dim_ floats; dim_ values are copied from it.
// Throws std::invalid_argument if `vec` is null or `id` is already a key of
// id_map_, and std::runtime_error("Index full") once count_ has reached
// max_elements_. Everything after the null check runs under an exclusive
// lock on global_mtx_.
82 void add(uint64_t
id,
const float* vec) {
83 if (!vec)
throw std::invalid_argument(
"Vector must not be null");
84 std::unique_lock glock(global_mtx_);
85 if (id_map_.count(
id))
throw std::invalid_argument(
"ID " + std::to_string(
id) +
" exists");
86 if (count_ >= max_elements_)
throw std::runtime_error(
"Index full");
// count_ doubles as the next free internal slot; claim it.
88 size_t iid = count_++;
// Copy the vector into its slot of the flat vectors_ array.
91 std::copy_n(vec, dim_, vectors_.begin() + iid * dim_);
// Top layer for this element; it gets one neighbor list per layer 0..level,
// each pre-sized to that layer's link cap.
93 int level = get_level();
95 neighbors_[iid].resize(level + 1);
96 for (
int l = 0; l <= level; ++l)
97 neighbors_[iid][l].reserve(l == 0 ? M_max0_ : M_max_);
// No entry point yet: this element becomes it and there is nothing to link to.
99 if (ep_.load() ==
INVALID_ID) { ep_.store(iid); max_level_.store(level);
return; }
101 size_t curr = ep_.load();
102 int cur_max_level = max_level_.load();
// Phase 1: on the layers above the new element's top layer, move `curr` to
// any neighbor that is closer to `vec` than the best distance seen so far.
103 if (level < cur_max_level) {
104 float d = dist_(vec, get_vec(curr));
105 for (
int l = cur_max_level; l > level; --l) {
109 for (
size_t n : neighbors_[curr][l]) {
110 float nd = dist_(vec, get_vec(n));
111 if (nd < d) { d = nd; curr = n; changed =
true; }
// Phase 2: from the highest layer shared with the existing graph down to
// layer 0, collect candidates with search_layer(), pick up to M_ of them and
// link in both directions.
117 for (
int l = std::min(level, cur_max_level); l >= 0; --l) {
118 auto top = search_layer(vec, curr, ef_construction_, l);
119 auto sel = select_neighbors(top, M_, l);
120 neighbors_[iid][l] = std::move(sel);
122 size_t max_conn = l == 0 ? M_max0_ : M_max_;
// Back-links: each chosen neighbor also gets an edge to the new element.
123 for (
size_t nid : neighbors_[iid][l]) {
124 auto& nc = neighbors_[nid][l];
125 if (nc.size() < max_conn) { nc.push_back(iid); }
// The lines below rank the neighbor's existing links plus the new element by
// distance from the neighbor and push the closest max_conn back into nc.
// NOTE(review): presumably this is the "list already full" path and nc is
// cleared on a line not shown before the push_back loop — confirm.
127 float d2new = dist_(get_vec(nid), vec);
128 std::vector<std::pair<float, size_t>> cands;
129 cands.reserve(nc.size() + 1);
130 for (
size_t c : nc) cands.emplace_back(dist_(get_vec(nid), get_vec(c)), c);
131 cands.emplace_back(d2new, iid);
132 std::sort(cands.begin(), cands.end());
134 for (
size_t i = 0; i < max_conn && i < cands.size(); ++i) nc.push_back(cands[i].second);
// Scan what is left in `top` for the pair with the smallest distance.
// NOTE(review): select_neighbors() above receives `top` by non-const
// reference and pops every element in both of its branches, so on the
// visible lines `top` is empty here and top() on an empty priority_queue is
// undefined behavior — confirm that a line not shown guards or refills it.
139 std::pair<float, size_t> best = top.top();
140 while (!top.empty()) {
141 if (top.top().first < best.first) best = top.top();
// The new element reaches above every existing layer: make it the entry point.
147 if (level > cur_max_level) { ep_.store(iid); max_level_.store(level); }
// Approximate nearest-neighbor query.
//
// Returns at most `k` results built from (external id, dist_ value) pairs,
// ordered by ascending dist_ value. Returns an empty vector when the index
// holds no elements. Throws std::invalid_argument if `query` is null or
// `k` is 0. The graph is read under a shared lock on global_mtx_.
150 std::vector<HNSWSearchResult>
search(
const float* query,
size_t k)
const {
151 if (!query)
throw std::invalid_argument(
"Query must not be null");
152 if (k == 0)
throw std::invalid_argument(
"k must be > 0");
153 std::shared_lock glock(global_mtx_);
154 if (count_ == 0)
return {};
// Descend from the top layer to layer 1, moving `curr` to any neighbor that
// is closer to the query than the best distance seen so far.
156 size_t curr = ep_.load();
157 float d = dist_(query, get_vec(curr));
158 for (
int l = max_level_.load(); l > 0; --l) {
// `curr` may have no neighbor list at this layer; skip indexing into it.
162 if (
static_cast<int>(neighbors_[curr].
size()) <= l)
continue;
163 for (
size_t n : neighbors_[curr][l]) {
164 float nd = dist_(query, get_vec(n));
165 if (nd < d) { d = nd; curr = n; changed =
true; }
// Layer-0 search with a candidate budget of max(ef_search_, k), so that at
// least k results can be produced when enough elements are reachable.
170 auto top = search_layer(query, curr, std::max(ef_search_.load(std::memory_order_relaxed), k), 0);
// Drain the heap and sort ascending, so the smallest dist_ values come first.
171 std::vector<std::pair<float, size_t>> temp;
172 while (!top.empty()) { temp.push_back(top.top()); top.pop(); }
173 std::sort(temp.begin(), temp.end());
// Keep the first k entries and translate internal ids to external ids.
175 std::vector<HNSWSearchResult> res;
176 res.reserve(std::min(k, temp.size()));
177 for (
size_t i = 0; i < k && i < temp.size(); ++i)
178 res.push_back({ext_ids_[temp[i].second], temp[i].first});
// Rejects ef == 0 with std::invalid_argument, then publishes the new value
// with a relaxed atomic store (no lock is taken on these lines).
183 if (ef == 0)
throw std::invalid_argument(
"ef_search must be > 0");
184 ef_search_.store(ef, std::memory_order_relaxed);
186 size_t get_ef_search()
const {
return ef_search_.load(std::memory_order_relaxed); }
187 size_t size()
const { std::shared_lock lk(global_mtx_);
return count_; }
190 bool contains(uint64_t
id)
const { std::shared_lock lk(global_mtx_);
return id_map_.count(
id); }
// Under a shared lock, look up the internal slot for external id `id` and
// return a copy of its dim_ floats. Throws std::runtime_error when the id is
// not a key of id_map_.
193 std::shared_lock lk(global_mtx_);
194 auto it = id_map_.find(
id);
195 if (it == id_map_.end())
throw std::runtime_error(
"ID not found: " + std::to_string(
id));
// it->second is the internal slot; each slot occupies dim_ consecutive floats.
196 const float* p = vectors_.data() + it->second * dim_;
197 return std::vector<float>(p, p + dim_);
// Serializes the index to `filename`.
//
// Data is written to "<filename>.tmp" first and renamed onto `filename` only
// after the stream reports no error, so a failed write does not replace an
// existing file. Throws std::runtime_error if the temporary file cannot be
// opened or a write fails. Runs under a shared lock on global_mtx_.
200 void save(
const std::string& filename)
const {
201 std::shared_lock glock(global_mtx_);
202 std::string tmp = filename +
".tmp";
203 std::ofstream f(tmp, std::ios::binary);
204 if (!f)
throw std::runtime_error(
"Cannot open: " + tmp);
// One neighbor list per (element slot, layer). neighbors_ is sized to
// max_elements by the constructor, so slots that hold no element are
// visited too; their inner loop simply runs zero times.
224 for (
size_t i = 0; i < neighbors_.size(); ++i) {
226 for (
size_t l = 0; l < neighbors_[i].size(); ++l)
detail::write_vec(f, neighbors_[i][l]);
// Persist the level generator through its stream-insertion operator so that
// load() can restore it with the matching extraction.
229 std::stringstream rng_ss;
230 rng_ss << level_gen_;
231 std::string rng_state = rng_ss.str();
// NOTE(review): load() reads a size before these bytes; presumably it is
// written on a line not shown here — confirm.
233 f.write(rng_state.data(), rng_state.size());
// Any earlier stream failure is sticky, so one check here covers all writes.
// NOTE(review): `f` is still open when remove() is called on this path —
// confirm that is acceptable on all target platforms.
235 if (!f) { std::filesystem::remove(tmp);
throw std::runtime_error(
"Write failed: " + tmp); }
238 std::filesystem::rename(tmp, filename);
239 }
// Cleanup on any exception: close the stream, delete the temporary file and
// rethrow the original exception unchanged.
catch (...) { f.close(); std::filesystem::remove(tmp);
throw; }
// Reads an index from `filename` and returns it as a newly allocated HNSWIndex.
//
// Throws std::runtime_error if the file cannot be opened or if any of the
// consistency checks below fails (magic, version, metric, element count,
// entry point, max level, id map, neighbor lists, RNG state). The index is
// built with std::make_unique<HNSWIndex>(...), so exceptions thrown by that
// constructor for the values read from the file propagate as well.
242 static std::unique_ptr<HNSWIndex>
load(
const std::string& filename) {
243 std::ifstream f(filename, std::ios::binary);
244 if (!f)
throw std::runtime_error(
"Cannot open: " + filename);
247 if (magic !=
MAGIC)
throw std::runtime_error(
"Invalid magic");
// Two format versions are accepted: the current VERSION and version 1.
249 if (ver !=
VERSION && ver != 1)
throw std::runtime_error(
"Unsupported version");
251 size_t dim, max_el, M, ef_con, ef_s; uint32_t met;
double mult;
// Range-check the raw metric value before it is cast to DistanceMetric.
254 if (met > 2)
throw std::runtime_error(
"Corrupted file: invalid metric");
// `seed` is left at the constructor default; the generator state is
// overwritten from the file at the end of this function.
261 auto idx = std::make_unique<HNSWIndex>(dim,
static_cast<DistanceMetric>(met), max_el, M, ef_con);
// NOTE(review): ef_s is stored directly; no "> 0" check is visible on these
// lines — confirm one exists on the lines not shown.
262 idx->ef_search_.store(ef_s);
268 if (cnt > max_el)
throw std::runtime_error(
"Corrupted file: count exceeds max_elements");
269 idx->count_.store(cnt);
// The entry point must refer to a stored element and the top layer must lie
// in [0, MAX_LEVEL].
274 if (ep_val >= cnt)
throw std::runtime_error(
"Corrupted file: invalid entry point");
275 if (max_level_val < 0 || max_level_val >
MAX_LEVEL)
276 throw std::runtime_error(
"Corrupted file: invalid max_level");
280 throw std::runtime_error(
"Corrupted file: non-empty entry point for empty index");
282 idx->ep_.store(ep_val);
283 idx->max_level_.store(max_level_val);
// External-id map: bounded by the element count before reserving, so a
// corrupted size cannot request an arbitrarily large reservation.
290 if (msz > cnt)
throw std::runtime_error(
"Corrupted file: id_map size exceeds count");
291 idx->id_map_.reserve(msz);
292 for (
size_t i = 0; i < msz; ++i) {
293 uint64_t k;
size_t v;
296 if (v >= cnt)
throw std::runtime_error(
"Corrupted file: invalid internal index in id_map");
// Neighbor lists: the outer size is bounded by max_elements, the per-element
// layer count by MAX_LEVEL + 1, and every stored neighbor id by the element
// count.
302 if (nsz > max_el)
throw std::runtime_error(
"Corrupted file: neighbors size exceeds max_elements");
303 idx->neighbors_.resize(nsz);
304 for (
size_t i = 0; i < nsz; ++i) {
307 if (lsz >
static_cast<size_t>(
MAX_LEVEL) + 1)
throw std::runtime_error(
"Corrupted file: too many levels");
308 idx->neighbors_[i].resize(lsz);
309 for (
size_t l = 0; l < lsz; ++l) {
312 for (
size_t nid : idx->neighbors_[i][l]) {
313 if (nid >= cnt)
throw std::runtime_error(
"Corrupted file: invalid neighbor index");
// Level generator state: a length-prefixed text blob that is parsed back
// with the generator's stream-extraction operator.
320 size_t rng_state_size;
323 throw std::runtime_error(
"Corrupted file: RNG state too large");
324 std::string rng_state(rng_state_size,
'\0');
325 if (!f.read(rng_state.data(), rng_state_size))
326 throw std::runtime_error(
"Unexpected end of file or read error");
327 std::stringstream rng_ss(rng_state);
328 rng_ss >> idx->level_gen_;
329 if (rng_ss.fail())
throw std::runtime_error(
"Corrupted file: invalid RNG state");
// Lower bound applied to the uniform sample used for level generation, so
// that std::log() is never evaluated at 0.
336 static constexpr double MIN_LEVEL_RANDOM = 1e-9;
// Heap of (distance, internal id) pairs using the default std::less ordering,
// so top() is the largest pair, i.e. the entry with the greatest distance.
337 using MaxHeap = std::priority_queue<std::pair<float, size_t>>;
// Draw r uniformly from [0, 1), clamp it up to MIN_LEVEL_RANDOM, and turn it
// into a layer number as -ln(r) * mult_ truncated to int. Larger layers are
// therefore exponentially less likely. Each draw advances level_gen_.
340 std::uniform_real_distribution<double> d(0.0, 1.0);
341 double r = std::max(d(level_gen_), MIN_LEVEL_RANDOM);
342 int level =
static_cast<int>(-std::log(r) * mult_);
346 const float* get_vec(
size_t iid)
const {
return vectors_.data() + iid * dim_; }
// Best-first search restricted to a single layer.
//   q     : query vector, compared against stored vectors with dist_.
//   ep    : internal id to start from; std::logic_error is thrown when it is
//           not below count_.
//   ef    : cap on the number of results kept (res is trimmed whenever it
//           grows beyond ef).
//   level : layer whose neighbor lists are followed.
// NOTE(review): the declarations of `res` and `lb` and the return statement
// are not visible here — presumably `res` is the MaxHeap that is returned and
// `lb` tracks the largest distance currently in it; confirm.
348 MaxHeap search_layer(
const float* q,
size_t ep,
size_t ef,
int level)
const {
// Visited set kept as epoch stamps: an entry counts as visited only when it
// equals the current vis_epoch, so incrementing the epoch invalidates all
// earlier marks without clearing the array. Both objects are function-local
// thread_local statics, hence shared by every call made on the same thread.
358 static thread_local std::vector<uint16_t> vis;
359 static thread_local uint16_t vis_epoch = 0;
360 const size_t total = count_.load(std::memory_order_relaxed);
364 if (ep >= total) [[unlikely]]
365 throw std::logic_error(
"HNSWIndex::search_layer: entry point out of range");
366 if (vis.size() < total) vis.resize(total, 0);
// The 16-bit epoch wrapped around to 0: wipe the stamps so that values left
// over from earlier epochs cannot be mistaken for current ones.
367 if (++vis_epoch == 0) {
368 std::fill(vis.begin(), vis.end(), 0);
// Frontier ordered with std::greater, so top() is the closest candidate.
372 std::priority_queue<std::pair<float, size_t>, std::vector<std::pair<float, size_t>>,
373 std::greater<std::pair<float, size_t>>> cands;
375 float d = dist_(q, get_vec(ep));
376 cands.emplace(d, ep);
380 while (!cands.empty()) {
381 auto [cd, cid] = cands.top();
// Stop once the closest remaining candidate is farther than lb and the
// result set already holds ef entries.
382 if (cd > lb && res.size() >= ef)
break;
// Nodes without a neighbor list at this layer have nothing to expand.
384 if (
static_cast<int>(neighbors_[cid].
size()) <= level)
continue;
385 for (
size_t n : neighbors_[cid][level]) {
386 if (vis[n] == vis_epoch)
continue;
388 float nd = dist_(q, get_vec(n));
// Admit the neighbor while there is still room or when it beats lb, then
// trim the result set back to ef and refresh lb from its top entry.
389 if (res.size() < ef || nd < lb) {
390 cands.emplace(nd, n);
392 if (res.size() > ef) res.pop();
393 if (!res.empty()) lb = res.top().first;
// Chooses at most M internal ids from `cands`.
//   cands : heap of (distance, internal id) pairs; it is taken by reference
//           and popped until empty on both paths below.
//   M     : maximum number of ids to return.
//   (int) : unnamed third parameter, unused.
// When cands holds no more than M entries, all of them are taken in pop
// order (largest pair first). Otherwise the candidates are sorted ascending
// and filtered by the check described below, then topped up to M.
400 std::vector<size_t> select_neighbors(MaxHeap& cands,
size_t M,
int )
const {
401 if (cands.size() <= M) {
402 std::vector<size_t> r;
403 r.reserve(cands.size());
404 while (!cands.empty()) { r.push_back(cands.top().second); cands.pop(); }
407 std::vector<std::pair<float, size_t>> sorted;
408 sorted.reserve(cands.size());
409 while (!cands.empty()) { sorted.push_back(cands.top()); cands.pop(); }
410 std::sort(sorted.begin(), sorted.end());
412 std::vector<size_t> r;
// Pass 1: walk candidates from closest to farthest. A candidate is rejected
// when `s` is closer to it than its own distance dq.
// NOTE(review): `s` and `ok` are declared on lines not shown; presumably `s`
// iterates over the ids already placed in r — confirm.
414 for (
auto& [dq, cid] : sorted) {
415 if (r.size() >= M)
break;
418 if (dist_(get_vec(cid), get_vec(s)) < dq) { ok =
false;
break; }
419 if (ok) r.push_back(cid);
// Pass 2: if fewer than M were accepted, top up with the closest candidates
// that are not already in r.
422 for (
auto& p : sorted) {
423 if (r.size() >= M)
break;
424 if (std::find(r.begin(), r.end(), p.second) == r.end()) r.push_back(p.second);
// Distance functor, called as dist_(a, b) on two float pointers; smaller
// results are treated as closer throughout this class.
432 DistanceComputer dist_;
// Capacity and graph parameters: M_ is the number of links requested per
// insertion, M_max_ / M_max0_ are the link caps for layers above 0 / layer 0,
// and ef_construction_ is the candidate budget used while inserting.
433 size_t max_elements_, M_, M_max_, M_max0_, ef_construction_;
// Candidate budget for queries; atomic so it can be read and written with
// relaxed operations outside global_mtx_.
434 std::atomic<size_t> ef_search_;
// Generator used to draw the layer of each new element.
436 std::mt19937 level_gen_;
// Flat vector storage: internal slot i occupies dim_ consecutive floats
// starting at index i * dim_.
437 std::vector<float> vectors_;
// External id of each internal slot (indexed by internal id).
438 std::vector<uint64_t> ext_ids_;
// External id -> internal slot.
439 std::unordered_map<uint64_t, size_t> id_map_;
// One int per slot, zero-initialized by the constructor.
440 std::vector<int> levels_;
// neighbors_[slot][layer] is the list of internal ids linked to `slot` on
// that layer.
441 std::vector<std::vector<std::vector<size_t>>> neighbors_;
// Highest layer stored for the entry point; starts at -1.
443 std::atomic<int> max_level_{-1};
// Number of stored elements; also the next internal slot handed out by add().
444 std::atomic<size_t> count_{0};
// Taken exclusively by add() and in shared mode by the read-only operations;
// mutable so that const member functions can lock it.
445 mutable std::shared_mutex global_mtx_;