Refactor bulk quantization writing into a unified class (#130354)
This is a small refactor, laying groundwork for more generalized bulk writing. I did some benchmarking and there was no significant performance difference (as expected).
This commit is contained in:
parent c1a4f8ae68
commit 044f34bf3e
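Both the writePostingList method removed below and the new DiskBBQBulkWriter produce the same posting-list layout: full bulks of BULK_SIZE quantized vectors, each bulk's packed vectors first and its corrective terms grouped field by field afterwards, then a tail where each leftover vector is followed immediately by its own corrections. A minimal, self-contained sketch of that layout, with a hypothetical Correction record standing in for OptimizedScalarQuantizer.QuantizationResult (plain java.io here for illustration; the real code writes through Lucene's little-endian IndexOutput):

import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical stand-in for OptimizedScalarQuantizer.QuantizationResult.
record Correction(float lower, float upper, int componentSum, float additional) {}

class PostingListLayoutSketch {
    // Mirrors the write order used in this commit: full bulks first (vectors,
    // then corrections grouped per field), then a tail with inline corrections.
    static void write(byte[][] packedVectors, Correction[] corrections, int bulkSize, DataOutputStream out) throws IOException {
        int limit = packedVectors.length - bulkSize + 1;
        int i = 0;
        for (; i < limit; i += bulkSize) {
            for (int j = 0; j < bulkSize; j++) out.write(packedVectors[i + j]);
            for (int j = 0; j < bulkSize; j++) out.writeFloat(corrections[i + j].lower());
            for (int j = 0; j < bulkSize; j++) out.writeFloat(corrections[i + j].upper());
            for (int j = 0; j < bulkSize; j++) out.writeShort(corrections[i + j].componentSum());
            for (int j = 0; j < bulkSize; j++) out.writeFloat(corrections[i + j].additional());
        }
        // tail: each leftover vector is followed immediately by its own corrections
        for (; i < packedVectors.length; i++) {
            out.write(packedVectors[i]);
            out.writeFloat(corrections[i].lower());
            out.writeFloat(corrections[i].upper());
            out.writeFloat(corrections[i].additional());
            out.writeShort(corrections[i].componentSum());
        }
    }
}

Grouping the corrections per field within a bulk presumably lets the bulk scorer read BULK_SIZE packed vectors, and then each correction field, as contiguous runs.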
@@ -28,10 +28,6 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.Arrays;
 
-import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
-
 /**
  * Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
  * partition the vector space, and then stores the centroids and posting list in a sequential
@@ -58,12 +54,15 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
         OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
-        BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
         DocIdsWriter docIdsWriter = new DocIdsWriter();
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
+            ES91OSQVectorsScorer.BULK_SIZE,
+            quantizer,
+            floatVectorValues,
+            postingsOutput
+        );
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
-            binarizedByteVectorValues.centroid = centroid;
             // TODO: add back in sorting vectors by distance to centroid
             int[] cluster = assignmentsByCluster[c];
             // TODO align???
@@ -75,7 +74,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
+            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
         }
 
         if (logger.isDebugEnabled()) {
@@ -115,54 +114,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         );
     }
 
-    private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
-        throws IOException {
-        int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
-        int cidx = 0;
-        OptimizedScalarQuantizer.QuantizationResult[] corrections =
-            new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
-        // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
-        for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                int ord = cluster[cidx + j];
-                byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
-                // write vector
-                postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
-                corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
-            }
-            // write corrections
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
-            }
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
-            }
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                int targetComponentSum = corrections[j].quantizedComponentSum();
-                assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
-                postingsOutput.writeShort((short) targetComponentSum);
-            }
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
-            }
-        }
-        // write tail
-        for (; cidx < cluster.length; cidx++) {
-            int ord = cluster[cidx];
-            // write vector
-            byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
-            OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
-            writeQuantizedValue(postingsOutput, binaryValue, correction);
-        }
-    }
-
     @Override
     CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
         return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
@@ -295,47 +246,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         }
     }
 
-    // TODO unify with OSQ format
-    static class BinarizedFloatVectorValues {
-        private OptimizedScalarQuantizer.QuantizationResult corrections;
-        private final byte[] binarized;
-        private final byte[] initQuantized;
-        private float[] centroid;
-        private final FloatVectorValues values;
-        private final OptimizedScalarQuantizer quantizer;
-
-        private int lastOrd = -1;
-
-        BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
-            this.values = delegate;
-            this.quantizer = quantizer;
-            this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
-            this.initQuantized = new byte[delegate.dimension()];
-        }
-
-        public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
-            if (ord != lastOrd) {
-                throw new IllegalStateException(
-                    "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
-                );
-            }
-            return corrections;
-        }
-
-        public byte[] vectorValue(int ord) throws IOException {
-            if (ord != lastOrd) {
-                binarize(ord);
-                lastOrd = ord;
-            }
-            return binarized;
-        }
-
-        private void binarize(int ord) throws IOException {
-            corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
-            packAsBinary(initQuantized, binarized);
-        }
-    }
-
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
         throws IOException {
         indexOutput.writeBytes(binaryValue, binaryValue.length);
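The commit message calls this groundwork for more generalized bulk writing, and the new class below is shaped for that: another bit-size strategy would just be another nested subclass. Purely as a hypothetical sketch (no such writer exists in this commit, and a real multi-bit writer would presumably pack several components per byte rather than writing them unpacked):

    // Hypothetical -- not part of this commit. Shows how a second bit-size
    // strategy could slot into the DiskBBQBulkWriter structure added below.
    public static class FourBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
        private final byte[] quantized;
        private final OptimizedScalarQuantizer.QuantizationResult[] corrections;

        public FourBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
            super(bulkSize, quantizer, fvv, out);
            this.quantized = new byte[fvv.dimension()];  // unpacked: one byte per component
            this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
        }

        @Override
        public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
            int limit = count - bulkSize + 1;
            int i = 0;
            for (; i < limit; i += bulkSize) {
                for (int j = 0; j < bulkSize; j++) {
                    int ord = ords.apply(i + j);
                    // same loop shape as the one-bit writer, just a different bit count
                    corrections[j] = quantizer.scalarQuantize(fvv.vectorValue(ord), quantized, (byte) 4, centroid);
                    out.writeBytes(quantized, quantized.length);
                }
                writeCorrections(corrections, out);
            }
            for (; i < count; ++i) {
                int ord = ords.apply(i);
                OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fvv.vectorValue(ord), quantized, (byte) 4, centroid);
                out.writeBytes(quantized, quantized.length);
                writeCorrection(correction, out);
            }
        }
    }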
@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
+
+import java.io.IOException;
+
+import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
+import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
+
+/**
+ * Base class for bulk writers that write vectors to disk using the BBQ encoding.
+ * This class provides the structure for writing vectors in bulk, with specific
+ * implementations for different bit-size strategies.
+ */
+public abstract class DiskBBQBulkWriter {
+    protected final int bulkSize;
+    protected final OptimizedScalarQuantizer quantizer;
+    protected final IndexOutput out;
+    protected final FloatVectorValues fvv;
+
+    protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+        this.bulkSize = bulkSize;
+        this.quantizer = quantizer;
+        this.out = out;
+        this.fvv = fvv;
+    }
+
+    public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
+
+    private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            int targetComponentSum = correction.quantizedComponentSum();
+            assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+            out.writeShort((short) targetComponentSum);
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        }
+    }
+
+    private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
+        out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        int targetComponentSum = correction.quantizedComponentSum();
+        assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+        out.writeShort((short) targetComponentSum);
+    }
+
+    public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
+        private final byte[] binarized;
+        private final byte[] initQuantized;
+        private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
+
+        public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+            super(bulkSize, quantizer, fvv, out);
+            this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
+            this.initQuantized = new byte[fvv.dimension()];
+            this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
+        }
+
+        @Override
+        public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
+            int limit = count - bulkSize + 1;
+            int i = 0;
+            for (; i < limit; i += bulkSize) {
+                for (int j = 0; j < bulkSize; j++) {
+                    int ord = ords.apply(i + j);
+                    float[] fv = fvv.vectorValue(ord);
+                    corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
+                    packAsBinary(initQuantized, binarized);
+                    out.writeBytes(binarized, binarized.length);
+                }
+                writeCorrections(corrections, out);
+            }
+            // write tail
+            for (; i < count; ++i) {
+                int ord = ords.apply(i);
+                float[] fv = fvv.vectorValue(ord);
+                OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
+                packAsBinary(initQuantized, binarized);
+                out.writeBytes(binarized, binarized.length);
+                writeCorrection(correction, out);
+            }
+        }
+    }
+}
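One detail of writeOrds worth spelling out is the bulk/tail split: limit = count - bulkSize + 1 is the first index at which no full bulk remains, so the first loop only ever writes complete bulks and the second loop handles the remainder. A worked example with hypothetical numbers (the call site passes ES91OSQVectorsScorer.BULK_SIZE; 32 and 70 are made up for illustration):

public class BulkSplitExample {
    public static void main(String[] args) {
        int count = 70;                    // vectors in one posting list (hypothetical)
        int bulkSize = 32;                 // hypothetical stand-in for ES91OSQVectorsScorer.BULK_SIZE
        int limit = count - bulkSize + 1;  // 39: a full bulk can start at any i < limit
        int i = 0;
        for (; i < limit; i += bulkSize) {
            // prints "bulk: 0..31" and "bulk: 32..63"; these get grouped corrections
            System.out.println("bulk: " + i + ".." + (i + bulkSize - 1));
        }
        for (; i < count; i++) {
            // prints "tail: 64" through "tail: 69"; the tail path writes each
            // of these vectors with its corrections inline
            System.out.println("tail: " + i);
        }
    }
}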
@@ -300,7 +300,7 @@ abstract class AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase extends Luc
             );
             assertEquals(8, results.scoreDocs.length);
             assertIdMatches(reader, "10", results.scoreDocs[0].doc);
-            assertIdMatches(reader, "8", results.scoreDocs[7].doc);
+            assertIdMatches(reader, "6", results.scoreDocs[7].doc);
         }
     }
 }