VaneDB 0.1.0
Embeddable vector database for edge AI
Loading...
Searching...
No Matches
hnsw_index.h
Go to the documentation of this file.
1// VaneDB - Copyright (c) 2025 Anton Tsvetkov - MIT License
2#pragma once
#include "distance_strategy.h"
#include "detail/file_utils.h"
#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <shared_mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <unordered_map>
#include <utility>
#include <vector>
22namespace vanedb {
23
namespace detail {

/// Largest element count read_vec() will accept from a file.
constexpr size_t MAX_VEC_SIZE = 100000000ULL;
/// Reasonable upper bound for serialized RNG state.
constexpr size_t MAX_RNG_STATE_SIZE = 10000;

/// Writes the raw object bytes of `v` (sizeof(T) of them) to `f`.
/// Stream errors are not reported here; they stay in `f`'s state.
template <typename T>
void write_bin(std::ofstream& f, const T& v) {
    const char* const raw = reinterpret_cast<const char*>(&v);
    f.write(raw, sizeof(T));
}

/// Overwrites `v` with the next sizeof(T) bytes of `f`.
/// @throws std::runtime_error if the stream cannot deliver all of them.
template <typename T>
void read_bin(std::ifstream& f, T& v) {
    f.read(reinterpret_cast<char*>(&v), sizeof(T));
    if (f.fail())
        throw std::runtime_error("Unexpected end of file or read error");
}

/// Writes `v` length-prefixed: the element count via write_bin(), then the
/// elements' raw bytes. An empty vector produces only the count.
template <typename T>
void write_vec(std::ofstream& f, const std::vector<T>& v) {
    write_bin(f, v.size());
    if (v.empty()) return;
    const auto payload_bytes = v.size() * sizeof(T);
    f.write(reinterpret_cast<const char*>(v.data()), payload_bytes);
}

/// Reads a vector in the layout produced by write_vec() into `v`.
/// @throws std::runtime_error if the stored count exceeds MAX_VEC_SIZE, if
///         count * sizeof(T) would overflow, or if the file ends early.
template <typename T>
void read_vec(std::ifstream& f, std::vector<T>& v) {
    size_t count = 0;
    read_bin(f, count);
    const bool over_cap = count > MAX_VEC_SIZE;
    const bool would_overflow = count > SIZE_MAX / sizeof(T);
    if (over_cap || would_overflow)
        throw std::runtime_error("Corrupted file: vector too large");
    v.resize(count);
    if (v.empty()) return;
    f.read(reinterpret_cast<char*>(v.data()), count * sizeof(T));
    if (f.fail())
        throw std::runtime_error("Unexpected end of file or read error");
}

} // namespace detail
47
50 float distance;
51 bool operator<(const HNSWSearchResult& o) const { return distance < o.distance; }
52 bool operator>(const HNSWSearchResult& o) const { return distance > o.distance; }
53};
54
55class HNSWIndex {
56public:
57 static constexpr uint32_t MAGIC = 0x51565244; // "QVRD" (legacy QuiverDB magic, retained for on-disk compat)
58 static constexpr uint32_t VERSION = 2; // v2: added RNG state serialization
59 static constexpr int MAX_LEVEL = 32; // Reasonable upper bound for HNSW levels
60 static constexpr size_t INVALID_ID = static_cast<size_t>(-1); // Sentinel for empty entry point
61
63 size_t max_elements = 100000, size_t M = 16, size_t ef_construction = 200, uint32_t seed = 42)
64 : dim_(dimension), metric_(metric), dist_(metric, dimension),
65 max_elements_(max_elements), M_(M), M_max_(M),
66 M_max0_(M * 2), ef_construction_(std::max(ef_construction, M)), ef_search_(50),
67 mult_(M > 1 ? 1.0 / std::log(static_cast<double>(M)) : 1.0), level_gen_(seed) {
68 if (dimension == 0) throw std::invalid_argument("Dimension must be > 0");
69 if (max_elements == 0) throw std::invalid_argument("max_elements must be > 0");
70 if (M < 2) throw std::invalid_argument("M must be >= 2");
71 if (max_elements > SIZE_MAX / dim_) throw std::invalid_argument("max_elements * dimension overflow");
72 vectors_.resize(max_elements * dim_);
73 ext_ids_.resize(max_elements);
74 levels_.resize(max_elements, 0);
75 neighbors_.resize(max_elements);
76 }
77
78 // Thread-safety: global_mtx_ is the single sync point. add() holds it
79 // exclusive; readers (search/size/contains/get_vector/save) hold it shared.
80 // No per-node locks are needed because add() can never run concurrently
81 // with any reader.
82 void add(uint64_t id, const float* vec) {
83 if (!vec) throw std::invalid_argument("Vector must not be null");
84 std::unique_lock glock(global_mtx_); // Exclusive: only one add() at a time
85 if (id_map_.count(id)) throw std::invalid_argument("ID " + std::to_string(id) + " exists");
86 if (count_ >= max_elements_) throw std::runtime_error("Index full");
87
88 size_t iid = count_++;
89 id_map_[id] = iid;
90 ext_ids_[iid] = id;
91 std::copy_n(vec, dim_, vectors_.begin() + iid * dim_);
92
93 int level = get_level();
94 levels_[iid] = level;
95 neighbors_[iid].resize(level + 1);
96 for (int l = 0; l <= level; ++l)
97 neighbors_[iid][l].reserve(l == 0 ? M_max0_ : M_max_);
98
99 if (ep_.load() == INVALID_ID) { ep_.store(iid); max_level_.store(level); return; }
100
101 size_t curr = ep_.load();
102 int cur_max_level = max_level_.load();
103 if (level < cur_max_level) {
104 float d = dist_(vec, get_vec(curr));
105 for (int l = cur_max_level; l > level; --l) {
106 bool changed = true;
107 while (changed) {
108 changed = false;
109 for (size_t n : neighbors_[curr][l]) {
110 float nd = dist_(vec, get_vec(n));
111 if (nd < d) { d = nd; curr = n; changed = true; }
112 }
113 }
114 }
115 }
116
117 for (int l = std::min(level, cur_max_level); l >= 0; --l) {
118 auto top = search_layer(vec, curr, ef_construction_, l);
119 auto sel = select_neighbors(top, M_, l);
120 neighbors_[iid][l] = std::move(sel);
121
122 size_t max_conn = l == 0 ? M_max0_ : M_max_;
123 for (size_t nid : neighbors_[iid][l]) {
124 auto& nc = neighbors_[nid][l];
125 if (nc.size() < max_conn) { nc.push_back(iid); }
126 else {
127 float d2new = dist_(get_vec(nid), vec);
128 std::vector<std::pair<float, size_t>> cands;
129 cands.reserve(nc.size() + 1);
130 for (size_t c : nc) cands.emplace_back(dist_(get_vec(nid), get_vec(c)), c);
131 cands.emplace_back(d2new, iid);
132 std::sort(cands.begin(), cands.end());
133 nc.clear();
134 for (size_t i = 0; i < max_conn && i < cands.size(); ++i) nc.push_back(cands[i].second);
135 }
136 }
137 // Use closest candidate (min distance) for next layer entry point
138 if (!top.empty()) {
139 std::pair<float, size_t> best = top.top();
140 while (!top.empty()) {
141 if (top.top().first < best.first) best = top.top();
142 top.pop();
143 }
144 curr = best.second;
145 }
146 }
147 if (level > cur_max_level) { ep_.store(iid); max_level_.store(level); }
148 }
149
150 std::vector<HNSWSearchResult> search(const float* query, size_t k) const {
151 if (!query) throw std::invalid_argument("Query must not be null");
152 if (k == 0) throw std::invalid_argument("k must be > 0");
153 std::shared_lock glock(global_mtx_);
154 if (count_ == 0) return {};
155
156 size_t curr = ep_.load();
157 float d = dist_(query, get_vec(curr));
158 for (int l = max_level_.load(); l > 0; --l) {
159 bool changed = true;
160 while (changed) {
161 changed = false;
162 if (static_cast<int>(neighbors_[curr].size()) <= l) continue;
163 for (size_t n : neighbors_[curr][l]) {
164 float nd = dist_(query, get_vec(n));
165 if (nd < d) { d = nd; curr = n; changed = true; }
166 }
167 }
168 }
169
170 auto top = search_layer(query, curr, std::max(ef_search_.load(std::memory_order_relaxed), k), 0);
171 std::vector<std::pair<float, size_t>> temp;
172 while (!top.empty()) { temp.push_back(top.top()); top.pop(); }
173 std::sort(temp.begin(), temp.end());
174
175 std::vector<HNSWSearchResult> res;
176 res.reserve(std::min(k, temp.size()));
177 for (size_t i = 0; i < k && i < temp.size(); ++i)
178 res.push_back({ext_ids_[temp[i].second], temp[i].first});
179 return res;
180 }
181
182 void set_ef_search(size_t ef) {
183 if (ef == 0) throw std::invalid_argument("ef_search must be > 0");
184 ef_search_.store(ef, std::memory_order_relaxed);
185 }
186 size_t get_ef_search() const { return ef_search_.load(std::memory_order_relaxed); }
187 size_t size() const { std::shared_lock lk(global_mtx_); return count_; }
188 size_t dimension() const { return dim_; }
189 size_t capacity() const { return max_elements_; }
190 bool contains(uint64_t id) const { std::shared_lock lk(global_mtx_); return id_map_.count(id); }
191
192 std::vector<float> get_vector(uint64_t id) const {
193 std::shared_lock lk(global_mtx_);
194 auto it = id_map_.find(id);
195 if (it == id_map_.end()) throw std::runtime_error("ID not found: " + std::to_string(id));
196 const float* p = vectors_.data() + it->second * dim_;
197 return std::vector<float>(p, p + dim_);
198 }
199
200 void save(const std::string& filename) const {
201 std::shared_lock glock(global_mtx_);
202 std::string tmp = filename + ".tmp";
203 std::ofstream f(tmp, std::ios::binary);
204 if (!f) throw std::runtime_error("Cannot open: " + tmp);
205 try {
208 detail::write_bin(f, dim_);
209 detail::write_bin(f, static_cast<uint32_t>(metric_));
210 detail::write_bin(f, max_elements_);
211 detail::write_bin(f, M_);
212 detail::write_bin(f, ef_construction_);
213 detail::write_bin(f, ef_search_.load());
214 detail::write_bin(f, mult_);
215 detail::write_bin(f, count_.load());
216 detail::write_bin(f, ep_.load());
217 detail::write_bin(f, max_level_.load());
218 detail::write_vec(f, vectors_);
219 detail::write_vec(f, ext_ids_);
220 detail::write_vec(f, levels_);
221 detail::write_bin(f, id_map_.size());
222 for (const auto& [k, v] : id_map_) { detail::write_bin(f, k); detail::write_bin(f, v); }
223 detail::write_bin(f, neighbors_.size());
224 for (size_t i = 0; i < neighbors_.size(); ++i) {
225 detail::write_bin(f, neighbors_[i].size());
226 for (size_t l = 0; l < neighbors_[i].size(); ++l) detail::write_vec(f, neighbors_[i][l]);
227 }
228 // Save RNG state for deterministic behavior after load
229 std::stringstream rng_ss;
230 rng_ss << level_gen_;
231 std::string rng_state = rng_ss.str();
232 detail::write_bin(f, rng_state.size());
233 f.write(rng_state.data(), rng_state.size());
234 f.flush();
235 if (!f) { std::filesystem::remove(tmp); throw std::runtime_error("Write failed: " + tmp); }
236 f.close(); // close before fsync_file (see file_utils.h: Windows lock contract)
238 std::filesystem::rename(tmp, filename);
239 } catch (...) { f.close(); std::filesystem::remove(tmp); throw; }
240 }
241
242 static std::unique_ptr<HNSWIndex> load(const std::string& filename) {
243 std::ifstream f(filename, std::ios::binary);
244 if (!f) throw std::runtime_error("Cannot open: " + filename);
245 uint32_t magic, ver;
246 detail::read_bin(f, magic);
247 if (magic != MAGIC) throw std::runtime_error("Invalid magic");
248 detail::read_bin(f, ver);
249 if (ver != VERSION && ver != 1) throw std::runtime_error("Unsupported version");
250
251 size_t dim, max_el, M, ef_con, ef_s; uint32_t met; double mult;
252 detail::read_bin(f, dim);
253 detail::read_bin(f, met);
254 if (met > 2) throw std::runtime_error("Corrupted file: invalid metric");
255 detail::read_bin(f, max_el);
256 detail::read_bin(f, M);
257 detail::read_bin(f, ef_con);
258 detail::read_bin(f, ef_s);
259 detail::read_bin(f, mult);
260
261 auto idx = std::make_unique<HNSWIndex>(dim, static_cast<DistanceMetric>(met), max_el, M, ef_con);
262 idx->ef_search_.store(ef_s);
263 idx->mult_ = mult;
264
265 size_t cnt, ep_val;
266 int max_level_val;
267 detail::read_bin(f, cnt);
268 if (cnt > max_el) throw std::runtime_error("Corrupted file: count exceeds max_elements");
269 idx->count_.store(cnt);
270 detail::read_bin(f, ep_val);
271 detail::read_bin(f, max_level_val);
272 // Validate ep_ and max_level_
273 if (cnt > 0) {
274 if (ep_val >= cnt) throw std::runtime_error("Corrupted file: invalid entry point");
275 if (max_level_val < 0 || max_level_val > MAX_LEVEL)
276 throw std::runtime_error("Corrupted file: invalid max_level");
277 } else {
278 // Empty index must have invalid entry point
279 if (ep_val != INVALID_ID)
280 throw std::runtime_error("Corrupted file: non-empty entry point for empty index");
281 }
282 idx->ep_.store(ep_val);
283 idx->max_level_.store(max_level_val);
284 detail::read_vec(f, idx->vectors_);
285 detail::read_vec(f, idx->ext_ids_);
286 detail::read_vec(f, idx->levels_);
287
288 size_t msz;
289 detail::read_bin(f, msz);
290 if (msz > cnt) throw std::runtime_error("Corrupted file: id_map size exceeds count");
291 idx->id_map_.reserve(msz);
292 for (size_t i = 0; i < msz; ++i) {
293 uint64_t k; size_t v;
294 detail::read_bin(f, k);
295 detail::read_bin(f, v);
296 if (v >= cnt) throw std::runtime_error("Corrupted file: invalid internal index in id_map");
297 idx->id_map_[k] = v;
298 }
299
300 size_t nsz;
301 detail::read_bin(f, nsz);
302 if (nsz > max_el) throw std::runtime_error("Corrupted file: neighbors size exceeds max_elements");
303 idx->neighbors_.resize(nsz);
304 for (size_t i = 0; i < nsz; ++i) {
305 size_t lsz;
306 detail::read_bin(f, lsz);
307 if (lsz > static_cast<size_t>(MAX_LEVEL) + 1) throw std::runtime_error("Corrupted file: too many levels");
308 idx->neighbors_[i].resize(lsz);
309 for (size_t l = 0; l < lsz; ++l) {
310 detail::read_vec(f, idx->neighbors_[i][l]);
311 // Validate neighbor indices are within bounds
312 for (size_t nid : idx->neighbors_[i][l]) {
313 if (nid >= cnt) throw std::runtime_error("Corrupted file: invalid neighbor index");
314 }
315 }
316 }
317
318 // Restore RNG state for deterministic behavior (v2+)
319 if (ver >= 2) {
320 size_t rng_state_size;
321 detail::read_bin(f, rng_state_size);
322 if (rng_state_size > detail::MAX_RNG_STATE_SIZE)
323 throw std::runtime_error("Corrupted file: RNG state too large");
324 std::string rng_state(rng_state_size, '\0');
325 if (!f.read(rng_state.data(), rng_state_size))
326 throw std::runtime_error("Unexpected end of file or read error");
327 std::stringstream rng_ss(rng_state);
328 rng_ss >> idx->level_gen_;
329 if (rng_ss.fail()) throw std::runtime_error("Corrupted file: invalid RNG state");
330 }
331 // Note: v1 files don't have RNG state, level_gen_ keeps default initialization
332 return idx;
333 }
334
335private:
336 static constexpr double MIN_LEVEL_RANDOM = 1e-9; // Clamp floor for level generation RNG
337 using MaxHeap = std::priority_queue<std::pair<float, size_t>>;
338
339 int get_level() {
340 std::uniform_real_distribution<double> d(0.0, 1.0);
341 double r = std::max(d(level_gen_), MIN_LEVEL_RANDOM);
342 int level = static_cast<int>(-std::log(r) * mult_);
343 return std::min(level, MAX_LEVEL);
344 }
345
346 const float* get_vec(size_t iid) const { return vectors_.data() + iid * dim_; }
347
348 MaxHeap search_layer(const float* q, size_t ep, size_t ef, int level) const {
349 // Versioned thread-local visited bitmap. vis[i] == vis_epoch means
350 // visited; bumping the epoch each call replaces the per-search O(N)
351 // zero-init a fresh bitmap would need with one O(count_) fill every
352 // 65k searches when the uint16_t epoch wraps. Buffer is shared across
353 // HNSWIndex instances on a thread (monotonic epoch keeps cross-index
354 // marks distinct) and is never shrunk.
355 //
356 // Relaxed load on count_ is safe: every caller holds global_mtx_
357 // (exclusive in add(), shared in search()).
358 static thread_local std::vector<uint16_t> vis;
359 static thread_local uint16_t vis_epoch = 0;
360 const size_t total = count_.load(std::memory_order_relaxed);
361 // Entry-point ID must be a live node. load() validates this for persisted
362 // indexes; this hot-path guard also catches in-memory corruption or future
363 // call-site bugs. Defensive — unreachable by construction in tests.
364 if (ep >= total) [[unlikely]] // LCOV_EXCL_LINE
365 throw std::logic_error("HNSWIndex::search_layer: entry point out of range"); // LCOV_EXCL_LINE
366 if (vis.size() < total) vis.resize(total, 0);
367 if (++vis_epoch == 0) {
368 std::fill(vis.begin(), vis.end(), 0);
369 vis_epoch = 1;
370 }
371 vis[ep] = vis_epoch;
372 std::priority_queue<std::pair<float, size_t>, std::vector<std::pair<float, size_t>>,
373 std::greater<std::pair<float, size_t>>> cands;
374 MaxHeap res;
375 float d = dist_(q, get_vec(ep));
376 cands.emplace(d, ep);
377 res.emplace(d, ep);
378 float lb = d;
379
380 while (!cands.empty()) {
381 auto [cd, cid] = cands.top();
382 if (cd > lb && res.size() >= ef) break;
383 cands.pop();
384 if (static_cast<int>(neighbors_[cid].size()) <= level) continue;
385 for (size_t n : neighbors_[cid][level]) {
386 if (vis[n] == vis_epoch) continue;
387 vis[n] = vis_epoch;
388 float nd = dist_(q, get_vec(n));
389 if (res.size() < ef || nd < lb) {
390 cands.emplace(nd, n);
391 res.emplace(nd, n);
392 if (res.size() > ef) res.pop();
393 if (!res.empty()) lb = res.top().first;
394 }
395 }
396 }
397 return res;
398 }
399
400 std::vector<size_t> select_neighbors(MaxHeap& cands, size_t M, int /*level*/) const {
401 if (cands.size() <= M) {
402 std::vector<size_t> r;
403 r.reserve(cands.size());
404 while (!cands.empty()) { r.push_back(cands.top().second); cands.pop(); }
405 return r;
406 }
407 std::vector<std::pair<float, size_t>> sorted;
408 sorted.reserve(cands.size());
409 while (!cands.empty()) { sorted.push_back(cands.top()); cands.pop(); }
410 std::sort(sorted.begin(), sorted.end());
411
412 std::vector<size_t> r;
413 r.reserve(M);
414 for (auto& [dq, cid] : sorted) {
415 if (r.size() >= M) break;
416 bool ok = true;
417 for (size_t s : r)
418 if (dist_(get_vec(cid), get_vec(s)) < dq) { ok = false; break; }
419 if (ok) r.push_back(cid);
420 }
421 if (r.size() < M) {
422 for (auto& p : sorted) {
423 if (r.size() >= M) break;
424 if (std::find(r.begin(), r.end(), p.second) == r.end()) r.push_back(p.second);
425 }
426 }
427 return r;
428 }
429
430 size_t dim_;
431 DistanceMetric metric_;
432 DistanceComputer dist_;
433 size_t max_elements_, M_, M_max_, M_max0_, ef_construction_;
434 std::atomic<size_t> ef_search_; // Atomic for thread-safe reads during search
435 double mult_;
436 std::mt19937 level_gen_;
437 std::vector<float> vectors_;
438 std::vector<uint64_t> ext_ids_;
439 std::unordered_map<uint64_t, size_t> id_map_;
440 std::vector<int> levels_;
441 std::vector<std::vector<std::vector<size_t>>> neighbors_;
442 std::atomic<size_t> ep_{INVALID_ID};
443 std::atomic<int> max_level_{-1};
444 std::atomic<size_t> count_{0};
445 mutable std::shared_mutex global_mtx_;
446};
447
448} // namespace vanedb
HNSWIndex(size_t dimension, DistanceMetric metric=DistanceMetric::L2, size_t max_elements=100000, size_t M=16, size_t ef_construction=200, uint32_t seed=42)
Definition hnsw_index.h:62
void add(uint64_t id, const float *vec)
Definition hnsw_index.h:82
static constexpr uint32_t VERSION
Definition hnsw_index.h:58
static constexpr int MAX_LEVEL
Definition hnsw_index.h:59
size_t dimension() const
Definition hnsw_index.h:188
static constexpr uint32_t MAGIC
Definition hnsw_index.h:57
static constexpr size_t INVALID_ID
Definition hnsw_index.h:60
bool contains(uint64_t id) const
Definition hnsw_index.h:190
static std::unique_ptr< HNSWIndex > load(const std::string &filename)
Definition hnsw_index.h:242
size_t capacity() const
Definition hnsw_index.h:189
std::vector< float > get_vector(uint64_t id) const
Definition hnsw_index.h:192
void save(const std::string &filename) const
Definition hnsw_index.h:200
std::vector< HNSWSearchResult > search(const float *query, size_t k) const
Definition hnsw_index.h:150
size_t get_ef_search() const
Definition hnsw_index.h:186
void set_ef_search(size_t ef)
Definition hnsw_index.h:182
size_t size() const
Definition hnsw_index.h:187
void write_bin(std::ofstream &f, const T &v)
Definition hnsw_index.h:25
constexpr size_t MAX_RNG_STATE_SIZE
Definition hnsw_index.h:37
void read_bin(std::ifstream &f, T &v)
Definition hnsw_index.h:28
void read_vec(std::ifstream &f, std::vector< T > &v)
Definition hnsw_index.h:38
constexpr size_t MAX_VEC_SIZE
Definition hnsw_index.h:36
void fsync_file(const std::string &path) noexcept
Definition file_utils.h:24
void write_vec(std::ofstream &f, const std::vector< T > &v)
Definition hnsw_index.h:32
bool operator>(const HNSWSearchResult &o) const
Definition hnsw_index.h:52
bool operator<(const HNSWSearchResult &o) const
Definition hnsw_index.h:51