Vector Sets fixes against corrupted data in absence of checksum verification (#14102)

Vector Sets deserialization was not designed to resist corrupted data,
assuming that a good checksum would mean everything is fine. However
Redis allows the user to specify extra protection via a specific
configuration option.

This commit makes the implementation more resistant, at the cost of some
slowdown. This also fixes a serialization bug that is unrelated (and has
no memory corruption effects) about the lack of the worst index /
distance serialization, that could lower the quality of a graph after
links are replaced. I'll address the serialization issues in a new PR
that will focus on that aspect alone (already work in progress).

The net result is that loading vector sets is, when the serialization of
worst index/distance is missing (always, for now), 100% slower, that is, 2
times the loading time we had before. Once that info is serialized,
loading will be just 10-15% slower, that is, only the cost of the new
sanity checks.

It may be worth exporting to modules whether advanced sanity checks are
needed or not. Anyway, most of the slowdown in this patch comes from having
to recompute the worst neighbor, since the detection of duplicated and
non-reciprocal links was heavily optimized with probabilistic algorithms.

---------

Co-authored-by: debing.sun <debing.sun@redis.com>
This commit is contained in:
Salvatore Sanfilippo 2025-06-10 15:55:09 +02:00 committed by YaacovHazan
parent 6b951ba34c
commit a39dda5462
4 changed files with 210 additions and 7 deletions

View File

@ -45,6 +45,7 @@
#include <float.h> /* for INFINITY if not in math.h */
#include <assert.h>
#include "hnsw.h"
#include "mixer.h"
#if 0
#define debugmsg printf
@ -2127,6 +2128,7 @@ void hnsw_free_serialized_node(hnswSerNode *sn) {
* The function returns NULL both on out of memory and if the remaining
* parameters length does not match the number of links or other items
* to load. */
#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX
hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value)
{
if (params_len < 2) return NULL;
@ -2158,6 +2160,13 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui
uint32_t num_links = params[param_idx++];
uint32_t max_links = params[param_idx++];
/* Sanity check: links must not exceed max links, and max links
* must be a reasonable amount. */
if (num_links > max_links || max_links > HNSW_MAX_M*4) {
hnsw_node_free(node);
return NULL;
}
/* If max_links is larger than current allocation, reallocate.
* It could happen in select_neighbors() that we over-allocate the
* node under very unlikely to happen conditions. */
@ -2185,6 +2194,10 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui
* fit more than 2^32 nodes in a 32 bit system. */
for (uint32_t j = 0; j < num_links; j++)
node->layers[i].links[j] = (hnswNode*)params[param_idx++];
/* XXX: fix me, we need to store the worst link info in a
* backward compatible way. */
node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING;
}
/* Get l2 and quantization range. */
@ -2221,13 +2234,28 @@ uint64_t hnsw_hash_node_id(uint64_t id) {
return id;
}
/* qsort() comparator that orders two pointer-sized slots by their numeric
 * value. hnsw_deserialize_index() uses it to sort the links of a layer so
 * that duplicated entries end up adjacent.
 *
 * Returns 1 if *aptr > *bptr, -1 if *aptr < *bptr, 0 if they are equal. */
static int qsort_compare_pointers(const void *aptr, const void *bptr) {
    const uintptr_t lhs = *(const uintptr_t *)aptr;
    const uintptr_t rhs = *(const uintptr_t *)bptr;
    /* Each comparison yields 0 or 1, so the difference is -1, 0 or 1. */
    return (lhs > rhs) - (lhs < rhs);
}
/* Fix pointers of neighbors nodes: after loading the serialized nodes, the
* neighbors links are just IDs (casted to pointers), instead of the actual
* pointers. We need to resolve IDs into pointers.
*
* The two integers salt0 and salt1 are used to make the internal state
* of the function unguessable to an external attacker, in order to protect
* from corruptions. They should be two random numbers from /dev/urandom if
* possible; otherwise they can be just 0,0 if the application is not
* security critical and never processes untrusted inputs.
*
* Return 0 on error (out of memory or some ID that can't be resolved), 1 on
* success. */
int hnsw_deserialize_index(HNSW *index) {
int hnsw_deserialize_index(HNSW *index, uint64_t salt0, uint64_t salt1) {
/* We will use simple linear probing, so over-allocating is a good
* idea: anyway this flat array of pointers will consume a fraction
* of the memory of the loaded index. */
@ -2253,12 +2281,60 @@ int hnsw_deserialize_index(HNSW *index) {
node = node->next;
}
/* Second pass: fix pointers of all the neighbors links. */
/* Second pass: fix pointers of all the neighbors links.
* As we scan and fix the links, we also compute the accumulator
* register "reciprocal", that is used in order to guarantee that all
* the links are reciprocal.
*
* This is how it works, we hash (using a strong hash function) the
* following key for each link that we see from A to B (or vice versa):
*
* hash(salt || A || B || link-level)
*
* We always sort A and B, so the same link from A to B and from B to A
* will hash the same. Then we xor the result into the 128 bit accumulator.
* If each link has its own backlink, the accumulator is guaranteed to
* be zero at the end.
*
* Collisions are extremely unlikely to happen, and an external attacker
* can't easily control the hash function output, since the salt is
* unknown, and they would also need to control the pointers.
*
* This algorithm is O(1) for each link so it is basically free for
* us, as we scan the list of nodes, and runs on constant and very
* small memory. */
uint64_t accumulator[2] = {0,0};
node = index->head; // Rewind.
while(node) {
uint64_t this_node_id = node->id;
for (uint32_t i = 0; i <= node->level; i++) {
// Check if there are duplicated links: those are
// also corruptions of the on-disk serialization format.
if (node->layers[i].num_links > 0) {
qsort(node->layers[i].links, node->layers[i].num_links,
sizeof(void*), qsort_compare_pointers);
for (uint32_t j = 0; j < node->layers[i].num_links-1; j++) {
if (node->layers[i].links[j] == node->layers[i].links[j+1])
goto corrupted;
}
}
// Resolve pointers.
for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
uint64_t linked_id = (uint64_t) node->layers[i].links[j];
// We can't link to our own node.
if (linked_id == this_node_id) goto corrupted;
// Compute accumulator for reciprocal links check.
uint64_t mixed_h1, mixed_h2;
secure_pair_mixer_128(salt0, salt1, this_node_id, linked_id, (uint64_t)i, &mixed_h1, &mixed_h2);
accumulator[0] ^= mixed_h1;
accumulator[1] ^= mixed_h2;
// Fix links.
uint64_t bucket = hnsw_hash_node_id(linked_id) & (table_size-1);
hnswNode *neighbor = NULL;
for (uint64_t k = 0; k < table_size; k++) {
@ -2268,19 +2344,37 @@ int hnsw_deserialize_index(HNSW *index) {
}
bucket = (bucket+1) & (table_size-1);
}
if (neighbor == NULL) {
/* The neighbor must exist and also exist at the right
* level. */
if (neighbor == NULL || neighbor->level < i) {
/* Unresolved link! Either a bug in this code
* or broken serialization data. */
hfree(table);
return 0;
goto corrupted;
}
node->layers[i].links[j] = neighbor;
}
/* The worst link information was missing from older
* serialization formats. Compute it on the fly if needed. */
if (node->layers[i].worst_idx == HNSW_SER_WORSTLINK_MISSING) {
hnsw_update_worst_neighbor(index,node,i);
}
}
node = node->next;
}
/* Check that links are reciprocal, otherwise fail. */
if (accumulator[0] || accumulator[1]) goto corrupted;
/* Everything fine. Return success. */
hfree(table);
return 1;
corrupted:
/* Some corruption error detected. */
hfree(table);
return 0;
}
/* ================================ Iterator ================================ */

View File

@ -158,7 +158,7 @@ void hnsw_free_insert_context(InsertContext *ctx);
hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node);
void hnsw_free_serialized_node(hnswSerNode *sn);
hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value);
int hnsw_deserialize_index(HNSW *index);
int hnsw_deserialize_index(HNSW *index, uint64_t salt0, uint64_t salt1);
// Helper function in case the user wants to directly copy
// the vector bytes.

