Refactor bulk quantization writing into a unified class (#130354)
This is a small refactor, laying the groundwork for more generalized bulk writing. I did some benchmarking and there was no significant performance difference (as expected).
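At a glance, the change replaces the ad-hoc writePostingList loop in DefaultIVFVectorsWriter with a reusable DiskBBQBulkWriter. A minimal sketch of the new call pattern, assembled from the diff below; the surrounding variables (fieldInfo, floatVectorValues, postingsOutput, centroidSupplier, assignmentsByCluster) are assumed to be set up as in the existing writer:

    OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
    // one-bit strategy; other bit widths would get their own subclass
    DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
        ES91OSQVectorsScorer.BULK_SIZE, // vectors per bulk; corrective terms are grouped per bulk
        quantizer,
        floatVectorValues,
        postingsOutput
    );
    for (int c = 0; c < centroidSupplier.size(); c++) {
        float[] centroid = centroidSupplier.centroid(c);
        int[] cluster = assignmentsByCluster[c];
        // quantizes each vector against the cluster centroid, then writes the
        // packed binary vectors followed by their corrective terms
        bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
    }

The per-bulk layout (packed vectors first, then grouped lower intervals, upper intervals, component sums, and additional corrections) is unchanged from the old writePostingList, which is why no performance difference is expected.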
This commit is contained in:
parent c1a4f8ae68
commit 044f34bf3e
@@ -28,10 +28,6 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
-import java.util.Arrays;
-
-import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
 
 /**
  * Default implementation of {@link IVFVectorsWriter}. It uses the {@link HierarchicalKMeans} algorithm to
  * partition the vector space, and then stores the centroids and posting list in a sequential
@@ -58,12 +54,15 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
         OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
-        BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
         DocIdsWriter docIdsWriter = new DocIdsWriter();
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
+            ES91OSQVectorsScorer.BULK_SIZE,
+            quantizer,
+            floatVectorValues,
+            postingsOutput
+        );
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
-            binarizedByteVectorValues.centroid = centroid;
             // TODO: add back in sorting vectors by distance to centroid
             int[] cluster = assignmentsByCluster[c];
             // TODO align???
@@ -75,7 +74,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
+            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
         }
 
         if (logger.isDebugEnabled()) {
@@ -115,54 +114,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         );
     }
 
-    private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
-        throws IOException {
-        int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
-        int cidx = 0;
-        OptimizedScalarQuantizer.QuantizationResult[] corrections =
-            new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
-        // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
-        for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                int ord = cluster[cidx + j];
-                byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
-                // write vector
-                postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
-                corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
-            }
-            // write corrections
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
-            }
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
-            }
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                int targetComponentSum = corrections[j].quantizedComponentSum();
-                assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
-                postingsOutput.writeShort((short) targetComponentSum);
-            }
-            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
-            }
-        }
-        // write tail
-        for (; cidx < cluster.length; cidx++) {
-            int ord = cluster[cidx];
-            // write vector
-            byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
-            OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
-            writeQuantizedValue(postingsOutput, binaryValue, correction);
-        }
-    }
-
     @Override
     CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
         return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
@@ -295,47 +246,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         }
     }
 
-    // TODO unify with OSQ format
-    static class BinarizedFloatVectorValues {
-        private OptimizedScalarQuantizer.QuantizationResult corrections;
-        private final byte[] binarized;
-        private final byte[] initQuantized;
-        private float[] centroid;
-        private final FloatVectorValues values;
-        private final OptimizedScalarQuantizer quantizer;
-
-        private int lastOrd = -1;
-
-        BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
-            this.values = delegate;
-            this.quantizer = quantizer;
-            this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
-            this.initQuantized = new byte[delegate.dimension()];
-        }
-
-        public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
-            if (ord != lastOrd) {
-                throw new IllegalStateException(
-                    "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
-                );
-            }
-            return corrections;
-        }
-
-        public byte[] vectorValue(int ord) throws IOException {
-            if (ord != lastOrd) {
-                binarize(ord);
-                lastOrd = ord;
-            }
-            return binarized;
-        }
-
-        private void binarize(int ord) throws IOException {
-            corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
-            packAsBinary(initQuantized, binarized);
-        }
-    }
-
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
         throws IOException {
         indexOutput.writeBytes(binaryValue, binaryValue.length);
@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
+
+import java.io.IOException;
+
+import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
+import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
+
+/**
+ * Base class for bulk writers that write vectors to disk using the BBQ encoding.
+ * This class provides the structure for writing vectors in bulk, with specific
+ * implementations for different bit-size strategies.
+ */
+public abstract class DiskBBQBulkWriter {
+    protected final int bulkSize;
+    protected final OptimizedScalarQuantizer quantizer;
+    protected final IndexOutput out;
+    protected final FloatVectorValues fvv;
+
+    protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+        this.bulkSize = bulkSize;
+        this.quantizer = quantizer;
+        this.out = out;
+        this.fvv = fvv;
+    }
+
+    public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
+
+    private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            int targetComponentSum = correction.quantizedComponentSum();
+            assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+            out.writeShort((short) targetComponentSum);
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        }
+    }
+
+    private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
+        out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        int targetComponentSum = correction.quantizedComponentSum();
+        assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+        out.writeShort((short) targetComponentSum);
+    }
+
+    public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
+        private final byte[] binarized;
+        private final byte[] initQuantized;
+        private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
+
+        public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+            super(bulkSize, quantizer, fvv, out);
+            this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
+            this.initQuantized = new byte[fvv.dimension()];
+            this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
+        }
+
+        @Override
+        public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
+            int limit = count - bulkSize + 1;
+            int i = 0;
+            // write vectors in bulks of bulkSize, grouping corrective terms per bulk
+            for (; i < limit; i += bulkSize) {
+                for (int j = 0; j < bulkSize; j++) {
+                    int ord = ords.apply(i + j);
+                    float[] fv = fvv.vectorValue(ord);
+                    corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
+                    packAsBinary(initQuantized, binarized);
+                    out.writeBytes(binarized, binarized.length);
+                }
+                writeCorrections(corrections, out);
+            }
+            // write tail
+            for (; i < count; ++i) {
+                int ord = ords.apply(i);
+                float[] fv = fvv.vectorValue(ord);
+                OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
+                packAsBinary(initQuantized, binarized);
+                out.writeBytes(binarized, binarized.length);
+                writeCorrection(correction, out);
+            }
+        }
+    }
+}
@@ -300,7 +300,7 @@ abstract class AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase extends Luc
             );
             assertEquals(8, results.scoreDocs.length);
             assertIdMatches(reader, "10", results.scoreDocs[0].doc);
-            assertIdMatches(reader, "8", results.scoreDocs[7].doc);
+            assertIdMatches(reader, "6", results.scoreDocs[7].doc);
         }
     }
 }