IVF Hierarchical KMeans Flush & Merge (#128675)

Added hierarchical k-means as a clustering algorithm to better partition the space when running IVF on flush and merge.
John Wagster 2025-06-10 15:19:27 -05:00 committed by GitHub
parent 1e13409049
commit 47d4b983af
17 changed files with 1099 additions and 1237 deletions
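
At a high level, the new clustering runs one coarse k-means pass, recursively re-clusters any partition still larger than the target size, and finishes with a SOAR pass that spills each vector into one nearby secondary cluster. A minimal sketch of the recursive step, assuming hypothetical Clustering, kmeans(...) and meanOf(...) helpers (the committed implementation is the HierarchicalKMeans class in the diffs below):

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.index.FloatVectorValues;

// Sketch only: Clustering, kmeans(...) and meanOf(...) are hypothetical stand-ins for the
// committed internals; FloatVectorValuesSlice is the real helper added by this PR.
static List<float[]> partition(FloatVectorValues vectors, int targetSize) throws IOException {
    if (vectors.size() <= targetSize) {
        return List.of(meanOf(vectors)); // small partition: a single centroid suffices
    }
    int k = Math.min(128, (vectors.size() + targetSize - 1) / targetSize); // MAXK-style cap
    Clustering coarse = kmeans(vectors, k); // one coarse k-means pass over this partition
    List<float[]> centroids = new ArrayList<>();
    for (int c = 0; c < k; c++) {
        int[] members = coarse.membersOf(c); // vector ordinals assigned to cluster c
        if (members.length > targetSize) {
            // oversized cluster: recurse on just its vectors through a slice view
            centroids.addAll(partition(new FloatVectorValuesSlice(vectors, members), targetSize));
        } else {
            centroids.add(coarse.centroid(c));
        }
    }
    return centroids;
}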

View File

@@ -276,12 +276,11 @@ class KnnSearcher {
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException {
Query knnQuery;
int topK = this.topK;
int efSearch = this.efSearch;
if (overSamplingFactor > 1f) {
// oversample the topK results to get more candidates for the final result
topK = (int) Math.ceil(topK * overSamplingFactor);
efSearch = Math.max(topK, efSearch);
}
int efSearch = Math.max(topK, this.efSearch);
if (indexType == KnnIndexTester.IndexType.IVF) {
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, null, nProbe);
} else {

View File

@@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.internal.hppc.IntArrayList;
final class CentroidAssignments {
private final int numCentroids;
private final float[][] cachedCentroids;
private final IntArrayList[] assignmentsByCluster;
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) {
this.numCentroids = numCentroids;
this.cachedCentroids = cachedCentroids;
this.assignmentsByCluster = assignmentsByCluster;
}
CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) {
this(centroids.length, centroids, assignmentsByCluster);
}
CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) {
this(numCentroids, null, assignmentsByCluster);
}
// Getters
public int numCentroids() {
return numCentroids;
}
public float[][] cachedCentroids() {
return cachedCentroids;
}
public IntArrayList[] assignmentsByCluster() {
return assignmentsByCluster;
}
}
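
For illustration, the two construction paths map onto flush and merge (hypothetical variable names): flush caches the centroids for the in-memory posting-list pass, while merge re-reads them off heap later and only needs the count:

// flush: keep the centroids on heap for the subsequent posting-list pass
CentroidAssignments flushResult = new CentroidAssignments(centroids, assignmentsByCluster);
// merge: centroids are re-read off heap later, so only the count is retained
CentroidAssignments mergeResult = new CentroidAssignments(centroids.length, assignmentsByCluster);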

View File

@@ -112,15 +112,6 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader {
};
}
@Override
protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) {
FieldEntry entry = fields.get(info.number);
if (entry == null) {
return null;
}
return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension());
}
@Override
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
throws IOException {

View File

@@ -14,21 +14,19 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
@@ -36,16 +34,12 @@ import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packA
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;
/**
* Default implementation of {@link IVFVectorsWriter}. It uses {@link KMeans} algorithm to
* partition the vector space, and then stores the centroids an posting list in a sequential
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
* partition the vector space, and then stores the centroids and posting list in a sequential
* fashion.
*/
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
static final float SOAR_LAMBDA = 1.0f;
// What percentage of the centroids do we do a second check on for SOAR assignment
static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f;
private final int vectorPerCluster;
public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException {
@@ -53,77 +47,81 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
this.vectorPerCluster = vectorPerCluster;
}
@Override
CentroidAssignmentScorer calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
float[] globalCentroid
) throws IOException {
if (floatVectorValues.size() == 0) {
return CentroidAssignmentScorer.EMPTY;
}
// calculate the centroids
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
final KMeans.Results kMeans = KMeans.cluster(
floatVectorValues,
desiredClusters,
false,
42L,
KMeans.KmeansInitializationMethod.PLUS_PLUS,
null,
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
1,
15,
desiredClusters * 256
);
float[][] centroids = kMeans.centroids();
// write them
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
return new OnHeapCentroidAssignmentScorer(centroids);
}
@Override
long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
InfoStream infoStream,
CentroidAssignmentScorer randomCentroidScorer,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput
IndexOutput postingsOutput,
InfoStream infoStream,
IntArrayList[] assignmentsByCluster
) throws IOException {
IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()];
for (int i = 0; i < randomCentroidScorer.size(); i++) {
clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4);
}
assignCentroids(randomCentroidScorer, floatVectorValues, clusters);
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
printClusterQualityStatistics(clusters, infoStream);
}
// write the posting lists
final long[] offsets = new long[randomCentroidScorer.size()];
final long[] offsets = new long[centroidSupplier.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
DocIdsWriter docIdsWriter = new DocIdsWriter();
for (int i = 0; i < randomCentroidScorer.size(); i++) {
float[] centroid = randomCentroidScorer.centroid(i);
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
binarizedByteVectorValues.centroid = centroid;
// TODO sort by distance to the centroid
IntArrayList cluster = clusters[i];
// TODO: add back in sorting vectors by distance to centroid
IntArrayList cluster = assignmentsByCluster[c];
// TODO align???
offsets[i] = postingsOutput.getFilePointer();
offsets[c] = postingsOutput.getFilePointer();
int size = cluster.size();
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
}
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
printClusterQualityStatistics(assignmentsByCluster, infoStream);
}
return offsets;
}
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
float mean = 0;
float m2 = 0;
// iteratively compute the variance & mean
int count = 0;
for (IntArrayList cluster : clusters) {
count += 1;
if (cluster == null) {
continue;
}
float delta = cluster.size() - mean;
mean += delta / count;
m2 += delta * (cluster.size() - mean);
min = Math.min(min, cluster.size());
max = Math.max(max, cluster.size());
}
float variance = m2 / (clusters.length - 1);
infoStream.message(
IVF_VECTOR_COMPONENT,
"Centroid count: "
+ clusters.length
+ " min: "
+ min
+ " max: "
+ max
+ " mean: "
+ mean
+ " stdDev: "
+ Math.sqrt(variance)
+ " variance: "
+ variance
);
}
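// Sanity check for the streaming statistics above (Welford's online algorithm): for cluster
// sizes {2, 4, 6} the loop yields mean = 4 and m2 = 8, so variance = 8 / (3 - 1) = 4 and
// stdDev = 2. Note that count is incremented before the null check; assignmentsByCluster
// entries are always non-null here, so the mean is not skewed in practice.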
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
throws IOException {
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
@@ -173,13 +171,8 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
}
@Override
CentroidAssignmentScorer createCentroidScorer(
IndexInput centroidsInput,
int numCentroids,
FieldInfo fieldInfo,
float[] globalCentroid
) {
return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
}
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
@@ -188,24 +181,8 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
// TODO do we want to store these distances as well for future use?
float[] distances = new float[centroids.length];
for (int i = 0; i < centroids.length; i++) {
distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
}
// sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
// (largest)
for (int i = 0; i < centroids.length; i++) {
for (int j = i + 1; j < centroids.length; j++) {
if (distances[i] > distances[j]) {
float[] tmp = centroids[i];
centroids[i] = centroids[j];
centroids[j] = tmp;
float tmpDistance = distances[i];
distances[i] = distances[j];
distances[j] = tmpDistance;
}
}
}
// TODO: sort centroids by global centroid (was doing so previously here)
// TODO: sorting tanks recall, possibly because centroid ordinals are no longer aligned
for (float[] centroid : centroids) {
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
@@ -223,190 +200,60 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
}
}
static float[][] gatherInitCentroids(
List<FloatVectorValues> centroidList,
List<SegmentCentroid> segmentCentroids,
int desiredClusters,
FieldInfo fieldInfo,
MergeState mergeState
) throws IOException {
if (centroidList.size() == 0) {
return null;
}
long startTime = System.nanoTime();
// sort centroid list by floatvector size
FloatVectorValues baseSegment = centroidList.get(0);
for (var l : centroidList) {
if (l.size() > baseSegment.size()) {
baseSegment = l;
}
}
float[] scratch = new float[fieldInfo.getVectorDimension()];
float minimumDistance = Float.MAX_VALUE;
for (int j = 0; j < baseSegment.size(); j++) {
System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension());
for (int k = j + 1; k < baseSegment.size(); k++) {
float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k));
if (d < minimumDistance) {
minimumDistance = d;
}
}
}
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(
IVF_VECTOR_COMPONENT,
"Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size()
);
}
int[] labels = new int[segmentCentroids.size()];
// loop over segments
int clusterIdx = 0;
// keep track of all inter-centroid distances,
// using less than centroid * centroid space (e.g. not keeping track of duplicates)
for (int i = 0; i < segmentCentroids.size(); i++) {
if (labels[i] == 0) {
clusterIdx += 1;
labels[i] = clusterIdx;
}
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
System.arraycopy(
centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid),
0,
scratch,
0,
baseSegment.dimension()
);
for (int j = i + 1; j < segmentCentroids.size(); j++) {
float d = VectorUtil.squareDistance(
scratch,
centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid())
);
if (d < minimumDistance / 2) {
if (labels[j] == 0) {
labels[j] = labels[i];
} else {
for (int k = 0; k < labels.length; k++) {
if (labels[k] == labels[j]) {
labels[k] = labels[i];
}
}
}
}
}
}
float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()];
int[] sum = new int[clusterIdx];
for (int i = 0; i < segmentCentroids.size(); i++) {
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
int label = labels[i];
FloatVectorValues segment = centroidList.get(segmentCentroid.segment());
float[] vector = segment.vectorValue(segmentCentroid.centroid);
for (int j = 0; j < vector.length; j++) {
initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize);
}
sum[label - 1] += segmentCentroid.centroidSize;
}
for (int i = 0; i < initCentroids.length; i++) {
if (sum[i] == 0 || sum[i] == 1) {
continue;
}
for (int j = 0; j < initCentroids[i].length; j++) {
initCentroids[i][j] /= sum[i];
}
}
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(
IVF_VECTOR_COMPONENT,
"Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0)
);
mergeState.infoStream.message(
IVF_VECTOR_COMPONENT,
"Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters
);
}
return initCentroids;
}
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
/**
* Calculate the centroids for the given field and write them to the given
* temporary centroid output.
* When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
* To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
* the largest segments intra-cluster distance are merged into a single centroid.
* The resulting centroids are then used to initialize the KMeans algorithm.
*
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param temporaryCentroidOutput the temporary centroid output
* @param mergeState the merge state
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @return the number of centroids written
* @throws IOException if an I/O error occurs
*/
@Override
protected int calculateAndWriteCentroids(
CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput temporaryCentroidOutput,
IndexOutput centroidOutput,
MergeState mergeState,
float[] globalCentroid
) throws IOException {
if (floatVectorValues.size() == 0) {
return 0;
}
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
// init centroids from merge state
List<FloatVectorValues> centroidList = new ArrayList<>();
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
int segmentIdx = 0;
for (var reader : mergeState.knnVectorsReaders) {
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
if (ivfVectorsReader == null) {
continue;
// TODO: take advantage of prior generated clusters from mergeState in the future
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, mergeState.infoStream, globalCentroid, false);
}
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
if (centroid == null) {
continue;
}
centroidList.add(centroid);
for (int i = 0; i < centroid.size(); i++) {
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
if (size == 0) {
continue;
}
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
}
segmentIdx++;
CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid
) throws IOException {
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, infoStream, globalCentroid, true);
}
float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState);
/**
* Calculate the centroids for the given field and write them to the given centroid output.
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
*
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param centroidOutput the centroid output
* @param infoStream the info stream used for diagnostic messages
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @param cacheCentroids whether the centroids are kept or discarded once computed
* @return the vector assignments, SOAR assignments, and, if requested, the computed centroids
* @throws IOException if an I/O error occurs
*/
CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid,
boolean cacheCentroids
) throws IOException {
// FIXME: run a custom version of KMeans that is just better...
long nanoTime = System.nanoTime();
final KMeans.Results kMeans = KMeans.cluster(
floatVectorValues,
desiredClusters,
false,
42L,
KMeans.KmeansInitializationMethod.PLUS_PLUS,
initCentroids,
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
1,
5,
desiredClusters * 64
);
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
}
float[][] centroids = kMeans.centroids();
// write them
// calculate the global centroid from all the centroids:
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
int[] soarAssignments = kMeansResult.soarAssignments();
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
// TODO: push this logic into vector util?
for (float[] centroid : centroids) {
for (int j = 0; j < centroid.length; j++) {
globalCentroid[j] += centroid[j];
@@ -415,197 +262,41 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
for (int j = 0; j < globalCentroid.length; j++) {
globalCentroid[j] /= centroids.length;
}
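// i.e. the global centroid is the component-wise mean of the centroids, weighting every
// cluster equally regardless of its size; see the TODO above about evaluating this against
// averaging the raw vectors.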
writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput);
return centroids.length;
}
@Override
long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidAssignmentScorer centroidAssignmentScorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
MergeState mergeState
) throws IOException {
IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()];
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4);
}
long nanoTime = System.nanoTime();
// Can we do a pre-filter by finding the nearest centroids to the original vector centroids?
// We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing
assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters);
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
}
// write centroids
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
printClusterQualityStatistics(clusters, mergeState.infoStream);
}
// write the posting lists
final long[] offsets = new long[centroidAssignmentScorer.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
DocIdsWriter docIdsWriter = new DocIdsWriter();
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
float[] centroid = centroidAssignmentScorer.centroid(i);
binarizedByteVectorValues.centroid = centroid;
// TODO: sort by distance to the centroid
IntArrayList cluster = clusters[i];
// TODO align???
offsets[i] = postingsOutput.getFilePointer();
int size = cluster.size();
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
}
return offsets;
}
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
float mean = 0;
float m2 = 0;
// iteratively compute the variance & mean
int count = 0;
for (IntArrayList cluster : clusters) {
count += 1;
if (cluster == null) {
continue;
}
float delta = cluster.size() - mean;
mean += delta / count;
m2 += delta * (cluster.size() - mean);
min = Math.min(min, cluster.size());
max = Math.max(max, cluster.size());
}
float variance = m2 / (clusters.length - 1);
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
infoStream.message(
IVF_VECTOR_COMPONENT,
"Centroid count: "
+ clusters.length
+ " min: "
+ min
+ " max: "
+ max
+ " mean: "
+ mean
+ " stdDev: "
+ Math.sqrt(variance)
+ " variance: "
+ variance
"calculate centroids and assign vectors time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)
);
infoStream.message(IVF_VECTOR_COMPONENT, "final centroid count: " + centroids.length);
}
static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException {
int numCentroids = scorer.size();
// we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible
int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO);
int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck);
NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true);
OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1);
float[] scratch = new float[vectors.dimension()];
for (int docID = 0; docID < vectors.size(); docID++) {
float[] vector = vectors.vectorValue(docID);
scorer.setScoringVector(vector);
int bestCentroid = 0;
float bestScore = Float.MAX_VALUE;
if (numCentroids > 1) {
for (short c = 0; c < numCentroids; c++) {
float squareDist = scorer.score(c);
neighborsToCheck.insertWithOverflow(c, squareDist);
}
// pop the best
int sz = neighborsToCheck.size();
int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores);
// Set the size to the number of neighbors we actually found
ordScoreIterator.setSize(sz);
bestScore = ordScoreIterator.getScore(best);
bestCentroid = ordScoreIterator.getOrd(best);
}
clusters[bestCentroid].add(docID);
if (soarClusterCheckCount > 0) {
assignCentroidSOAR(
ordScoreIterator,
docID,
bestCentroid,
scorer.centroid(bestCentroid),
bestScore,
scratch,
scorer,
vector,
clusters
);
}
neighborsToCheck.clear();
IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
for (int c = 0; c < centroids.length; c++) {
IntArrayList cluster = new IntArrayList(vectorPerCluster);
for (int j = 0; j < assignments.length; j++) {
if (assignments[j] == c) {
cluster.add(j);
}
}
static void assignCentroidSOAR(
OrdScoreIterator centroidsToCheck,
int vecOrd,
int bestCentroidId,
float[] bestCentroid,
float bestScore,
float[] scratch,
CentroidAssignmentScorer scorer,
float[] vector,
IntArrayList[] clusters
) throws IOException {
ESVectorUtil.subtract(vector, bestCentroid, scratch);
int bestSecondaryCentroid = -1;
float minDist = Float.MAX_VALUE;
for (int i = 0; i < centroidsToCheck.size(); i++) {
float score = centroidsToCheck.getScore(i);
int centroidOrdinal = centroidsToCheck.getOrd(i);
if (centroidOrdinal == bestCentroidId) {
continue;
}
float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch);
score += SOAR_LAMBDA * proj * proj / bestScore;
if (score < minDist) {
bestSecondaryCentroid = centroidOrdinal;
minDist = score;
}
}
if (bestSecondaryCentroid != -1) {
clusters[bestSecondaryCentroid].add(vecOrd);
for (int j = 0; j < soarAssignments.length; j++) {
if (soarAssignments[j] == c) {
cluster.add(j);
}
}
static class OrdScoreIterator {
private final int[] ords;
private final float[] scores;
private int idx = 0;
OrdScoreIterator(int size) {
this.ords = new int[size];
this.scores = new float[size];
cluster.trimToSize();
assignmentsByCluster[c] = cluster;
}
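// Each vector ordinal can therefore appear in two clusters: once via its primary assignment
// and once via its SOAR assignment, so the resulting posting lists intentionally overlap.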
int setSize(int size) {
if (size > ords.length) {
throw new IllegalArgumentException("size must be <= " + ords.length);
}
this.idx = size;
return size;
}
int getOrd(int idx) {
return ords[idx];
}
float getScore(int idx) {
return scores[idx];
}
int size() {
return idx;
if (cacheCentroids) {
return new CentroidAssignments(centroids, assignmentsByCluster);
} else {
return new CentroidAssignments(centroids.length, assignmentsByCluster);
}
}
@@ -650,16 +341,25 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
}
}
static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
throws IOException {
indexOutput.writeBytes(binaryValue, binaryValue.length);
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
indexOutput.writeShort((short) corrections.quantizedComponentSum());
}
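// Per-entry layout written above: the packed quantized bytes, three float32 corrections
// (lowerInterval, upperInterval, additionalCorrection), then quantizedComponentSum as an
// unsigned 16-bit value.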
static class OffHeapCentroidSupplier implements CentroidSupplier {
private final IndexInput centroidsInput;
private final int numCentroids;
private final int dimension;
private final float[] scratch;
private float[] q;
private final long rawCentroidOffset;
private int currOrd = -1;
OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
this.centroidsInput = centroidsInput;
this.numCentroids = numCentroids;
this.dimension = info.getVectorDimension();
@@ -682,55 +382,5 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
this.currOrd = centroidOrdinal;
return scratch;
}
@Override
public void setScoringVector(float[] vector) {
q = vector;
}
@Override
public float score(int centroidOrdinal) throws IOException {
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
}
}
// TODO throw away rawCentroids
static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
private final float[][] centroids;
private float[] q;
OnHeapCentroidAssignmentScorer(float[][] centroids) {
this.centroids = centroids;
}
@Override
public int size() {
return centroids.length;
}
@Override
public void setScoringVector(float[] vector) {
q = vector;
}
@Override
public float[] centroid(int centroidOrdinal) throws IOException {
return centroids[centroidOrdinal];
}
@Override
public float score(int centroidOrdinal) throws IOException {
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
}
}
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
throws IOException {
indexOutput.writeBytes(binaryValue, binaryValue.length);
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
indexOutput.writeShort((short) corrections.quantizedComponentSum());
}
}

View File

@@ -97,23 +97,6 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
IndexInput clusters
) throws IOException;
protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException;
public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException {
FieldEntry entry = fields.get(fieldInfo.number);
if (entry == null) {
return null;
}
return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo);
}
int centroidSize(String fieldName, int centroidOrdinal) throws IOException {
FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName);
FieldEntry entry = fields.get(fieldInfo.number);
ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]);
return ivfClusters.readVInt();
}
private static IndexInput openDataInput(
SegmentReadState state,
int versionMeta,

View File

@@ -11,11 +11,9 @@ package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
@@ -25,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
@@ -123,38 +122,32 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
return rawVectorDelegate;
}
protected abstract int calculateAndWriteCentroids(
abstract CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput temporaryCentroidOutput,
IndexOutput centroidOutput,
MergeState mergeState,
float[] globalCentroid
) throws IOException;
abstract long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidAssignmentScorer scorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
MergeState mergeState
) throws IOException;
abstract CentroidAssignmentScorer calculateAndWriteCentroids(
abstract CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid
) throws IOException;
abstract long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
InfoStream infoStream,
CentroidAssignmentScorer scorer,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput
IndexOutput postingsOutput,
InfoStream infoStream,
IntArrayList[] assignmentsByCluster
) throws IOException;
abstract CentroidAssignmentScorer createCentroidScorer(
abstract CentroidSupplier createCentroidSupplier(
IndexInput centroidsInput,
int numCentroids,
FieldInfo fieldInfo,
@@ -166,33 +159,31 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter fieldWriter : fieldWriters) {
float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
// calculate global centroid
for (var vector : fieldWriter.delegate.getVectors()) {
for (int i = 0; i < globalCentroid.length; i++) {
globalCentroid[i] += vector[i];
}
}
for (int i = 0; i < globalCentroid.length; i++) {
globalCentroid[i] /= fieldWriter.delegate.getVectors().size();
}
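// On flush the global centroid is the mean of all raw vectors; the merge path instead
// derives it from the centroids themselves (see DefaultIVFVectorsWriter), hence the TODO
// there about the two computations being duplicative.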
// build a float vector values with random access
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
// build centroids
long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids(
final CentroidAssignments centroidAssignments = calculateAndWriteCentroids(
fieldWriter.fieldInfo,
floatVectorValues,
ivfCentroids,
segmentWriteState.infoStream,
globalCentroid
);
CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.cachedCentroids());
long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
final long[] offsets = buildAndWritePostingsLists(
fieldWriter.fieldInfo,
segmentWriteState.infoStream,
centroidAssignmentScorer,
centroidSupplier,
floatVectorValues,
ivfClusters
ivfClusters,
segmentWriteState.infoStream,
centroidAssignments.assignmentsByCluster()
);
// write posting lists
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
}
}
@@ -240,16 +231,6 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
};
}
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
vectorsReader = candidateReader.getFieldReader(fieldName);
}
if (vectorsReader instanceof IVFVectorsReader reader) {
return reader;
}
return null;
}
@Override
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
@@ -277,22 +258,25 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
success = false;
CentroidAssignmentScorer centroidAssignmentScorer;
long centroidOffset;
long centroidLength;
String centroidTempName = null;
int numCentroids;
IndexOutput centroidTemp = null;
CentroidAssignments centroidAssignments;
try {
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
centroidTempName = centroidTemp.getName();
numCentroids = calculateAndWriteCentroids(
centroidAssignments = calculateAndWriteCentroids(
fieldInfo,
floatVectorValues,
centroidTemp,
mergeState,
calculatedGlobalCentroid
);
numCentroids = centroidAssignments.numCentroids();
success = true;
} finally {
if (success == false && centroidTempName != null) {
@@ -311,21 +295,28 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
CodecUtil.writeFooter(centroidTemp);
IOUtils.close(centroidTemp);
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength());
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
assert centroidAssignmentScorer.size() == numCentroids;
CentroidSupplier centroidSupplier = createCentroidSupplier(
centroidsInput,
numCentroids,
fieldInfo,
calculatedGlobalCentroid
);
// build and write the posting lists from the merged centroids
final long[] offsets = buildAndWritePostingsLists(
fieldInfo,
centroidAssignmentScorer,
centroidSupplier,
floatVectorValues,
ivfClusters,
mergeState
mergeState.infoStream,
centroidAssignments.assignmentsByCluster()
);
assert offsets.length == centroidAssignmentScorer.size();
assert offsets.length == centroidSupplier.size();
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
}
} finally {
@@ -453,8 +444,8 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> delegate) {}
interface CentroidAssignmentScorer {
CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() {
interface CentroidSupplier {
CentroidSupplier EMPTY = new CentroidSupplier() {
@Override
public int size() {
return 0;
@@ -464,24 +455,29 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
public float[] centroid(int centroidOrdinal) {
throw new IllegalStateException("No centroids");
}
@Override
public float score(int centroidOrdinal) {
throw new IllegalStateException("No centroids");
}
@Override
public void setScoringVector(float[] vector) {
throw new IllegalStateException("No centroids");
}
};
int size();
float[] centroid(int centroidOrdinal) throws IOException;
}
void setScoringVector(float[] vector);
// TODO throw away rawCentroids
static class OnHeapCentroidSupplier implements CentroidSupplier {
private final float[][] centroids;
float score(int centroidOrdinal) throws IOException;
OnHeapCentroidSupplier(float[][] centroids) {
this.centroids = centroids;
}
@Override
public int size() {
return centroids.length;
}
@Override
public float[] centroid(int centroidOrdinal) throws IOException {
return centroids[centroidOrdinal];
}
}
}
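
Putting the two supplier implementations in context (both call sites appear in the diffs above): flush serves the just-computed centroids from the heap, while merge streams them back from the temporary centroid file:

// flush path: centroids were just computed and cached in CentroidAssignments
CentroidSupplier onHeap = new OnHeapCentroidSupplier(centroidAssignments.cachedCentroids());
// merge path: centroids live in a temp file and are re-read on demand
CentroidSupplier offHeap = createCentroidSupplier(centroidsInput, numCentroids, fieldInfo, calculatedGlobalCentroid);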

View File

@@ -1,494 +0,0 @@
/*
* @notice
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modifications copyright (C) 2025 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import static org.elasticsearch.index.codec.vectors.SampleReader.createSampleReader;
/** KMeans clustering algorithm for vectors */
class KMeans {
public static final int DEFAULT_RESTARTS = 1;
public static final int DEFAULT_ITRS = 10;
public static final int DEFAULT_SAMPLE_VECTORS_PER_CENTROID = 128;
private static final float EPS = 1f / 1024f;
private final FloatVectorValues vectors;
private final int numVectors;
private final int numCentroids;
private final Random random;
private final KmeansInitializationMethod initializationMethod;
private final float[][] initCentroids;
private final int restarts;
private final int iters;
/**
* Cluster vectors into a given number of clusters
*
* @param vectors float vectors
* @param similarityFunction vector similarity function. For COSINE similarity, vectors must be
* normalized.
* @param numClusters number of cluster to cluster vector into
* @return results of clustering: produced centroids and for each vector its centroid
* @throws IOException when if there is an error accessing vectors
*/
static Results cluster(FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters) throws IOException {
return cluster(
vectors,
numClusters,
true,
42L,
KmeansInitializationMethod.PLUS_PLUS,
null,
similarityFunction == VectorSimilarityFunction.COSINE,
DEFAULT_RESTARTS,
DEFAULT_ITRS,
DEFAULT_SAMPLE_VECTORS_PER_CENTROID * numClusters
);
}
/**
* Expert: Cluster vectors into a given number of clusters
*
* @param vectors float vectors
* @param numClusters number of cluster to cluster vector into
* @param assignCentroidsToVectors if {@code true} assign centroids for all vectors. Centroids are
* computed on a sample of vectors. If this parameter is {@code true}, in results also return
* for all vectors what centroids they belong to.
* @param seed random seed
* @param initializationMethod Kmeans initialization method
* @param initCentroids initial centroids, if not {@code null} utilize as initial centroids for
* the given initialization method
* @param normalizeCenters for cosine distance, set to true, to use spherical k-means where
* centers are normalized
* @param restarts how many times to run Kmeans algorithm
* @param iters how many iterations to do within a single run
* @param sampleSize sample size to select from all vectors on which to run Kmeans algorithm
* @return results of clustering: produced centroids and if {@code assignCentroidsToVectors ==
* true} also for each vector its centroid
* @throws IOException if there is error accessing vectors
*/
static Results cluster(
FloatVectorValues vectors,
int numClusters,
boolean assignCentroidsToVectors,
long seed,
KmeansInitializationMethod initializationMethod,
float[][] initCentroids,
boolean normalizeCenters,
int restarts,
int iters,
int sampleSize
) throws IOException {
if (vectors.size() == 0) {
return null;
}
// adjust sampleSize and numClusters
sampleSize = Math.max(sampleSize, 100 * numClusters);
if (sampleSize > vectors.size()) {
sampleSize = vectors.size();
// Decrease the number of clusters if needed
int maxNumClusters = Math.max(1, sampleSize / 100);
numClusters = Math.min(numClusters, maxNumClusters);
}
Random random = new Random(seed);
float[][] centroids;
if (numClusters == 1) {
centroids = new float[1][vectors.dimension()];
for (int i = 0; i < vectors.size(); i++) {
float[] vector = vectors.vectorValue(i);
for (int dim = 0; dim < vector.length; dim++) {
centroids[0][dim] += vector[dim];
}
}
for (int dim = 0; dim < centroids[0].length; dim++) {
centroids[0][dim] /= vectors.size();
}
} else {
FloatVectorValues sampleVectors = vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed);
KMeans kmeans = new KMeans(sampleVectors, numClusters, random, initializationMethod, initCentroids, restarts, iters);
centroids = kmeans.computeCentroids(normalizeCenters);
}
int[] vectorCentroids = null;
int[] centroidSize = null;
// Assign each vector to the nearest centroid and update the centres
if (assignCentroidsToVectors) {
vectorCentroids = new int[vectors.size()];
centroidSize = new int[centroids.length];
assignCentroids(random, vectorCentroids, centroidSize, vectors, centroids);
}
if (normalizeCenters) {
for (float[] centroid : centroids) {
VectorUtil.l2normalize(centroid, false);
}
}
return new Results(centroids, centroidSize, vectorCentroids);
}
private static void assignCentroids(
Random random,
int[] docCentroids,
int[] centroidSize,
FloatVectorValues vectors,
float[][] centroids
) throws IOException {
short numCentroids = (short) centroids.length;
assert Arrays.stream(centroidSize).allMatch(size -> size == 0);
for (int docID = 0; docID < vectors.size(); docID++) {
float[] vector = vectors.vectorValue(docID);
short bestCentroid = 0;
if (numCentroids > 1) {
float minSquaredDist = Float.MAX_VALUE;
for (short c = 0; c < numCentroids; c++) {
// TODO: replace with RandomVectorScorer::score possible on quantized vectors
float squareDist = VectorUtil.squareDistance(centroids[c], vector);
if (squareDist < minSquaredDist) {
bestCentroid = c;
minSquaredDist = squareDist;
}
}
}
centroidSize[bestCentroid] += 1;
docCentroids[docID] = bestCentroid;
}
IntArrayList unassignedCentroids = new IntArrayList();
for (int c = 0; c < numCentroids; c++) {
if (centroidSize[c] == 0) {
unassignedCentroids.add(c);
}
}
if (unassignedCentroids.size() > 0) {
throwAwayAndSplitCentroids(random, vectors, centroids, docCentroids, centroidSize, unassignedCentroids);
}
assert Arrays.stream(centroidSize).sum() == vectors.size();
}
private final float[] kmeansPlusPlusScratch;
KMeans(
FloatVectorValues vectors,
int numCentroids,
Random random,
KmeansInitializationMethod initializationMethod,
float[][] initCentroids,
int restarts,
int iters
) {
this.vectors = vectors;
this.numVectors = vectors.size();
this.numCentroids = numCentroids;
this.random = random;
this.initializationMethod = initializationMethod;
this.restarts = restarts;
this.iters = iters;
this.initCentroids = initCentroids;
this.kmeansPlusPlusScratch = initializationMethod == KmeansInitializationMethod.PLUS_PLUS ? new float[numVectors] : null;
}
float[][] computeCentroids(boolean normalizeCenters) throws IOException {
// TODO can we make this off-heap, or reusable? This could be a big array
int[] vectorCentroids = new int[numVectors];
double minSquaredDist = Double.MAX_VALUE;
double squaredDist = 0;
float[][] bestCentroids = null;
float[][] centroids = new float[numCentroids][vectors.dimension()];
int restarts = this.restarts;
int numInitializedCentroids = 0;
// The user has given us a solid number of centroids to start of with, so skip restarts, fill in
// where we can, and refine
if (initCentroids != null && initCentroids.length > numCentroids / 2) {
int i = 0;
for (; i < Math.min(numCentroids, initCentroids.length); i++) {
System.arraycopy(initCentroids[i], 0, centroids[i], 0, initCentroids[i].length);
}
numInitializedCentroids = i;
restarts = 1;
}
for (int restart = 0; restart < restarts; restart++) {
switch (initializationMethod) {
case FORGY -> initializeForgy(centroids, numInitializedCentroids);
case RESERVOIR_SAMPLING -> initializeReservoirSampling(centroids, numInitializedCentroids);
case PLUS_PLUS -> initializePlusPlus(centroids, numInitializedCentroids);
}
double prevSquaredDist = Double.MAX_VALUE;
int[] centroidSize = new int[centroids.length];
for (int iter = 0; iter < iters; iter++) {
squaredDist = runKMeansStep(centroids, centroidSize, vectorCentroids, normalizeCenters);
// Check for convergence
if (prevSquaredDist <= (squaredDist + 1e-6)) {
break;
}
Arrays.fill(centroidSize, 0);
prevSquaredDist = squaredDist;
}
if (squaredDist < minSquaredDist) {
minSquaredDist = squaredDist;
// Copy out the best centroid as it might be overwritten by the next restart
bestCentroids = new float[centroids.length][];
for (int i = 0; i < centroids.length; i++) {
bestCentroids[i] = ArrayUtil.copyArray(centroids[i]);
}
}
}
return bestCentroids;
}
/**
* Initialize centroids using Forgy method: randomly select numCentroids vectors for initial
* centroids
*/
private void initializeForgy(float[][] initialCentroids, int fromCentroid) throws IOException {
if (fromCentroid >= numCentroids) {
return;
}
int numCentroids = this.numCentroids - fromCentroid;
Set<Integer> selection = new HashSet<>();
while (selection.size() < numCentroids) {
selection.add(random.nextInt(numVectors));
}
int i = 0;
for (Integer selectedIdx : selection) {
float[] vector = vectors.vectorValue(selectedIdx);
System.arraycopy(vector, 0, initialCentroids[fromCentroid + i++], 0, vector.length);
}
}
/** Initialize centroids using a reservoir sampling method */
private void initializeReservoirSampling(float[][] initialCentroids, int fromCentroid) throws IOException {
if (fromCentroid >= numCentroids) {
return;
}
int numCentroids = this.numCentroids - fromCentroid;
for (int index = 0; index < numVectors; index++) {
float[] vector = vectors.vectorValue(index);
if (index < numCentroids) {
System.arraycopy(vector, 0, initialCentroids[index + fromCentroid], 0, vector.length);
} else if (random.nextDouble() < numCentroids * (1.0 / index)) {
int c = random.nextInt(numCentroids);
System.arraycopy(vector, 0, initialCentroids[c + fromCentroid], 0, vector.length);
}
}
}
/** Initialize centroids using Kmeans++ method */
private void initializePlusPlus(float[][] initialCentroids, int fromCentroid) throws IOException {
if (fromCentroid >= numCentroids) {
return;
}
// Choose the first centroid uniformly at random
int firstIndex = random.nextInt(numVectors);
float[] value = vectors.vectorValue(firstIndex);
System.arraycopy(value, 0, initialCentroids[fromCentroid], 0, value.length);
// Store distances of each point to the nearest centroid
Arrays.fill(kmeansPlusPlusScratch, Float.MAX_VALUE);
// Step 2 and 3: Select remaining centroids
for (int i = fromCentroid + 1; i < numCentroids; i++) {
// Update distances with the new centroid
double totalSum = 0;
for (int j = 0; j < numVectors; j++) {
// TODO: replace with RandomVectorScorer::score possible on quantized vectors
float dist = VectorUtil.squareDistance(vectors.vectorValue(j), initialCentroids[i - 1]);
if (dist < kmeansPlusPlusScratch[j]) {
kmeansPlusPlusScratch[j] = dist;
}
totalSum += kmeansPlusPlusScratch[j];
}
// Randomly select next centroid
double r = totalSum * random.nextDouble();
double cumulativeSum = 0;
int nextCentroidIndex = 0;
for (int j = 0; j < numVectors; j++) {
cumulativeSum += kmeansPlusPlusScratch[j];
if (cumulativeSum >= r && kmeansPlusPlusScratch[j] > 0) {
nextCentroidIndex = j;
break;
}
}
// Update centroid
value = vectors.vectorValue(nextCentroidIndex);
System.arraycopy(value, 0, initialCentroids[i], 0, value.length);
}
}
/**
* Run kmeans step
*
* @param centroids centroids, new calculated centroids are written here
* @param docCentroids for each document which centroid it belongs to, results will be written
* here
* @param normalizeCentroids if centroids should be normalized; used for cosine similarity only
* @throws IOException if there is an error accessing vector values
*/
private double runKMeansStep(float[][] centroids, int[] centroidSize, int[] docCentroids, boolean normalizeCentroids)
throws IOException {
short numCentroids = (short) centroids.length;
assert Arrays.stream(centroidSize).allMatch(size -> size == 0);
float[][] newCentroids = new float[numCentroids][centroids[0].length];
double sumSquaredDist = 0;
for (int docID = 0; docID < vectors.size(); docID++) {
float[] vector = vectors.vectorValue(docID);
short bestCentroid = 0;
if (numCentroids > 1) {
float minSquaredDist = Float.MAX_VALUE;
for (short c = 0; c < numCentroids; c++) {
// TODO: replace with RandomVectorScorer::score possible on quantized vectors
float squareDist = VectorUtil.squareDistance(centroids[c], vector);
if (squareDist < minSquaredDist) {
bestCentroid = c;
minSquaredDist = squareDist;
}
}
sumSquaredDist += minSquaredDist;
}
centroidSize[bestCentroid] += 1;
for (int dim = 0; dim < vector.length; dim++) {
newCentroids[bestCentroid][dim] += vector[dim];
}
docCentroids[docID] = bestCentroid;
}
IntArrayList unassignedCentroids = new IntArrayList();
for (int c = 0; c < numCentroids; c++) {
if (centroidSize[c] > 0) {
for (int dim = 0; dim < newCentroids[c].length; dim++) {
centroids[c][dim] = newCentroids[c][dim] / centroidSize[c];
}
} else {
unassignedCentroids.add(c);
}
}
if (unassignedCentroids.size() > 0) {
throwAwayAndSplitCentroids(random, vectors, centroids, docCentroids, centroidSize, unassignedCentroids);
}
if (normalizeCentroids) {
for (float[] centroid : centroids) {
VectorUtil.l2normalize(centroid, false);
}
}
assert Arrays.stream(centroidSize).sum() == vectors.size();
return sumSquaredDist;
}
static void throwAwayAndSplitCentroids(
Random random,
FloatVectorValues vectors,
float[][] centroids,
int[] docCentroids,
int[] centroidSize,
IntArrayList unassignedCentroidsIdxs
) throws IOException {
IntObjectHashMap<IntArrayList> splitCentroids = new IntObjectHashMap<>(unassignedCentroidsIdxs.size());
// used for splitting logic
int[] splitSizes = Arrays.copyOf(centroidSize, centroidSize.length);
// FAISS style algorithm for splitting
for (int i = 0; i < unassignedCentroidsIdxs.size(); i++) {
int toSplit;
for (toSplit = 0; true; toSplit = (toSplit + 1) % centroids.length) {
/* probability to pick this cluster for split */
double p = (splitSizes[toSplit] - 1.0) / (float) (docCentroids.length - centroids.length);
float r = random.nextFloat();
if (r < p) {
break; /* found our cluster to be split */
}
}
int unassignedCentroidIdx = unassignedCentroidsIdxs.get(i);
// keep track of those that are split, this way we reassign docCentroids and fix up true size
// & centroids
splitCentroids.getOrDefault(toSplit, new IntArrayList()).add(unassignedCentroidIdx);
System.arraycopy(centroids[toSplit], 0, centroids[unassignedCentroidIdx], 0, centroids[unassignedCentroidIdx].length);
for (int dim = 0; dim < centroids[unassignedCentroidIdx].length; dim++) {
if (dim % 2 == 0) {
centroids[unassignedCentroidIdx][dim] *= (1 + EPS);
centroids[toSplit][dim] *= (1 - EPS);
} else {
centroids[unassignedCentroidIdx][dim] *= (1 - EPS);
centroids[toSplit][dim] *= (1 + EPS);
}
}
splitSizes[unassignedCentroidIdx] = splitSizes[toSplit] / 2;
splitSizes[toSplit] -= splitSizes[unassignedCentroidIdx];
}
// now we need to reassign docCentroids and fix up true size & centroids
for (int i = 0; i < docCentroids.length; i++) {
int docCentroid = docCentroids[i];
IntArrayList split = splitCentroids.get(docCentroid);
if (split != null) {
// we need to reassign this doc
int bestCentroid = docCentroid;
float bestDist = VectorUtil.squareDistance(centroids[docCentroid], vectors.vectorValue(i));
for (int j = 0; j < split.size(); j++) {
int newCentroid = split.get(j);
float dist = VectorUtil.squareDistance(centroids[newCentroid], vectors.vectorValue(i));
if (dist < bestDist) {
bestCentroid = newCentroid;
bestDist = dist;
}
}
if (bestCentroid != docCentroid) {
// we need to update the centroid size
centroidSize[docCentroid]--;
centroidSize[bestCentroid]++;
docCentroids[i] = (short) bestCentroid;
// we need to update the old and new centroid accounting for size as well
for (int dim = 0; dim < centroids[docCentroid].length; dim++) {
centroids[docCentroid][dim] -= vectors.vectorValue(i)[dim] / centroidSize[docCentroid];
centroids[bestCentroid][dim] += vectors.vectorValue(i)[dim] / centroidSize[bestCentroid];
}
}
}
}
}
/** Kmeans initialization methods */
public enum KmeansInitializationMethod {
FORGY,
RESERVOIR_SAMPLING,
PLUS_PLUS
}
/**
* Results of KMeans clustering
*
* @param centroids the produced centroids
* @param centroidsSize for each centroid how many vectors belong to it
* @param vectorCentroids for each vector which centroid it belongs to
*/
public record Results(float[][] centroids, int[] centroidsSize, int[] vectorCentroids) {}
}

View File

@@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues;
import java.io.IOException;
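/**
 * A read-only view over a subset of a {@link FloatVectorValues}, remapping ordinals through
 * the given slice of backing ordinals.
 */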
class FloatVectorValuesSlice extends FloatVectorValues {
private final FloatVectorValues allValues;
private final int[] slice;
FloatVectorValuesSlice(FloatVectorValues allValues, int[] slice) {
assert slice != null;
assert slice.length <= allValues.size();
this.allValues = allValues;
this.slice = slice;
}
@Override
public float[] vectorValue(int ord) throws IOException {
return this.allValues.vectorValue(this.slice[ord]);
}
@Override
public int dimension() {
return this.allValues.dimension();
}
@Override
public int size() {
return slice.length;
}
@Override
public int ordToDoc(int ord) {
return this.slice[ord];
}
@Override
public FloatVectorValues copy() throws IOException {
return new FloatVectorValuesSlice(this.allValues.copy(), this.slice);
}
}

View File

@@ -0,0 +1,197 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import java.io.IOException;
/**
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
*/
public class HierarchicalKMeans {
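// tuning knobs: MAXK caps the branching factor of each recursive split, MAX_ITERATIONS_DEFAULT bounds
// the Lloyd iterations per level, SAMPLES_PER_CLUSTER_DEFAULT sizes the sample used when refining, and
// DEFAULT_SOAR_LAMBDA weights the SOAR penalty for secondary assignments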
static final int MAXK = 128;
static final int MAX_ITERATIONS_DEFAULT = 6;
static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
static final float DEFAULT_SOAR_LAMBDA = 1.0f;
final int dimension;
final int maxIterations;
final int samplesPerCluster;
final int clustersPerNeighborhood;
final float soarLambda;
public HierarchicalKMeans(int dimension) {
this(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA);
}
HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) {
this.dimension = dimension;
this.maxIterations = maxIterations;
this.samplesPerCluster = samplesPerCluster;
this.clustersPerNeighborhood = clustersPerNeighborhood;
this.soarLambda = soarLambda;
}
/**
* clusters (or rather partitions) the set of vectors, starting with a rough number of partitions and recursively refining them;
* lastly, a pass is made to adjust nearby neighborhoods and add an extra (SOAR) assignment per vector to a nearby neighborhood
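* <p>A minimal usage sketch (the {@code vectors} source and target size of 384 are hypothetical):
* <pre>{@code
* KMeansResult result = new HierarchicalKMeans(vectors.dimension()).cluster(vectors, 384);
* float[][] centroids = result.centroids();
* int[] assignments = result.assignments(); // one centroid ordinal per vector
* int[] soarAssignments = result.soarAssignments(); // secondary spilled assignments
* }</pre>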
*
* @param vectors the vectors to cluster
* @param targetSize the rough number of vectors that should be attached to a cluster
* @return the centroids, the vector assignments, and the SOAR (spilled from nearby neighborhoods) assignments
* @throws IOException is thrown if vectors is inaccessible
*/
public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
if (vectors.size() == 0) {
return new KMeansIntermediate();
}
// if we have a small number of vectors pick one and output that as the centroid
if (vectors.size() <= targetSize) {
float[] centroid = new float[dimension];
System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension);
return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]);
}
// partition the space
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
int localSampleSize = (int) (f * vectors.size());
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
}
return kMeansIntermediate;
}
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
if (vectors.size() <= targetSize) {
return new KMeansIntermediate();
}
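// k is roughly vectors.size() / targetSize (rounded), clamped to [2, MAXK];
// e.g. 10,000 vectors with a targetSize of 384 yields k = 26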
int k = Math.clamp((int) ((vectors.size() + targetSize / 2.0f) / (float) targetSize), 2, MAXK);
int m = Math.min(k * samplesPerCluster, vectors.size());
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
int[] assignments = new int[vectors.size()];
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
kmeans.cluster(vectors, kMeansIntermediate);
// TODO: consider adding cluster size counts to the kmeans algo
// handle assignment here so we can track distance and cluster size
int[] centroidVectorCount = new int[centroids.length];
float[][] nextCentroids = new float[centroids.length][dimension];
for (int i = 0; i < vectors.size(); i++) {
float smallest = Float.MAX_VALUE;
int centroidIdx = -1;
float[] vector = vectors.vectorValue(i);
for (int j = 0; j < centroids.length; j++) {
float[] centroid = centroids[j];
float d = VectorUtil.squareDistance(vector, centroid);
if (d < smallest) {
smallest = d;
centroidIdx = j;
}
}
centroidVectorCount[centroidIdx]++;
for (int j = 0; j < dimension; j++) {
nextCentroids[centroidIdx][j] += vector[j];
}
assignments[i] = centroidIdx;
}
// update centroids based on assignments of all vectors
for (int i = 0; i < centroids.length; i++) {
if (centroidVectorCount[i] > 0) {
for (int j = 0; j < dimension; j++) {
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
}
}
}
int effectiveK = 0;
for (int i = 0; i < centroidVectorCount.length; i++) {
if (centroidVectorCount[i] > 0) {
effectiveK++;
}
}
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
if (effectiveK == 1) {
return kMeansIntermediate;
}
for (int c = 0; c < centroidVectorCount.length; c++) {
// Recurse for each cluster which is larger than targetSize
// Give ourselves roughly a one-third (34%) margin over the target size
if (100 * centroidVectorCount[c] > 134 * targetSize) {
FloatVectorValues sample = createClusterSlice(centroidVectorCount[c], c, vectors, assignments);
// TODO: consider iterative here instead of recursive
// recursive call to build out the sub partitions around this centroid c
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
}
}
return kMeansIntermediate;
}
static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
int[] slice = new int[clusterSize];
int idx = 0;
for (int i = 0; i < assignments.length; i++) {
if (assignments[i] == cluster) {
slice[idx] = i;
idx++;
}
}
return new FloatVectorValuesSlice(vectors, slice);
}
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
int orgCentroidsSize = current.centroids().length;
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
// update based on the outcomes from the split clusters recursion
if (subPartitions.centroids().length > 1) {
float[][] newCentroids = new float[newCentroidsSize][dimension];
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
// replace the original cluster
int origCentroidOrd = 0;
newCentroids[cluster] = subPartitions.centroids()[0];
// append the remainder
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
current.setCentroids(newCentroids);
for (int i = 0; i < subPartitions.assignments().length; i++) {
// this is a new centroid that was added, and so we'll need to remap it
if (subPartitions.assignments()[i] != origCentroidOrd) {
int parentOrd = subPartitions.ordToDoc(i);
assert current.assignments()[parentOrd] == cluster;
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
}
}
}
}
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.util.hnsw.IntToIntFunction;
/**
* Intermediate object for clustering (partitioning) a set of vectors
*/
class KMeansIntermediate extends KMeansResult {
private final IntToIntFunction assignmentOrds;
private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrds, int[] soarAssignments) {
super(centroids, assignments, soarAssignments);
assert assignmentOrds != null;
this.assignmentOrds = assignmentOrds;
}
KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrdinals) {
this(centroids, assignments, assignmentOrdinals, new int[0]);
}
KMeansIntermediate() {
this(new float[0][0], new int[0], i -> i, new int[0]);
}
KMeansIntermediate(float[][] centroids) {
this(centroids, new int[0], i -> i, new int[0]);
}
KMeansIntermediate(float[][] centroids, int[] assignments) {
this(centroids, assignments, i -> i, new int[0]);
}
public int ordToDoc(int ord) {
return assignmentOrds.apply(ord);
}
}

View File

@@ -0,0 +1,306 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
 * k-means implementation specific to the needs of the {@link HierarchicalKMeans} algorithm; it finalizes nearby
 * pre-established clusters and generates
 * <a href="https://research.google/blog/soar-new-algorithms-for-even-faster-vector-search-with-scann/">SOAR</a> assignments
*/
class KMeansLocal {
final int sampleSize;
final int maxIterations;
final int clustersPerNeighborhood;
final float soarLambda;
KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) {
this.sampleSize = sampleSize;
this.maxIterations = maxIterations;
this.clustersPerNeighborhood = clustersPerNeighborhood;
this.soarLambda = soarLambda;
}
KMeansLocal(int sampleSize, int maxIterations) {
this(sampleSize, maxIterations, -1, -1f);
}
/**
 * uses a reservoir sampling approach to pick the initial centroids, which are subsequently
 * expected to be used by a clustering algorithm
*
* @param vectors used to pick an initial set of random centroids
* @param centroidCount the total number of centroids to pick
 * @return randomly selected centroids, at most the min of centroidCount and the number of vectors
* @throws IOException is thrown if vectors is inaccessible
*/
static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCount) throws IOException {
Random random = new Random(42L);
int centroidsSize = Math.min(vectors.size(), centroidCount);
float[][] centroids = new float[centroidsSize][vectors.dimension()];
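// reservoir sampling (Algorithm R): fill the first centroidCount slots directly, then replace
// a random slot with probability centroidCount / i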
for (int i = 0; i < vectors.size(); i++) {
float[] vector;
if (i < centroidCount) {
vector = vectors.vectorValue(i);
System.arraycopy(vector, 0, centroids[i], 0, vector.length);
} else if (random.nextDouble() < centroidCount * (1.0 / i)) {
int c = random.nextInt(centroidCount);
vector = vectors.vectorValue(i);
System.arraycopy(vector, 0, centroids[c], 0, vector.length);
}
}
return centroids;
}
private boolean stepLloyd(
FloatVectorValues vectors,
float[][] centroids,
float[][] nextCentroids,
int[] assignments,
int sampleSize,
List<int[]> neighborhoods
) throws IOException {
boolean changed = false;
int dim = vectors.dimension();
int[] centroidCounts = new int[centroids.length];
for (int i = 0; i < nextCentroids.length; i++) {
Arrays.fill(nextCentroids[i], 0.0f);
}
for (int i = 0; i < sampleSize; i++) {
float[] vector = vectors.vectorValue(i);
int[] neighborOffsets = null;
int centroidIdx = -1;
if (neighborhoods != null) {
neighborOffsets = neighborhoods.get(assignments[i]);
centroidIdx = assignments[i];
}
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
if (assignments[i] != bestCentroidOffset) {
changed = true;
}
assignments[i] = bestCentroidOffset;
centroidCounts[bestCentroidOffset]++;
for (short d = 0; d < dim; d++) {
nextCentroids[bestCentroidOffset][d] += vector[d];
}
}
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
if (centroidCounts[clusterIdx] > 0) {
float countF = (float) centroidCounts[clusterIdx];
for (short d = 0; d < dim; d++) {
centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
}
}
}
return changed;
}
int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int bestCentroidOffset = centroidIdx;
float minDsq;
if (centroidIdx >= 0 && centroidIdx < centroids.length) {
minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
} else {
minDsq = Float.MAX_VALUE;
}
if (centroidOffsets == null) {
// no neighborhood restriction, consider every centroid
for (int j = 0; j < centroids.length; j++) {
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = j;
}
}
} else {
// only consider the centroids in the precomputed neighborhood
for (int j : centroidOffsets) {
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = j;
}
}
}
return bestCentroidOffset;
}
private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods, int clustersPerNeighborhood) {
int k = neighborhoods.size();
if (k == 0 || clustersPerNeighborhood <= 0) {
return;
}
List<NeighborQueue> neighborQueues = new ArrayList<>(k);
for (int i = 0; i < k; i++) {
neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true));
}
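// brute-force O(k^2) pairwise pass; each pair contributes its distance to both queues so every
// centroid ends up with at most clustersPerNeighborhood nearest neighbors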
for (int i = 0; i < k - 1; i++) {
for (int j = i + 1; j < k; j++) {
float dsq = VectorUtil.squareDistance(centers[i], centers[j]);
neighborQueues.get(j).insertWithOverflow(i, dsq);
neighborQueues.get(i).insertWithOverflow(j, dsq);
}
}
for (int i = 0; i < k; i++) {
NeighborQueue queue = neighborQueues.get(i);
int neighborCount = queue.size();
int[] neighbors = new int[neighborCount];
queue.consumeNodes(neighbors);
neighborhoods.set(i, neighbors);
}
}
private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
throws IOException {
// SOAR uses an adjusted distance for assigning spilled documents which is
// given by:
//
// soar(x, c) = ||x - c||^2 + lambda * ((x - c_1)^t (x - c))^2 / ||x - c_1||^2
//
// Here, x is the document, c is the nearest centroid, and c_1 is the first
// centroid the document was assigned to. The document is assigned to the
// cluster with the smallest soar(x, c).
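// Below, vectorCentroidDist caches ||x - c_1||^2 and soarLambda weights the correlation penalty.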
int[] spilledAssignments = new int[assignments.length];
float[] diffs = new float[vectors.dimension()];
for (int i = 0; i < vectors.size(); i++) {
float[] vector = vectors.vectorValue(i);
int currAssignment = assignments[i];
float[] currentCentroid = centroids[currAssignment];
for (short j = 0; j < vectors.dimension(); j++) {
float diff = vector[j] - currentCentroid[j];
diffs[j] = diff;
}
// TODO: cache these?
// float vectorCentroidDist = assignmentDistances[i];
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
int bestAssignment = -1;
float minSoar = Float.MAX_VALUE;
assert neighborhoods.get(currAssignment) != null;
for (int neighbor : neighborhoods.get(currAssignment)) {
if (neighbor == currAssignment) {
continue;
}
float[] neighborCentroid = centroids[neighbor];
float soar = distanceSoar(diffs, vector, neighborCentroid, vectorCentroidDist);
if (soar < minSoar) {
bestAssignment = neighbor;
minSoar = soar;
}
}
spilledAssignments[i] = bestAssignment;
}
return spilledAssignments;
}
private float distanceSoar(float[] residual, float[] vector, float[] centroid, float rnorm) {
// TODO: combine these to be more efficient
float dsq = VectorUtil.squareDistance(vector, centroid);
float rproj = ESVectorUtil.soarResidual(vector, centroid, residual);
return dsq + soarLambda * rproj * rproj / rnorm;
}
/**
 * cluster using Lloyd's k-means algorithm without neighbor awareness
*
* @param vectors the vectors to cluster
* @param kMeansIntermediate the output object to populate which minimally includes centroids,
* but may include assignments and soar assignments as well; care should be taken in
* passing in a valid output object with a centroids array that is the size of centroids expected
* @throws IOException is thrown if vectors is inaccessible
*/
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
cluster(vectors, kMeansIntermediate, false);
}
/**
 * cluster using Lloyd's k-means algorithm, also considering prior clustered neighborhoods when adjusting centroids;
 * this mode additionally generates the neighborhood-aware (SOAR) assignments
*
* @param vectors the vectors to cluster
* @param kMeansIntermediate the output object to populate which minimally includes centroids,
* the prior assignments of the given vectors; care should be taken in
* passing in a valid output object with a centroids array that is the size of centroids expected
* and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
* @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
* implies SOAR assignments
* @throws IOException is thrown if vectors is inaccessible
*/
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
List<int[]> neighborhoods = null;
if (neighborAware) {
int k = centroids.length;
neighborhoods = new ArrayList<>(k);
for (int i = 0; i < k; ++i) {
neighborhoods.add(null);
}
computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood);
}
cluster(vectors, kMeansIntermediate, neighborhoods);
if (neighborAware && clustersPerNeighborhood > 0) {
int[] assignments = kMeansIntermediate.assignments();
assert assignments != null;
assert assignments.length == vectors.size();
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments));
}
}
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
int k = centroids.length;
int n = vectors.size();
if (k == 1 || k >= n) {
return;
}
// reuse the caller's assignments when they cover all vectors so refinements remain visible to the caller
int[] assignments = kMeansIntermediate.assignments().length == n ? kMeansIntermediate.assignments() : new int[n];
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
for (int i = 0; i < maxIterations; i++) {
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
break;
}
}
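// one final pass over every vector (not just the sample) so assignments reflect the final centroids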
stepLloyd(vectors, centroids, nextCentroids, assignments, vectors.size(), neighborhoods);
}
/**
* helper that calls {@link KMeansLocal#cluster(FloatVectorValues, KMeansIntermediate)} given a set of initialized centroids,
* this call is not neighbor aware
*
* @param vectors the vectors to cluster
* @param centroids the initialized centroids to be shifted using k-means
* @param sampleSize the subset of vectors to use when shifting centroids
* @param maxIterations the max iterations to shift centroids
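 * @throws IOException is thrown if vectors is inaccessible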
*/
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
kMeans.cluster(vectors, kMeansIntermediate);
}
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
/**
* Output object for clustering (partitioning) a set of vectors
*/
public class KMeansResult {
private float[][] centroids;
private final int[] assignments;
private int[] soarAssignments;
KMeansResult(float[][] centroids, int[] assignments, int[] soarAssignments) {
assert centroids != null;
assert assignments != null;
assert soarAssignments != null;
this.centroids = centroids;
this.assignments = assignments;
this.soarAssignments = soarAssignments;
}
public float[][] centroids() {
return centroids;
}
void setCentroids(float[][] centroids) {
this.centroids = centroids;
}
public int[] assignments() {
return assignments;
}
void setSoarAssignments(int[] soarAssignments) {
this.soarAssignments = soarAssignments;
}
public int[] soarAssignments() {
return soarAssignments;
}
}

View File

@@ -14,10 +14,10 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modifications copyright (C) 2025 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors;
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.util.LongHeap;
import org.apache.lucene.util.NumericUtils;

View File

@@ -1,154 +0,0 @@
/*
* @notice
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modifications copyright (C) 2025 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class KMeansTests extends ESTestCase {
public void testKMeansAPI() throws IOException {
int nClusters = random().nextInt(1, 10);
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
int dims = random().nextInt(2, 20);
int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
// default case
{
KMeans.Results results = KMeans.cluster(vectors, similarityFunction, nClusters);
assertResults(results, nClusters, nVectors, true);
assertEquals(nClusters, results.centroids().length);
assertEquals(nClusters, results.centroidsSize().length);
assertEquals(nVectors, results.vectorCentroids().length);
}
// expert case
{
boolean assignCentroidsToVectors = random().nextBoolean();
int randIdx2 = random().nextInt(KMeans.KmeansInitializationMethod.values().length);
KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.values()[randIdx2];
int restarts = random().nextInt(1, 6);
int iters = random().nextInt(1, 10);
int sampleSize = random().nextInt(10, nVectors * 2);
KMeans.Results results = KMeans.cluster(
vectors,
nClusters,
assignCentroidsToVectors,
random().nextLong(),
initializationMethod,
null,
similarityFunction == VectorSimilarityFunction.COSINE,
restarts,
iters,
sampleSize
);
assertResults(results, nClusters, nVectors, assignCentroidsToVectors);
}
}
private void assertResults(KMeans.Results results, int nClusters, int nVectors, boolean assignCentroidsToVectors) {
assertEquals(nClusters, results.centroids().length);
if (assignCentroidsToVectors) {
assertEquals(nClusters, results.centroidsSize().length);
assertEquals(nVectors, results.vectorCentroids().length);
int[] centroidsSize = new int[nClusters];
for (int i = 0; i < nVectors; i++) {
centroidsSize[results.vectorCentroids()[i]]++;
}
assertArrayEquals(centroidsSize, results.centroidsSize());
} else {
assertNull(results.vectorCentroids());
}
}
public void testKMeansSpecialCases() throws IOException {
{
// nClusters > nVectors
int nClusters = 20;
int nVectors = 10;
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
// assert that we get 1 centroid, as nClusters will be adjusted
assertEquals(1, results.centroids().length);
assertEquals(nVectors, results.vectorCentroids().length);
}
{
// small sample size
int sampleSize = 2;
int nClusters = 2;
int nVectors = 300;
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.PLUS_PLUS;
KMeans.Results results = KMeans.cluster(
vectors,
nClusters,
true,
random().nextLong(),
initializationMethod,
null,
false,
1,
2,
sampleSize
);
assertResults(results, nClusters, nVectors, true);
}
}
public void testKMeansSAllZero() throws IOException {
int nClusters = 10;
List<float[]> vectors = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
float[] vector = new float[5];
vectors.add(vector);
}
KMeans.Results results = KMeans.cluster(FloatVectorValues.fromFloats(vectors, 5), VectorSimilarityFunction.EUCLIDEAN, nClusters);
assertResults(results, nClusters, 1000, true);
}
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
List<float[]> vectors = new ArrayList<>(nSamples);
float[][] centroids = new float[nClusters][nDims];
// Generate random centroids
for (int i = 0; i < nClusters; i++) {
for (int j = 0; j < nDims; j++) {
centroids[i][j] = random().nextFloat() * 100;
}
}
// Generate data points around centroids
for (int i = 0; i < nSamples; i++) {
int cluster = random().nextInt(nClusters);
float[] vector = new float[nDims];
for (int j = 0; j < nDims; j++) {
vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
}
vectors.add(vector);
}
return FloatVectorValues.fromFloats(vectors, nDims);
}
}

View File

@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class HierarchicalKMeansTests extends ESTestCase {
public void testHKmeans() throws IOException {
int nClusters = random().nextInt(1, 10);
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
int dims = random().nextInt(2, 20);
int sampleSize = random().nextInt(100, nVectors + 1);
int maxIterations = random().nextInt(0, 100);
int clustersPerNeighborhood = random().nextInt(0, 512);
float soarLambda = random().nextFloat(0.5f, 1.5f);
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
int targetSize = (int) ((float) nVectors / (float) nClusters);
HierarchicalKMeans hkmeans = new HierarchicalKMeans(dims, maxIterations, sampleSize, clustersPerNeighborhood, soarLambda);
KMeansResult result = hkmeans.cluster(vectors, targetSize);
float[][] centroids = result.centroids();
int[] assignments = result.assignments();
int[] soarAssignments = result.soarAssignments();
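// the recursive splitting may over- or under-shoot the requested cluster count, so allow a small delta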
assertEquals(nClusters, centroids.length, 6);
assertEquals(nVectors, assignments.length);
if (centroids.length > 1 && clustersPerNeighborhood > 0) {
assertEquals(nVectors, soarAssignments.length);
// verify that each SOAR assignment differs from the vector's primary assignment
for (int i = 0; i < assignments.length; i++) {
assertNotEquals(assignments[i], soarAssignments[i]);
}
}
}
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
List<float[]> vectors = new ArrayList<>(nSamples);
float[][] centroids = new float[nClusters][nDims];
// Generate random centroids
for (int i = 0; i < nClusters; i++) {
for (int j = 0; j < nDims; j++) {
centroids[i][j] = random().nextFloat() * 100;
}
}
// Generate data points around centroids
for (int i = 0; i < nSamples; i++) {
int cluster = random().nextInt(nClusters);
float[] vector = new float[nDims];
for (int j = 0; j < nDims; j++) {
vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
}
vectors.add(vector);
}
return FloatVectorValues.fromFloats(vectors, nDims);
}
}

View File

@@ -0,0 +1,127 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class KMeansLocalTests extends ESTestCase {
public void testKMeansNeighbors() throws IOException {
int nClusters = random().nextInt(1, 10);
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
int dims = random().nextInt(2, 20);
int sampleSize = random().nextInt(100, nVectors + 1);
int maxIterations = random().nextInt(0, 100);
int clustersPerNeighborhood = random().nextInt(0, 512);
float soarLambda = random().nextFloat(0.5f, 1.5f);
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, nClusters);
KMeansLocal.cluster(vectors, centroids, sampleSize, maxIterations);
int[] assignments = new int[vectors.size()];
int[] assignmentOrdinals = new int[vectors.size()];
for (int i = 0; i < vectors.size(); i++) {
float minDist = Float.MAX_VALUE;
int ord = -1;
for (int j = 0; j < centroids.length; j++) {
float dist = VectorUtil.squareDistance(vectors.vectorValue(i), centroids[j]);
if (dist < minDist) {
minDist = dist;
ord = j;
}
}
assignments[i] = ord;
assignmentOrdinals[i] = i;
}
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
assertEquals(nClusters, centroids.length);
assertNotNull(kMeansIntermediate.soarAssignments());
}
public void testKMeansNeighborsAllZero() throws IOException {
int nClusters = 10;
int maxIterations = 10;
int clustersPerNeighborhood = 128;
float soarLambda = 1.0f;
int nVectors = 1000;
List<float[]> vectors = new ArrayList<>();
for (int i = 0; i < nVectors; i++) {
float[] vector = new float[5];
vectors.add(vector);
}
int sampleSize = vectors.size();
FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, 5);
float[][] centroids = KMeansLocal.pickInitialCentroids(fvv, nClusters);
KMeansLocal.cluster(fvv, centroids, sampleSize, maxIterations);
int[] assignments = new int[vectors.size()];
int[] assignmentOrdinals = new int[vectors.size()];
for (int i = 0; i < vectors.size(); i++) {
float minDist = Float.MAX_VALUE;
int ord = -1;
for (int j = 0; j < centroids.length; j++) {
float dist = VectorUtil.squareDistance(fvv.vectorValue(i), centroids[j]);
if (dist < minDist) {
minDist = dist;
ord = j;
}
}
assignments[i] = ord;
assignmentOrdinals[i] = i;
}
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
kMeansLocal.cluster(fvv, kMeansIntermediate, true);
assertEquals(nClusters, centroids.length);
assertNotNull(kMeansIntermediate.soarAssignments());
// all-zero input vectors must produce all-zero centroids
for (float[] centroid : centroids) {
for (float v : centroid) {
assertEquals(0.0f, v, 1e-7f);
}
}
}
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
List<float[]> vectors = new ArrayList<>(nSamples);
float[][] centroids = new float[nClusters][nDims];
// Generate random centroids
for (int i = 0; i < nClusters; i++) {
for (int j = 0; j < nDims; j++) {
centroids[i][j] = random().nextFloat() * 100;
}
}
// Generate data points around centroids
for (int i = 0; i < nSamples; i++) {
int cluster = random().nextInt(nClusters);
float[] vector = new float[nDims];
for (int j = 0; j < nDims; j++) {
vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
}
vectors.add(vector);
}
return FloatVectorValues.fromFloats(vectors, nDims);
}
}

View File

@@ -14,10 +14,10 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* Modifications copyright (C) 2025 Elasticsearch B.V.
*
* Modifications copyright (C) 2024 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors;
package org.elasticsearch.index.codec.vectors.cluster;
import org.elasticsearch.test.ESTestCase;