Handle soar assignments when vector and centroid are very close (#130206)
This commit is contained in:
parent
745d71fdbb
commit
23cd462e07
|
@ -9,25 +9,23 @@
|
|||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
|
||||
final class CentroidAssignments {
|
||||
|
||||
private final int numCentroids;
|
||||
private final float[][] cachedCentroids;
|
||||
private final IntArrayList[] assignmentsByCluster;
|
||||
private final int[][] assignmentsByCluster;
|
||||
|
||||
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) {
|
||||
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[][] assignmentsByCluster) {
|
||||
this.numCentroids = numCentroids;
|
||||
this.cachedCentroids = cachedCentroids;
|
||||
this.assignmentsByCluster = assignmentsByCluster;
|
||||
}
|
||||
|
||||
CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) {
|
||||
CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
|
||||
this(centroids.length, centroids, assignmentsByCluster);
|
||||
}
|
||||
|
||||
CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) {
|
||||
CentroidAssignments(int numCentroids, int[][] assignmentsByCluster) {
|
||||
this(numCentroids, null, assignmentsByCluster);
|
||||
}
|
||||
|
||||
|
@ -40,7 +38,7 @@ final class CentroidAssignments {
|
|||
return cachedCentroids;
|
||||
}
|
||||
|
||||
public IntArrayList[] assignmentsByCluster() {
|
||||
public int[][] assignmentsByCluster() {
|
||||
return assignmentsByCluster;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import org.apache.lucene.index.FieldInfo;
|
|||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
@ -27,6 +26,7 @@ import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
|||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
|
||||
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
|
||||
|
@ -53,7 +53,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
CentroidSupplier centroidSupplier,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput,
|
||||
IntArrayList[] assignmentsByCluster
|
||||
int[][] assignmentsByCluster
|
||||
) throws IOException {
|
||||
// write the posting lists
|
||||
final long[] offsets = new long[centroidSupplier.size()];
|
||||
|
@ -65,16 +65,16 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
float[] centroid = centroidSupplier.centroid(c);
|
||||
binarizedByteVectorValues.centroid = centroid;
|
||||
// TODO: add back in sorting vectors by distance to centroid
|
||||
IntArrayList cluster = assignmentsByCluster[c];
|
||||
int[] cluster = assignmentsByCluster[c];
|
||||
// TODO align???
|
||||
offsets[c] = postingsOutput.getFilePointer();
|
||||
int size = cluster.size();
|
||||
int size = cluster.length;
|
||||
postingsOutput.writeVInt(size);
|
||||
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
|
||||
// TODO we might want to consider putting the docIds in a separate file
|
||||
// to aid with only having to fetch vectors from slower storage when they are required
|
||||
// keeping them in the same file indicates we pull the entire file into cache
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
|
||||
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
|
||||
}
|
||||
|
||||
|
@ -85,23 +85,23 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
return offsets;
|
||||
}
|
||||
|
||||
private static void printClusterQualityStatistics(IntArrayList[] clusters) {
|
||||
private static void printClusterQualityStatistics(int[][] clusters) {
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = Float.MIN_VALUE;
|
||||
float mean = 0;
|
||||
float m2 = 0;
|
||||
// iteratively compute the variance & mean
|
||||
int count = 0;
|
||||
for (IntArrayList cluster : clusters) {
|
||||
for (int[] cluster : clusters) {
|
||||
count += 1;
|
||||
if (cluster == null) {
|
||||
continue;
|
||||
}
|
||||
float delta = cluster.size() - mean;
|
||||
float delta = cluster.length - mean;
|
||||
mean += delta / count;
|
||||
m2 += delta * (cluster.size() - mean);
|
||||
min = Math.min(min, cluster.size());
|
||||
max = Math.max(max, cluster.size());
|
||||
m2 += delta * (cluster.length - mean);
|
||||
min = Math.min(min, cluster.length);
|
||||
max = Math.max(max, cluster.length);
|
||||
}
|
||||
float variance = m2 / (clusters.length - 1);
|
||||
logger.debug(
|
||||
|
@ -115,16 +115,16 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
);
|
||||
}
|
||||
|
||||
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
|
||||
private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
|
||||
throws IOException {
|
||||
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
|
||||
int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
|
||||
int cidx = 0;
|
||||
OptimizedScalarQuantizer.QuantizationResult[] corrections =
|
||||
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
|
||||
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
|
||||
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
|
||||
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||
int ord = cluster.get(cidx + j);
|
||||
int ord = cluster[cidx + j];
|
||||
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
|
||||
// write vector
|
||||
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
|
||||
|
@ -147,8 +147,8 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
}
|
||||
}
|
||||
// write tail
|
||||
for (; cidx < cluster.size(); cidx++) {
|
||||
int ord = cluster.get(cidx);
|
||||
for (; cidx < cluster.length; cidx++) {
|
||||
int ord = cluster[cidx];
|
||||
// write vector
|
||||
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
|
||||
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||
|
@ -261,23 +261,31 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
logger.debug("final centroid count: {}", centroids.length);
|
||||
}
|
||||
|
||||
IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
|
||||
int[] centroidVectorCount = new int[centroids.length];
|
||||
for (int i = 0; i < assignments.length; i++) {
|
||||
centroidVectorCount[assignments[i]]++;
|
||||
// if soar assignments are present, count them as well
|
||||
if (soarAssignments.length > i && soarAssignments[i] != -1) {
|
||||
centroidVectorCount[soarAssignments[i]]++;
|
||||
}
|
||||
}
|
||||
|
||||
int[][] assignmentsByCluster = new int[centroids.length][];
|
||||
for (int c = 0; c < centroids.length; c++) {
|
||||
IntArrayList cluster = new IntArrayList(vectorPerCluster);
|
||||
for (int j = 0; j < assignments.length; j++) {
|
||||
if (assignments[j] == c) {
|
||||
cluster.add(j);
|
||||
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
|
||||
}
|
||||
Arrays.fill(centroidVectorCount, 0);
|
||||
|
||||
for (int i = 0; i < assignments.length; i++) {
|
||||
int c = assignments[i];
|
||||
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
|
||||
// if soar assignments are present, add them to the cluster as well
|
||||
if (soarAssignments.length > i) {
|
||||
int s = soarAssignments[i];
|
||||
if (s != -1) {
|
||||
assignmentsByCluster[s][centroidVectorCount[s]++] = i;
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < soarAssignments.length; j++) {
|
||||
if (soarAssignments[j] == c) {
|
||||
cluster.add(j);
|
||||
}
|
||||
}
|
||||
|
||||
cluster.trimToSize();
|
||||
assignmentsByCluster[c] = cluster;
|
||||
}
|
||||
|
||||
if (cacheCentroids) {
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.apache.lucene.index.SegmentWriteState;
|
|||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@ -140,7 +139,7 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
CentroidSupplier centroidSupplier,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput,
|
||||
IntArrayList[] assignmentsByCluster
|
||||
int[][] assignmentsByCluster
|
||||
) throws IOException;
|
||||
|
||||
abstract CentroidSupplier createCentroidSupplier(
|
||||
|
|
|
@ -26,6 +26,11 @@ import java.util.Random;
|
|||
*/
|
||||
class KMeansLocal {
|
||||
|
||||
// the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance.
|
||||
// For vectors that are closer than this distance to the centroid, we use the squared distance to find the
|
||||
// second closest centroid.
|
||||
private static final float SOAR_MIN_DISTANCE = 1e-16f;
|
||||
|
||||
final int sampleSize;
|
||||
final int maxIterations;
|
||||
final int clustersPerNeighborhood;
|
||||
|
@ -190,15 +195,18 @@ class KMeansLocal {
|
|||
|
||||
int currAssignment = assignments[i];
|
||||
float[] currentCentroid = centroids[currAssignment];
|
||||
for (int j = 0; j < vectors.dimension(); j++) {
|
||||
float diff = vector[j] - currentCentroid[j];
|
||||
diffs[j] = diff;
|
||||
}
|
||||
|
||||
// TODO: cache these?
|
||||
// float vectorCentroidDist = assignmentDistances[i];
|
||||
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
|
||||
|
||||
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
|
||||
for (int j = 0; j < vectors.dimension(); j++) {
|
||||
float diff = vector[j] - currentCentroid[j];
|
||||
diffs[j] = diff;
|
||||
}
|
||||
}
|
||||
|
||||
int bestAssignment = -1;
|
||||
float minSoar = Float.MAX_VALUE;
|
||||
assert neighborhoods.get(currAssignment) != null;
|
||||
|
@ -207,13 +215,19 @@ class KMeansLocal {
|
|||
continue;
|
||||
}
|
||||
float[] neighborCentroid = centroids[neighbor];
|
||||
float soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
|
||||
final float soar;
|
||||
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
|
||||
soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
|
||||
} else {
|
||||
// if the vector is very close to the centroid, we look for the second-nearest centroid
|
||||
soar = VectorUtil.squareDistance(vector, neighborCentroid);
|
||||
}
|
||||
if (soar < minSoar) {
|
||||
bestAssignment = neighbor;
|
||||
minSoar = soar;
|
||||
}
|
||||
}
|
||||
|
||||
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
|
||||
spilledAssignments[i] = bestAssignment;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue