Handle soar assignments when vector and centroid are very close (#130206)

This commit is contained in:
Ignacio Vera 2025-06-30 08:50:36 +02:00 committed by GitHub
parent 745d71fdbb
commit 23cd462e07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 45 deletions

View File

@ -9,25 +9,23 @@
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.internal.hppc.IntArrayList;
final class CentroidAssignments {
private final int numCentroids;
private final float[][] cachedCentroids;
private final IntArrayList[] assignmentsByCluster;
private final int[][] assignmentsByCluster;
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) {
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[][] assignmentsByCluster) {
this.numCentroids = numCentroids;
this.cachedCentroids = cachedCentroids;
this.assignmentsByCluster = assignmentsByCluster;
}
CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) {
CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
this(centroids.length, centroids, assignmentsByCluster);
}
CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) {
CentroidAssignments(int numCentroids, int[][] assignmentsByCluster) {
this(numCentroids, null, assignmentsByCluster);
}
@ -40,7 +38,7 @@ final class CentroidAssignments {
return cachedCentroids;
}
public IntArrayList[] assignmentsByCluster() {
public int[][] assignmentsByCluster() {
return assignmentsByCluster;
}
}

View File

@ -14,7 +14,6 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.VectorUtil;
@ -27,6 +26,7 @@ import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
@ -53,7 +53,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
IntArrayList[] assignmentsByCluster
int[][] assignmentsByCluster
) throws IOException {
// write the posting lists
final long[] offsets = new long[centroidSupplier.size()];
@ -65,16 +65,16 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
float[] centroid = centroidSupplier.centroid(c);
binarizedByteVectorValues.centroid = centroid;
// TODO: add back in sorting vectors by distance to centroid
IntArrayList cluster = assignmentsByCluster[c];
int[] cluster = assignmentsByCluster[c];
// TODO align???
offsets[c] = postingsOutput.getFilePointer();
int size = cluster.size();
int size = cluster.length;
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
}
@ -85,23 +85,23 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
return offsets;
}
private static void printClusterQualityStatistics(IntArrayList[] clusters) {
private static void printClusterQualityStatistics(int[][] clusters) {
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
float mean = 0;
float m2 = 0;
// iteratively compute the variance & mean
int count = 0;
for (IntArrayList cluster : clusters) {
for (int[] cluster : clusters) {
count += 1;
if (cluster == null) {
continue;
}
float delta = cluster.size() - mean;
float delta = cluster.length - mean;
mean += delta / count;
m2 += delta * (cluster.size() - mean);
min = Math.min(min, cluster.size());
max = Math.max(max, cluster.size());
m2 += delta * (cluster.length - mean);
min = Math.min(min, cluster.length);
max = Math.max(max, cluster.length);
}
float variance = m2 / (clusters.length - 1);
logger.debug(
@ -115,16 +115,16 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
);
}
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
throws IOException {
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
int cidx = 0;
OptimizedScalarQuantizer.QuantizationResult[] corrections =
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int ord = cluster.get(cidx + j);
int ord = cluster[cidx + j];
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
// write vector
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
@ -147,8 +147,8 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
}
}
// write tail
for (; cidx < cluster.size(); cidx++) {
int ord = cluster.get(cidx);
for (; cidx < cluster.length; cidx++) {
int ord = cluster[cidx];
// write vector
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
@ -261,23 +261,31 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
logger.debug("final centroid count: {}", centroids.length);
}
IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
int[] centroidVectorCount = new int[centroids.length];
for (int i = 0; i < assignments.length; i++) {
centroidVectorCount[assignments[i]]++;
// if soar assignments are present, count them as well
if (soarAssignments.length > i && soarAssignments[i] != -1) {
centroidVectorCount[soarAssignments[i]]++;
}
}
int[][] assignmentsByCluster = new int[centroids.length][];
for (int c = 0; c < centroids.length; c++) {
IntArrayList cluster = new IntArrayList(vectorPerCluster);
for (int j = 0; j < assignments.length; j++) {
if (assignments[j] == c) {
cluster.add(j);
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
}
Arrays.fill(centroidVectorCount, 0);
for (int i = 0; i < assignments.length; i++) {
int c = assignments[i];
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
// if soar assignments are present, add them to the cluster as well
if (soarAssignments.length > i) {
int s = soarAssignments[i];
if (s != -1) {
assignmentsByCluster[s][centroidVectorCount[s]++] = i;
}
}
for (int j = 0; j < soarAssignments.length; j++) {
if (soarAssignments[j] == c) {
cluster.add(j);
}
}
cluster.trimToSize();
assignmentsByCluster[c] = cluster;
}
if (cacheCentroids) {

View File

@ -23,7 +23,6 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
@ -140,7 +139,7 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
IntArrayList[] assignmentsByCluster
int[][] assignmentsByCluster
) throws IOException;
abstract CentroidSupplier createCentroidSupplier(

View File

@ -26,6 +26,11 @@ import java.util.Random;
*/
class KMeansLocal {
// the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance.
// For vectors that are closer than this distance to the centroid, we use the squared distance to find the
// second closest centroid.
private static final float SOAR_MIN_DISTANCE = 1e-16f;
final int sampleSize;
final int maxIterations;
final int clustersPerNeighborhood;
@ -190,15 +195,18 @@ class KMeansLocal {
int currAssignment = assignments[i];
float[] currentCentroid = centroids[currAssignment];
for (int j = 0; j < vectors.dimension(); j++) {
float diff = vector[j] - currentCentroid[j];
diffs[j] = diff;
}
// TODO: cache these?
// float vectorCentroidDist = assignmentDistances[i];
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
for (int j = 0; j < vectors.dimension(); j++) {
float diff = vector[j] - currentCentroid[j];
diffs[j] = diff;
}
}
int bestAssignment = -1;
float minSoar = Float.MAX_VALUE;
assert neighborhoods.get(currAssignment) != null;
@ -207,13 +215,19 @@ class KMeansLocal {
continue;
}
float[] neighborCentroid = centroids[neighbor];
float soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
final float soar;
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
} else {
// if the vector is very close to the centroid, we look for the second-nearest centroid
soar = VectorUtil.squareDistance(vector, neighborCentroid);
}
if (soar < minSoar) {
bestAssignment = neighbor;
minSoar = soar;
}
}
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
spilledAssignments[i] = bestAssignment;
}