Skip to content

Commit a291449

Browse files
committed
Query: add find_ids_by_score()
1 parent 23fbe65 commit a291449

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

objectbox/c.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,9 @@ def c_array_pointer(py_list: Union[List[Any], np.ndarray], c_type):
881881
# OBX_C_API OBX_id_score_array* obx_query_find_ids_with_scores(OBX_query* query);
882882
obx_query_find_ids_with_scores = c_fn('obx_query_find_ids_with_scores', OBX_id_score_array_p, [OBX_query_p])
883883

884+
# OBX_C_API OBX_id_array* obx_query_find_ids_by_score(OBX_query* query);
885+
obx_query_find_ids_by_score = c_fn('obx_query_find_ids_by_score', OBX_id_array_p, [OBX_query_p])
886+
884887
# OBX_C_API obx_err obx_query_count(OBX_query* query, uint64_t* out_count);
885888
obx_query_count = c_fn_rc('obx_query_count', [OBX_query_p, ctypes.POINTER(ctypes.c_uint64)])
886889

objectbox/query.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ def find_ids_with_scores(self) -> List[Tuple[int, float]]:
8888
finally:
8989
obx_id_score_array_free(c_id_score_array_p)
9090

91+
def find_ids_by_score(self) -> List[int]:
92+
""" Finds object IDs matching the query ordered by their query score (e.g. distance in NN search).
93+
The resulting list of IDs is sorted by score in ascending order. """
94+
# TODO extract utility function for ID array conversion
95+
c_id_array_p = obx_query_find_ids_by_score(self._c_query)
96+
try:
97+
c_id_array: OBX_id_array = c_id_array_p.contents
98+
if c_id_array.count == 0:
99+
return []
100+
ids = ctypes.cast(c_id_array.ids, ctypes.POINTER(obx_id * c_id_array.count))
101+
return list(ids.contents)
102+
finally:
103+
obx_id_array_free(c_id_array_p)
104+
91105
def count(self) -> int:
92106
count = ctypes.c_uint64()
93107
obx_query_count(self._c_query, ctypes.byref(count))

tests/test_hnsw.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def _test_combined_nn_search(distance_type: VectorDistanceType = VectorDistanceT
163163
assert search_results[2][0] == 4
164164
assert search_results[3][0] == 2
165165

166+
search_results = query.find_ids_by_score()
167+
assert len(search_results) == 4
168+
assert search_results[0] == 9
169+
assert search_results[1] == 5
170+
assert search_results[2] == 4
171+
assert search_results[3] == 2
172+
166173
search_results = query.find_ids()
167174
assert len(search_results) == 4
168175
assert search_results[0] == 2

0 commit comments

Comments
 (0)