VaneDB 0.1.0
Embeddable vector database for edge AI
Loading...
Searching...
No Matches
vector_store.h
Go to the documentation of this file.
1// VaneDB - Copyright (c) 2025 Anton Tsvetkov - MIT License
2#pragma once
#include "distance_strategy.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <mutex>
#include <shared_mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
14
15namespace vanedb {
16
19 float distance;
20 bool operator<(const SearchResult& o) const { return distance < o.distance; }
21};
22
24public:
26 : dim_(dimension), metric_(metric), dist_(metric, dimension) {
27 if (dimension == 0) throw std::invalid_argument("Dimension must be > 0");
28 }
29
30 void add(uint64_t id, const float* vector) {
31 if (!vector) throw std::invalid_argument("Vector must not be null");
32 std::unique_lock lock(mutex_);
33 if (id_to_index_.count(id)) throw std::invalid_argument("ID " + std::to_string(id) + " exists");
34 vectors_data_.insert(vectors_data_.end(), vector, vector + dim_);
35 ids_.push_back(id);
36 id_to_index_[id] = ids_.size() - 1;
37 }
38
39 bool remove(uint64_t id) {
40 std::unique_lock lock(mutex_);
41 auto it = id_to_index_.find(id);
42 if (it == id_to_index_.end()) return false;
43 size_t idx = it->second, last = ids_.size() - 1;
44 if (idx != last) {
45 std::copy_n(vectors_data_.data() + last * dim_, dim_, vectors_data_.data() + idx * dim_);
46 ids_[idx] = ids_[last];
47 id_to_index_[ids_[idx]] = idx;
48 }
49 vectors_data_.resize(vectors_data_.size() - dim_);
50 ids_.pop_back();
51 id_to_index_.erase(it);
52 return true;
53 }
54
55 // WARNING: Returned pointer invalidated by any write operation
56 const float* get(uint64_t id) const {
57 std::shared_lock lock(mutex_);
58 auto it = id_to_index_.find(id);
59 return it == id_to_index_.end() ? nullptr : vectors_data_.data() + it->second * dim_;
60 }
61
62 // Thread-safe: returns a copy of the vector (safe for concurrent access)
63 std::vector<float> get_copy(uint64_t id) const {
64 std::shared_lock lock(mutex_);
65 auto it = id_to_index_.find(id);
66 if (it == id_to_index_.end()) return {};
67 const float* ptr = vectors_data_.data() + it->second * dim_;
68 return std::vector<float>(ptr, ptr + dim_);
69 }
70
71 std::vector<SearchResult> search(const float* query, size_t k) const {
72 if (!query) throw std::invalid_argument("Query must not be null");
73 if (k == 0) throw std::invalid_argument("k must be > 0");
74 std::shared_lock lock(mutex_);
75 std::vector<SearchResult> results;
76 results.reserve(ids_.size());
77 for (size_t i = 0; i < ids_.size(); ++i)
78 results.push_back({ids_[i], dist_(query, vectors_data_.data() + i * dim_)});
79 size_t n = std::min(k, results.size());
80 std::partial_sort(results.begin(), results.begin() + n, results.end());
81 results.resize(n);
82 return results;
83 }
84
85 size_t size() const { std::shared_lock lock(mutex_); return ids_.size(); }
86 size_t dimension() const { return dim_; }
87 DistanceMetric metric() const { return metric_; }
88
89 void clear() {
90 std::unique_lock lock(mutex_);
91 vectors_data_.clear();
92 ids_.clear();
93 id_to_index_.clear();
94 }
95
96 bool contains(uint64_t id) const {
97 std::shared_lock lock(mutex_);
98 return id_to_index_.count(id);
99 }
100
101 void reserve(size_t capacity) {
102 std::unique_lock lock(mutex_);
103 vectors_data_.reserve(capacity * dim_);
104 ids_.reserve(capacity);
105 id_to_index_.reserve(capacity);
106 }
107
108 bool update(uint64_t id, const float* vector) {
109 if (!vector) throw std::invalid_argument("Vector must not be null");
110 std::unique_lock lock(mutex_);
111 auto it = id_to_index_.find(id);
112 if (it == id_to_index_.end()) return false;
113 std::copy_n(vector, dim_, vectors_data_.data() + it->second * dim_);
114 return true;
115 }
116
private:
// Length of every stored vector; the constructor rejects 0.
size_t dim_;
// Metric reported by metric().
DistanceMetric metric_;
// Distance functor used by search(); built from (metric, dimension) in the
// constructor.
DistanceComputer dist_;
// All vectors packed back-to-back: the vector in slot i occupies floats
// [i * dim_, (i + 1) * dim_). Intended invariant maintained by add/remove:
// vectors_data_.size() == ids_.size() * dim_.
std::vector<float> vectors_data_;
// ids_[i] is the id of the vector stored in slot i.
std::vector<uint64_t> ids_;
// Inverse of ids_ (id -> slot). remove() moves the last slot into the hole,
// so an id's slot can change whenever another id is removed.
std::unordered_map<uint64_t, size_t> id_to_index_;
// Guards the three containers above: readers take std::shared_lock, writers
// take std::unique_lock. It is mutable so that const readers can lock it.
mutable std::shared_mutex mutex_;
};
126
127} // namespace vanedb
bool remove(uint64_t id)
size_t dimension() const
DistanceMetric metric() const
size_t size() const
const float * get(uint64_t id) const
std::vector< SearchResult > search(const float *query, size_t k) const
VectorStore(size_t dimension, DistanceMetric metric=DistanceMetric::L2)
bool contains(uint64_t id) const
std::vector< float > get_copy(uint64_t id) const
bool update(uint64_t id, const float *vector)
void reserve(size_t capacity)
void add(uint64_t id, const float *vector)
bool operator<(const SearchResult &o) const