diff --git a/modules/vector-sets/hnsw.c b/modules/vector-sets/hnsw.c index cd38e4454..701c1f0d2 100644 --- a/modules/vector-sets/hnsw.c +++ b/modules/vector-sets/hnsw.c @@ -2058,6 +2058,18 @@ hnswNode *hnsw_random_node(HNSW *index, int slot) { * hash table, then scan all the nodes again and fix all the links converting * the ID to the pointer. */ +/* History of serialization versions: + * version 0: the first implementation, lacking worst node id/info. + * version 1: includes worst link id/info. */ +#define HNSW_SERIALIZATION_VERSION 1 + +/* This is a special worst link index that is set when loading a serialized + * node with version 0 (this version of the serialization lacked explicit + * information about the worst link index/distance). This way, later, the + * function that fixes a deserialized index will know to compute the worst + * index info at runtime. */ +#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX + /* Return the serialized node information as specified in the top comment * above. Note that the returned information is true as long as the node * provided is not deleted or modified, so this function should be called @@ -2073,6 +2085,7 @@ hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) { for (uint32_t i = 0; i <= node->level; i++) { num_params += 2; // max_links and num_links info for this layer. num_params += node->layers[i].num_links; // The IDs of linked nodes. + num_params += 1; // worst link id/distance parameter. } /* We use another 64bit value to store two floats that are about @@ -2096,18 +2109,43 @@ hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) { uint32_t param_idx = 0; sn->params[param_idx++] = node->id; - sn->params[param_idx++] = node->level; + /* The second parameter contains information about the serialization + * version of this node, the node level and some unused field: + * + * +--------+--------+--------+--------+ + * |VVVVVVVV|........|........|LLLLLLLL| + * +--------+--------+--------+--------+ + * + * V is the version, 8 bits. + * L is the node level, 8 bits (but actually 16 is the max so far). + * The middle two bytes are reserved for future uses. */ + sn->params[param_idx] = node->level & 0xff; + sn->params[param_idx] |= HNSW_SERIALIZATION_VERSION << 24; + param_idx++; for (uint32_t i = 0; i <= node->level; i++) { sn->params[param_idx++] = node->layers[i].num_links; sn->params[param_idx++] = node->layers[i].max_links; for (uint32_t j = 0; j < node->layers[i].num_links; j++) { sn->params[param_idx++] = node->layers[i].links[j]->id; } + /* Since version 1: pack and store worst_idx and worst_distance. */ + uint32_t worst_distance_bits; + memcpy(&worst_distance_bits, &node->layers[i].worst_distance, + sizeof(float)); + uint64_t wi = + (((uint64_t)worst_distance_bits) << 32) | node->layers[i].worst_idx; + sn->params[param_idx++] = wi; } - uint64_t l2_and_range = 0; - unsigned char *aux = (unsigned char*)&l2_and_range; - memcpy(aux,&node->l2,sizeof(float)); - memcpy(aux+4,&node->quants_range,sizeof(float)); + + /* Store l2 and range as uint32_t, in a way that is endian-safe. + * Note that in big endian archs both are reversed: integers and + * also the bytes of floats, so they will match. */ + uint64_t l2_and_range; + uint32_t l2_bits, range_bits; + memcpy(&l2_bits,&node->l2,sizeof(float)); + memcpy(&range_bits,&node->quants_range,sizeof(float)); + l2_and_range = ((uint64_t)range_bits<<32) | l2_bits; + sn->params[param_idx++] = l2_and_range; /* Better safe than sorry: */ @@ -2128,13 +2166,18 @@ void hnsw_free_serialized_node(hnswSerNode *sn) { * The function returns NULL both on out of memory and if the remaining * parameters length does not match the number of links or other items * to load. */ -#define HNSW_SER_WORSTLINK_MISSING UINT32_MAX hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value) { if (params_len < 2) return NULL; uint64_t id = params[0]; - uint32_t level = params[1]; + /* Check the node serialization function for the specific layout + * of param[1] fields. */ + uint32_t level = params[1] & 0xff; // Node level. + uint32_t version = (params[1] & 0xff000000) >> 24; // Format version. + + if (version > HNSW_SERIALIZATION_VERSION) return NULL; + int has_worst_link_info = version > 0; /* Keep track of maximum ID seen while loading. */ if (id >= index->last_id) index->last_id = id; @@ -2152,7 +2195,7 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui uint32_t param_idx = 2; for (uint32_t i = 0; i <= level; i++) { /* Sanity check. */ - if (param_idx + 2 > params_len) { + if (param_idx + 2 + has_worst_link_info > params_len) { hnsw_node_free(node); return NULL; } @@ -2183,7 +2226,7 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui node->layers[i].num_links = num_links; /* Sanity check. */ - if (param_idx + num_links > params_len) { + if (param_idx + num_links + has_worst_link_info > params_len) { hnsw_node_free(node); return NULL; } @@ -2195,9 +2238,26 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui for (uint32_t j = 0; j < num_links; j++) node->layers[i].links[j] = (hnswNode*)params[param_idx++]; - /* XXX: fix me, we need to store the worst link info in a - * backward compatible way. */ - node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING; + if (has_worst_link_info) { + uint64_t wi = params[param_idx++]; + uint32_t worst_idx = wi & 0xffffffff; + uint32_t worst_distance_bits = wi >> 32; + float worst_distance; + memcpy(&worst_distance,&worst_distance_bits,sizeof(float)); + node->layers[i].worst_idx = worst_idx; + node->layers[i].worst_distance = worst_distance; + + // Sanity check the worst ID range. + if (node->layers[i].num_links > 0 && + node->layers[i].worst_idx >= node->layers[i].num_links) + { + hnsw_node_free(node); + return NULL; + } + } else { + node->layers[i].worst_idx = HNSW_SER_WORSTLINK_MISSING; + node->layers[i].worst_distance = 0; + } } /* Get l2 and quantization range. */ @@ -2205,10 +2265,14 @@ hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, ui hnsw_node_free(node); return NULL; } + + /* Load l2 and range packed into an uint64_t in an endian safe way. */ uint64_t l2_and_range = params[param_idx]; - unsigned char *aux = (unsigned char*)&l2_and_range; - memcpy(&node->l2, aux, sizeof(float)); - memcpy(&node->quants_range, aux+4, sizeof(float)); + uint32_t l2_bits, range_bits; + l2_bits = l2_and_range & 0xffffffff; + range_bits = l2_and_range >> 32; + memcpy(&node->l2, &l2_bits, sizeof(float)); + memcpy(&node->quants_range, &range_bits, sizeof(float)); node->value = value; hnsw_add_node(index, node);