[Vector sets] Endianness fix and speedup of data loading (#14144)
CI / test-ubuntu-latest (push) Waiting to run Details
CI / test-sanitizer-address (push) Waiting to run Details
CI / build-debian-old (push) Waiting to run Details
CI / build-macos-latest (push) Waiting to run Details
CI / build-32bit (push) Waiting to run Details
CI / build-libc-malloc (push) Waiting to run Details
CI / build-centos-jemalloc (push) Waiting to run Details
CI / build-old-chain-jemalloc (push) Waiting to run Details
Codecov / code-coverage (push) Waiting to run Details
External Server Tests / test-external-standalone (push) Waiting to run Details
External Server Tests / test-external-cluster (push) Waiting to run Details
External Server Tests / test-external-nodebug (push) Waiting to run Details
Spellcheck / Spellcheck (push) Waiting to run Details

Hello, this is a patch that improves vector sets in two ways:

1. It makes the RDB format compatible with big endian machines: yes,
they are practically nonexistent nowadays, but it is still better to be
correct. The behavior remains unchanged on little endian systems; it only
changes what happens on big endian systems, so that they load and emit the
exact same format produced by little endian systems. The implementation
was *already largely safe* except for one detail.

2. More importantly, this PR saves each node's worst link score / index in
a backward compatible way, also introducing versioning information for the
serialized node encoding, which could be useful in the future. With this
information, which in the past was not saved due to a programming error
(mine), there is no longer any need to compute the worst link info at
runtime when loading data. This results in a speed improvement of about
30% when loading data from disk / RESTORE. The saving performance is
unaffected.

The patch was tested with care to be sure that data produced by old
vector sets implementations is loaded without issues (that is, the
backward compatibility was hand-tested). The new code is exercised by the
persistence test already in the test suite, so no new test was added.
This commit is contained in:
Salvatore Sanfilippo 2025-07-10 04:08:59 +02:00 committed by GitHub
parent 92e39cac96
commit b5d54866ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 79 additions and 15 deletions

View File

@ -2058,6 +2058,18 @@ hnswNode *hnsw_random_node(HNSW *index, int slot) {
* hash table, then scan all the nodes again and fix all the links converting * hash table, then scan all the nodes again and fix all the links converting
* the ID to the pointer. */ * the ID to the pointer. */
/* History of serialization versions:
* version 0: the first implementation, lacking the per-layer worst link
*            index/distance information.
* version 1: includes the worst link index/distance for each layer.
*
* This is the version written into newly serialized nodes. */
#define HNSW_SERIALIZATION_VERSION 1
/* This is a special worst link index that is set when loading a serialized
* node with version 0 (this version of the serialization lacked explicit
* information about the worst link index/distance). This way, later, the
* function that fixes a deserialized index will know that it has to compute
* the worst link info at runtime. */
#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX
/* Return the serialized node information as specified in the top comment /* Return the serialized node information as specified in the top comment
* above. Note that the returned information is true as long as the node * above. Note that the returned information is true as long as the node
* provided is not deleted or modified, so this function should be called * provided is not deleted or modified, so this function should be called
@ -2073,6 +2085,7 @@ hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) {
for (uint32_t i = 0; i <= node->level; i++) { for (uint32_t i = 0; i <= node->level; i++) {
num_params += 2; // max_links and num_links info for this layer. num_params += 2; // max_links and num_links info for this layer.
num_params += node->layers[i].num_links; // The IDs of linked nodes. num_params += node->layers[i].num_links; // The IDs of linked nodes.
num_params += 1; // worst link id/distance parameter.
} }
/* We use another 64bit value to store two floats that are about /* We use another 64bit value to store two floats that are about
@ -2096,18 +2109,43 @@ hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) {
uint32_t param_idx = 0; uint32_t param_idx = 0;
sn->params[param_idx++] = node->id; sn->params[param_idx++] = node->id;
sn->params[param_idx++] = node->level; /* The second parameter contains information about the serialization
* version of this node, the node level and some unused field:
*
* +--------+--------+--------+--------+
* |VVVVVVVV|........|........|LLLLLLLL|
* +--------+--------+--------+--------+
*
* V is the version, 8 bits.
* L is the node level, 8 bits (but actually 16 is the max so far).
* The middle two bytes are reserved for future uses. */
sn->params[param_idx] = node->level & 0xff;
sn->params[param_idx] |= HNSW_SERIALIZATION_VERSION << 24;
param_idx++;
for (uint32_t i = 0; i <= node->level; i++) { for (uint32_t i = 0; i <= node->level; i++) {
sn->params[param_idx++] = node->layers[i].num_links; sn->params[param_idx++] = node->layers[i].num_links;
sn->params[param_idx++] = node->layers[i].max_links; sn->params[param_idx++] = node->layers[i].max_links;
for (uint32_t j = 0; j < node->layers[i].num_links; j++) { for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
sn->params[param_idx++] = node->layers[i].links[j]->id; sn->params[param_idx++] = node->layers[i].links[j]->id;
} }
/* Since version 1: pack and store worst_idx and worst_distance. */
uint32_t worst_distance_bits;
memcpy(&worst_distance_bits, &node->layers[i].worst_distance,
sizeof(float));
uint64_t wi =
(((uint64_t)worst_distance_bits) << 32) | node->layers[i].worst_idx;
sn->params[param_idx++] = wi;
} }
uint64_t l2_and_range = 0;
unsigned char *aux = (unsigned char*)&l2_and_range; /* Store l2 and range as uint32_t, in a way that is endian-safe.
memcpy(aux,&node->l2,sizeof(float)); * Note that in big endian archs both are reversed: integers and
memcpy(aux+4,&node->quants_range,sizeof(float)); * also the bytes of floats, so they will match. */
uint64_t l2_and_range;
uint32_t l2_bits, range_bits;
memcpy(&l2_bits,&node->l2,sizeof(float));
memcpy(&range_bits,&node->quants_range,sizeof(float));
l2_and_range = ((uint64_t)range_bits<<32) | l2_bits;
sn->params[param_idx++] = l2_and_range; sn->params[param_idx++] = l2_and_range;
/* Better safe than sorry: */ /* Better safe than sorry: */
@ -2128,13 +2166,18 @@ void hnsw_free_serialized_node(hnswSerNode *sn) {
* The function returns NULL both on out of memory and if the remaining * The function returns NULL both on out of memory and if the remaining
* parameters length does not match the number of links or other items * parameters length does not match the number of links or other items
* to load. */ * to load. */
#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX
hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value) hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value)
{ {
if (params_len < 2) return NULL; if (params_len < 2) return NULL;
uint64_t id = params[0]; uint64_t id = params[0];
uint32_t level = params[1]; /* Check the node serialization function for the specific layout
* of param[1] fields. */
uint32_t level = params[1] & 0xff; // Node level.
uint32_t version = (params[1] & 0xff000000) >> 24; // Format version.
if (version > HNSW_SERIALIZATION_VERSION) return NULL;
int has_worst_link_info = version > 0;
/* Keep track of maximum ID seen while loading. */ /* Keep track of maximum ID seen while loading. */
if (id >= index->last_id) index->last_id = id; if (id >= index->last_id) index->last_id = id;
@ -2152,7 +2195,7 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui
uint32_t param_idx = 2; uint32_t param_idx = 2;
for (uint32_t i = 0; i <= level; i++) { for (uint32_t i = 0; i <= level; i++) {
/* Sanity check. */ /* Sanity check. */
if (param_idx + 2 > params_len) { if (param_idx + 2 + has_worst_link_info > params_len) {
hnsw_node_free(node); hnsw_node_free(node);
return NULL; return NULL;
} }
@ -2183,7 +2226,7 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui
node->layers[i].num_links = num_links; node->layers[i].num_links = num_links;
/* Sanity check. */ /* Sanity check. */
if (param_idx + num_links > params_len) { if (param_idx + num_links + has_worst_link_info > params_len) {
hnsw_node_free(node); hnsw_node_free(node);
return NULL; return NULL;
} }
@ -2195,9 +2238,26 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui
for (uint32_t j = 0; j < num_links; j++) for (uint32_t j = 0; j < num_links; j++)
node->layers[i].links[j] = (hnswNode*)params[param_idx++]; node->layers[i].links[j] = (hnswNode*)params[param_idx++];
/* XXX: fix me, we need to store the worst link info in a if (has_worst_link_info) {
* backward compatible way. */ uint64_t wi = params[param_idx++];
node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING; uint32_t worst_idx = wi & 0xffffffff;
uint32_t worst_distance_bits = wi >> 32;
float worst_distance;
memcpy(&worst_distance,&worst_distance_bits,sizeof(float));
node->layers[i].worst_idx = worst_idx;
node->layers[i].worst_distance = worst_distance;
// Sanity check the worst ID range.
if (node->layers[i].num_links > 0 &&
node->layers[i].worst_idx >= node->layers[i].num_links)
{
hnsw_node_free(node);
return NULL;
}
} else {
node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING;
node->layers[i].worst_distance = 0;
}
} }
/* Get l2 and quantization range. */ /* Get l2 and quantization range. */
@ -2205,10 +2265,14 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui
hnsw_node_free(node); hnsw_node_free(node);
return NULL; return NULL;
} }
/* Load l2 and range packed into an uint64_t in an endian safe way. */
uint64_t l2_and_range = params[param_idx]; uint64_t l2_and_range = params[param_idx];
unsigned char *aux = (unsigned char*)&l2_and_range; uint32_t l2_bits, range_bits;
memcpy(&node->l2, aux, sizeof(float)); l2_bits = l2_and_range & 0xffffffff;
memcpy(&node->quants_range, aux+4, sizeof(float)); range_bits = l2_and_range >> 32;
memcpy(&node->l2, &l2_bits, sizeof(float));
memcpy(&node->quants_range, &range_bits, sizeof(float));
node->value = value; node->value = value;
hnsw_add_node(index, node); hnsw_add_node(index, node);