106
modules/vector-sets/mixer.h Normal file
View File

@ -0,0 +1,106 @@
/* Redis implementation for vector sets. The data structure itself
* is implemented in hnsw.c.
*
* Copyright (c) 2009-Present, Redis Ltd.
* All rights reserved.
*
* Licensed under your choice of (a) the Redis Source Available License 2.0
* (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
* GNU Affero General Public License v3 (AGPLv3).
* Originally authored by: Salvatore Sanfilippo.
*
* =============================================================================
*
* Mixing function for HNSW link integrity verification
* Designed to resist collision attacks when salts are unknown.
*/
#include <stdint.h>
#include <string.h>
/* Rotate the 64-bit value x left by r bit positions.
 *
 * The rotation count is reduced modulo 64 and both shift counts are masked
 * into the 0..63 range, so r == 0 (and any multiple of 64) returns x
 * unchanged instead of performing an undefined 64-bit shift by 64. */
static inline uint64_t ROTL64(uint64_t x, int r) {
    unsigned int s = (unsigned int)r & 63u;
    return (x << s) | (x >> ((64u - s) & 63u));
}
// 64-bit multiplier constants used by the multiply steps of
// secure_pair_mixer_128(). Every value below is odd, so multiplying by it
// modulo 2^64 is a reversible operation that does not discard state.
// NOTE(review): the MIX_PRIME_ prefix suggests these are primes, but nothing
// in this file verifies primality or records where the values come from;
// confirm their provenance before relying on that property.
#define MIX_PRIME_1 0xFF51AFD7ED558CCDULL
#define MIX_PRIME_2 0xC4CEB9FE1A85EC53ULL
#define MIX_PRIME_3 0x9E3779B97F4A7C15ULL
#define MIX_PRIME_4 0xBF58476D1CE4E5B9ULL
#define MIX_PRIME_5 0x94D049BB133111EBULL
#define MIX_PRIME_6 0x2B7E151628AED2A7ULL
/* Compute a salted 128-bit hash of the link between two node IDs at a given
 * level. The two IDs are sorted before hashing, so the link A -> B and the
 * link B -> A produce the same output for the same salts and level.
 *
 * Parameters:
 *   salt0, salt1   - caller-provided salt words, mixed in at the start and
 *                    again in every cross-mixing round.
 *   id1_in, id2_in - the two node IDs forming the link, in any order.
 *   level          - the level the link belongs to.
 *   out_h1, out_h2 - receive the two 64-bit halves of the result; both must
 *                    be valid, non-NULL pointers.
 *
 * The function is 'static inline' because it is defined in a header: with
 * external linkage, every translation unit including this file would emit
 * its own definition and the link would fail with duplicate symbols.
 *
 * Mixer design goals:
 * 1. Thorough mixing of the level parameter.
 * 2. Enough rounds of mixing.
 * 3. Cross-influence between h1 and h2.
 * 4. Domain separation to prevent related-key attacks.
 */
