Refactor bulk quantization writing into a unified class (#130354)

This is a small refactor, laying the groundwork for more generalized bulk
writing.

I did some benchmarking and found no significant performance difference
(as expected).
Benjamin Trent 2025-07-01 16:53:45 -04:00 committed by GitHub
parent c1a4f8ae68
commit 044f34bf3e
3 changed files with 112 additions and 98 deletions
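
At a glance: the inlined bulk-writing loop in DefaultIVFVectorsWriter#writePostingList moves into the new DiskBBQBulkWriter class. A sketch of the per-bulk record layout the new writer emits, derived from writeCorrections below (the field labels are descriptive, not actual identifiers):

    for each full bulk of BULK_SIZE vectors:
        BULK_SIZE x packed 1-bit vector   // packAsBinary output, discretize(dim, 64) / 8 bytes each
        BULK_SIZE x int                   // lowerInterval, written as float bits
        BULK_SIZE x int                   // upperInterval, written as float bits
        BULK_SIZE x short                 // quantizedComponentSum
        BULK_SIZE x int                   // additionalCorrection, written as float bits

The trailing partial bulk is written one vector at a time, each vector followed immediately by its own corrections (see writeCorrection below, which orders them lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum).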

DefaultIVFVectorsWriter.java

@@ -28,10 +28,6 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
/**
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
* partition the vector space, and then stores the centroids and posting list in a sequential
@@ -58,12 +54,15 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
// write the posting lists
final long[] offsets = new long[centroidSupplier.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
ES91OSQVectorsScorer.BULK_SIZE,
quantizer,
floatVectorValues,
postingsOutput
);
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
binarizedByteVectorValues.centroid = centroid;
// TODO: add back in sorting vectors by distance to centroid
int[] cluster = assignmentsByCluster[c];
// TODO align???
@@ -75,7 +74,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
}
if (logger.isDebugEnabled()) {
@@ -115,54 +114,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
);
}
private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
throws IOException {
int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
int cidx = 0;
OptimizedScalarQuantizer.QuantizationResult[] corrections =
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int ord = cluster[cidx + j];
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
// write vector
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
}
// write corrections
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int targetComponentSum = corrections[j].quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
postingsOutput.writeShort((short) targetComponentSum);
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
}
}
// write tail
for (; cidx < cluster.length; cidx++) {
int ord = cluster[cidx];
// write vector
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
writeQuantizedValue(postingsOutput, binaryValue, correction);
}
}
@Override
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
@@ -295,47 +246,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
}
}
// TODO unify with OSQ format
static class BinarizedFloatVectorValues {
private OptimizedScalarQuantizer.QuantizationResult corrections;
private final byte[] binarized;
private final byte[] initQuantized;
private float[] centroid;
private final FloatVectorValues values;
private final OptimizedScalarQuantizer quantizer;
private int lastOrd = -1;
BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
this.values = delegate;
this.quantizer = quantizer;
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
this.initQuantized = new byte[delegate.dimension()];
}
public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
if (ord != lastOrd) {
throw new IllegalStateException(
"attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
);
}
return corrections;
}
public byte[] vectorValue(int ord) throws IOException {
if (ord != lastOrd) {
binarize(ord);
lastOrd = ord;
}
return binarized;
}
private void binarize(int ord) throws IOException {
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
packAsBinary(initQuantized, binarized);
}
}
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
throws IOException {
indexOutput.writeBytes(binaryValue, binaryValue.length);

DiskBBQBulkWriter.java

@@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import java.io.IOException;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
/**
* Base class for bulk writers that write vectors to disk using the BBQ encoding.
* This class provides the structure for writing vectors in bulk, with specific
* implementations for different bit-size strategies.
*/
public abstract class DiskBBQBulkWriter {
protected final int bulkSize;
protected final OptimizedScalarQuantizer quantizer;
protected final IndexOutput out;
protected final FloatVectorValues fvv;
protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
this.bulkSize = bulkSize;
this.quantizer = quantizer;
this.out = out;
this.fvv = fvv;
}
public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
}
}
private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}
public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final byte[] binarized;
private final byte[] initQuantized;
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
super(bulkSize, quantizer, fvv, out);
this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
this.initQuantized = new byte[fvv.dimension()];
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
}
@Override
public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
int limit = count - bulkSize + 1;
int i = 0;
for (; i < limit; i += bulkSize) {
for (int j = 0; j < bulkSize; j++) {
int ord = ords.apply(i + j);
float[] fv = fvv.vectorValue(ord);
corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
packAsBinary(initQuantized, binarized);
out.writeBytes(binarized, binarized.length);
}
writeCorrections(corrections, out);
}
// write tail
for (; i < count; ++i) {
int ord = ords.apply(i);
float[] fv = fvv.vectorValue(ord);
OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
packAsBinary(initQuantized, binarized);
out.writeBytes(binarized, binarized.length);
writeCorrection(correction, out);
}
}
}
}
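
A minimal usage sketch for the new class, mirroring the call site in DefaultIVFVectorsWriter above; clusters and centroids are illustrative stand-ins for the per-centroid ordinal assignments and centroid vectors:

    // One writer per field: quantizes, packs, and writes vectors plus corrections.
    DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
        ES91OSQVectorsScorer.BULK_SIZE,  // vectors per bulk
        quantizer,                       // OptimizedScalarQuantizer for the field's similarity
        floatVectorValues,               // raw float vectors, addressed by ordinal
        postingsOutput                   // destination IndexOutput for the posting lists
    );
    for (int c = 0; c < centroids.length; c++) {
        int[] cluster = clusters[c];     // vector ordinals assigned to centroid c
        // The IntToIntFunction maps a position within the cluster to a vector ordinal.
        bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroids[c]);
    }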

AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java

@@ -300,7 +300,7 @@ abstract class AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase extends Luc
);
assertEquals(8, results.scoreDocs.length);
assertIdMatches(reader, "10", results.scoreDocs[0].doc);
assertIdMatches(reader, "8", results.scoreDocs[7].doc);
assertIdMatches(reader, "6", results.scoreDocs[7].doc);
}
}
}