VaneDB 0.1.0
Embeddable vector database for edge AI
Loading...
Searching...
No Matches
distance.h
Go to the documentation of this file.
1// VaneDB - Copyright (c) 2025 Anton Tsvetkov - MIT License
2#pragma once
3#include <algorithm>
4#include <cassert>
5#include <cmath>
6#include <cstddef>
7
// SIMD backend selection. NEON is preferred on ARM. On x86 the AVX2 kernels
// below also call _mm256_fmadd_ps, which belongs to the separate FMA
// extension: GCC and Clang do not enable FMA with -mavx2 alone, so the AVX2
// path additionally requires __FMA__. MSVC (not clang-cl) does not define
// __FMA__ and does not gate intrinsics on /arch, so it is accepted as-is.
// Without a usable backend every kernel falls back to its scalar loop.
#if defined(__ARM_NEON) || defined(__aarch64__)
#include <arm_neon.h>
#define VANE_ARM_NEON
#elif defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && !defined(__clang__)))
#include <immintrin.h>
#define VANE_AVX2
#endif

// Non-standard no-alias qualifier for pointer parameters. The name is generic,
// so respect an existing definition instead of redefining it (a conflicting
// macro redefinition is ill-formed).
#ifndef RESTRICT
#if defined(_MSC_VER)
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif
#endif

namespace vanedb {

// Threshold used by cosine_distance() on the product of the squared norms,
// ||a||^2 * ||b||^2. Below it the pair is treated as orthogonal (distance 1).
// This corresponds to ||a|| * ||b|| < 1e-6.
inline constexpr float COSINE_EPSILON = 1e-12f;

#ifdef VANE_ARM_NEON
/// Horizontal sum of the four float lanes of @p v.
[[nodiscard]] inline float hsum(float32x4_t v) noexcept {
#if defined(__aarch64__)
    // AArch64 provides a single across-vector add.
    return vaddvq_f32(v);
#else
    // 32-bit NEON: add high half onto low half, then pairwise-add the two lanes.
    float32x2_t r = vadd_f32(vget_low_f32(v), vget_high_f32(v));
    return vget_lane_f32(vpadd_f32(r, r), 0);
#endif
}
#endif

#ifdef VANE_AVX2
/// Horizontal sum of the eight float lanes of @p v.
[[nodiscard]] inline float hsum(__m256 v) noexcept {
    // Fold the upper 128-bit half onto the lower: lanes are s_k = v[k] + v[k+4].
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    lo = _mm_add_ps(lo, hi);
    // movehdup duplicates odd lanes: shuf = {s1, s1, s3, s3},
    // so lo becomes {s0+s1, _, s2+s3, _}.
    __m128 shuf = _mm_movehdup_ps(lo);
    lo = _mm_add_ps(lo, shuf);
    // movehl moves lane 2 (s2+s3) into lane 0; add_ss adds only lane 0.
    return _mm_cvtss_f32(_mm_add_ss(lo, _mm_movehl_ps(shuf, lo)));
}
#endif

/// Squared Euclidean (L2) distance between two float arrays.
///
/// The SIMD paths accumulate in a different order than the scalar loop, so
/// results can differ in the last bits between backends.
///
/// @param a  first array of @p n floats; may be null only when @p n == 0
/// @param b  second array of @p n floats; may be null only when @p n == 0
/// @param n  number of elements
/// @return   sum of (a[i] - b[i])^2 over all i; 0 when @p n == 0
[[nodiscard]] inline float l2_sq(const float* RESTRICT a, const float* RESTRICT b, size_t n) noexcept {
    // Empty ranges often arrive as (nullptr, 0), e.g. an empty vector's data();
    // nothing is dereferenced then, so only require pointers when n > 0.
    assert(n == 0 || (a && b));
    float sum = 0.0f;
    size_t i = 0;

#ifdef VANE_ARM_NEON
    float32x4_t acc = vdupq_n_f32(0.0f);
    for (; i + 4 <= n; i += 4) {
        float32x4_t d = vsubq_f32(vld1q_f32(a + i), vld1q_f32(b + i));
        acc = vmlaq_f32(acc, d, d);
    }
    sum = hsum(acc);
#elif defined(VANE_AVX2)
    __m256 acc = _mm256_setzero_ps();
    for (; i + 8 <= n; i += 8) {
        __m256 d = _mm256_sub_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i));
        acc = _mm256_fmadd_ps(d, d, acc);
    }
    sum = hsum(acc);
#endif

    // Scalar tail (or the whole range when no SIMD backend is enabled).
    for (; i < n; ++i) {
        float d = a[i] - b[i];
        sum += d * d;
    }
    return sum;
}

