1use parking_lot::RwLock;
2use rand::Rng;
3use std::collections::{BinaryHeap, HashSet};
4use std::cmp::Reverse;
5
/// One stored point of the graph.
///
/// A node created at level `L` owns `L + 1` adjacency lists, one per layer
/// from 0 up to `L`.
#[derive(Clone, Debug)]
struct HnswNode {
    /// The vector that was inserted for this node.
    vector: Vec<f32>,
    /// `neighbors[layer]` holds the ids of the nodes this node links to on
    /// that layer; ids are positions in the index's node table.
    neighbors: Vec<Vec<u32>>,
}
11
12pub struct HnswIndex {
13 nodes: RwLock<Vec<HnswNode>>,
14 entry_point: RwLock<Option<u32>>,
15 max_level: RwLock<usize>,
16 m: usize,
17 m_max0: usize,
18 ef_construction: usize,
19 ml: f64,
20 dist_fn: fn(&[f32], &[f32]) -> f32,
21}
22
/// A node id paired with its distance to the query being processed.
///
/// Candidates are totally ordered by `distance` ascending, with `id` as the
/// tiebreak. A NaN distance compares greater than every number, so such a
/// candidate is treated as the farthest one instead of causing a panic when
/// it is sorted or pushed into a `BinaryHeap`. Equality is defined through
/// the same ordering, which keeps `Eq` and `Ord` consistent with each other.
#[derive(Clone, Debug)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_distance = match (self.distance.is_nan(), other.distance.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => self
                .distance
                .partial_cmp(&other.distance)
                .unwrap_or(Ordering::Equal),
        };
        // The id tiebreak makes equal-distance results deterministic.
        by_distance.then_with(|| self.id.cmp(&other.id))
    }
}
42
43impl HnswIndex {
44 pub fn new(m: usize, ef_construction: usize) -> Self {
45 Self {
46 nodes: RwLock::new(Vec::new()),
47 entry_point: RwLock::new(None),
48 max_level: RwLock::new(0),
49 m,
50 m_max0: m * 2,
51 ef_construction,
52 ml: 1.0 / (m as f64).ln(),
53 dist_fn: |a, b| 1.0 - cosine_similarity(a, b),
54 }
55 }
56
57 fn random_level(&self) -> usize {
58 let mut rng = rand::thread_rng();
59 let r: f64 = rng.gen();
60 (-r.ln() * self.ml) as usize
61 }
62
63 pub fn insert(&self, vector: Vec<f32>) -> u32 {
64 let level = self.random_level();
65 let mut node = HnswNode {
66 vector,
67 neighbors: vec![Vec::new(); level + 1],
68 };
69
70 let mut nodes = self.nodes.write();
71 let id = nodes.len() as u32;
72 nodes.push(node);
73
74 let entry_point = *self.entry_point.read();
75 if entry_point.is_none() {
76 *self.entry_point.write() = Some(id);
77 *self.max_level.write() = level;
78 return id;
79 }
80
81 let ep = entry_point.unwrap();
82 let current_max_level = *self.max_level.read();
83
84
85 let mut current = ep;
86 let query = &nodes[id as usize].vector;
87
88 for l in (level + 1..=current_max_level).rev() {
89 loop {
90 let mut changed = false;
91 let neighbors = &nodes[current as usize].neighbors;
92 if l < neighbors.len() {
93 for &neighbor in &neighbors[l] {
94 let dist = (self.dist_fn)(
95 query,
96 &nodes[neighbor as usize].vector,
97 );
98 let current_dist = (self.dist_fn)(
99 query,
100 &nodes[current as usize].vector,
101 );
102 if dist < current_dist {
103 current = neighbor;
104 changed = true;
105 }
106 }
107 }
108 if !changed { break; }
109 }
110 }
111
112
113 let mut ep_list = vec![current];
114 let insert_level = level.min(current_max_level);
115
116 for l in (0..=insert_level).rev() {
117 let candidates = self.search_layer(
118 &nodes, query, &ep_list, self.ef_construction, l,
119 );
120
121 let max_conn = if l == 0 { self.m_max0 } else { self.m };
122 let selected: Vec<u32> = candidates
123 .iter()
124 .take(max_conn)
125 .map(|c| c.id)
126 .collect();
127
128
129 nodes[id as usize].neighbors[l] = selected.clone();
130
131
132 for &neighbor_id in &selected {
133 let neighbor = &mut nodes[neighbor_id as usize];
134 if l < neighbor.neighbors.len() {
135 neighbor.neighbors[l].push(id);
136 if neighbor.neighbors[l].len() > max_conn {
137
138 let nv = neighbor.vector.clone();
139 let mut scored: Vec<_> = neighbor.neighbors[l]
140 .iter()
141 .map(|&nid| Candidate {
142 id: nid,
143 distance: (self.dist_fn)(
144 &nv,
145 &nodes[nid as usize].vector,
146 ),
147 })
148 .collect();
149 scored.sort();
150 neighbor.neighbors[l] = scored
151 .into_iter()
152 .take(max_conn)
153 .map(|c| c.id)
154 .collect();
155 }
156 }
157 }
158
159 ep_list = selected;
160 }
161
162 if level > current_max_level {
163 *self.max_level.write() = level;
164 *self.entry_point.write() = Some(id);
165 }
166
167 id
168 }
169
170 fn search_layer(
171 &self,
172 nodes: &[HnswNode],
173 query: &[f32],
174 entry_points: &[u32],
175 ef: usize,
176 level: usize,
177 ) -> Vec<Candidate> {
178 let mut visited = HashSet::new();
179 let mut candidates = BinaryHeap::new();
180 let mut results = BinaryHeap::new();
181
182 for &ep in entry_points {
183 let dist = (self.dist_fn)(query, &nodes[ep as usize].vector);
184 visited.insert(ep);
185 candidates.push(Reverse(Candidate { id: ep, distance: dist }));
186 results.push(Candidate { id: ep, distance: dist });
187 }
188
189 while let Some(Reverse(nearest)) = candidates.pop() {
190 let farthest_dist = results.peek().map(|c| c.distance).unwrap_or(f32::MAX);
191 if nearest.distance > farthest_dist {
192 break;
193 }
194
195 let neighbors = &nodes[nearest.id as usize].neighbors;
196 if level >= neighbors.len() { continue; }
197
198 for &neighbor_id in &neighbors[level] {
199 if visited.contains(&neighbor_id) { continue; }
200 visited.insert(neighbor_id);
201
202 let dist = (self.dist_fn)(
203 query,
204 &nodes[neighbor_id as usize].vector,
205 );
206
207 let farthest_dist = results.peek().map(|c| c.distance).unwrap_or(f32::MAX);
208
209 if results.len() < ef || dist < farthest_dist {
210 candidates.push(Reverse(Candidate {
211 id: neighbor_id,
212 distance: dist,
213 }));
214 results.push(Candidate {
215 id: neighbor_id,
216 distance: dist,
217 });
218 if results.len() > ef {
219 results.pop();
220 }
221 }
222 }
223 }
224
225 let mut result_vec: Vec<_> = results.into_vec();
226 result_vec.sort();
227 result_vec
228 }
229
230 pub fn search(&self, query: &[f32], top_k: usize, ef_search: usize) -> Vec<(u32, f32)> {
231 let nodes = self.nodes.read();
232 let entry_point = *self.entry_point.read();
233
234 let ep = match entry_point {
235 Some(ep) => ep,
236 None => return Vec::new(),
237 };
238
239 let max_level = *self.max_level.read();
240 let mut current = ep;
241
242
243 for l in (1..=max_level).rev() {
244 loop {
245 let mut changed = false;
246 let neighbors = &nodes[current as usize].neighbors;
247 if l < neighbors.len() {
248 for &neighbor in &neighbors[l] {
249 let dist = (self.dist_fn)(
250 query,
251 &nodes[neighbor as usize].vector,
252 );
253 let cur_dist = (self.dist_fn)(
254 query,
255 &nodes[current as usize].vector,
256 );
257 if dist < cur_dist {
258 current = neighbor;
259 changed = true;
260 }
261 }
262 }
263 if !changed { break; }
264 }
265 }
266
267
268 let results = self.search_layer(
269 &nodes, query, &[current], ef_search, 0,
270 );
271
272 results
273 .into_iter()
274 .take(top_k)
275 .map(|c| (c.id, 1.0 - c.distance))
276 .collect()
277 }
278}
279