IVF Hierarchical KMeans Flush & Merge (#128675)
added hierarchical kmeans as a clustering algorithm to better partition the space when running ivf on flush and merge
This commit is contained in:
parent 1e13409049
commit 47d4b983af
|
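For context on the commit description above: hierarchical k-means recursively splits the vector set and only keeps refining partitions that are still too large, instead of solving one flat k-means problem over every vector. The sketch below is only an illustration of that idea; the class name, the k=2 bisecting split, and the iteration counts are assumptions and not the actual HierarchicalKMeans implementation added by this commit.

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Minimal sketch of recursive (bisecting) k-means over simple in-memory vectors.
final class HierarchicalKMeansSketch {

    private final int targetClusterSize;
    private final Random random = new Random(42L);

    HierarchicalKMeansSketch(int targetClusterSize) {
        this.targetClusterSize = targetClusterSize;
    }

    // Returns one centroid per leaf partition.
    List<float[]> cluster(List<float[]> vectors) {
        List<float[]> centroids = new ArrayList<>();
        clusterRecursively(vectors, centroids);
        return centroids;
    }

    private void clusterRecursively(List<float[]> vectors, List<float[]> centroids) {
        if (vectors.isEmpty()) {
            return;
        }
        if (vectors.size() <= targetClusterSize) {
            centroids.add(mean(vectors));
            return;
        }
        // split in two with a few Lloyd iterations, then recurse into each side
        float[] c0 = vectors.get(random.nextInt(vectors.size())).clone();
        float[] c1 = vectors.get(random.nextInt(vectors.size())).clone();
        List<float[]> left = new ArrayList<>();
        List<float[]> right = new ArrayList<>();
        for (int iter = 0; iter < 5; iter++) {
            left.clear();
            right.clear();
            for (float[] v : vectors) {
                (squareDistance(v, c0) <= squareDistance(v, c1) ? left : right).add(v);
            }
            if (left.isEmpty() || right.isEmpty()) {
                break; // degenerate split, stop refining
            }
            c0 = mean(left);
            c1 = mean(right);
        }
        if (left.isEmpty() || right.isEmpty()) {
            centroids.add(mean(vectors)); // could not split this partition any further
            return;
        }
        clusterRecursively(left, centroids);
        clusterRecursively(right, centroids);
    }

    private static float[] mean(List<float[]> vectors) {
        float[] m = new float[vectors.get(0).length];
        for (float[] v : vectors) {
            for (int i = 0; i < m.length; i++) {
                m[i] += v[i];
            }
        }
        for (int i = 0; i < m.length; i++) {
            m[i] /= vectors.size();
        }
        return m;
    }

    private static float squareDistance(float[] a, float[] b) {
        float sum = 0;
        for (int i = 0; i < a.length; i++) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }
}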
@@ -276,12 +276,11 @@ class KnnSearcher {
    TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException {
        Query knnQuery;
        int topK = this.topK;
        int efSearch = this.efSearch;
        if (overSamplingFactor > 1f) {
            // oversample the topK results to get more candidates for the final result
            topK = (int) Math.ceil(topK * overSamplingFactor);
            efSearch = Math.max(topK, efSearch);
        }
        int efSearch = Math.max(topK, this.efSearch);
        if (indexType == KnnIndexTester.IndexType.IVF) {
            knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, null, nProbe);
        } else {
|
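As a quick sanity check of the oversampling arithmetic in the hunk above (the numbers are made up for illustration):

float overSamplingFactor = 1.5f;
int topK = 10;
int efSearch = 12;
if (overSamplingFactor > 1f) {
    topK = (int) Math.ceil(topK * overSamplingFactor); // 15
    efSearch = Math.max(topK, efSearch);               // 15
}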
|
|
@@ -0,0 +1,46 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.internal.hppc.IntArrayList;

final class CentroidAssignments {

    private final int numCentroids;
    private final float[][] cachedCentroids;
    private final IntArrayList[] assignmentsByCluster;

    private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) {
        this.numCentroids = numCentroids;
        this.cachedCentroids = cachedCentroids;
        this.assignmentsByCluster = assignmentsByCluster;
    }

    CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) {
        this(centroids.length, centroids, assignmentsByCluster);
    }

    CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) {
        this(numCentroids, null, assignmentsByCluster);
    }

    // Getters and setters
    public int numCentroids() {
        return numCentroids;
    }

    public float[][] cachedCentroids() {
        return cachedCentroids;
    }

    public IntArrayList[] assignmentsByCluster() {
        return assignmentsByCluster;
    }
}
|
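A small usage sketch of the new holder class above, with hypothetical toy data. As the writer changes further down show, the two package-private constructors correspond to the flush path (centroids kept on heap for the posting-list pass) and the merge path (centroids re-read from the centroid file, so only the count is kept).

// hypothetical toy data, in the same package as CentroidAssignments (it is package-private)
float[][] centroids = { { 0.1f, 0.2f }, { 0.9f, 0.8f } };
IntArrayList cluster0 = new IntArrayList(2);
cluster0.add(0);
cluster0.add(2);
IntArrayList cluster1 = new IntArrayList(2);
cluster1.add(1);
cluster1.add(3);
IntArrayList[] assignmentsByCluster = { cluster0, cluster1 };

// flush path: keep the centroids cached on heap for the posting-list pass
CentroidAssignments cached = new CentroidAssignments(centroids, assignmentsByCluster);
assert cached.numCentroids() == 2 && cached.cachedCentroids() != null;

// merge path: centroids are re-read from disk later, so only the count is kept
CentroidAssignments countOnly = new CentroidAssignments(centroids.length, assignmentsByCluster);
assert countOnly.cachedCentroids() == null;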
|
@@ -112,15 +112,6 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader {
|
|||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) {
|
||||
FieldEntry entry = fields.get(info.number);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension());
|
||||
}
|
||||
|
||||
@Override
|
||||
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
|
||||
throws IOException {
|
||||
|
|
|
@@ -14,21 +14,19 @@ import org.apache.lucene.index.FieldInfo;
|
|||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
|
||||
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
|
||||
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
|
||||
|
@@ -36,16 +34,12 @@ import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packA
|
|||
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;
|
||||
|
||||
/**
|
||||
* Default implementation of {@link IVFVectorsWriter}. It uses {@link KMeans} algorithm to
|
||||
* partition the vector space, and then stores the centroids an posting list in a sequential
|
||||
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
|
||||
* partition the vector space, and then stores the centroids and posting list in a sequential
|
||||
* fashion.
|
||||
*/
|
||||
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
||||
|
||||
static final float SOAR_LAMBDA = 1.0f;
|
||||
// What percentage of the centroids do we do a second check on for SOAR assignment
|
||||
static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f;
|
||||
|
||||
private final int vectorPerCluster;
|
||||
|
||||
public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException {
|
||||
|
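For readers unfamiliar with SOAR: besides its primary cluster, each vector also gets a secondary (spill) assignment, chosen by penalizing candidate centroids whose direction lines up with the residual to the primary centroid; SOAR_LAMBDA weights that penalty and EXT_SOAR_LIMIT_CHECK_RATIO bounds how many nearby centroids are re-checked. The helper below is a hedged sketch of the SOAR-style adjusted distance as described in the SOAR/ScaNN literature; the actual code delegates to ESVectorUtil.soarResidual, whose exact normalization is not shown here, so treat the details as assumptions.

// d_soar(v, c) = ||v - c||^2 + lambda * ((v - c) . r)^2 / ||r||^2
// where r = v - primaryCentroid and residualNormSq = ||r||^2 (the best primary distance)
static float soarDistance(float[] v, float[] candidate, float[] residual, float residualNormSq, float lambda) {
    float dist = 0f;
    float proj = 0f;
    for (int i = 0; i < v.length; i++) {
        float d = v[i] - candidate[i];
        dist += d * d;
        proj += d * residual[i];
    }
    return dist + lambda * proj * proj / residualNormSq;
}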
@@ -53,77 +47,81 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
this.vectorPerCluster = vectorPerCluster;
|
||||
}
|
||||
|
||||
@Override
|
||||
CentroidAssignmentScorer calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput centroidOutput,
|
||||
float[] globalCentroid
|
||||
) throws IOException {
|
||||
if (floatVectorValues.size() == 0) {
|
||||
return CentroidAssignmentScorer.EMPTY;
|
||||
}
|
||||
// calculate the centroids
|
||||
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
|
||||
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
|
||||
final KMeans.Results kMeans = KMeans.cluster(
|
||||
floatVectorValues,
|
||||
desiredClusters,
|
||||
false,
|
||||
42L,
|
||||
KMeans.KmeansInitializationMethod.PLUS_PLUS,
|
||||
null,
|
||||
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
|
||||
1,
|
||||
15,
|
||||
desiredClusters * 256
|
||||
);
|
||||
float[][] centroids = kMeans.centroids();
|
||||
// write them
|
||||
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
|
||||
return new OnHeapCentroidAssignmentScorer(centroids);
|
||||
}
|
||||
|
||||
@Override
|
||||
long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
InfoStream infoStream,
|
||||
CentroidAssignmentScorer randomCentroidScorer,
|
||||
CentroidSupplier centroidSupplier,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput
|
||||
IndexOutput postingsOutput,
|
||||
InfoStream infoStream,
|
||||
IntArrayList[] assignmentsByCluster
|
||||
) throws IOException {
|
||||
IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()];
|
||||
for (int i = 0; i < randomCentroidScorer.size(); i++) {
|
||||
clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4);
|
||||
}
|
||||
assignCentroids(randomCentroidScorer, floatVectorValues, clusters);
|
||||
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
printClusterQualityStatistics(clusters, infoStream);
|
||||
}
|
||||
// write the posting lists
|
||||
final long[] offsets = new long[randomCentroidScorer.size()];
|
||||
final long[] offsets = new long[centroidSupplier.size()];
|
||||
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
|
||||
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||
for (int i = 0; i < randomCentroidScorer.size(); i++) {
|
||||
float[] centroid = randomCentroidScorer.centroid(i);
|
||||
|
||||
for (int c = 0; c < centroidSupplier.size(); c++) {
|
||||
float[] centroid = centroidSupplier.centroid(c);
|
||||
binarizedByteVectorValues.centroid = centroid;
|
||||
// TODO sort by distance to the centroid
|
||||
IntArrayList cluster = clusters[i];
|
||||
// TODO: add back in sorting vectors by distance to centroid
|
||||
IntArrayList cluster = assignmentsByCluster[c];
|
||||
// TODO align???
|
||||
offsets[i] = postingsOutput.getFilePointer();
|
||||
offsets[c] = postingsOutput.getFilePointer();
|
||||
int size = cluster.size();
|
||||
postingsOutput.writeVInt(size);
|
||||
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
|
||||
// TODO we might want to consider putting the docIds in a separate file
|
||||
// to aid with only having to fetch vectors from slower storage when they are required
|
||||
// keeping them in the same file indicates we pull the entire file into cache
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
|
||||
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
|
||||
}
|
||||
|
||||
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
printClusterQualityStatistics(assignmentsByCluster, infoStream);
|
||||
}
|
||||
|
||||
return offsets;
|
||||
}
|
||||
|
||||
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = Float.MIN_VALUE;
|
||||
float mean = 0;
|
||||
float m2 = 0;
|
||||
// iteratively compute the variance & mean
|
||||
int count = 0;
|
||||
for (IntArrayList cluster : clusters) {
|
||||
count += 1;
|
||||
if (cluster == null) {
|
||||
continue;
|
||||
}
|
||||
float delta = cluster.size() - mean;
|
||||
mean += delta / count;
|
||||
m2 += delta * (cluster.size() - mean);
|
||||
min = Math.min(min, cluster.size());
|
||||
max = Math.max(max, cluster.size());
|
||||
}
|
||||
float variance = m2 / (clusters.length - 1);
|
||||
infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Centroid count: "
|
||||
+ clusters.length
|
||||
+ " min: "
|
||||
+ min
|
||||
+ " max: "
|
||||
+ max
|
||||
+ " mean: "
|
||||
+ mean
|
||||
+ " stdDev: "
|
||||
+ Math.sqrt(variance)
|
||||
+ " variance: "
|
||||
+ variance
|
||||
);
|
||||
}
|
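The single-pass mean/variance computation in printClusterQualityStatistics above is the standard streaming (Welford-style) update. A tiny worked example with made-up cluster sizes, showing it agrees with the textbook sample variance:

int[] clusterSizes = { 3, 5, 7 };
float mean = 0, m2 = 0;
int count = 0;
for (int size : clusterSizes) {
    count += 1;
    float delta = size - mean;
    mean += delta / count;
    m2 += delta * (size - mean);
}
float variance = m2 / (clusterSizes.length - 1);
// mean == 5.0, variance == 4.0, stdDev == 2.0 — same as ((3-5)^2 + (5-5)^2 + (7-5)^2) / 2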
||||
|
||||
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
|
||||
throws IOException {
|
||||
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
|
||||
|
@@ -173,13 +171,8 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
CentroidAssignmentScorer createCentroidScorer(
|
||||
IndexInput centroidsInput,
|
||||
int numCentroids,
|
||||
FieldInfo fieldInfo,
|
||||
float[] globalCentroid
|
||||
) {
|
||||
return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
|
||||
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
|
||||
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
|
||||
}
|
||||
|
||||
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
|
||||
|
@@ -188,24 +181,8 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
|
||||
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
|
||||
// TODO do we want to store these distances as well for future use?
|
||||
float[] distances = new float[centroids.length];
|
||||
for (int i = 0; i < centroids.length; i++) {
|
||||
distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
|
||||
}
|
||||
// sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
|
||||
// (largest)
|
||||
for (int i = 0; i < centroids.length; i++) {
|
||||
for (int j = i + 1; j < centroids.length; j++) {
|
||||
if (distances[i] > distances[j]) {
|
||||
float[] tmp = centroids[i];
|
||||
centroids[i] = centroids[j];
|
||||
centroids[j] = tmp;
|
||||
float tmpDistance = distances[i];
|
||||
distances[i] = distances[j];
|
||||
distances[j] = tmpDistance;
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: sort centroids by global centroid (was doing so previously here)
|
||||
// TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
|
||||
for (float[] centroid : centroids) {
|
||||
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
|
||||
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
|
||||
|
@@ -223,190 +200,60 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
static float[][] gatherInitCentroids(
|
||||
List<FloatVectorValues> centroidList,
|
||||
List<SegmentCentroid> segmentCentroids,
|
||||
int desiredClusters,
|
||||
FieldInfo fieldInfo,
|
||||
MergeState mergeState
|
||||
) throws IOException {
|
||||
if (centroidList.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
long startTime = System.nanoTime();
|
||||
// sort centroid list by floatvector size
|
||||
FloatVectorValues baseSegment = centroidList.get(0);
|
||||
for (var l : centroidList) {
|
||||
if (l.size() > baseSegment.size()) {
|
||||
baseSegment = l;
|
||||
}
|
||||
}
|
||||
float[] scratch = new float[fieldInfo.getVectorDimension()];
|
||||
float minimumDistance = Float.MAX_VALUE;
|
||||
for (int j = 0; j < baseSegment.size(); j++) {
|
||||
System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension());
|
||||
for (int k = j + 1; k < baseSegment.size(); k++) {
|
||||
float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k));
|
||||
if (d < minimumDistance) {
|
||||
minimumDistance = d;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size()
|
||||
);
|
||||
}
|
||||
int[] labels = new int[segmentCentroids.size()];
|
||||
// loop over segments
|
||||
int clusterIdx = 0;
|
||||
// keep track of all inter-centroid distances,
|
||||
// using less than centroid * centroid space (e.g. not keeping track of duplicates)
|
||||
for (int i = 0; i < segmentCentroids.size(); i++) {
|
||||
if (labels[i] == 0) {
|
||||
clusterIdx += 1;
|
||||
labels[i] = clusterIdx;
|
||||
}
|
||||
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
|
||||
System.arraycopy(
|
||||
centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid),
|
||||
0,
|
||||
scratch,
|
||||
0,
|
||||
baseSegment.dimension()
|
||||
);
|
||||
for (int j = i + 1; j < segmentCentroids.size(); j++) {
|
||||
float d = VectorUtil.squareDistance(
|
||||
scratch,
|
||||
centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid())
|
||||
);
|
||||
if (d < minimumDistance / 2) {
|
||||
if (labels[j] == 0) {
|
||||
labels[j] = labels[i];
|
||||
} else {
|
||||
for (int k = 0; k < labels.length; k++) {
|
||||
if (labels[k] == labels[j]) {
|
||||
labels[k] = labels[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()];
|
||||
int[] sum = new int[clusterIdx];
|
||||
for (int i = 0; i < segmentCentroids.size(); i++) {
|
||||
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
|
||||
int label = labels[i];
|
||||
FloatVectorValues segment = centroidList.get(segmentCentroid.segment());
|
||||
float[] vector = segment.vectorValue(segmentCentroid.centroid);
|
||||
for (int j = 0; j < vector.length; j++) {
|
||||
initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize);
|
||||
}
|
||||
sum[label - 1] += segmentCentroid.centroidSize;
|
||||
}
|
||||
for (int i = 0; i < initCentroids.length; i++) {
|
||||
if (sum[i] == 0 || sum[i] == 1) {
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < initCentroids[i].length; j++) {
|
||||
initCentroids[i][j] /= sum[i];
|
||||
}
|
||||
}
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0)
|
||||
);
|
||||
mergeState.infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters
|
||||
);
|
||||
}
|
||||
return initCentroids;
|
||||
}
|
||||
|
||||
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
|
||||
|
||||
/**
|
||||
* Calculate the centroids for the given field and write them to the given
|
||||
* temporary centroid output.
|
||||
* When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
|
||||
* To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
|
||||
* the largest segments intra-cluster distance are merged into a single centroid.
|
||||
* The resulting centroids are then used to initialize the KMeans algorithm.
|
||||
*
|
||||
* @param fieldInfo merging field info
|
||||
* @param floatVectorValues the float vector values to merge
|
||||
* @param temporaryCentroidOutput the temporary centroid output
|
||||
* @param mergeState the merge state
|
||||
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
|
||||
* @return the number of centroids written
|
||||
* @throws IOException if an I/O error occurs
|
||||
*/
|
||||
@Override
|
||||
protected int calculateAndWriteCentroids(
|
||||
CentroidAssignments calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput temporaryCentroidOutput,
|
||||
IndexOutput centroidOutput,
|
||||
MergeState mergeState,
|
||||
float[] globalCentroid
|
||||
) throws IOException {
|
||||
if (floatVectorValues.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
|
||||
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
|
||||
// init centroids from merge state
|
||||
List<FloatVectorValues> centroidList = new ArrayList<>();
|
||||
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
|
||||
|
||||
int segmentIdx = 0;
|
||||
for (var reader : mergeState.knnVectorsReaders) {
|
||||
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
|
||||
if (ivfVectorsReader == null) {
|
||||
continue;
|
||||
// TODO: take advantage of prior generated clusters from mergeState in the future
|
||||
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, mergeState.infoStream, globalCentroid, false);
|
||||
}
|
||||
|
||||
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
|
||||
if (centroid == null) {
|
||||
continue;
|
||||
}
|
||||
centroidList.add(centroid);
|
||||
for (int i = 0; i < centroid.size(); i++) {
|
||||
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
|
||||
if (size == 0) {
|
||||
continue;
|
||||
}
|
||||
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
|
||||
}
|
||||
segmentIdx++;
|
||||
CentroidAssignments calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput centroidOutput,
|
||||
InfoStream infoStream,
|
||||
float[] globalCentroid
|
||||
) throws IOException {
|
||||
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, infoStream, globalCentroid, true);
|
||||
}
|
||||
|
||||
float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState);
|
||||
/**
|
||||
* Calculate the centroids for the given field and write them to the given centroid output.
|
||||
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
|
||||
*
|
||||
* @param fieldInfo merging field info
|
||||
* @param floatVectorValues the float vector values to merge
|
||||
* @param centroidOutput the centroid output
|
||||
* @param infoStream the merge state
|
||||
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
|
||||
* @param cacheCentroids whether the centroids are kept or discarded once computed
|
||||
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
|
||||
* @throws IOException if an I/O error occurs
|
||||
*/
|
||||
CentroidAssignments calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput centroidOutput,
|
||||
InfoStream infoStream,
|
||||
float[] globalCentroid,
|
||||
boolean cacheCentroids
|
||||
) throws IOException {
|
||||
|
||||
// FIXME: run a custom version of KMeans that is just better...
|
||||
long nanoTime = System.nanoTime();
|
||||
final KMeans.Results kMeans = KMeans.cluster(
|
||||
floatVectorValues,
|
||||
desiredClusters,
|
||||
false,
|
||||
42L,
|
||||
KMeans.KmeansInitializationMethod.PLUS_PLUS,
|
||||
initCentroids,
|
||||
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
|
||||
1,
|
||||
5,
|
||||
desiredClusters * 64
|
||||
);
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
|
||||
}
|
||||
float[][] centroids = kMeans.centroids();
|
||||
|
||||
// write them
|
||||
// calculate the global centroid from all the centroids:
|
||||
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
|
||||
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
|
||||
float[][] centroids = kMeansResult.centroids();
|
||||
int[] assignments = kMeansResult.assignments();
|
||||
int[] soarAssignments = kMeansResult.soarAssignments();
|
||||
|
||||
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
|
||||
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
|
||||
// TODO: push this logic into vector util?
|
||||
for (float[] centroid : centroids) {
|
||||
for (int j = 0; j < centroid.length; j++) {
|
||||
globalCentroid[j] += centroid[j];
|
||||
|
@@ -415,197 +262,41 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
for (int j = 0; j < globalCentroid.length; j++) {
|
||||
globalCentroid[j] /= centroids.length;
|
||||
}
|
||||
writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput);
|
||||
return centroids.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
CentroidAssignmentScorer centroidAssignmentScorer,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput,
|
||||
MergeState mergeState
|
||||
) throws IOException {
|
||||
IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()];
|
||||
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
|
||||
clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4);
|
||||
}
|
||||
long nanoTime = System.nanoTime();
|
||||
// Can we do a pre-filter by finding the nearest centroids to the original vector centroids?
|
||||
// We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing
|
||||
assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters);
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
|
||||
}
|
||||
// write centroids
|
||||
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
|
||||
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
printClusterQualityStatistics(clusters, mergeState.infoStream);
|
||||
}
|
||||
// write the posting lists
|
||||
final long[] offsets = new long[centroidAssignmentScorer.size()];
|
||||
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
|
||||
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
|
||||
float[] centroid = centroidAssignmentScorer.centroid(i);
|
||||
binarizedByteVectorValues.centroid = centroid;
|
||||
// TODO: sort by distance to the centroid
|
||||
IntArrayList cluster = clusters[i];
|
||||
// TODO align???
|
||||
offsets[i] = postingsOutput.getFilePointer();
|
||||
int size = cluster.size();
|
||||
postingsOutput.writeVInt(size);
|
||||
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
|
||||
// TODO we might want to consider putting the docIds in a separate file
|
||||
// to aid with only having to fetch vectors from slower storage when they are required
|
||||
// keeping them in the same file indicates we pull the entire file into cache
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
|
||||
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = Float.MIN_VALUE;
|
||||
float mean = 0;
|
||||
float m2 = 0;
|
||||
// iteratively compute the variance & mean
|
||||
int count = 0;
|
||||
for (IntArrayList cluster : clusters) {
|
||||
count += 1;
|
||||
if (cluster == null) {
|
||||
continue;
|
||||
}
|
||||
float delta = cluster.size() - mean;
|
||||
mean += delta / count;
|
||||
m2 += delta * (cluster.size() - mean);
|
||||
min = Math.min(min, cluster.size());
|
||||
max = Math.max(max, cluster.size());
|
||||
}
|
||||
float variance = m2 / (clusters.length - 1);
|
||||
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Centroid count: "
|
||||
+ clusters.length
|
||||
+ " min: "
|
||||
+ min
|
||||
+ " max: "
|
||||
+ max
|
||||
+ " mean: "
|
||||
+ mean
|
||||
+ " stdDev: "
|
||||
+ Math.sqrt(variance)
|
||||
+ " variance: "
|
||||
+ variance
|
||||
"calculate centroids and assign vectors time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)
|
||||
);
|
||||
infoStream.message(IVF_VECTOR_COMPONENT, "final centroid count: " + centroids.length);
|
||||
}
|
||||
|
||||
static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException {
|
||||
int numCentroids = scorer.size();
|
||||
// we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible
|
||||
int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO);
|
||||
int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck);
|
||||
NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true);
|
||||
OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1);
|
||||
float[] scratch = new float[vectors.dimension()];
|
||||
for (int docID = 0; docID < vectors.size(); docID++) {
|
||||
float[] vector = vectors.vectorValue(docID);
|
||||
scorer.setScoringVector(vector);
|
||||
int bestCentroid = 0;
|
||||
float bestScore = Float.MAX_VALUE;
|
||||
if (numCentroids > 1) {
|
||||
for (short c = 0; c < numCentroids; c++) {
|
||||
float squareDist = scorer.score(c);
|
||||
neighborsToCheck.insertWithOverflow(c, squareDist);
|
||||
}
|
||||
// pop the best
|
||||
int sz = neighborsToCheck.size();
|
||||
int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores);
|
||||
// Set the size to the number of neighbors we actually found
|
||||
ordScoreIterator.setSize(sz);
|
||||
bestScore = ordScoreIterator.getScore(best);
|
||||
bestCentroid = ordScoreIterator.getOrd(best);
|
||||
}
|
||||
clusters[bestCentroid].add(docID);
|
||||
if (soarClusterCheckCount > 0) {
|
||||
assignCentroidSOAR(
|
||||
ordScoreIterator,
|
||||
docID,
|
||||
bestCentroid,
|
||||
scorer.centroid(bestCentroid),
|
||||
bestScore,
|
||||
scratch,
|
||||
scorer,
|
||||
vector,
|
||||
clusters
|
||||
);
|
||||
}
|
||||
neighborsToCheck.clear();
|
||||
IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
|
||||
for (int c = 0; c < centroids.length; c++) {
|
||||
IntArrayList cluster = new IntArrayList(vectorPerCluster);
|
||||
for (int j = 0; j < assignments.length; j++) {
|
||||
if (assignments[j] == c) {
|
||||
cluster.add(j);
|
||||
}
|
||||
}
|
||||
|
||||
static void assignCentroidSOAR(
|
||||
OrdScoreIterator centroidsToCheck,
|
||||
int vecOrd,
|
||||
int bestCentroidId,
|
||||
float[] bestCentroid,
|
||||
float bestScore,
|
||||
float[] scratch,
|
||||
CentroidAssignmentScorer scorer,
|
||||
float[] vector,
|
||||
IntArrayList[] clusters
|
||||
) throws IOException {
|
||||
ESVectorUtil.subtract(vector, bestCentroid, scratch);
|
||||
int bestSecondaryCentroid = -1;
|
||||
float minDist = Float.MAX_VALUE;
|
||||
for (int i = 0; i < centroidsToCheck.size(); i++) {
|
||||
float score = centroidsToCheck.getScore(i);
|
||||
int centroidOrdinal = centroidsToCheck.getOrd(i);
|
||||
if (centroidOrdinal == bestCentroidId) {
|
||||
continue;
|
||||
}
|
||||
float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch);
|
||||
score += SOAR_LAMBDA * proj * proj / bestScore;
|
||||
if (score < minDist) {
|
||||
bestSecondaryCentroid = centroidOrdinal;
|
||||
minDist = score;
|
||||
}
|
||||
}
|
||||
if (bestSecondaryCentroid != -1) {
|
||||
clusters[bestSecondaryCentroid].add(vecOrd);
|
||||
for (int j = 0; j < soarAssignments.length; j++) {
|
||||
if (soarAssignments[j] == c) {
|
||||
cluster.add(j);
|
||||
}
|
||||
}
|
||||
|
||||
static class OrdScoreIterator {
|
||||
private final int[] ords;
|
||||
private final float[] scores;
|
||||
private int idx = 0;
|
||||
|
||||
OrdScoreIterator(int size) {
|
||||
this.ords = new int[size];
|
||||
this.scores = new float[size];
|
||||
cluster.trimToSize();
|
||||
assignmentsByCluster[c] = cluster;
|
||||
}
|
||||
|
||||
int setSize(int size) {
|
||||
if (size > ords.length) {
|
||||
throw new IllegalArgumentException("size must be <= " + ords.length);
|
||||
}
|
||||
this.idx = size;
|
||||
return size;
|
||||
}
|
||||
|
||||
int getOrd(int idx) {
|
||||
return ords[idx];
|
||||
}
|
||||
|
||||
float getScore(int idx) {
|
||||
return scores[idx];
|
||||
}
|
||||
|
||||
int size() {
|
||||
return idx;
|
||||
if (cacheCentroids) {
|
||||
return new CentroidAssignments(centroids, assignmentsByCluster);
|
||||
} else {
|
||||
return new CentroidAssignments(centroids.length, assignmentsByCluster);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -650,16 +341,25 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
|
||||
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
|
||||
throws IOException {
|
||||
indexOutput.writeBytes(binaryValue, binaryValue.length);
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
|
||||
assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
|
||||
indexOutput.writeShort((short) corrections.quantizedComponentSum());
|
||||
}
|
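For orientation, each quantized vector record produced by writeQuantizedValue above is the packed quantized bytes followed by three float corrections and one short, so its on-disk size is binaryValue.length + 14 bytes. A small sketch of that bookkeeping (the helper name is made up):

// bytes written per entry: quantized payload + lower/upper interval + additional correction + component sum
static long quantizedRecordSize(int quantizedByteLength) {
    return quantizedByteLength      // writeBytes(binaryValue, ...)
        + Integer.BYTES * 3L        // three Float.floatToIntBits(...) corrections
        + Short.BYTES;              // quantizedComponentSum written as an unsigned short
}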
||||
|
||||
static class OffHeapCentroidSupplier implements CentroidSupplier {
|
||||
private final IndexInput centroidsInput;
|
||||
private final int numCentroids;
|
||||
private final int dimension;
|
||||
private final float[] scratch;
|
||||
private float[] q;
|
||||
private final long rawCentroidOffset;
|
||||
private int currOrd = -1;
|
||||
|
||||
OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
|
||||
OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
|
||||
this.centroidsInput = centroidsInput;
|
||||
this.numCentroids = numCentroids;
|
||||
this.dimension = info.getVectorDimension();
|
||||
|
@@ -682,55 +382,5 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
this.currOrd = centroidOrdinal;
|
||||
return scratch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setScoringVector(float[] vector) {
|
||||
q = vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) throws IOException {
|
||||
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO throw away rawCentroids
|
||||
static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
|
||||
private final float[][] centroids;
|
||||
private float[] q;
|
||||
|
||||
OnHeapCentroidAssignmentScorer(float[][] centroids) {
|
||||
this.centroids = centroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return centroids.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setScoringVector(float[] vector) {
|
||||
q = vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||
return centroids[centroidOrdinal];
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) throws IOException {
|
||||
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
|
||||
}
|
||||
}
|
||||
|
||||
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
|
||||
throws IOException {
|
||||
indexOutput.writeBytes(binaryValue, binaryValue.length);
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
|
||||
assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
|
||||
indexOutput.writeShort((short) corrections.quantizedComponentSum());
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -97,23 +97,6 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
|
|||
IndexInput clusters
|
||||
) throws IOException;
|
||||
|
||||
protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException;
|
||||
|
||||
public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException {
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo);
|
||||
}
|
||||
|
||||
int centroidSize(String fieldName, int centroidOrdinal) throws IOException {
|
||||
FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName);
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]);
|
||||
return ivfClusters.readVInt();
|
||||
}
|
||||
|
||||
private static IndexInput openDataInput(
|
||||
SegmentReadState state,
|
||||
int versionMeta,
|
||||
|
|
|
@@ -11,11 +11,9 @@ package org.elasticsearch.index.codec.vectors;
|
|||
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
|
@@ -25,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
|
|||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@@ -123,38 +122,32 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
return rawVectorDelegate;
|
||||
}
|
||||
|
||||
protected abstract int calculateAndWriteCentroids(
|
||||
abstract CentroidAssignments calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput temporaryCentroidOutput,
|
||||
IndexOutput centroidOutput,
|
||||
MergeState mergeState,
|
||||
float[] globalCentroid
|
||||
) throws IOException;
|
||||
|
||||
abstract long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
CentroidAssignmentScorer scorer,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput,
|
||||
MergeState mergeState
|
||||
) throws IOException;
|
||||
|
||||
abstract CentroidAssignmentScorer calculateAndWriteCentroids(
|
||||
abstract CentroidAssignments calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput centroidOutput,
|
||||
InfoStream infoStream,
|
||||
float[] globalCentroid
|
||||
) throws IOException;
|
||||
|
||||
abstract long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
InfoStream infoStream,
|
||||
CentroidAssignmentScorer scorer,
|
||||
CentroidSupplier centroidSupplier,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput
|
||||
IndexOutput postingsOutput,
|
||||
InfoStream infoStream,
|
||||
IntArrayList[] assignmentsByCluster
|
||||
) throws IOException;
|
||||
|
||||
abstract CentroidAssignmentScorer createCentroidScorer(
|
||||
abstract CentroidSupplier createCentroidSupplier(
|
||||
IndexInput centroidsInput,
|
||||
int numCentroids,
|
||||
FieldInfo fieldInfo,
|
||||
|
@@ -166,33 +159,31 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
rawVectorDelegate.flush(maxDoc, sortMap);
|
||||
for (FieldWriter fieldWriter : fieldWriters) {
|
||||
float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
|
||||
// calculate global centroid
|
||||
for (var vector : fieldWriter.delegate.getVectors()) {
|
||||
for (int i = 0; i < globalCentroid.length; i++) {
|
||||
globalCentroid[i] += vector[i];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < globalCentroid.length; i++) {
|
||||
globalCentroid[i] /= fieldWriter.delegate.getVectors().size();
|
||||
}
|
||||
// build a float vector values with random access
|
||||
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
|
||||
// build centroids
|
||||
long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
|
||||
final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids(
|
||||
|
||||
final CentroidAssignments centroidAssignments = calculateAndWriteCentroids(
|
||||
fieldWriter.fieldInfo,
|
||||
floatVectorValues,
|
||||
ivfCentroids,
|
||||
segmentWriteState.infoStream,
|
||||
globalCentroid
|
||||
);
|
||||
|
||||
CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.cachedCentroids());
|
||||
|
||||
long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
|
||||
final long[] offsets = buildAndWritePostingsLists(
|
||||
fieldWriter.fieldInfo,
|
||||
segmentWriteState.infoStream,
|
||||
centroidAssignmentScorer,
|
||||
centroidSupplier,
|
||||
floatVectorValues,
|
||||
ivfClusters
|
||||
ivfClusters,
|
||||
segmentWriteState.infoStream,
|
||||
centroidAssignments.assignmentsByCluster()
|
||||
);
|
||||
// write posting lists
|
||||
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
|
||||
}
|
||||
}
|
||||
|
@@ -240,16 +231,6 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
};
|
||||
}
|
||||
|
||||
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
|
||||
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||
vectorsReader = candidateReader.getFieldReader(fieldName);
|
||||
}
|
||||
if (vectorsReader instanceof IVFVectorsReader reader) {
|
||||
return reader;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
|
||||
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
|
@@ -277,22 +258,25 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
|
||||
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
|
||||
success = false;
|
||||
CentroidAssignmentScorer centroidAssignmentScorer;
|
||||
long centroidOffset;
|
||||
long centroidLength;
|
||||
String centroidTempName = null;
|
||||
int numCentroids;
|
||||
IndexOutput centroidTemp = null;
|
||||
CentroidAssignments centroidAssignments;
|
||||
try {
|
||||
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
|
||||
centroidTempName = centroidTemp.getName();
|
||||
numCentroids = calculateAndWriteCentroids(
|
||||
|
||||
centroidAssignments = calculateAndWriteCentroids(
|
||||
fieldInfo,
|
||||
floatVectorValues,
|
||||
centroidTemp,
|
||||
mergeState,
|
||||
calculatedGlobalCentroid
|
||||
);
|
||||
numCentroids = centroidAssignments.numCentroids();
|
||||
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false && centroidTempName != null) {
|
||||
|
@@ -311,21 +295,28 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
CodecUtil.writeFooter(centroidTemp);
|
||||
IOUtils.close(centroidTemp);
|
||||
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
|
||||
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
|
||||
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
|
||||
try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
|
||||
ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength());
|
||||
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
|
||||
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
|
||||
assert centroidAssignmentScorer.size() == numCentroids;
|
||||
|
||||
CentroidSupplier centroidSupplier = createCentroidSupplier(
|
||||
centroidsInput,
|
||||
numCentroids,
|
||||
fieldInfo,
|
||||
calculatedGlobalCentroid
|
||||
);
|
||||
|
||||
// build a float vector values with random access
|
||||
// build centroids
|
||||
final long[] offsets = buildAndWritePostingsLists(
|
||||
fieldInfo,
|
||||
centroidAssignmentScorer,
|
||||
centroidSupplier,
|
||||
floatVectorValues,
|
||||
ivfClusters,
|
||||
mergeState
|
||||
mergeState.infoStream,
|
||||
centroidAssignments.assignmentsByCluster()
|
||||
);
|
||||
assert offsets.length == centroidAssignmentScorer.size();
|
||||
assert offsets.length == centroidSupplier.size();
|
||||
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
|
||||
}
|
||||
} finally {
|
||||
|
@@ -453,8 +444,8 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
|
||||
private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> delegate) {}
|
||||
|
||||
interface CentroidAssignmentScorer {
|
||||
CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() {
|
||||
interface CentroidSupplier {
|
||||
CentroidSupplier EMPTY = new CentroidSupplier() {
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
|
@@ -464,24 +455,29 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
|||
public float[] centroid(int centroidOrdinal) {
|
||||
throw new IllegalStateException("No centroids");
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) {
|
||||
throw new IllegalStateException("No centroids");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setScoringVector(float[] vector) {
|
||||
throw new IllegalStateException("No centroids");
|
||||
}
|
||||
};
|
||||
|
||||
int size();
|
||||
|
||||
float[] centroid(int centroidOrdinal) throws IOException;
|
||||
}
|
||||
|
||||
void setScoringVector(float[] vector);
|
||||
// TODO throw away rawCentroids
|
||||
static class OnHeapCentroidSupplier implements CentroidSupplier {
|
||||
private final float[][] centroids;
|
||||
|
||||
float score(int centroidOrdinal) throws IOException;
|
||||
OnHeapCentroidSupplier(float[][] centroids) {
|
||||
this.centroids = centroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return centroids.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||
return centroids[centroidOrdinal];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -1,494 +0,0 @@
|
|||
/*
|
||||
* @notice
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||
*/
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.index.codec.vectors.SampleReader.createSampleReader;
|
||||
|
||||
/** KMeans clustering algorithm for vectors */
|
||||
class KMeans {
|
||||
public static final int DEFAULT_RESTARTS = 1;
|
||||
public static final int DEFAULT_ITRS = 10;
|
||||
public static final int DEFAULT_SAMPLE_VECTORS_PER_CENTROID = 128;
|
||||
|
||||
private static final float EPS = 1f / 1024f;
|
||||
private final FloatVectorValues vectors;
|
||||
private final int numVectors;
|
||||
private final int numCentroids;
|
||||
private final Random random;
|
||||
private final KmeansInitializationMethod initializationMethod;
|
||||
private final float[][] initCentroids;
|
||||
private final int restarts;
|
||||
private final int iters;
|
||||
|
||||
/**
|
||||
* Cluster vectors into a given number of clusters
|
||||
*
|
||||
* @param vectors float vectors
|
||||
* @param similarityFunction vector similarity function. For COSINE similarity, vectors must be
|
||||
* normalized.
|
||||
* @param numClusters number of cluster to cluster vector into
|
||||
* @return results of clustering: produced centroids and for each vector its centroid
|
||||
* @throws IOException when if there is an error accessing vectors
|
||||
*/
|
||||
static Results cluster(FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters) throws IOException {
|
||||
return cluster(
|
||||
vectors,
|
||||
numClusters,
|
||||
true,
|
||||
42L,
|
||||
KmeansInitializationMethod.PLUS_PLUS,
|
||||
null,
|
||||
similarityFunction == VectorSimilarityFunction.COSINE,
|
||||
DEFAULT_RESTARTS,
|
||||
DEFAULT_ITRS,
|
||||
DEFAULT_SAMPLE_VECTORS_PER_CENTROID * numClusters
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Expert: Cluster vectors into a given number of clusters
|
||||
*
|
||||
* @param vectors float vectors
|
||||
* @param numClusters number of cluster to cluster vector into
|
||||
* @param assignCentroidsToVectors if {@code true} assign centroids for all vectors. Centroids are
|
||||
* computed on a sample of vectors. If this parameter is {@code true}, in results also return
|
||||
* for all vectors what centroids they belong to.
|
||||
* @param seed random seed
|
||||
* @param initializationMethod Kmeans initialization method
|
||||
* @param initCentroids initial centroids, if not {@code null} utilize as initial centroids for
|
||||
* the given initialization method
|
||||
* @param normalizeCenters for cosine distance, set to true, to use spherical k-means where
|
||||
* centers are normalized
|
||||
* @param restarts how many times to run Kmeans algorithm
|
||||
* @param iters how many iterations to do within a single run
|
||||
* @param sampleSize sample size to select from all vectors on which to run Kmeans algorithm
|
||||
* @return results of clustering: produced centroids and if {@code assignCentroidsToVectors ==
|
||||
* true} also for each vector its centroid
|
||||
* @throws IOException if there is error accessing vectors
|
||||
*/
|
||||
static Results cluster(
|
||||
FloatVectorValues vectors,
|
||||
int numClusters,
|
||||
boolean assignCentroidsToVectors,
|
||||
long seed,
|
||||
KmeansInitializationMethod initializationMethod,
|
||||
float[][] initCentroids,
|
||||
boolean normalizeCenters,
|
||||
int restarts,
|
||||
int iters,
|
||||
int sampleSize
|
||||
) throws IOException {
|
||||
if (vectors.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
// adjust sampleSize and numClusters
|
||||
sampleSize = Math.max(sampleSize, 100 * numClusters);
|
||||
if (sampleSize > vectors.size()) {
|
||||
sampleSize = vectors.size();
|
||||
// Decrease the number of clusters if needed
|
||||
int maxNumClusters = Math.max(1, sampleSize / 100);
|
||||
numClusters = Math.min(numClusters, maxNumClusters);
|
||||
}
|
||||
|
||||
Random random = new Random(seed);
|
||||
float[][] centroids;
|
||||
if (numClusters == 1) {
|
||||
centroids = new float[1][vectors.dimension()];
|
||||
for (int i = 0; i < vectors.size(); i++) {
|
||||
float[] vector = vectors.vectorValue(i);
|
||||
for (int dim = 0; dim < vector.length; dim++) {
|
||||
centroids[0][dim] += vector[dim];
|
||||
}
|
||||
}
|
||||
for (int dim = 0; dim < centroids[0].length; dim++) {
|
||||
centroids[0][dim] /= vectors.size();
|
||||
}
|
||||
} else {
|
||||
FloatVectorValues sampleVectors = vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed);
|
||||
KMeans kmeans = new KMeans(sampleVectors, numClusters, random, initializationMethod, initCentroids, restarts, iters);
|
||||
centroids = kmeans.computeCentroids(normalizeCenters);
|
||||
}
|
||||
|
||||
int[] vectorCentroids = null;
|
||||
int[] centroidSize = null;
|
||||
// Assign each vector to the nearest centroid and update the centres
|
||||
if (assignCentroidsToVectors) {
|
||||
vectorCentroids = new int[vectors.size()];
|
||||
centroidSize = new int[centroids.length];
|
||||
assignCentroids(random, vectorCentroids, centroidSize, vectors, centroids);
|
||||
}
|
||||
if (normalizeCenters) {
|
||||
for (float[] centroid : centroids) {
|
||||
VectorUtil.l2normalize(centroid, false);
|
||||
}
|
||||
}
|
||||
return new Results(centroids, centroidSize, vectorCentroids);
|
||||
}
|
||||
|
||||
private static void assignCentroids(
|
||||
Random random,
|
||||
int[] docCentroids,
|
||||
int[] centroidSize,
|
||||
FloatVectorValues vectors,
|
||||
float[][] centroids
|
||||
) throws IOException {
|
||||
short numCentroids = (short) centroids.length;
|
||||
assert Arrays.stream(centroidSize).allMatch(size -> size == 0);
|
||||
for (int docID = 0; docID < vectors.size(); docID++) {
|
||||
float[] vector = vectors.vectorValue(docID);
|
||||
short bestCentroid = 0;
|
||||
if (numCentroids > 1) {
|
||||
float minSquaredDist = Float.MAX_VALUE;
|
||||
for (short c = 0; c < numCentroids; c++) {
|
||||
// TODO: replace with RandomVectorScorer::score possible on quantized vectors
|
||||
float squareDist = VectorUtil.squareDistance(centroids[c], vector);
|
||||
if (squareDist < minSquaredDist) {
|
||||
bestCentroid = c;
|
||||
minSquaredDist = squareDist;
|
||||
}
|
||||
}
|
||||
}
|
||||
centroidSize[bestCentroid] += 1;
|
||||
docCentroids[docID] = bestCentroid;
|
||||
}
|
||||
|
||||
IntArrayList unassignedCentroids = new IntArrayList();
|
||||
for (int c = 0; c < numCentroids; c++) {
|
||||
if (centroidSize[c] == 0) {
|
||||
unassignedCentroids.add(c);
|
||||
}
|
||||
}
|
||||
if (unassignedCentroids.size() > 0) {
|
||||
throwAwayAndSplitCentroids(random, vectors, centroids, docCentroids, centroidSize, unassignedCentroids);
|
||||
}
|
||||
assert Arrays.stream(centroidSize).sum() == vectors.size();
|
||||
}
|
||||
|
||||
private final float[] kmeansPlusPlusScratch;
|
||||
|
||||
KMeans(
|
||||
FloatVectorValues vectors,
|
||||
int numCentroids,
|
||||
Random random,
|
||||
KmeansInitializationMethod initializationMethod,
|
||||
float[][] initCentroids,
|
||||
int restarts,
|
||||
int iters
|
||||
) {
|
||||
this.vectors = vectors;
|
||||
this.numVectors = vectors.size();
|
||||
this.numCentroids = numCentroids;
|
||||
this.random = random;
|
||||
this.initializationMethod = initializationMethod;
|
||||
this.restarts = restarts;
|
||||
this.iters = iters;
|
||||
this.initCentroids = initCentroids;
|
||||
this.kmeansPlusPlusScratch = initializationMethod == KmeansInitializationMethod.PLUS_PLUS ? new float[numVectors] : null;
|
||||
}
|
||||
|
||||
float[][] computeCentroids(boolean normalizeCenters) throws IOException {
|
||||
// TODO can we make this off-heap, or reusable? This could be a big array
|
||||
int[] vectorCentroids = new int[numVectors];
|
||||
double minSquaredDist = Double.MAX_VALUE;
|
||||
double squaredDist = 0;
|
||||
float[][] bestCentroids = null;
|
||||
float[][] centroids = new float[numCentroids][vectors.dimension()];
|
||||
int restarts = this.restarts;
|
||||
int numInitializedCentroids = 0;
|
||||
// The user has given us a solid number of centroids to start of with, so skip restarts, fill in
|
||||
// where we can, and refine
|
||||
if (initCentroids != null && initCentroids.length > numCentroids / 2) {
|
||||
int i = 0;
|
||||
for (; i < Math.min(numCentroids, initCentroids.length); i++) {
|
||||
System.arraycopy(initCentroids[i], 0, centroids[i], 0, initCentroids[i].length);
|
||||
}
|
||||
numInitializedCentroids = i;
|
||||
restarts = 1;
|
||||
}
|
||||
|
||||
for (int restart = 0; restart < restarts; restart++) {
|
||||
switch (initializationMethod) {
|
||||
case FORGY -> initializeForgy(centroids, numInitializedCentroids);
|
||||
case RESERVOIR_SAMPLING -> initializeReservoirSampling(centroids, numInitializedCentroids);
|
||||
case PLUS_PLUS -> initializePlusPlus(centroids, numInitializedCentroids);
|
||||
}
|
||||
double prevSquaredDist = Double.MAX_VALUE;
|
||||
int[] centroidSize = new int[centroids.length];
|
||||
for (int iter = 0; iter < iters; iter++) {
|
||||
squaredDist = runKMeansStep(centroids, centroidSize, vectorCentroids, normalizeCenters);
|
||||
// Check for convergence
|
||||
if (prevSquaredDist <= (squaredDist + 1e-6)) {
|
||||
break;
|
||||
}
|
||||
Arrays.fill(centroidSize, 0);
|
||||
prevSquaredDist = squaredDist;
|
||||
}
|
||||
if (squaredDist < minSquaredDist) {
|
||||
minSquaredDist = squaredDist;
|
||||
// Copy out the best centroids as they might be overwritten by the next restart
|
||||
bestCentroids = new float[centroids.length][];
|
||||
for (int i = 0; i < centroids.length; i++) {
|
||||
bestCentroids[i] = ArrayUtil.copyArray(centroids[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return bestCentroids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize centroids using the Forgy method: randomly select numCentroids vectors as the initial
|
||||
* centroids
|
||||
*/
|
||||
private void initializeForgy(float[][] initialCentroids, int fromCentroid) throws IOException {
|
||||
if (fromCentroid >= numCentroids) {
|
||||
return;
|
||||
}
|
||||
int numCentroids = this.numCentroids - fromCentroid;
|
||||
Set<Integer> selection = new HashSet<>();
|
||||
while (selection.size() < numCentroids) {
|
||||
selection.add(random.nextInt(numVectors));
|
||||
}
|
||||
int i = 0;
|
||||
for (Integer selectedIdx : selection) {
|
||||
float[] vector = vectors.vectorValue(selectedIdx);
|
||||
System.arraycopy(vector, 0, initialCentroids[fromCentroid + i++], 0, vector.length);
|
||||
}
|
||||
}
|
||||
|
||||
/** Initialize centroids using a reservoir sampling method */
|
||||
private void initializeReservoirSampling(float[][] initialCentroids, int fromCentroid) throws IOException {
|
||||
if (fromCentroid >= numCentroids) {
|
||||
return;
|
||||
}
|
||||
int numCentroids = this.numCentroids - fromCentroid;
|
||||
for (int index = 0; index < numVectors; index++) {
|
||||
float[] vector = vectors.vectorValue(index);
|
||||
if (index < numCentroids) {
|
||||
System.arraycopy(vector, 0, initialCentroids[index + fromCentroid], 0, vector.length);
|
||||
} else if (random.nextDouble() < numCentroids * (1.0 / index)) {
|
||||
int c = random.nextInt(numCentroids);
|
||||
System.arraycopy(vector, 0, initialCentroids[c + fromCentroid], 0, vector.length);
|
||||
}
|
||||
}
|
||||
}
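For context, the replacement test above is standard reservoir sampling: once the first slots are filled, the vector at position index replaces a uniformly random slot with probability numCentroids / index, which leaves every vector with the same overall chance of seeding a centroid. Written out (a restatement of the code, not new behavior):
\[ \Pr[\text{vector } i \text{ replaces a slot}] = \frac{k}{i}, \qquad \Pr[\text{any given vector ends up as a seed}] = \frac{k}{n} \]
with k = numCentroids and n = numVectors.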
|
||||
|
||||
/** Initialize centroids using the k-means++ method */
|
||||
private void initializePlusPlus(float[][] initialCentroids, int fromCentroid) throws IOException {
|
||||
if (fromCentroid >= numCentroids) {
|
||||
return;
|
||||
}
|
||||
// Choose the first centroid uniformly at random
|
||||
int firstIndex = random.nextInt(numVectors);
|
||||
float[] value = vectors.vectorValue(firstIndex);
|
||||
System.arraycopy(value, 0, initialCentroids[fromCentroid], 0, value.length);
|
||||
|
||||
// Store distances of each point to the nearest centroid
|
||||
Arrays.fill(kmeansPlusPlusScratch, Float.MAX_VALUE);
|
||||
|
||||
// Steps 2 and 3: select the remaining centroids
|
||||
for (int i = fromCentroid + 1; i < numCentroids; i++) {
|
||||
// Update distances with the new centroid
|
||||
double totalSum = 0;
|
||||
for (int j = 0; j < numVectors; j++) {
|
||||
// TODO: replace with RandomVectorScorer::score possible on quantized vectors
|
||||
float dist = VectorUtil.squareDistance(vectors.vectorValue(j), initialCentroids[i - 1]);
|
||||
if (dist < kmeansPlusPlusScratch[j]) {
|
||||
kmeansPlusPlusScratch[j] = dist;
|
||||
}
|
||||
totalSum += kmeansPlusPlusScratch[j];
|
||||
}
|
||||
|
||||
// Randomly select next centroid
|
||||
double r = totalSum * random.nextDouble();
|
||||
double cumulativeSum = 0;
|
||||
int nextCentroidIndex = 0;
|
||||
for (int j = 0; j < numVectors; j++) {
|
||||
cumulativeSum += kmeansPlusPlusScratch[j];
|
||||
if (cumulativeSum >= r && kmeansPlusPlusScratch[j] > 0) {
|
||||
nextCentroidIndex = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Update centroid
|
||||
value = vectors.vectorValue(nextCentroidIndex);
|
||||
System.arraycopy(value, 0, initialCentroids[i], 0, value.length);
|
||||
}
|
||||
}
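As a point of reference, the weighted draw above is the usual k-means++ seeding distribution; the cumulative-sum selection is equivalent to sampling each vector with probability proportional to its squared distance to the nearest already-chosen centroid (the values held in kmeansPlusPlusScratch):
\[ \Pr[x_j \text{ is the next centroid}] = \frac{D^2(x_j)}{\sum_{i=1}^{n} D^2(x_i)} \]
where D^2(x) is the squared distance from x to the closest centroid picked so far.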
|
||||
|
||||
/**
|
||||
* Run kmeans step
|
||||
*
|
||||
* @param centroids centroids, new calculated centroids are written here
|
||||
* @param docCentroids for each document which centroid it belongs to, results will be written
|
||||
* here
|
||||
* @param normalizeCentroids if centroids should be normalized; used for cosine similarity only
|
||||
* @throws IOException if there is an error accessing vector values
|
||||
*/
|
||||
private double runKMeansStep(float[][] centroids, int[] centroidSize, int[] docCentroids, boolean normalizeCentroids)
|
||||
throws IOException {
|
||||
short numCentroids = (short) centroids.length;
|
||||
assert Arrays.stream(centroidSize).allMatch(size -> size == 0);
|
||||
float[][] newCentroids = new float[numCentroids][centroids[0].length];
|
||||
|
||||
double sumSquaredDist = 0;
|
||||
for (int docID = 0; docID < vectors.size(); docID++) {
|
||||
float[] vector = vectors.vectorValue(docID);
|
||||
short bestCentroid = 0;
|
||||
if (numCentroids > 1) {
|
||||
float minSquaredDist = Float.MAX_VALUE;
|
||||
for (short c = 0; c < numCentroids; c++) {
|
||||
// TODO: replace with RandomVectorScorer::score possible on quantized vectors
|
||||
float squareDist = VectorUtil.squareDistance(centroids[c], vector);
|
||||
if (squareDist < minSquaredDist) {
|
||||
bestCentroid = c;
|
||||
minSquaredDist = squareDist;
|
||||
}
|
||||
}
|
||||
sumSquaredDist += minSquaredDist;
|
||||
}
|
||||
|
||||
centroidSize[bestCentroid] += 1;
|
||||
for (int dim = 0; dim < vector.length; dim++) {
|
||||
newCentroids[bestCentroid][dim] += vector[dim];
|
||||
}
|
||||
docCentroids[docID] = bestCentroid;
|
||||
}
|
||||
|
||||
IntArrayList unassignedCentroids = new IntArrayList();
|
||||
for (int c = 0; c < numCentroids; c++) {
|
||||
if (centroidSize[c] > 0) {
|
||||
for (int dim = 0; dim < newCentroids[c].length; dim++) {
|
||||
centroids[c][dim] = newCentroids[c][dim] / centroidSize[c];
|
||||
}
|
||||
} else {
|
||||
unassignedCentroids.add(c);
|
||||
}
|
||||
}
|
||||
if (unassignedCentroids.size() > 0) {
|
||||
throwAwayAndSplitCentroids(random, vectors, centroids, docCentroids, centroidSize, unassignedCentroids);
|
||||
}
|
||||
if (normalizeCentroids) {
|
||||
for (float[] centroid : centroids) {
|
||||
VectorUtil.l2normalize(centroid, false);
|
||||
}
|
||||
}
|
||||
assert Arrays.stream(centroidSize).sum() == vectors.size();
|
||||
return sumSquaredDist;
|
||||
}
|
||||
|
||||
static void throwAwayAndSplitCentroids(
|
||||
Random random,
|
||||
FloatVectorValues vectors,
|
||||
float[][] centroids,
|
||||
int[] docCentroids,
|
||||
int[] centroidSize,
|
||||
IntArrayList unassignedCentroidsIdxs
|
||||
) throws IOException {
|
||||
IntObjectHashMap<IntArrayList> splitCentroids = new IntObjectHashMap<>(unassignedCentroidsIdxs.size());
|
||||
// used for splitting logic
|
||||
int[] splitSizes = Arrays.copyOf(centroidSize, centroidSize.length);
|
||||
// FAISS style algorithm for splitting
|
||||
for (int i = 0; i < unassignedCentroidsIdxs.size(); i++) {
|
||||
int toSplit;
|
||||
for (toSplit = 0; true; toSplit = (toSplit + 1) % centroids.length) {
|
||||
/* probability to pick this cluster for split */
|
||||
double p = (splitSizes[toSplit] - 1.0) / (float) (docCentroids.length - centroids.length);
|
||||
float r = random.nextFloat();
|
||||
if (r < p) {
|
||||
break; /* found our cluster to be split */
|
||||
}
|
||||
}
|
||||
int unassignedCentroidIdx = unassignedCentroidsIdxs.get(i);
|
||||
// keep track of those that are split, this way we reassign docCentroids and fix up true size
|
||||
// & centroids
|
||||
// getOrDefault does not insert the new list into the map, so store it explicitly before adding
IntArrayList clusterSplits = splitCentroids.getOrDefault(toSplit, new IntArrayList());
clusterSplits.add(unassignedCentroidIdx);
splitCentroids.put(toSplit, clusterSplits);
|
||||
System.arraycopy(centroids[toSplit], 0, centroids[unassignedCentroidIdx], 0, centroids[unassignedCentroidIdx].length);
|
||||
for (int dim = 0; dim < centroids[unassignedCentroidIdx].length; dim++) {
|
||||
if (dim % 2 == 0) {
|
||||
centroids[unassignedCentroidIdx][dim] *= (1 + EPS);
|
||||
centroids[toSplit][dim] *= (1 - EPS);
|
||||
} else {
|
||||
centroids[unassignedCentroidIdx][dim] *= (1 - EPS);
|
||||
centroids[toSplit][dim] *= (1 + EPS);
|
||||
}
|
||||
}
|
||||
splitSizes[unassignedCentroidIdx] = splitSizes[toSplit] / 2;
|
||||
splitSizes[toSplit] -= splitSizes[unassignedCentroidIdx];
|
||||
}
|
||||
// now we need to reassign docCentroids and fix up true size & centroids
|
||||
for (int i = 0; i < docCentroids.length; i++) {
|
||||
int docCentroid = docCentroids[i];
|
||||
IntArrayList split = splitCentroids.get(docCentroid);
|
||||
if (split != null) {
|
||||
// we need to reassign this doc
|
||||
int bestCentroid = docCentroid;
|
||||
float bestDist = VectorUtil.squareDistance(centroids[docCentroid], vectors.vectorValue(i));
|
||||
for (int j = 0; j < split.size(); j++) {
|
||||
int newCentroid = split.get(j);
|
||||
float dist = VectorUtil.squareDistance(centroids[newCentroid], vectors.vectorValue(i));
|
||||
if (dist < bestDist) {
|
||||
bestCentroid = newCentroid;
|
||||
bestDist = dist;
|
||||
}
|
||||
}
|
||||
if (bestCentroid != docCentroid) {
|
||||
// we need to update the centroid size
|
||||
centroidSize[docCentroid]--;
|
||||
centroidSize[bestCentroid]++;
|
||||
docCentroids[i] = (short) bestCentroid;
|
||||
// we need to update the old and new centroid accounting for size as well
|
||||
for (int dim = 0; dim < centroids[docCentroid].length; dim++) {
|
||||
centroids[docCentroid][dim] -= vectors.vectorValue(i)[dim] / centroidSize[docCentroid];
|
||||
centroids[bestCentroid][dim] += vectors.vectorValue(i)[dim] / centroidSize[bestCentroid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
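Put as a formula, the FAISS-style loop above picks the cluster to split with probability proportional to its running size:
\[ p(\text{split } c) = \frac{\text{splitSizes}[c] - 1}{n - k} \]
with n = docCentroids.length and k = centroids.length, so larger clusters are the ones most likely to donate half of their mass to an otherwise empty centroid.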
|
||||
|
||||
/** Kmeans initialization methods */
|
||||
public enum KmeansInitializationMethod {
|
||||
FORGY,
|
||||
RESERVOIR_SAMPLING,
|
||||
PLUS_PLUS
|
||||
}
|
||||
|
||||
/**
|
||||
* Results of KMeans clustering
|
||||
*
|
||||
* @param centroids the produced centroids
|
||||
* @param centroidsSize for each centroid how many vectors belong to it
|
||||
* @param vectorCentroids for each vector which centroid it belongs to
|
||||
*/
|
||||
public record Results(float[][] centroids, int[] centroidsSize, int[] vectorCentroids) {}
|
||||
}
|
|
@@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
class FloatVectorValuesSlice extends FloatVectorValues {
|
||||
|
||||
private final FloatVectorValues allValues;
|
||||
private final int[] slice;
|
||||
|
||||
FloatVectorValuesSlice(FloatVectorValues allValues, int[] slice) {
|
||||
assert slice != null;
|
||||
assert slice.length <= allValues.size();
|
||||
this.allValues = allValues;
|
||||
this.slice = slice;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
return this.allValues.vectorValue(this.slice[ord]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return this.allValues.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return slice.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return this.slice[ord];
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues copy() throws IOException {
|
||||
return new FloatVectorValuesSlice(this.allValues.copy(), this.slice);
|
||||
}
|
||||
}
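A tiny usage sketch of the ordinal remapping; allValues and the chosen ordinals are made up for illustration:
// hypothetical: ordinals 3, 7 and 9 of allValues were assigned to one cluster
FloatVectorValues slice = new FloatVectorValuesSlice(allValues, new int[] { 3, 7, 9 });
slice.size(); // 3
slice.vectorValue(1); // delegates to allValues.vectorValue(7)
slice.ordToDoc(2); // 9
createClusterSlice in HierarchicalKMeans below builds exactly this kind of slice from the ordinals assigned to a single cluster.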
|
|
@@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
|
||||
*/
|
||||
public class HierarchicalKMeans {
|
||||
|
||||
static final int MAXK = 128;
|
||||
static final int MAX_ITERATIONS_DEFAULT = 6;
|
||||
static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
|
||||
static final float DEFAULT_SOAR_LAMBDA = 1.0f;
|
||||
|
||||
final int dimension;
|
||||
final int maxIterations;
|
||||
final int samplesPerCluster;
|
||||
final int clustersPerNeighborhood;
|
||||
final float soarLambda;
|
||||
|
||||
public HierarchicalKMeans(int dimension) {
|
||||
this(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA);
|
||||
}
|
||||
|
||||
HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) {
|
||||
this.dimension = dimension;
|
||||
this.maxIterations = maxIterations;
|
||||
this.samplesPerCluster = samplesPerCluster;
|
||||
this.clustersPerNeighborhood = clustersPerNeighborhood;
|
||||
this.soarLambda = soarLambda;
|
||||
}
|
||||
|
||||
/**
|
||||
* clusters (more precisely, partitions) the set of vectors, starting with a rough number of partitions and then recursively refining those;
|
||||
* lastly a pass is made to adjust centroids across nearby neighborhoods and to add one extra (SOAR) assignment per vector to a nearby neighborhood
|
||||
*
|
||||
* @param vectors the vectors to cluster
|
||||
* @param targetSize the rough number of vectors that should be attached to a cluster
|
||||
* @return the centroids, the per-vector assignments, and the SOAR (spilled to nearby neighborhoods) assignments
|
||||
* @throws IOException if the vectors are inaccessible
|
||||
*/
|
||||
public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
|
||||
|
||||
if (vectors.size() == 0) {
|
||||
return new KMeansIntermediate();
|
||||
}
|
||||
|
||||
// if we have only a small number of vectors, pick the first one and output it as the single centroid
|
||||
if (vectors.size() <= targetSize) {
|
||||
float[] centroid = new float[dimension];
|
||||
System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension);
|
||||
return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]);
|
||||
}
|
||||
|
||||
// partition the space
|
||||
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
|
||||
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
|
||||
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
|
||||
int localSampleSize = (int) (f * vectors.size());
|
||||
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
|
||||
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
|
||||
}
|
||||
|
||||
return kMeansIntermediate;
|
||||
}
|
||||
|
||||
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
|
||||
if (vectors.size() <= targetSize) {
|
||||
return new KMeansIntermediate();
|
||||
}
|
||||
|
||||
int k = Math.clamp((int) ((vectors.size() + targetSize / 2.0f) / (float) targetSize), 2, MAXK);
|
||||
int m = Math.min(k * samplesPerCluster, vectors.size());
|
||||
|
||||
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
|
||||
int[] assignments = new int[vectors.size()];
|
||||
|
||||
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
|
||||
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
|
||||
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
|
||||
kmeans.cluster(vectors, kMeansIntermediate);
|
||||
|
||||
// TODO: consider adding cluster size counts to the kmeans algo
|
||||
// handle assignment here so we can track distance and cluster size
|
||||
int[] centroidVectorCount = new int[centroids.length];
|
||||
float[][] nextCentroids = new float[centroids.length][dimension];
|
||||
for (int i = 0; i < vectors.size(); i++) {
|
||||
float smallest = Float.MAX_VALUE;
|
||||
int centroidIdx = -1;
|
||||
float[] vector = vectors.vectorValue(i);
|
||||
for (int j = 0; j < centroids.length; j++) {
|
||||
float[] centroid = centroids[j];
|
||||
float d = VectorUtil.squareDistance(vector, centroid);
|
||||
if (d < smallest) {
|
||||
smallest = d;
|
||||
centroidIdx = j;
|
||||
}
|
||||
}
|
||||
centroidVectorCount[centroidIdx]++;
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
nextCentroids[centroidIdx][j] += vector[j];
|
||||
}
|
||||
assignments[i] = centroidIdx;
|
||||
}
|
||||
|
||||
// update centroids based on assignments of all vectors
|
||||
for (int i = 0; i < centroids.length; i++) {
|
||||
if (centroidVectorCount[i] > 0) {
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int effectiveK = 0;
|
||||
for (int i = 0; i < centroidVectorCount.length; i++) {
|
||||
if (centroidVectorCount[i] > 0) {
|
||||
effectiveK++;
|
||||
}
|
||||
}
|
||||
|
||||
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
|
||||
|
||||
if (effectiveK == 1) {
|
||||
return kMeansIntermediate;
|
||||
}
|
||||
|
||||
for (int c = 0; c < centroidVectorCount.length; c++) {
|
||||
// Recurse for each cluster which is larger than targetSize
|
||||
// Give ourselves roughly a 34% margin over the target size
|
||||
if (100 * centroidVectorCount[c] > 134 * targetSize) {
|
||||
FloatVectorValues sample = createClusterSlice(centroidVectorCount[c], c, vectors, assignments);
|
||||
|
||||
// TODO: consider iterative here instead of recursive
|
||||
// recursive call to build out the sub partitions around this centroid c
|
||||
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
|
||||
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
|
||||
}
|
||||
}
|
||||
|
||||
return kMeansIntermediate;
|
||||
}
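To make the branching factor concrete, a worked example with made-up numbers (none of them come from the source): for vectors.size() = 10_000 and targetSize = 256, k = clamp((10000 + 128) / 256, 2, MAXK) = clamp(39, 2, 128) = 39, the Lloyd sample size is m = min(39 * 256, 10_000) = 9_984, and after assignment any cluster holding more than 134 * 256 / 100 ≈ 343 vectors is recursively re-partitioned.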
|
||||
|
||||
static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
|
||||
int[] slice = new int[clusterSize];
|
||||
int idx = 0;
|
||||
for (int i = 0; i < assignments.length; i++) {
|
||||
if (assignments[i] == cluster) {
|
||||
slice[idx] = i;
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
|
||||
return new FloatVectorValuesSlice(vectors, slice);
|
||||
}
|
||||
|
||||
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
|
||||
int orgCentroidsSize = current.centroids().length;
|
||||
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
|
||||
|
||||
// update based on the outcomes from the split clusters recursion
|
||||
if (subPartitions.centroids().length > 1) {
|
||||
float[][] newCentroids = new float[newCentroidsSize][dimension];
|
||||
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
|
||||
|
||||
// replace the original cluster
|
||||
int origCentroidOrd = 0;
|
||||
newCentroids[cluster] = subPartitions.centroids()[0];
|
||||
|
||||
// append the remainder
|
||||
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
|
||||
|
||||
current.setCentroids(newCentroids);
|
||||
|
||||
for (int i = 0; i < subPartitions.assignments().length; i++) {
|
||||
// this is a new centroid that was added, and so we'll need to remap it
|
||||
if (subPartitions.assignments()[i] != origCentroidOrd) {
|
||||
int parentOrd = subPartitions.ordToDoc(i);
|
||||
assert current.assignments()[parentOrd] == cluster;
|
||||
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
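For orientation, a minimal driver sketch in the spirit of the tests further down; vectors is assumed to be an existing FloatVectorValues and the target of 128 vectors per cluster is an arbitrary placeholder:
// sketch only: `vectors` and the target cluster size are assumptions, not taken from this change
HierarchicalKMeans hkmeans = new HierarchicalKMeans(vectors.dimension());
KMeansResult result = hkmeans.cluster(vectors, 128);
float[][] centroids = result.centroids(); // one centroid per produced partition
int[] assignments = result.assignments(); // primary cluster per vector ordinal
int[] soarAssignments = result.soarAssignments(); // secondary (SOAR) cluster per vector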
|
|
@@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.util.hnsw.IntToIntFunction;
|
||||
|
||||
/**
|
||||
* Intermediate object for clustering (partitioning) a set of vectors
|
||||
*/
|
||||
class KMeansIntermediate extends KMeansResult {
|
||||
private final IntToIntFunction assignmentOrds;
|
||||
|
||||
private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrds, int[] soarAssignments) {
|
||||
super(centroids, assignments, soarAssignments);
|
||||
assert assignmentOrds != null;
|
||||
this.assignmentOrds = assignmentOrds;
|
||||
}
|
||||
|
||||
KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrdinals) {
|
||||
this(centroids, assignments, assignmentOrdinals, new int[0]);
|
||||
}
|
||||
|
||||
KMeansIntermediate() {
|
||||
this(new float[0][0], new int[0], i -> i, new int[0]);
|
||||
}
|
||||
|
||||
KMeansIntermediate(float[][] centroids) {
|
||||
this(centroids, new int[0], i -> i, new int[0]);
|
||||
}
|
||||
|
||||
KMeansIntermediate(float[][] centroids, int[] assignments) {
|
||||
this(centroids, assignments, i -> i, new int[0]);
|
||||
}
|
||||
|
||||
public int ordToDoc(int ord) {
|
||||
return assignmentOrds.apply(ord);
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,306 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* k-means implementation specific to the needs of the {@link HierarchicalKMeans} algorithm that deals specifically
|
||||
* with finalizing nearby pre-established clusters and generating
|
||||
* <a href="https://research.google/blog/soar-new-algorithms-for-even-faster-vector-search-with-scann/">SOAR</a> assignments
|
||||
*/
|
||||
class KMeansLocal {
|
||||
|
||||
final int sampleSize;
|
||||
final int maxIterations;
|
||||
final int clustersPerNeighborhood;
|
||||
final float soarLambda;
|
||||
|
||||
KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) {
|
||||
this.sampleSize = sampleSize;
|
||||
this.maxIterations = maxIterations;
|
||||
this.clustersPerNeighborhood = clustersPerNeighborhood;
|
||||
this.soarLambda = soarLambda;
|
||||
}
|
||||
|
||||
KMeansLocal(int sampleSize, int maxIterations) {
|
||||
this(sampleSize, maxIterations, -1, -1f);
|
||||
}
|
||||
|
||||
/**
|
||||
* uses a reservoir sampling approach to pick the initial centroids, which are subsequently expected
|
||||
* to be used by a clustering algorithm
|
||||
*
|
||||
* @param vectors used to pick an initial set of random centroids
|
||||
* @param centroidCount the total number of centroids to pick
|
||||
* @return randomly selected centroids; the number returned is the min of centroidCount and the number of vectors
|
||||
* @throws IOException if the vectors are inaccessible
|
||||
*/
|
||||
static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCount) throws IOException {
|
||||
Random random = new Random(42L);
|
||||
int centroidsSize = Math.min(vectors.size(), centroidCount);
|
||||
float[][] centroids = new float[centroidsSize][vectors.dimension()];
|
||||
for (int i = 0; i < vectors.size(); i++) {
|
||||
float[] vector;
|
||||
if (i < centroidCount) {
|
||||
vector = vectors.vectorValue(i);
|
||||
System.arraycopy(vector, 0, centroids[i], 0, vector.length);
|
||||
} else if (random.nextDouble() < centroidCount * (1.0 / i)) {
|
||||
int c = random.nextInt(centroidCount);
|
||||
vector = vectors.vectorValue(i);
|
||||
System.arraycopy(vector, 0, centroids[c], 0, vector.length);
|
||||
}
|
||||
}
|
||||
return centroids;
|
||||
}
|
||||
|
||||
private boolean stepLloyd(
|
||||
FloatVectorValues vectors,
|
||||
float[][] centroids,
|
||||
float[][] nextCentroids,
|
||||
int[] assignments,
|
||||
int sampleSize,
|
||||
List<int[]> neighborhoods
|
||||
) throws IOException {
|
||||
boolean changed = false;
|
||||
int dim = vectors.dimension();
|
||||
int[] centroidCounts = new int[centroids.length];
|
||||
|
||||
for (int i = 0; i < nextCentroids.length; i++) {
|
||||
Arrays.fill(nextCentroids[i], 0.0f);
|
||||
}
|
||||
|
||||
for (int i = 0; i < sampleSize; i++) {
|
||||
float[] vector = vectors.vectorValue(i);
|
||||
int[] neighborOffsets = null;
|
||||
int centroidIdx = -1;
|
||||
if (neighborhoods != null) {
|
||||
neighborOffsets = neighborhoods.get(assignments[i]);
|
||||
centroidIdx = assignments[i];
|
||||
}
|
||||
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
|
||||
if (assignments[i] != bestCentroidOffset) {
|
||||
changed = true;
|
||||
}
|
||||
assignments[i] = bestCentroidOffset;
|
||||
centroidCounts[bestCentroidOffset]++;
|
||||
for (short d = 0; d < dim; d++) {
|
||||
nextCentroids[bestCentroidOffset][d] += vector[d];
|
||||
}
|
||||
}
|
||||
|
||||
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
|
||||
if (centroidCounts[clusterIdx] > 0) {
|
||||
float countF = (float) centroidCounts[clusterIdx];
|
||||
for (short d = 0; d < dim; d++) {
|
||||
centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
|
||||
int bestCentroidOffset = centroidIdx;
|
||||
float minDsq;
|
||||
if (centroidIdx > 0 && centroidIdx < centroids.length) {
|
||||
minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
|
||||
} else {
|
||||
minDsq = Float.MAX_VALUE;
|
||||
}
|
||||
|
||||
int k = 0;
|
||||
for (int j = 0; j < centroids.length; j++) {
|
||||
if (centroidOffsets == null || j == centroidOffsets[k]) {
|
||||
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
|
||||
if (dsq < minDsq) {
|
||||
minDsq = dsq;
|
||||
bestCentroidOffset = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
return bestCentroidOffset;
|
||||
}
|
||||
|
||||
private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods, int clustersPerNeighborhood) {
|
||||
int k = neighborhoods.size();
|
||||
|
||||
if (k == 0 || clustersPerNeighborhood <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<NeighborQueue> neighborQueues = new ArrayList<>(k);
|
||||
for (int i = 0; i < k; i++) {
|
||||
neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true));
|
||||
}
|
||||
for (int i = 0; i < k - 1; i++) {
|
||||
for (int j = i + 1; j < k; j++) {
|
||||
float dsq = VectorUtil.squareDistance(centers[i], centers[j]);
|
||||
neighborQueues.get(j).insertWithOverflow(i, dsq);
|
||||
neighborQueues.get(i).insertWithOverflow(j, dsq);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < k; i++) {
|
||||
NeighborQueue queue = neighborQueues.get(i);
|
||||
int neighborCount = queue.size();
|
||||
int[] neighbors = new int[neighborCount];
|
||||
queue.consumeNodes(neighbors);
|
||||
neighborhoods.set(i, neighbors);
|
||||
}
|
||||
}
|
||||
|
||||
private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
|
||||
throws IOException {
|
||||
// SOAR uses an adjusted distance for assigning spilled documents which is
|
||||
// given by:
|
||||
//
|
||||
// soar(x, c) = ||x - c||^2 + lambda * ((x - c_1)^t (x - c))^2 / ||x - c_1||^2
|
||||
//
|
||||
// Here, x is the document, c is the nearest centroid, and c_1 is the first
|
||||
// centroid the document was assigned to. The document is assigned to the
|
||||
// cluster with the smallest soar(x, c).
|
||||
|
||||
int[] spilledAssignments = new int[assignments.length];
|
||||
|
||||
float[] diffs = new float[vectors.dimension()];
|
||||
for (int i = 0; i < vectors.size(); i++) {
|
||||
float[] vector = vectors.vectorValue(i);
|
||||
|
||||
int currAssignment = assignments[i];
|
||||
float[] currentCentroid = centroids[currAssignment];
|
||||
for (short j = 0; j < vectors.dimension(); j++) {
|
||||
float diff = vector[j] - currentCentroid[j];
|
||||
diffs[j] = diff;
|
||||
}
|
||||
|
||||
// TODO: cache these?
|
||||
// float vectorCentroidDist = assignmentDistances[i];
|
||||
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
|
||||
|
||||
int bestAssignment = -1;
|
||||
float minSoar = Float.MAX_VALUE;
|
||||
assert neighborhoods.get(currAssignment) != null;
|
||||
for (int neighbor : neighborhoods.get(currAssignment)) {
|
||||
if (neighbor == currAssignment) {
|
||||
continue;
|
||||
}
|
||||
float[] neighborCentroid = centroids[neighbor];
|
||||
float soar = distanceSoar(diffs, vector, neighborCentroid, vectorCentroidDist);
|
||||
if (soar < minSoar) {
|
||||
bestAssignment = neighbor;
|
||||
minSoar = soar;
|
||||
}
|
||||
}
|
||||
|
||||
spilledAssignments[i] = bestAssignment;
|
||||
}
|
||||
|
||||
return spilledAssignments;
|
||||
}
|
||||
|
||||
private float distanceSoar(float[] residual, float[] vector, float[] centroid, float rnorm) {
|
||||
// TODO: combine these to be more efficient
|
||||
float dsq = VectorUtil.squareDistance(vector, centroid);
|
||||
float rproj = ESVectorUtil.soarResidual(vector, centroid, residual);
|
||||
return dsq + soarLambda * rproj * rproj / rnorm;
|
||||
}
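The same adjusted distance written out; this restates the formula from the comment in assignSpilled, with rnorm = ||x - c_1||^2 and rproj assumed to be the projection term computed by ESVectorUtil.soarResidual:
\[ \mathrm{soar}(x, c) = \lVert x - c \rVert^2 + \lambda \, \frac{\big((x - c_1)^\top (x - c)\big)^2}{\lVert x - c_1 \rVert^2} \]
where c is the candidate neighboring centroid and c_1 the centroid the vector is primarily assigned to.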
|
||||
|
||||
/**
|
||||
* cluster using Lloyd's k-means algorithm without neighborhood awareness
|
||||
*
|
||||
* @param vectors the vectors to cluster
|
||||
* @param kMeansIntermediate the output object to populate which minimally includes centroids,
|
||||
* but may include assignments and soar assignments as well; care should be taken in
|
||||
* passing in a valid output object with a centroids array that is the size of centroids expected
|
||||
* @throws IOException if the vectors are inaccessible
|
||||
*/
|
||||
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
|
||||
cluster(vectors, kMeansIntermediate, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* cluster using Lloyd's k-means algorithm that also considers previously clustered neighborhoods when adjusting centroids;
|
||||
* this is also used to generate the neighborhood-aware additional (SOAR) assignments
|
||||
*
|
||||
* @param vectors the vectors to cluster
|
||||
* @param kMeansIntermediate the output object to populate which minimally includes centroids,
|
||||
* the prior assignments of the given vectors; care should be taken in
|
||||
* passing in a valid output object with a centroids array that is the size of centroids expected
|
||||
* and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
|
||||
* @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
|
||||
* implies SOAR assignments
|
||||
* @throws IOException if the vectors are inaccessible
|
||||
*/
|
||||
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
|
||||
float[][] centroids = kMeansIntermediate.centroids();
|
||||
|
||||
List<int[]> neighborhoods = null;
|
||||
if (neighborAware) {
|
||||
int k = centroids.length;
|
||||
neighborhoods = new ArrayList<>(k);
|
||||
for (int i = 0; i < k; ++i) {
|
||||
neighborhoods.add(null);
|
||||
}
|
||||
computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood);
|
||||
}
|
||||
cluster(vectors, kMeansIntermediate, neighborhoods);
|
||||
if (neighborAware && clustersPerNeighborhood > 0) {
|
||||
int[] assignments = kMeansIntermediate.assignments();
|
||||
assert assignments != null;
|
||||
assert assignments.length == vectors.size();
|
||||
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments));
|
||||
}
|
||||
}
|
||||
|
||||
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
|
||||
float[][] centroids = kMeansIntermediate.centroids();
|
||||
int k = centroids.length;
|
||||
int n = vectors.size();
|
||||
|
||||
if (k == 1 || k >= n) {
|
||||
return;
|
||||
}
|
||||
|
||||
int[] assignments = new int[n];
|
||||
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
|
||||
for (int i = 0; i < maxIterations; i++) {
|
||||
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
stepLloyd(vectors, centroids, nextCentroids, assignments, vectors.size(), neighborhoods);
|
||||
}
|
||||
|
||||
/**
|
||||
* helper that calls {@link KMeansLocal#cluster(FloatVectorValues, KMeansIntermediate)} given a set of initialized centroids,
|
||||
* this call is not neighbor aware
|
||||
*
|
||||
* @param vectors the vectors to cluster
|
||||
* @param centroids the initialized centroids to be shifted using k-means
|
||||
* @param sampleSize the subset of vectors to use when shifting centroids
|
||||
* @param maxIterations the max iterations to shift centroids
|
||||
*/
|
||||
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
|
||||
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
|
||||
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
|
||||
kMeans.cluster(vectors, kMeansIntermediate);
|
||||
}
|
||||
|
||||
}
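A short usage sketch of this static helper, mirroring how the tests below drive it (k and the iteration count are placeholders):
// sketch: seed centroids via reservoir sampling, then run plain Lloyd iterations over all vectors
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansLocal.cluster(vectors, centroids, /* sampleSize= */ vectors.size(), /* maxIterations= */ 10);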
|
|
@@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
/**
|
||||
* Output object for clustering (partitioning) a set of vectors
|
||||
*/
|
||||
public class KMeansResult {
|
||||
private float[][] centroids;
|
||||
private final int[] assignments;
|
||||
private int[] soarAssignments;
|
||||
|
||||
KMeansResult(float[][] centroids, int[] assignments, int[] soarAssignments) {
|
||||
assert centroids != null;
|
||||
assert assignments != null;
|
||||
assert soarAssignments != null;
|
||||
this.centroids = centroids;
|
||||
this.assignments = assignments;
|
||||
this.soarAssignments = soarAssignments;
|
||||
}
|
||||
|
||||
public float[][] centroids() {
|
||||
return centroids;
|
||||
}
|
||||
|
||||
void setCentroids(float[][] centroids) {
|
||||
this.centroids = centroids;
|
||||
}
|
||||
|
||||
public int[] assignments() {
|
||||
return assignments;
|
||||
}
|
||||
|
||||
void setSoarAssignments(int[] soarAssignments) {
|
||||
this.soarAssignments = soarAssignments;
|
||||
}
|
||||
|
||||
public int[] soarAssignments() {
|
||||
return soarAssignments;
|
||||
}
|
||||
}
|
|
@@ -14,10 +14,10 @@
|
|||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.util.LongHeap;
|
||||
import org.apache.lucene.util.NumericUtils;
|
|
@@ -1,154 +0,0 @@
|
|||
/*
|
||||
* @notice
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class KMeansTests extends ESTestCase {
|
||||
|
||||
public void testKMeansAPI() throws IOException {
|
||||
int nClusters = random().nextInt(1, 10);
|
||||
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
|
||||
int dims = random().nextInt(2, 20);
|
||||
int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
|
||||
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
|
||||
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
|
||||
|
||||
// default case
|
||||
{
|
||||
KMeans.Results results = KMeans.cluster(vectors, similarityFunction, nClusters);
|
||||
assertResults(results, nClusters, nVectors, true);
|
||||
assertEquals(nClusters, results.centroids().length);
|
||||
assertEquals(nClusters, results.centroidsSize().length);
|
||||
assertEquals(nVectors, results.vectorCentroids().length);
|
||||
}
|
||||
// expert case
|
||||
{
|
||||
boolean assignCentroidsToVectors = random().nextBoolean();
|
||||
int randIdx2 = random().nextInt(KMeans.KmeansInitializationMethod.values().length);
|
||||
KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.values()[randIdx2];
|
||||
int restarts = random().nextInt(1, 6);
|
||||
int iters = random().nextInt(1, 10);
|
||||
int sampleSize = random().nextInt(10, nVectors * 2);
|
||||
|
||||
KMeans.Results results = KMeans.cluster(
|
||||
vectors,
|
||||
nClusters,
|
||||
assignCentroidsToVectors,
|
||||
random().nextLong(),
|
||||
initializationMethod,
|
||||
null,
|
||||
similarityFunction == VectorSimilarityFunction.COSINE,
|
||||
restarts,
|
||||
iters,
|
||||
sampleSize
|
||||
);
|
||||
assertResults(results, nClusters, nVectors, assignCentroidsToVectors);
|
||||
}
|
||||
}
|
||||
|
||||
private void assertResults(KMeans.Results results, int nClusters, int nVectors, boolean assignCentroidsToVectors) {
|
||||
assertEquals(nClusters, results.centroids().length);
|
||||
if (assignCentroidsToVectors) {
|
||||
assertEquals(nClusters, results.centroidsSize().length);
|
||||
assertEquals(nVectors, results.vectorCentroids().length);
|
||||
int[] centroidsSize = new int[nClusters];
|
||||
for (int i = 0; i < nVectors; i++) {
|
||||
centroidsSize[results.vectorCentroids()[i]]++;
|
||||
}
|
||||
assertArrayEquals(centroidsSize, results.centroidsSize());
|
||||
} else {
|
||||
assertNull(results.vectorCentroids());
|
||||
}
|
||||
}
|
||||
|
||||
public void testKMeansSpecialCases() throws IOException {
|
||||
{
|
||||
// nClusters > nVectors
|
||||
int nClusters = 20;
|
||||
int nVectors = 10;
|
||||
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
|
||||
KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
|
||||
// assert that we get 1 centroid, as nClusters will be adjusted
|
||||
assertEquals(1, results.centroids().length);
|
||||
assertEquals(nVectors, results.vectorCentroids().length);
|
||||
}
|
||||
{
|
||||
// small sample size
|
||||
int sampleSize = 2;
|
||||
int nClusters = 2;
|
||||
int nVectors = 300;
|
||||
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
|
||||
KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.PLUS_PLUS;
|
||||
KMeans.Results results = KMeans.cluster(
|
||||
vectors,
|
||||
nClusters,
|
||||
true,
|
||||
random().nextLong(),
|
||||
initializationMethod,
|
||||
null,
|
||||
false,
|
||||
1,
|
||||
2,
|
||||
sampleSize
|
||||
);
|
||||
assertResults(results, nClusters, nVectors, true);
|
||||
}
|
||||
}
|
||||
|
||||
public void testKMeansSAllZero() throws IOException {
|
||||
int nClusters = 10;
|
||||
List<float[]> vectors = new ArrayList<>();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
float[] vector = new float[5];
|
||||
vectors.add(vector);
|
||||
}
|
||||
KMeans.Results results = KMeans.cluster(FloatVectorValues.fromFloats(vectors, 5), VectorSimilarityFunction.EUCLIDEAN, nClusters);
|
||||
assertResults(results, nClusters, 1000, true);
|
||||
}
|
||||
|
||||
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
|
||||
List<float[]> vectors = new ArrayList<>(nSamples);
|
||||
float[][] centroids = new float[nClusters][nDims];
|
||||
// Generate random centroids
|
||||
for (int i = 0; i < nClusters; i++) {
|
||||
for (int j = 0; j < nDims; j++) {
|
||||
centroids[i][j] = random().nextFloat() * 100;
|
||||
}
|
||||
}
|
||||
// Generate data points around centroids
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
int cluster = random().nextInt(nClusters);
|
||||
float[] vector = new float[nDims];
|
||||
for (int j = 0; j < nDims; j++) {
|
||||
vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
|
||||
}
|
||||
vectors.add(vector);
|
||||
}
|
||||
return FloatVectorValues.fromFloats(vectors, nDims);
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class HierarchicalKMeansTests extends ESTestCase {
|
||||
|
||||
public void testHKmeans() throws IOException {
|
||||
int nClusters = random().nextInt(1, 10);
|
||||
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
|
||||
int dims = random().nextInt(2, 20);
|
||||
int sampleSize = random().nextInt(100, nVectors + 1);
|
||||
int maxIterations = random().nextInt(0, 100);
|
||||
int clustersPerNeighborhood = random().nextInt(0, 512);
|
||||
float soarLambda = random().nextFloat(0.5f, 1.5f);
|
||||
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
|
||||
|
||||
int targetSize = (int) ((float) nVectors / (float) nClusters);
|
||||
HierarchicalKMeans hkmeans = new HierarchicalKMeans(dims, maxIterations, sampleSize, clustersPerNeighborhood, soarLambda);
|
||||
|
||||
KMeansResult result = hkmeans.cluster(vectors, targetSize);
|
||||
|
||||
float[][] centroids = result.centroids();
|
||||
int[] assignments = result.assignments();
|
||||
int[] soarAssignments = result.soarAssignments();
|
||||
|
||||
assertEquals(nClusters, centroids.length, 6);
|
||||
assertEquals(nVectors, assignments.length);
|
||||
if (centroids.length > 1 && clustersPerNeighborhood > 0) {
|
||||
assertEquals(nVectors, soarAssignments.length);
|
||||
// verify the SOAR assignment always differs from the primary assignment
|
||||
for (int i = 0; i < assignments.length; i++) {
|
||||
assert assignments[i] != soarAssignments[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
|
||||
List<float[]> vectors = new ArrayList<>(nSamples);
|
||||
float[][] centroids = new float[nClusters][nDims];
|
||||
// Generate random centroids
|
||||
for (int i = 0; i < nClusters; i++) {
|
||||
for (int j = 0; j < nDims; j++) {
|
||||
centroids[i][j] = random().nextFloat() * 100;
|
||||
}
|
||||
}
|
||||
// Generate data points around centroids
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
int cluster = random().nextInt(nClusters);
|
||||
float[] vector = new float[nDims];
|
||||
for (int j = 0; j < nDims; j++) {
|
||||
vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
|
||||
}
|
||||
vectors.add(vector);
|
||||
}
|
||||
return FloatVectorValues.fromFloats(vectors, nDims);
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,127 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class KMeansLocalTests extends ESTestCase {
|
||||
|
||||
public void testKMeansNeighbors() throws IOException {
|
||||
int nClusters = random().nextInt(1, 10);
|
||||
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
|
||||
int dims = random().nextInt(2, 20);
|
||||
int sampleSize = random().nextInt(100, nVectors + 1);
|
||||
int maxIterations = random().nextInt(0, 100);
|
||||
int clustersPerNeighborhood = random().nextInt(0, 512);
|
||||
float soarLambda = random().nextFloat(0.5f, 1.5f);
|
||||
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
|
||||
|
||||
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, nClusters);
|
||||
KMeansLocal.cluster(vectors, centroids, sampleSize, maxIterations);
|
||||
|
||||
int[] assignments = new int[vectors.size()];
|
||||
int[] assignmentOrdinals = new int[vectors.size()];
|
||||
for (int i = 0; i < vectors.size(); i++) {
|
||||
float minDist = Float.MAX_VALUE;
|
||||
int ord = -1;
|
||||
for (int j = 0; j < centroids.length; j++) {
|
||||
float dist = VectorUtil.squareDistance(vectors.vectorValue(i), centroids[j]);
|
||||
if (dist < minDist) {
|
||||
minDist = dist;
|
||||
ord = j;
|
||||
}
|
||||
}
|
||||
assignments[i] = ord;
|
||||
assignmentOrdinals[i] = i;
|
||||
}
|
||||
|
||||
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
|
||||
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
|
||||
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
|
||||
|
||||
assertEquals(nClusters, centroids.length);
|
||||
assertNotNull(kMeansIntermediate.soarAssignments());
|
||||
}
|
||||
|
||||
public void testKMeansNeighborsAllZero() throws IOException {
|
||||
int nClusters = 10;
|
||||
int maxIterations = 10;
|
||||
int clustersPerNeighborhood = 128;
|
||||
float soarLambda = 1.0f;
|
||||
int nVectors = 1000;
|
||||
List<float[]> vectors = new ArrayList<>();
|
||||
for (int i = 0; i < nVectors; i++) {
|
||||
float[] vector = new float[5];
|
||||
vectors.add(vector);
|
||||
}
|
||||
int sampleSize = vectors.size();
|
||||
FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, 5);
|
||||
|
||||
float[][] centroids = KMeansLocal.pickInitialCentroids(fvv, nClusters);
|
||||
KMeansLocal.cluster(fvv, centroids, sampleSize, maxIterations);
|
||||
|
||||
int[] assignments = new int[vectors.size()];
|
||||
int[] assignmentOrdinals = new int[vectors.size()];
|
||||
for (int i = 0; i < vectors.size(); i++) {
|
||||
float minDist = Float.MAX_VALUE;
|
||||
int ord = -1;
|
||||
for (int j = 0; j < centroids.length; j++) {
|
||||
float dist = VectorUtil.squareDistance(fvv.vectorValue(i), centroids[j]);
|
||||
if (dist < minDist) {
|
||||
minDist = dist;
|
||||
ord = j;
|
||||
}
|
||||
}
|
||||
assignments[i] = ord;
|
||||
assignmentOrdinals[i] = i;
|
||||
}
|
||||
|
||||
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
|
||||
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
|
||||
kMeansLocal.cluster(fvv, kMeansIntermediate, true);
|
||||
|
||||
assertEquals(nClusters, centroids.length);
|
||||
assertNotNull(kMeansIntermediate.soarAssignments());
|
||||
for (float[] centroid : centroids) {
|
||||
for (float v : centroid) {
|
||||
if (v > 0.0000001f) {
|
||||
assertEquals(0.0f, v, 0.00000001f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
|
||||
List<float[]> vectors = new ArrayList<>(nSamples);
|
||||
float[][] centroids = new float[nClusters][nDims];
|
||||
// Generate random centroids
|
||||
for (int i = 0; i < nClusters; i++) {
|
||||
for (int j = 0; j < nDims; j++) {
|
||||
centroids[i][j] = random().nextFloat() * 100;
|
||||
}
|
||||
}
|
||||
// Generate data points around centroids
|
||||
for (int i = 0; i < nSamples; i++) {
|
||||
int cluster = random().nextInt(nClusters);
|
||||
float[] vector = new float[nDims];
|
||||
for (int j = 0; j < nDims; j++) {
|
||||
vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
|
||||
}
|
||||
vectors.add(vector);
|
||||
}
|
||||
return FloatVectorValues.fromFloats(vectors, nDims);
|
||||
}
|
||||
}
|
|
@@ -14,10 +14,10 @@
|
|||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||
*
|
||||
* Modifications copyright (C) 2024 Elasticsearch B.V.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
package org.elasticsearch.index.codec.vectors.cluster;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|