/// Dot product of two float arrays.
///
/// @param a  first array of @p n floats; may be null only when @p n == 0
/// @param b  second array of @p n floats; may be null only when @p n == 0
/// @param n  number of elements
/// @return   sum of a[i] * b[i] over all i; 0 when @p n == 0
[[nodiscard]] inline float dot_product(const float* RESTRICT a, const float* RESTRICT b, size_t n) noexcept {
    // See l2_sq: (nullptr, nullptr, 0) is a valid empty call.
    assert(n == 0 || (a && b));
    float sum = 0.0f;
    size_t i = 0;

#ifdef VANE_ARM_NEON
    float32x4_t acc = vdupq_n_f32(0.0f);
    for (; i + 4 <= n; i += 4)
        acc = vmlaq_f32(acc, vld1q_f32(a + i), vld1q_f32(b + i));
    sum = hsum(acc);
#elif defined(VANE_AVX2)
    __m256 acc = _mm256_setzero_ps();
    for (; i + 8 <= n; i += 8)
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc);
    sum = hsum(acc);
#endif

    for (; i < n; ++i)
        sum += a[i] * b[i];
    return sum;
}

/// Cosine distance, 1 - cos(angle(a, b)), between two float arrays.
///
/// @param a  first array of @p n floats; may be null only when @p n == 0
/// @param b  second array of @p n floats; may be null only when @p n == 0
/// @param n  number of elements
/// @return   0 for parallel, 1 for orthogonal and 2 for opposite vectors.
///           Returns exactly 1 when ||a||^2 * ||b||^2 < COSINE_EPSILON (this
///           includes zero vectors and n == 0). The similarity is clamped to
///           [-1, 1] to absorb rounding error.
[[nodiscard]] inline float cosine_distance(const float* RESTRICT a, const float* RESTRICT b, size_t n) noexcept {
    // See l2_sq: (nullptr, nullptr, 0) is a valid empty call.
    assert(n == 0 || (a && b));
    float dot = 0.0f, na = 0.0f, nb = 0.0f;
    size_t i = 0;

    // One pass computes a.b, ||a||^2 and ||b||^2 together.
#ifdef VANE_ARM_NEON
    float32x4_t vdot = vdupq_n_f32(0.0f), vna = vdupq_n_f32(0.0f), vnb = vdupq_n_f32(0.0f);
    for (; i + 4 <= n; i += 4) {
        float32x4_t va = vld1q_f32(a + i), vb = vld1q_f32(b + i);
        vdot = vmlaq_f32(vdot, va, vb);
        vna = vmlaq_f32(vna, va, va);
        vnb = vmlaq_f32(vnb, vb, vb);
    }
    dot = hsum(vdot); na = hsum(vna); nb = hsum(vnb);
#elif defined(VANE_AVX2)
    __m256 vdot = _mm256_setzero_ps(), vna = _mm256_setzero_ps(), vnb = _mm256_setzero_ps();
    for (; i + 8 <= n; i += 8) {
        __m256 va = _mm256_loadu_ps(a + i), vb = _mm256_loadu_ps(b + i);
        vdot = _mm256_fmadd_ps(va, vb, vdot);
        vna = _mm256_fmadd_ps(va, va, vna);
        vnb = _mm256_fmadd_ps(vb, vb, vnb);
    }
    dot = hsum(vdot); na = hsum(vna); nb = hsum(vnb);
#endif

    for (; i < n; ++i) {
        dot += a[i] * b[i];
        na += a[i] * a[i];
        nb += b[i] * b[i];
    }

    // The threshold test keeps its original meaning. If na * nb overflows to
    // +inf it simply fails the test, which is the correct outcome.
    if (na * nb < COSINE_EPSILON) return 1.0f;
    // Do not divide by sqrt(na * nb): that product of squared norms overflows
    // float for large inputs (norms around 1e10 each), which made parallel vectors
    // report distance 1. Taking the roots first keeps the divisor in range.
    const float sim = dot / (std::sqrt(na) * std::sqrt(nb));
    return 1.0f - std::clamp(sim, -1.0f, 1.0f);
}

} // namespace vanedb
#define RESTRICT
Definition distance.h:19
float cosine_distance(const float *__restrict__ a, const float *__restrict__ b, size_t n) noexcept
Definition distance.h:99
float hsum(float32x4_t v) noexcept
Definition distance.h:28
float dot_product(const float *__restrict__ a, const float *__restrict__ b, size_t n) noexcept
Definition distance.h:77
constexpr float COSINE_EPSILON
Definition distance.h:25
float l2_sq(const float *__restrict__ a, const float *__restrict__ b, size_t n) noexcept
Definition distance.h:49