5#include <TargetConditionals.h>
6#if TARGET_OS_MAC || TARGET_OS_IPHONE
10#import <Foundation/Foundation.h>
/// Returns the lazily-created, process-wide MetalCompute instance.
/// The function-local static is constructed on first use; C++11 guarantees
/// that this initialisation happens exactly once even under concurrent calls.
static MetalCompute& get() {
    static MetalCompute shared;
    return shared;
}
/// True when Metal setup succeeded. Every setup failure path clears dev_,
/// so a non-nil device is the single indicator of a usable instance.
bool ok() const {
    const bool hasDevice = (dev_ != nil);
    return hasDevice;
}
/// Squared Euclidean distance from q (d floats) to each of the n vectors
/// stored back to back in v (n*d floats), computed on the GPU.
/// Delegates to run(), which throws std::runtime_error when Metal is
/// unavailable or d is not a multiple of 4.
std::vector<float> l2(const float* q, const float* v, size_t d, size_t n) {
    id<MTLComputePipelineState> pipeline = l2_;
    return run(q, v, d, n, pipeline);
}
/// Negated dot product between q (d floats) and each of the n vectors stored
/// back to back in v (n*d floats), so that smaller scores mean more similar.
/// Delegates to run(), which throws std::runtime_error when Metal is
/// unavailable or d is not a multiple of 4.
std::vector<float> dot(const float* q, const float* v, size_t d, size_t n) {
    id<MTLComputePipelineState> pipeline = dot_;
    return run(q, v, d, n, pipeline);
}
/// Cosine distance (1 - cosine similarity, similarity clamped to [-1, 1])
/// between q (d floats) and each of the n vectors stored back to back in v.
/// Pairs whose squared-norm product is below 1e-12 score exactly 1.
/// Delegates to run(), which throws std::runtime_error when Metal is
/// unavailable or d is not a multiple of 4.
std::vector<float> cos(const float* q, const float* v, size_t d, size_t n) {
    id<MTLComputePipelineState> pipeline = cos_;
    return run(q, v, d, n, pipeline);
}
/// Copies n vectors of d floats (n*d contiguous floats starting at v) into a
/// new shared-storage Metal buffer so repeated search() calls can reuse it
/// without re-uploading the data.
///
/// @param v  host data, n*d floats, vector i at offset i*d.
/// @param n  number of vectors.
/// @param d  floats per vector.
/// @return the buffer, or nil when Metal is unavailable, v is null, n or d is
///         zero, the byte size would overflow size_t, or allocation fails.
id<MTLBuffer> upload(const float* v, size_t n, size_t d) {
    // Zero-length Metal buffers and reads through a null pointer are both
    // invalid; report them the same way as "Metal unavailable": nil.
    if (!ok() || v == nullptr || n == 0 || d == 0)
        return nil;
    // Guard n*d*sizeof(float) against size_t wrap-around.
    if (n > SIZE_MAX / sizeof(float) / d)
        return nil;
    return [dev_ newBufferWithBytes:v
                             length:n * d * sizeof(float)
                            options:MTLResourceStorageModeShared];
}
/// Scores the query q (d floats) against n vectors already resident in vbuf
/// (for example a buffer returned by upload()), using metric m.
///
/// Vector i is read from vbuf at float offset i*d. MetalMetric::L2 selects
/// the l2_ pipeline, MetalMetric::DOT the dot_ pipeline, anything else cos_.
///
/// @throws std::runtime_error if Metal is unavailable, d % 4 != 0, d == 0,
///         q is null, vbuf is nil or holds fewer than n*d floats, the problem
///         exceeds the kernels' 32-bit indexing, allocation fails, or the
///         command buffer finishes with an error.
/// @return n scores; empty when n == 0.
std::vector<float> search(const float* q, id<MTLBuffer> vbuf, size_t d,
                          size_t n, MetalMetric m) {
    if (!ok() || d % 4 != 0)
        throw std::runtime_error("Metal unavailable or dim%4!=0");
    if (n == 0)
        return {};  // nothing to score; also avoids zero-size buffers/threadgroups
    if (d == 0 || q == nullptr)
        throw std::runtime_error("Metal search: null query or zero dimension");

    // The kernels index with 32-bit `uint` (o = i*d4 + j), so n*d4 must fit.
    const uint64_t lanes = d / 4;
    if (static_cast<uint64_t>(n) > static_cast<uint64_t>(UINT32_MAX) / lanes)
        throw std::runtime_error("Metal search: input too large for 32-bit kernel indexing");
    // Reading past the end of vbuf on the GPU is undefined; check it up front.
    if (vbuf == nil || vbuf.length < n * d * sizeof(float))
        throw std::runtime_error("Metal search: vector buffer is nil or too small");

    id<MTLComputePipelineState> p =
        (m == MetalMetric::L2) ? l2_ : (m == MetalMetric::DOT) ? dot_ : cos_;

    std::vector<float> out(n);
    NSString* failure = nil;
    // The pool keeps autoreleased Metal objects from accumulating when this is
    // called in a loop; errors are recorded and thrown only after it is popped.
    @autoreleasepool {
        id<MTLBuffer> qb = [dev_ newBufferWithBytes:q
                                             length:d * sizeof(float)
                                            options:MTLResourceStorageModeShared];
        id<MTLBuffer> rb = [dev_ newBufferWithLength:n * sizeof(float)
                                             options:MTLResourceStorageModeShared];
        // Only open an encoder once the buffers exist: an encoder must be ended.
        id<MTLCommandBuffer> cmd = (qb && rb) ? [q_ commandBuffer] : nil;
        id<MTLComputeCommandEncoder> e = [cmd computeCommandEncoder];
        if (!e) {
            failure = @"Metal search: failed to allocate buffers or encoder";
        } else {
            const uint32_t d4 = static_cast<uint32_t>(lanes);
            [e setComputePipelineState:p];
            [e setBuffer:qb offset:0 atIndex:0];
            [e setBuffer:vbuf offset:0 atIndex:1];
            [e setBuffer:rb offset:0 atIndex:2];
            [e setBytes:&d4 length:sizeof(d4) atIndex:3];
            // One thread per vector. dispatchThreads: relies on non-uniform
            // threadgroup support; NOTE(review): confirm all target GPUs have it.
            const NSUInteger width = MIN(p.maxTotalThreadsPerThreadgroup, (NSUInteger)n);
            [e dispatchThreads:MTLSizeMake(n, 1, 1)
                threadsPerThreadgroup:MTLSizeMake(width, 1, 1)];
            [e endEncoding];
            [cmd commit];
            [cmd waitUntilCompleted];
            if (cmd.status == MTLCommandBufferStatusCompleted) {
                memcpy(out.data(), [rb contents], n * sizeof(float));
            } else {
                failure = [NSString stringWithFormat:@"Metal search: command buffer failed: %@",
                                                     cmd.error.localizedDescription ?: @"unknown error"];
            }
        }
    }
    if (failure)
        throw std::runtime_error(failure.UTF8String);
    return out;
}
/// Performs all Metal setup eagerly via init(); get() holds the shared
/// instance. Failures are reflected by ok() rather than thrown.
MetalCompute() {
    this->init();
}
/// One-time Metal setup: takes the system default device, creates a command
/// queue, compiles the l2/dp/cs kernels from source and builds one compute
/// pipeline per kernel. Never throws; on any failure dev_ is reset to nil so
/// ok() reports false (compiler diagnostics left in `err` are discarded).
void init() {
    dev_ = MTLCreateSystemDefaultDevice();
    if (!dev_)
        return;
    q_ = [dev_ newCommandQueue];
    if (!q_) { dev_ = nil; return; }

    // Kernels run one thread per stored vector i. q and v are read as float4,
    // d4 is the per-vector float4 count, and vector i starts at index i*d4.
    static const char* const kKernelSource = R"MSL(
#include <metal_stdlib>
using namespace metal;
kernel void l2(device const float4* q[[buffer(0)]], device const float4* v[[buffer(1)]],
  device float* r[[buffer(2)]], constant uint& d4[[buffer(3)]], uint i[[thread_position_in_grid]]) {
  float4 s=0; uint o=i*d4; for(uint j=0;j<d4;++j){float4 x=q[j]-v[o+j];s+=x*x;} r[i]=s.x+s.y+s.z+s.w;
}
kernel void dp(device const float4* q[[buffer(0)]], device const float4* v[[buffer(1)]],
  device float* r[[buffer(2)]], constant uint& d4[[buffer(3)]], uint i[[thread_position_in_grid]]) {
  float4 s=0; uint o=i*d4; for(uint j=0;j<d4;++j)s+=q[j]*v[o+j]; r[i]=-(s.x+s.y+s.z+s.w);
}
kernel void cs(device const float4* q[[buffer(0)]], device const float4* v[[buffer(1)]],
  device float* r[[buffer(2)]], constant uint& d4[[buffer(3)]], uint i[[thread_position_in_grid]]) {
  float4 d=0,nq=0,nv=0; uint o=i*d4;
  for(uint j=0;j<d4;++j){float4 a=q[j],b=v[o+j];d+=a*b;nq+=a*a;nv+=b*b;}
  float dot=d.x+d.y+d.z+d.w,na=nq.x+nq.y+nq.z+nq.w,nb=nv.x+nv.y+nv.z+nv.w;
  float dn=na*nb; r[i]=1.0f-clamp((dn<1e-12f)?0.0f:dot*rsqrt(dn),-1.0f,1.0f);
}
)MSL";

    NSError* err = nil;
    NSString* src = [NSString stringWithUTF8String:kKernelSource];
    id<MTLLibrary> lib = [dev_ newLibraryWithSource:src options:nil error:&err];
    if (!lib) { dev_ = nil; return; }

    // Resolve the functions first: pipeline creation takes a non-null function.
    id<MTLFunction> l2Fn = [lib newFunctionWithName:@"l2"];
    id<MTLFunction> dotFn = [lib newFunctionWithName:@"dp"];
    id<MTLFunction> cosFn = [lib newFunctionWithName:@"cs"];
    if (!l2Fn || !dotFn || !cosFn) { dev_ = nil; return; }

    l2_ = [dev_ newComputePipelineStateWithFunction:l2Fn error:&err];
    dot_ = [dev_ newComputePipelineStateWithFunction:dotFn error:&err];
    cos_ = [dev_ newComputePipelineStateWithFunction:cosFn error:&err];
    if (!l2_ || !dot_ || !cos_)
        dev_ = nil;
}
/// Shared GPU path behind l2()/dot()/cos(): uploads q (d floats) and the n
/// vectors in v (n*d contiguous floats, vector i at offset i*d), runs
/// pipeline p with one thread per vector, and returns the n scores.
///
/// @throws std::runtime_error if Metal is unavailable, d % 4 != 0, d == 0,
///         q or v is null, the problem exceeds the kernels' 32-bit indexing,
///         allocation fails, or the command buffer finishes with an error.
/// @return n scores; empty when n == 0.
std::vector<float> run(const float* q, const float* v, size_t d, size_t n,
                       id<MTLComputePipelineState> p) {
    if (!ok() || d % 4 != 0)
        throw std::runtime_error("Metal unavailable or dim%4!=0");
    if (n == 0)
        return {};  // nothing to score; also avoids zero-size buffers/threadgroups
    if (d == 0 || q == nullptr || v == nullptr)
        throw std::runtime_error("Metal: null input or zero dimension");

    // The kernels index with 32-bit `uint` (o = i*d4 + j), so n*d4 must fit.
    const uint64_t lanes = d / 4;
    if (static_cast<uint64_t>(n) > static_cast<uint64_t>(UINT32_MAX) / lanes)
        throw std::runtime_error("Metal: input too large for 32-bit kernel indexing");

    std::vector<float> out(n);
    NSString* failure = nil;
    // The pool keeps autoreleased Metal objects from accumulating when this is
    // called in a loop; errors are recorded and thrown only after it is popped.
    @autoreleasepool {
        id<MTLBuffer> qb = [dev_ newBufferWithBytes:q
                                             length:d * sizeof(float)
                                            options:MTLResourceStorageModeShared];
        id<MTLBuffer> vb = [dev_ newBufferWithBytes:v
                                             length:n * d * sizeof(float)
                                            options:MTLResourceStorageModeShared];
        id<MTLBuffer> rb = [dev_ newBufferWithLength:n * sizeof(float)
                                             options:MTLResourceStorageModeShared];
        // Only open an encoder once the buffers exist: an encoder must be ended.
        id<MTLCommandBuffer> cmd = (qb && vb && rb) ? [q_ commandBuffer] : nil;
        id<MTLComputeCommandEncoder> e = [cmd computeCommandEncoder];
        if (!e) {
            failure = @"Metal: failed to allocate buffers or encoder";
        } else {
            const uint32_t d4 = static_cast<uint32_t>(lanes);
            [e setComputePipelineState:p];
            [e setBuffer:qb offset:0 atIndex:0];
            [e setBuffer:vb offset:0 atIndex:1];
            [e setBuffer:rb offset:0 atIndex:2];
            [e setBytes:&d4 length:sizeof(d4) atIndex:3];
            // One thread per vector. dispatchThreads: relies on non-uniform
            // threadgroup support; NOTE(review): confirm all target GPUs have it.
            const NSUInteger width = MIN(p.maxTotalThreadsPerThreadgroup, (NSUInteger)n);
            [e dispatchThreads:MTLSizeMake(n, 1, 1)
                threadsPerThreadgroup:MTLSizeMake(width, 1, 1)];
            [e endEncoding];
            [cmd commit];
            [cmd waitUntilCompleted];
            if (cmd.status == MTLCommandBufferStatusCompleted) {
                memcpy(out.data(), [rb contents], n * sizeof(float));
            } else {
                failure = [NSString stringWithFormat:@"Metal: command buffer failed: %@",
                                                     cmd.error.localizedDescription ?: @"unknown error"];
            }
        }
    }
    if (failure)
        throw std::runtime_error(failure.UTF8String);
    return out;
}
// Metal state. dev_ doubles as the "set up successfully" flag read by ok();
// the three pipelines back l2(), dot() and cos() respectively.
id<MTLDevice> dev_ = nil;
id<MTLCommandQueue> q_ = nil;
id<MTLComputePipelineState> l2_ = nil;
id<MTLComputePipelineState> dot_ = nil;
id<MTLComputePipelineState> cos_ = nil;
/// Reports whether the shared MetalCompute instance set itself up successfully.
/// The first call constructs that instance, which runs the Metal setup.
inline bool metal_available() {
    MetalCompute& engine = MetalCompute::get();
    return engine.ok();
}