static inline void secure_pair_mixer_128(uint64_t salt0, uint64_t salt1,
                                         uint64_t id1_in, uint64_t id2_in,
                                         uint64_t level,
                                         uint64_t *out_h1, uint64_t *out_h2) {
    // Order independence (A -> B links should hash as B -> A links).
    uint64_t id_a = (id1_in < id2_in) ? id1_in : id2_in;
    uint64_t id_b = (id1_in < id2_in) ? id2_in : id1_in;

    // Domain separation: mix salts with a constant to prevent
    // related-key attacks.
    uint64_t h1 = salt0 ^ 0xDEADBEEFDEADBEEFULL;
    uint64_t h2 = salt1 ^ 0xCAFEBABECAFEBABEULL;

    // First, thoroughly mix the level into both accumulators.
    // This prevents predictable level values from being a weakness.
    uint64_t level_mix = level;
    level_mix *= MIX_PRIME_5;
    level_mix ^= level_mix >> 32;
    level_mix *= MIX_PRIME_6;
    h1 ^= level_mix;
    h2 ^= ROTL64(level_mix, 31);

    // Mix in id_a (the smaller ID) with strong diffusion.
    h1 ^= id_a;
    h1 *= MIX_PRIME_1;
    h1 = ROTL64(h1, 23);
    h1 *= MIX_PRIME_2;

    // Mix in id_b (the larger ID).
    h2 ^= id_b;
    h2 *= MIX_PRIME_3;
    h2 = ROTL64(h2, 29);
    h2 *= MIX_PRIME_4;

    // Three rounds of cross-mixing for better security.
    for (int i = 0; i < 3; i++) {
        // Cross-influence: each half absorbs the other's pre-round value.
        uint64_t tmp = h1;
        h1 += h2;
        h2 += tmp;

        // Mix h1.
        h1 ^= ROTL64(h1, 31);
        h1 *= MIX_PRIME_1;
        h1 ^= salt0;

        // Mix h2.
        h2 ^= ROTL64(h2, 37);
        h2 *= MIX_PRIME_2;
        h2 ^= salt1;
    }

    // Finalization with avalanche rounds.
    h1 ^= h1 >> 33;
    h1 *= MIX_PRIME_3;
    h1 ^= h1 >> 29;
    h1 *= MIX_PRIME_4;
    h1 ^= h1 >> 32;

    h2 ^= h2 >> 33;
    h2 *= MIX_PRIME_5;
    h2 ^= h2 >> 29;
    h2 *= MIX_PRIME_6;
    h2 ^= h2 >> 32;

    *out_h1 = h1;
    *out_h2 = h2;
}

View File

@ -1883,7 +1883,10 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
RedisModule_Free(vector);
RedisModule_Free(params);
}
if (!hnsw_deserialize_index(vset->hnsw)) goto ioerr;
uint64_t salt[2];
RedisModule_GetRandomBytes((unsigned char*)salt,sizeof(salt));
if (!hnsw_deserialize_index(vset->hnsw, salt[0], salt[1])) goto ioerr;
return vset;