mirror of https://github.com/redis/redis.git
Expr filtering: implement HNSW filter in search_layer().
This commit is contained in:
parent
438adc917b
commit
025790fc50
76
hnsw.c
76
hnsw.c
|
|
@ -618,9 +618,19 @@ void hnsw_add_node(HNSW *index, hnswNode *node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Search the specified layer starting from the specified entry point
|
/* Search the specified layer starting from the specified entry point
|
||||||
* to collect 'ef' nodes that are near to 'query'. */
|
* to collect 'ef' nodes that are near to 'query'.
|
||||||
pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
|
*
|
||||||
uint32_t ef, uint32_t layer, uint32_t slot)
|
* This function implements optional hybrid search, so that each node
|
||||||
|
* can be accepted or not based on its associated value. In this case
|
||||||
|
* a callback 'filter_callback' should be passed, together with a maximum
|
||||||
|
* effort for the search (number of candidates to evaluate), since even
|
||||||
|
* with a low "EF" value we risk that there are too few nodes that satisfy
|
||||||
|
* the provided filter, and we could trigger a full scan. */
|
||||||
|
pqueue *search_layer_with_filter(
|
||||||
|
HNSW *index, hnswNode *query, hnswNode *entry_point,
|
||||||
|
uint32_t ef, uint32_t layer, uint32_t slot,
|
||||||
|
int (*filter_callback)(void *value, void *privdata),
|
||||||
|
void *filter_privdata, uint32_t max_candidates)
|
||||||
{
|
{
|
||||||
// Mark visited nodes with a never seen epoch.
|
// Mark visited nodes with a never seen epoch.
|
||||||
index->current_epoch[slot]++;
|
index->current_epoch[slot]++;
|
||||||
|
|
@ -633,27 +643,33 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keep track of the total effort: only used when filtering via
|
||||||
|
// a callback, to keep the effort bounded.
|
||||||
|
uint32_t evaluated_candidates = 1;
|
||||||
|
|
||||||
// Add entry point.
|
// Add entry point.
|
||||||
float dist = hnsw_distance(index, query, entry_point);
|
float dist = hnsw_distance(index, query, entry_point);
|
||||||
pq_push(candidates, entry_point, dist);
|
pq_push(candidates, entry_point, dist);
|
||||||
pq_push(results, entry_point, dist);
|
if (filter_callback == NULL ||
|
||||||
|
filter_callback(entry_point->value, filter_privdata))
|
||||||
|
{
|
||||||
|
pq_push(results, entry_point, dist);
|
||||||
|
}
|
||||||
entry_point->visited_epoch[slot] = index->current_epoch[slot];
|
entry_point->visited_epoch[slot] = index->current_epoch[slot];
|
||||||
|
|
||||||
// Process candidates.
|
// Process candidates.
|
||||||
while (candidates->count > 0) {
|
while (candidates->count > 0) {
|
||||||
// Pop closest element and use its saved distance.
|
// Max effort. If zero, we keep scanning.
|
||||||
|
if (filter_callback &&
|
||||||
|
max_candidates &&
|
||||||
|
evaluated_candidates >= max_candidates) break;
|
||||||
|
|
||||||
float cur_dist;
|
float cur_dist;
|
||||||
hnswNode *current = pq_pop(candidates, &cur_dist);
|
hnswNode *current = pq_pop(candidates, &cur_dist);
|
||||||
|
evaluated_candidates++;
|
||||||
|
|
||||||
/* Stop if we can't get better results. Note that this can
|
|
||||||
* be true only if we already collected 'ef' elements in
|
|
||||||
* the priority queue. This is why: if we have less than EF
|
|
||||||
* elements, later in the for loop that checks the neighbors we
|
|
||||||
* add new elements BOTH in the results and candidates pqueue: this
|
|
||||||
* means that before accumulating EF elements, the worst candidate
|
|
||||||
* can be as bad as the worst result, but not worse. */
|
|
||||||
float furthest = pq_max_distance(results);
|
float furthest = pq_max_distance(results);
|
||||||
if (cur_dist > furthest) break;
|
if (results->count >= ef && cur_dist > furthest) break;
|
||||||
|
|
||||||
/* Check neighbors. */
|
/* Check neighbors. */
|
||||||
for (uint32_t i = 0; i < current->layers[layer].num_links; i++) {
|
for (uint32_t i = 0; i < current->layers[layer].num_links; i++) {
|
||||||
|
|
@ -664,11 +680,29 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
|
||||||
|
|
||||||
neighbor->visited_epoch[slot] = index->current_epoch[slot];
|
neighbor->visited_epoch[slot] = index->current_epoch[slot];
|
||||||
float neighbor_dist = hnsw_distance(index, query, neighbor);
|
float neighbor_dist = hnsw_distance(index, query, neighbor);
|
||||||
// Add to results if better than current max or results not full.
|
|
||||||
furthest = pq_max_distance(results);
|
furthest = pq_max_distance(results);
|
||||||
if (neighbor_dist < furthest || results->count < ef) {
|
if (filter_callback == NULL) {
|
||||||
pq_push(candidates, neighbor, neighbor_dist);
|
/* Original HNSW logic when no filtering:
|
||||||
pq_push(results, neighbor, neighbor_dist);
|
* Add to results if better than current max or
|
||||||
|
* results not full. */
|
||||||
|
if (neighbor_dist < furthest || results->count < ef) {
|
||||||
|
pq_push(candidates, neighbor, neighbor_dist);
|
||||||
|
pq_push(results, neighbor, neighbor_dist);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
/* With filtering: we add candidates even if they don't match
|
||||||
|
* the filter, in order to continue to explore the graph. */
|
||||||
|
if (neighbor_dist < furthest || candidates->count < ef) {
|
||||||
|
pq_push(candidates, neighbor, neighbor_dist);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Add results only if passes filter. */
|
||||||
|
if (filter_callback(neighbor->value, filter_privdata)) {
|
||||||
|
if (neighbor_dist < furthest || results->count < ef) {
|
||||||
|
pq_push(results, neighbor, neighbor_dist);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -677,6 +711,14 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Unfiltered layer search: forwards its arguments unchanged to search_layer_with_filter(), passing a NULL filter callback, NULL privdata and max_candidates of 0, and returns that call's result. */
|
||||||
|
pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
|
||||||
|
uint32_t ef, uint32_t layer, uint32_t slot)
|
||||||
|
{
|
||||||
|
return search_layer_with_filter(index, query, entry_point, ef, layer, slot,
|
||||||
|
NULL, NULL, 0);
|
||||||
|
}
|
||||||
|
|
||||||
/* This function is used in order to initialize a node allocated in the
|
/* This function is used in order to initialize a node allocated in the
|
||||||
* function stack with the specified vector. The idea is that we can
|
* function stack with the specified vector. The idea is that we can
|
||||||
* easily use hnsw_distance() from a vector and the HNSW nodes this way:
|
* easily use hnsw_distance() from a vector and the HNSW nodes this way:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue