Skip to content

Commit c091e75

Browse files
committed
Query: add find_ids_by_score_numpy()
1 parent 00a3b91 commit c091e75

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

objectbox/query.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,23 @@ def find_ids_by_score(self) -> List[int]:
102102
finally:
103103
obx_id_array_free(c_id_array_p)
104104

105+
def find_ids_by_score_numpy(self) -> np.array:
106+
""" Finds object IDs matching the query ordered by their query score (e.g. distance in NN search).
107+
The resulting list of IDs is sorted by score in ascending order. """
108+
# TODO extract utility function for ID array conversion
109+
c_id_array_p = obx_query_find_ids_by_score(self._c_query)
110+
try:
111+
c_id_array: OBX_id_array = c_id_array_p.contents
112+
c_count = c_id_array.count
113+
if c_count == 0:
114+
return []
115+
c_ids = ctypes.cast(c_id_array.ids, ctypes.POINTER(obx_id))
116+
numpy_array = np.empty(c_count, dtype=np.uint64)
117+
ctypes.memmove(numpy_array.ctypes.data, c_ids, numpy_array.nbytes)
118+
return numpy_array
119+
finally:
120+
obx_id_array_free(c_id_array_p)
121+
105122
def count(self) -> int:
106123
count = ctypes.c_uint64()
107124
obx_query_count(self._c_query, ctypes.byref(count))

tests/test_hnsw.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,13 @@ def _test_combined_nn_search(distance_type: VectorDistanceType = VectorDistanceT
170170
assert search_results[2] == 4
171171
assert search_results[3] == 2
172172

173+
search_results = query.find_ids_by_score_numpy()
174+
assert search_results.size == 4
175+
assert search_results[0] == 9
176+
assert search_results[1] == 5
177+
assert search_results[2] == 4
178+
assert search_results[3] == 2
179+
173180
search_results = query.find_ids()
174181
assert len(search_results) == 4
175182
assert search_results[0] == 2

0 commit comments

Comments
 (0)