From 044f34bf3ee07d2dc3dcdf2012cb5a2d516d5008 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 1 Jul 2025 16:53:45 -0400 Subject: [PATCH] Refactor bulk quantization writing into a unified class (#130354) This is a small refactor, laying groundwork for more generalized bulk writing. I did some benchmarking and there was no significant performance difference (as expected). --- .../vectors/DefaultIVFVectorsWriter.java | 104 ++---------------- .../codec/vectors/DiskBBQBulkWriter.java | 104 ++++++++++++++++++ ...yingChildrenIVFKnnVectorQueryTestCase.java | 2 +- 3 files changed, 112 insertions(+), 98 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 37081e42a65f..c9506041584f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -28,10 +28,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary; - /** * Default implementation of {@link IVFVectorsWriter}. 
It uses {@link HierarchicalKMeans} algorithm to * partition the vector space, and then stores the centroids and posting list in a sequential @@ -58,12 +54,15 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { // write the posting lists final long[] offsets = new long[centroidSupplier.size()]; OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); DocIdsWriter docIdsWriter = new DocIdsWriter(); - + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter( + ES91OSQVectorsScorer.BULK_SIZE, + quantizer, + floatVectorValues, + postingsOutput + ); for (int c = 0; c < centroidSupplier.size(); c++) { float[] centroid = centroidSupplier.centroid(c); - binarizedByteVectorValues.centroid = centroid; // TODO: add back in sorting vectors by distance to centroid int[] cluster = assignmentsByCluster[c]; // TODO align??? 
@@ -75,7 +74,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { // to aid with only having to fetch vectors from slower storage when they are required // keeping them in the same file indicates we pull the entire file into cache docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput); - writePostingList(cluster, postingsOutput, binarizedByteVectorValues); + bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid); } if (logger.isDebugEnabled()) { @@ -115,54 +114,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { ); } - private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues) - throws IOException { - int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1; - int cidx = 0; - OptimizedScalarQuantizer.QuantizationResult[] corrections = - new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE]; - // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE. 
- for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) { - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - int ord = cluster[cidx + j]; - byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord); - // write vector - postingsOutput.writeBytes(binaryValue, 0, binaryValue.length); - corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord); - } - // write corrections - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval())); - } - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval())); - } - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - int targetComponentSum = corrections[j].quantizedComponentSum(); - assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; - postingsOutput.writeShort((short) targetComponentSum); - } - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection())); - } - } - // write tail - for (; cidx < cluster.length; cidx++) { - int ord = cluster[cidx]; - // write vector - byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord); - OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord); - writeQuantizedValue(postingsOutput, binaryValue, correction); - binarizedByteVectorValues.getCorrectiveTerms(ord); - postingsOutput.writeBytes(binaryValue, 0, binaryValue.length); - postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval())); - postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval())); - postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection())); - assert correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 0xffff; - postingsOutput.writeShort((short) correction.quantizedComponentSum()); - } - } - @Override 
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) { return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo); @@ -295,47 +246,6 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { } } - // TODO unify with OSQ format - static class BinarizedFloatVectorValues { - private OptimizedScalarQuantizer.QuantizationResult corrections; - private final byte[] binarized; - private final byte[] initQuantized; - private float[] centroid; - private final FloatVectorValues values; - private final OptimizedScalarQuantizer quantizer; - - private int lastOrd = -1; - - BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) { - this.values = delegate; - this.quantizer = quantizer; - this.binarized = new byte[discretize(delegate.dimension(), 64) / 8]; - this.initQuantized = new byte[delegate.dimension()]; - } - - public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { - if (ord != lastOrd) { - throw new IllegalStateException( - "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd - ); - } - return corrections; - } - - public byte[] vectorValue(int ord) throws IOException { - if (ord != lastOrd) { - binarize(ord); - lastOrd = ord; - } - return binarized; - } - - private void binarize(int ord) throws IOException { - corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid); - packAsBinary(initQuantized, binarized); - } - } - static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) throws IOException { indexOutput.writeBytes(binaryValue, binaryValue.length); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java new 
file mode 100644 index 000000000000..07f44e87c1a4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +import java.io.IOException; + +import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; +import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary; + +/** + * Base class for bulk writers that write vectors to disk using the BBQ encoding. + * This class provides the structure for writing vectors in bulk, with specific + * implementations for different bit-size strategies. 
+ */ +public abstract class DiskBBQBulkWriter { + protected final int bulkSize; + protected final OptimizedScalarQuantizer quantizer; + protected final IndexOutput out; + protected final FloatVectorValues fvv; + + protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { + this.bulkSize = bulkSize; + this.quantizer = quantizer; + this.out = out; + this.fvv = fvv; + } + + public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException; + + private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException { + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + int targetComponentSum = correction.quantizedComponentSum(); + assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; + out.writeShort((short) targetComponentSum); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + } + } + + private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + int targetComponentSum = correction.quantizedComponentSum(); + assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; + out.writeShort((short) targetComponentSum); + } + + public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { + private 
final byte[] binarized; + private final byte[] initQuantized; + private final OptimizedScalarQuantizer.QuantizationResult[] corrections; + + public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { + super(bulkSize, quantizer, fvv, out); + this.binarized = new byte[discretize(fvv.dimension(), 64) / 8]; + this.initQuantized = new byte[fvv.dimension()]; + this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; + } + + @Override + public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException { + int limit = count - bulkSize + 1; + int i = 0; + for (; i < limit; i += bulkSize) { + for (int j = 0; j < bulkSize; j++) { + int ord = ords.apply(i + j); + float[] fv = fvv.vectorValue(ord); + corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid); + packAsBinary(initQuantized, binarized); + out.writeBytes(binarized, binarized.length); + } + writeCorrections(corrections, out); + } + // write tail + for (; i < count; ++i) { + int ord = ords.apply(i); + float[] fv = fvv.vectorValue(ord); + OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid); + packAsBinary(initQuantized, binarized); + out.writeBytes(binarized, binarized.length); + writeCorrection(correction, out); + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java index bf3f7e761e94..f73d1e5a3199 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java @@ -300,7 +300,7 @@ abstract class 
AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase extends Luc ); assertEquals(8, results.scoreDocs.length); assertIdMatches(reader, "10", results.scoreDocs[0].doc); - assertIdMatches(reader, "8", results.scoreDocs[7].doc); + assertIdMatches(reader, "6", results.scoreDocs[7].doc); } } }