From f81d35536d72a3711502673f3854128da2a8ca4b Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Wed, 2 Jul 2025 14:57:59 +0100 Subject: [PATCH] optimize OptimizedScalarQuantizer#scalarQuantize (#129874) optimize OptimizedScalarQuantizer#scalarQuantize when destination can optimize OptimizedScalarQuantizer#scalarQuantize when destination can be an integer array --- .../OptimizedScalarQuantizerBenchmark.java | 19 +++++-- .../elasticsearch/simdvec/ESVectorUtil.java | 21 ++++++++ .../DefaultESVectorUtilSupport.java | 14 +++++ .../vectorization/ESVectorUtilSupport.java | 2 + .../PanamaESVectorUtilSupport.java | 28 ++++++++++ .../simdvec/ESVectorUtilTests.java | 23 ++++++++ .../index/codec/vectors/BQSpaceUtils.java | 29 ++++++++++ .../index/codec/vectors/BQVectorUtils.java | 2 +- .../vectors/DefaultIVFVectorsReader.java | 12 +++-- .../vectors/DefaultIVFVectorsWriter.java | 8 ++- .../codec/vectors/DiskBBQBulkWriter.java | 4 +- .../vectors/OptimizedScalarQuantizer.java | 53 ++++++++++++++----- .../es818/ES818BinaryFlatVectorsScorer.java | 2 +- .../ES818BinaryQuantizedVectorsWriter.java | 10 ++-- .../codec/vectors/BQVectorUtilsTests.java | 8 +-- .../OptimizedScalarQuantizerTests.java | 41 +++++++++----- ...S818BinaryQuantizedVectorsFormatTests.java | 2 +- 17 files changed, 227 insertions(+), 51 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java index 4fa0a1f95495..ea3309f89a26 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java @@ -43,7 +43,8 @@ public class OptimizedScalarQuantizerBenchmark { float[] vector; float[] centroid; - byte[] destination; + byte[] legacyDestination; + int[] destination; @Param({ "1", "4", "7" }) byte bits; @@ 
-54,7 +55,8 @@ public class OptimizedScalarQuantizerBenchmark { public void init() { ThreadLocalRandom random = ThreadLocalRandom.current(); // random byte arrays for binary methods - destination = new byte[dims]; + legacyDestination = new byte[dims]; + destination = new int[dims]; vector = new float[dims]; centroid = new float[dims]; for (int i = 0; i < dims; ++i) { @@ -65,13 +67,20 @@ public class OptimizedScalarQuantizerBenchmark { @Benchmark public byte[] scalar() { - osq.scalarQuantize(vector, destination, bits, centroid); - return destination; + osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid); + return legacyDestination; } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public byte[] vector() { + public byte[] legacyVector() { + osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid); + return legacyDestination; + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public int[] vector() { osq.scalarQuantize(vector, destination, bits, centroid); return destination; } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 6671ed5084a8..b4279d57e6bd 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -258,4 +258,25 @@ public class ESVectorUtil { } return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm); } + + /** + * Optimized-scalar quantization of the provided vector to the provided destination array. 
+ * + * @param vector the vector to quantize + * @param destination the array to store the result + * @param lowInterval the minimum value; lower values in the original array will be replaced by this value + * @param upperInterval the maximum value; bigger values in the original array will be replaced by this value + * @param bit the number of bits to use for quantization, must be between 1 and 8 + * + * @return the sum of all the elements of the resulting quantized vector. + */ + public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) { + if (vector.length > destination.length) { + throw new IllegalArgumentException("vector dimensions differ: " + vector.length + "!=" + destination.length); + } + if (bit <= 0 || bit > Byte.SIZE) { + throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit); + } + return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index 022f189a2e04..d77cb70c0857 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -269,4 +269,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport { } return ret; } + + @Override + public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) { + float nSteps = ((1 << bits) - 1); + float step = (upperInterval - lowInterval) / nSteps; + int sumQuery = 0; + for (int h = 0; h < vector.length; h++) { + float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval); + int
assignment = Math.round((xi - lowInterval) / step); + sumQuery += assignment; + destination[h] = assignment; + } + return sumQuery; + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java index dfd324547d84..7c60f07a95d4 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -39,4 +39,6 @@ public interface ESVectorUtilSupport { float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm); + int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit); + } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 87e8a39c4842..defdeaf12b0b 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -791,4 +791,32 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { return sum; } + + @Override + public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) { + float nSteps = ((1 << bits) - 1); + float step = (upperInterval - lowInterval) / nSteps; + int sumQuery = 0; + int i = 0; + if (vector.length > 2 * FLOAT_SPECIES.length()) { + int limit = FLOAT_SPECIES.loopBound(vector.length); + FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval); + FloatVector upperVec = 
FloatVector.broadcast(FLOAT_SPECIES, upperInterval); + FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step); + for (; i < limit; i += FLOAT_SPECIES.length()) { + FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i); + FloatVector xi = v.max(lowVec).min(upperVec); // clamp + IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round + sumQuery += assignment.reduceLanes(ADD); + assignment.intoArray(destination, i); + } + } + for (; i < vector.length; i++) { + float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval); + int assignment = Math.round((xi - lowInterval) / step); + sumQuery += assignment; + destination[i] = assignment; + } + return sumQuery; + } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 2a83c0ec6212..f41722ce9474 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -286,6 +286,29 @@ public class ESVectorUtilTests extends BaseVectorizationTests { assertEquals(expected, result, deltaEps); } + public void testQuantizeVectorWithIntervals() { + int vectorSize = randomIntBetween(1, 2048); + float[] vector = new float[vectorSize]; + + byte bits = (byte) randomIntBetween(1, 8); + for (int i = 0; i < vectorSize; ++i) { + vector[i] = random().nextFloat(); + } + float low = random().nextFloat(); + float high = random().nextFloat(); + if (low > high) { + float tmp = low; + low = high; + high = tmp; + } + int[] quantizeExpected = new int[vectorSize]; + int[] quantizeResult = new int[vectorSize]; + var expected = defaultedProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeExpected, low, high, bits); + var result = defOrPanamaProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, 
quantizeResult, low, high, bits); + assertArrayEquals(quantizeExpected, quantizeResult); + assertEquals(expected, result, 0f); + } + void testIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { int iterations = atLeast(50); for (int i = 0; i < iterations; i++) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java index f9fad7483568..99ead8334c21 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java @@ -57,4 +57,33 @@ public class BQSpaceUtils { quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; } } + + /** + * Same as {@link #transposeHalfByte(byte[], byte[])} but the input vector is provided as + * an array of integers. + * + * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 + * @param quantQueryByte the byte array to store the transposed query vector + * */ + public static void transposeHalfByte(int[] q, byte[] quantQueryByte) { + for (int i = 0; i < q.length;) { + assert q[i] >= 0 && q[i] <= 15; + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; j >= 0 && i < q.length; j--) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + i++; + } + int index = ((i + 7) / 8) - 1; + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java index 1aff06a17596..75d0d2aa93a4 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java @@ -40,7 +40,7 @@ public class BQVectorUtils { return Math.abs(l1norm - 1.0d) <= EPSILON; } - public static void packAsBinary(byte[] vector, byte[] packed) { + public static void packAsBinary(int[] vector, byte[] packed) { for (int i = 0; i < vector.length;) { byte result = 0; for (int j = 7; j >= 0 && i < vector.length; j--) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index e7b41d005d7e..9ef017796500 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -52,13 +52,17 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap final FieldEntry fieldEntry = fields.get(fieldInfo.number); final float globalCentroidDp = fieldEntry.globalCentroidDp(); final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - final byte[] quantized = new byte[targetQuery.length]; + final int[] scratch = new int[targetQuery.length]; final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( ArrayUtil.copyArray(targetQuery), - quantized, + scratch, (byte) 4, fieldEntry.globalCentroid() ); + final byte[] quantized = new byte[targetQuery.length]; + for (int i = 0; i < quantized.length; i++) { + quantized[i] = (byte) scratch[i]; + } final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); return new CentroidQueryScorer() { int 
currentCentroid = -1; @@ -182,7 +186,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap DocIdsWriter docIdsWriter = new DocIdsWriter(); final float[] scratch; - final byte[] quantizationScratch; + final int[] quantizationScratch; final byte[] quantizedQueryScratch; final OptimizedScalarQuantizer quantizer; final float[] correctiveValues = new float[3]; @@ -202,7 +206,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap this.needsScoring = needsScoring; scratch = new float[target.length]; - quantizationScratch = new byte[target.length]; + quantizationScratch = new int[target.length]; final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64); quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index c9506041584f..df0c277bdd5a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -122,8 +122,9 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput) throws IOException { final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()]; + int[] quantizedScratch = new int[fieldInfo.getVectorDimension()]; float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; + final byte[] quantized = new byte[fieldInfo.getVectorDimension()]; // TODO do we want to store these distances as well for future use? 
// TODO: sort centroids by global centroid (was doing so previously here) // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned @@ -135,7 +136,10 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { (byte) 4, globalCentroid ); - writeQuantizedValue(centroidOutput, quantizedScratch, result); + for (int i = 0; i < quantizedScratch.length; i++) { + quantized[i] = (byte) quantizedScratch[i]; + } + writeQuantizedValue(centroidOutput, quantized, result); } final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); for (float[] centroid : centroids) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java index 07f44e87c1a4..6974cd50d4ab 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java @@ -66,13 +66,13 @@ public abstract class DiskBBQBulkWriter { public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final byte[] binarized; - private final byte[] initQuantized; + private final int[] initQuantized; private final OptimizedScalarQuantizer.QuantizationResult[] corrections; public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { super(bulkSize, quantizer, fvv, out); this.binarized = new byte[discretize(fvv.dimension(), 64) / 8]; - this.initQuantized = new byte[fvv.dimension()]; + this.initQuantized = new int[fvv.dimension()]; this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java index 
565e8116edc2..00e97f92ed81 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java @@ -57,7 +57,7 @@ public class OptimizedScalarQuantizer { public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {} - public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) { + public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinations, byte[] bits, float[] centroid) { assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); assert bits.length == destinations.length; @@ -79,18 +79,14 @@ public class OptimizedScalarQuantizer { // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch); optimizeIntervals(intervalScratch, vector, norm2, points); - float nSteps = ((1 << bits[i]) - 1); - float a = intervalScratch[0]; - float b = intervalScratch[1]; - float step = (b - a) / nSteps; - int sumQuery = 0; // Now we have the optimized intervals, quantize the vector - for (int h = 0; h < vector.length; h++) { - float xi = (float) clamp(vector[h], a, b); - int assignment = Math.round((xi - a) / step); - sumQuery += assignment; - destinations[i][h] = (byte) assignment; - } + int sumQuery = ESVectorUtil.quantizeVectorWithIntervals( + vector, + destinations[i], + intervalScratch[0], + intervalScratch[1], + bits[i] + ); results[i] = new QuantizationResult( intervalScratch[0], intervalScratch[1], @@ -101,7 +97,8 @@ public class OptimizedScalarQuantizer { return results; } - public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) { + // This method 
is only used for benchmarking and equivalence testing; it is not used in production + public QuantizationResult legacyScalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) { assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); assert vector.length <= destination.length; @@ -141,6 +138,36 @@ public class OptimizedScalarQuantizer { ); } + public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert vector.length <= destination.length; + assert bits > 0 && bits <= 8; + int points = 1 << bits; + if (similarityFunction == EUCLIDEAN) { + ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch); + } else { + ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch); + } + float vecMean = statsScratch[0]; + float vecVar = statsScratch[1]; + float norm2 = statsScratch[2]; + float min = statsScratch[3]; + float max = statsScratch[4]; + float vecStd = (float) Math.sqrt(vecVar); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + initInterval(bits, vecStd, vecMean, min, max, intervalScratch); + optimizeIntervals(intervalScratch, vector, norm2, points); + // Now we have the optimized intervals, quantize the vector + int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits); + return new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5], + sumQuery + ); + } + /** * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization * loss.
Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index d66a28712ec9..efb098373489 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -77,7 +77,7 @@ public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer { VectorUtil.l2normalize(copy); } target = copy; - byte[] initial = new byte[target.length]; + int[] initial = new int[target.length]; byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid); BQSpaceUtils.transposeHalfByte(initial, quantized); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index a4983e234c8d..22520567f295 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -198,7 +198,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) throws IOException { int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); - byte[] quantizationScratch = new byte[discreteDims]; + int[] quantizationScratch = new 
int[discreteDims]; byte[] vector = new byte[discreteDims / 8]; for (int i = 0; i < fieldData.getVectors().size(); i++) { float[] v = fieldData.getVectors().get(i); @@ -246,7 +246,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { OptimizedScalarQuantizer scalarQuantizer ) throws IOException { int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); - byte[] quantizationScratch = new byte[discreteDims]; + int[] quantizationScratch = new int[discreteDims]; byte[] vector = new byte[discreteDims / 8]; for (int ordinal : ordMap) { float[] v = fieldData.getVectors().get(ordinal); @@ -364,7 +364,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { ) throws IOException { int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64); DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()]; + int[][] quantizationScratch = new int[2][floatVectorValues.dimension()]; byte[] toIndex = new byte[discretizedDimension / 8]; byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY]; KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); @@ -801,7 +801,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { private OptimizedScalarQuantizer.QuantizationResult corrections; private final byte[] binarized; - private final byte[] initQuantized; + private final int[] initQuantized; private final float[] centroid; private final FloatVectorValues values; private final OptimizedScalarQuantizer quantizer; @@ -812,7 +812,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { this.values = delegate; this.quantizer = quantizer; this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; - this.initQuantized = new 
byte[delegate.dimension()]; + this.initQuantized = new int[delegate.dimension()]; this.centroid = centroid; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java index 270ad54e9a96..35349cd054df 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java @@ -40,25 +40,25 @@ public class BQVectorUtilsTests extends LuceneTestCase { public void testPackAsBinary() { // 5 bits - byte[] toPack = new byte[] { 1, 1, 0, 0, 1 }; + int[] toPack = new int[] { 1, 1, 0, 0, 1 }; byte[] packed = new byte[1]; BQVectorUtils.packAsBinary(toPack, packed); assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed); // 8 bits - toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 }; + toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0 }; packed = new byte[1]; BQVectorUtils.packAsBinary(toPack, packed); assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed); // 10 bits - toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 }; + toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 }; packed = new byte[2]; BQVectorUtils.packAsBinary(toPack, packed); assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed); // 16 bits - toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 }; + toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 }; packed = new byte[2]; BQVectorUtils.packAsBinary(toPack, packed); assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java index 55171f48f373..4b4c4621c19e 100644 --- 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java @@ -19,7 +19,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) { + static float[] deQuantize(int[] quantized, byte bits, float[] interval, float[] centroid) { float[] dequantized = new float[quantized.length]; float a = interval[0]; float b = interval[1]; @@ -52,10 +52,12 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { float[] scratch = new float[dims]; for (byte bit : ALL_BITS) { float eps = (1f / (float) (1 << (bit))); - byte[] destination = new byte[dims]; + byte[] legacyDestination = new byte[dims]; + int[] destination = new int[dims]; for (int i = 0; i < numVectors; ++i) { System.arraycopy(vectors[i], 0, scratch, 0, dims); OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid); + assertValidResults(result); assertValidQuantizedRange(destination, bit); @@ -71,6 +73,19 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { } mae /= dims; assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps); + + // check the legacy byte[] version produces the same result as the int[] version + System.arraycopy(vectors[i], 0, scratch, 0, dims); + OptimizedScalarQuantizer.QuantizationResult intResults = osq.legacyScalarQuantize( + scratch, + legacyDestination, + bit, + centroid + ); + assertEquals(result, intResults); + for (int h = 0; h < dims; ++h) { + assertEquals((byte) destination[h], legacyDestination[h]); + } } } } @@ -84,18 +99,18 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { float[] vector = new float[4096]; float[] centroid = new float[4096]; OptimizedScalarQuantizer osq = new
OptimizedScalarQuantizer(vectorSimilarityFunction); - byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096]; + int[][] destinations = new int[MINIMUM_MSE_GRID.length][4096]; OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); assertEquals(MINIMUM_MSE_GRID.length, results.length); assertValidResults(results); - for (byte[] destination : destinations) { - assertArrayEquals(new byte[4096], destination); + for (int[] destination : destinations) { + assertArrayEquals(new int[4096], destination); } - byte[] destination = new byte[4096]; + int[] destination = new int[4096]; for (byte bit : ALL_BITS) { OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); assertValidResults(result); - assertArrayEquals(new byte[4096], destination); + assertArrayEquals(new int[4096], destination); } } @@ -108,7 +123,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { VectorUtil.l2normalize(centroid); } OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); - byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1]; + int[][] destinations = new int[MINIMUM_MSE_GRID.length][1]; OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); assertEquals(MINIMUM_MSE_GRID.length, results.length); assertValidResults(results); @@ -122,7 +137,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { VectorUtil.l2normalize(vector); VectorUtil.l2normalize(centroid); } - byte[] destination = new byte[1]; + int[] destination = new int[1]; OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); assertValidResults(result); assertValidQuantizedRange(destination, bit); @@ -150,7 +165,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { VectorUtil.l2normalize(centroid); } 
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); - byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims]; + int[][] destinations = new int[MINIMUM_MSE_GRID.length][dims]; OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid); assertEquals(MINIMUM_MSE_GRID.length, results.length); assertValidResults(results); @@ -158,7 +173,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { assertValidQuantizedRange(destinations[i], ALL_BITS[i]); } for (byte bit : ALL_BITS) { - byte[] destination = new byte[dims]; + int[] destination = new int[dims]; System.arraycopy(vector, 0, copy, 0, dims); if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { VectorUtil.l2normalize(copy); @@ -171,8 +186,8 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { } } - static void assertValidQuantizedRange(byte[] quantized, byte bits) { - for (byte b : quantized) { + static void assertValidQuantizedRange(int[] quantized, byte bits) { + for (int b : quantized) { if (bits < 8) { assertTrue(b >= 0); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java index 84fafde5af7c..9fa10305562c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -243,7 +243,7 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat assertEquals(centroid.length, dims); OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); - byte[] quantizedVector = new byte[dims]; + int[] quantizedVector = new int[dims]; byte[] expectedVector = new 
byte[BQVectorUtils.discretize(dims, 64) / 8]; if (similarityFunction == VectorSimilarityFunction.COSINE) { vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);