optimize OptimizedScalarQuantizer#scalarQuantize (#129874)

optimize OptimizedScalarQuantizer#scalarQuantize when the destination can be an integer array
This commit is contained in:
Ignacio Vera 2025-07-02 14:57:59 +01:00 committed by GitHub
parent 24a5440851
commit f81d35536d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 227 additions and 51 deletions

View File

@ -43,7 +43,8 @@ public class OptimizedScalarQuantizerBenchmark {
float[] vector;
float[] centroid;
byte[] destination;
byte[] legacyDestination;
int[] destination;
@Param({ "1", "4", "7" })
byte bits;
@ -54,7 +55,8 @@ public class OptimizedScalarQuantizerBenchmark {
public void init() {
ThreadLocalRandom random = ThreadLocalRandom.current();
// random byte arrays for binary methods
destination = new byte[dims];
legacyDestination = new byte[dims];
destination = new int[dims];
vector = new float[dims];
centroid = new float[dims];
for (int i = 0; i < dims; ++i) {
@ -65,13 +67,20 @@ public class OptimizedScalarQuantizerBenchmark {
@Benchmark
public byte[] scalar() {
osq.scalarQuantize(vector, destination, bits, centroid);
return destination;
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
return legacyDestination;
}
@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public byte[] vector() {
public byte[] legacyVector() {
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
return legacyDestination;
}
@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public int[] vector() {
osq.scalarQuantize(vector, destination, bits, centroid);
return destination;
}

View File

@ -258,4 +258,25 @@ public class ESVectorUtil {
}
return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
}
/**
 * Optimized-scalar quantization of the provided vector to the provided destination array.
 *
 * @param vector the vector to quantize
 * @param destination the array to store the result; must be at least as long as {@code vector}
 * @param lowInterval the minimum value, lower values in the original array will be replaced by this value
 * @param upperInterval the maximum value, bigger values in the original array will be replaced by this value
 * @param bit the number of bits to use for quantization, must be between 1 and 8
 *
 * @return the sum of all the elements of the resulting quantized vector
 */
public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) {
// The destination only needs to be large enough to hold the vector, so report that, not "dimensions differ".
if (vector.length > destination.length) {
throw new IllegalArgumentException(
"destination array too small: vector length " + vector.length + " > destination length " + destination.length
);
}
if (bit <= 0 || bit > Byte.SIZE) {
throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit);
}
// Dispatch to the platform-specific implementation (scalar default or Panama SIMD).
return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
}
}

View File

@ -269,4 +269,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
}
return ret;
}
@Override
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
// Scalar reference implementation: clamp each component into the interval,
// map it to the nearest of (2^bits - 1) evenly spaced steps, and accumulate the sum.
final float steps = ((1 << bits) - 1);
final float bucketWidth = (upperInterval - lowInterval) / steps;
int total = 0;
for (int idx = 0; idx < vector.length; idx++) {
final float clamped = Math.min(Math.max(vector[idx], lowInterval), upperInterval);
final int bucket = Math.round((clamped - lowInterval) / bucketWidth);
destination[idx] = bucket;
total += bucket;
}
return total;
}
}

View File

@ -39,4 +39,6 @@ public interface ESVectorUtilSupport {
float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);
}

View File

@ -791,4 +791,32 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
return sum;
}
@Override
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
// Width of one quantization step for the requested bit count.
float nSteps = ((1 << bits) - 1);
float step = (upperInterval - lowInterval) / nSteps;
int sumQuery = 0;
int i = 0;
// Only take the SIMD path when the vector is comfortably larger than one species width;
// short vectors fall through to the scalar loop below.
if (vector.length > 2 * FLOAT_SPECIES.length()) {
int limit = FLOAT_SPECIES.loopBound(vector.length);
FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
FloatVector xi = v.max(lowVec).min(upperVec); // clamp
// add(0.5f) followed by F2I truncation emulates Math.round for the non-negative values produced here
IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
sumQuery += assignment.reduceLanes(ADD);
assignment.intoArray(destination, i);
}
}
// Scalar tail (and full computation for short vectors); mirrors the default scalar implementation.
for (; i < vector.length; i++) {
float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
int assignment = Math.round((xi - lowInterval) / step);
sumQuery += assignment;
destination[i] = assignment;
}
return sumQuery;
}
}

View File

@ -286,6 +286,29 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
assertEquals(expected, result, deltaEps);
}
public void testQuantizeVectorWithIntervals() {
// Checks that the Panama (SIMD) implementation matches the default scalar implementation.
int vectorSize = randomIntBetween(1, 2048);
float[] vector = new float[vectorSize];
byte bits = (byte) randomIntBetween(1, 8);
for (int i = 0; i < vectorSize; ++i) {
vector[i] = random().nextFloat();
}
float low = random().nextFloat();
float high = random().nextFloat();
if (low > high) {
float tmp = low;
low = high;
high = tmp;
}
// Guard against a degenerate zero-width interval, which would make the step size 0 and produce NaN assignments.
if (low == high) {
high = low + 1f;
}
int[] quantizeExpected = new int[vectorSize];
int[] quantizeResult = new int[vectorSize];
var expected = defaultedProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeExpected, low, high, bits);
var result = defOrPanamaProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeResult, low, high, bits);
assertArrayEquals(quantizeExpected, quantizeResult);
// Exact int comparison: no float-delta overload needed for integer sums.
assertEquals(expected, result);
}
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
int iterations = atLeast(50);
for (int i = 0; i < iterations; i++) {

View File

@ -57,4 +57,33 @@ public class BQSpaceUtils {
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
}
/**
 * Same as {@link #transposeHalfByte(byte[], byte[])} but the input vector is provided as
 * an array of integers.
 *
 * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
 * @param quantQueryByte the byte array to store the transposed query vector
 * */
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
for (int i = 0; i < q.length;) {
int lowerByte = 0;
int lowerMiddleByte = 0;
int upperMiddleByte = 0;
int upperByte = 0;
// Pack up to 8 consecutive values, one bit-plane per output byte (bit k of q[i] goes to plane k).
for (int j = 7; j >= 0 && i < q.length; j--) {
// Validate every element; previously the assert ran before this loop and only checked every 8th value.
assert q[i] >= 0 && q[i] <= 15;
lowerByte |= (q[i] & 1) << j;
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
upperByte |= ((q[i] >> 3) & 1) << j;
i++;
}
// Index of the byte just filled; the four bit-planes live in consecutive quarters of the output.
int index = ((i + 7) / 8) - 1;
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
}
}

View File

@ -40,7 +40,7 @@ public class BQVectorUtils {
return Math.abs(l1norm - 1.0d) <= EPSILON;
}
public static void packAsBinary(byte[] vector, byte[] packed) {
public static void packAsBinary(int[] vector, byte[] packed) {
for (int i = 0; i < vector.length;) {
byte result = 0;
for (int j = 7; j >= 0 && i < vector.length; j--) {

View File

@ -52,13 +52,17 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
final byte[] quantized = new byte[targetQuery.length];
final int[] scratch = new int[targetQuery.length];
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
ArrayUtil.copyArray(targetQuery),
quantized,
scratch,
(byte) 4,
fieldEntry.globalCentroid()
);
final byte[] quantized = new byte[targetQuery.length];
for (int i = 0; i < quantized.length; i++) {
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
return new CentroidQueryScorer() {
int currentCentroid = -1;
@ -182,7 +186,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
DocIdsWriter docIdsWriter = new DocIdsWriter();
final float[] scratch;
final byte[] quantizationScratch;
final int[] quantizationScratch;
final byte[] quantizedQueryScratch;
final OptimizedScalarQuantizer quantizer;
final float[] correctiveValues = new float[3];
@ -202,7 +206,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
this.needsScoring = needsScoring;
scratch = new float[target.length];
quantizationScratch = new byte[target.length];
quantizationScratch = new int[target.length];
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;

View File

@ -122,8 +122,9 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
throws IOException {
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
int[] quantizedScratch = new int[fieldInfo.getVectorDimension()];
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
final byte[] quantized = new byte[fieldInfo.getVectorDimension()];
// TODO do we want to store these distances as well for future use?
// TODO: sort centroids by global centroid (was doing so previously here)
// TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
@ -135,7 +136,10 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
(byte) 4,
globalCentroid
);
writeQuantizedValue(centroidOutput, quantizedScratch, result);
for (int i = 0; i < quantizedScratch.length; i++) {
quantized[i] = (byte) quantizedScratch[i];
}
writeQuantizedValue(centroidOutput, quantized, result);
}
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (float[] centroid : centroids) {

View File

@ -66,13 +66,13 @@ public abstract class DiskBBQBulkWriter {
public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final byte[] binarized;
private final byte[] initQuantized;
private final int[] initQuantized;
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
super(bulkSize, quantizer, fvv, out);
this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
this.initQuantized = new byte[fvv.dimension()];
this.initQuantized = new int[fvv.dimension()];
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
}

View File

@ -57,7 +57,7 @@ public class OptimizedScalarQuantizer {
public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {}
public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinations, byte[] bits, float[] centroid) {
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
assert bits.length == destinations.length;
@ -79,18 +79,14 @@ public class OptimizedScalarQuantizer {
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
optimizeIntervals(intervalScratch, vector, norm2, points);
float nSteps = ((1 << bits[i]) - 1);
float a = intervalScratch[0];
float b = intervalScratch[1];
float step = (b - a) / nSteps;
int sumQuery = 0;
// Now we have the optimized intervals, quantize the vector
for (int h = 0; h < vector.length; h++) {
float xi = (float) clamp(vector[h], a, b);
int assignment = Math.round((xi - a) / step);
sumQuery += assignment;
destinations[i][h] = (byte) assignment;
}
int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
vector,
destinations[i],
intervalScratch[0],
intervalScratch[1],
bits[i]
);
results[i] = new QuantizationResult(
intervalScratch[0],
intervalScratch[1],
@ -101,7 +97,8 @@ public class OptimizedScalarQuantizer {
return results;
}
public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
// This method is only used for benchmarking purposes, it is not used in production
public QuantizationResult legacyScalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
assert vector.length <= destination.length;
@ -141,6 +138,36 @@ public class OptimizedScalarQuantizer {
);
}
// Quantizes the centroid-centered vector into the int[] destination and returns the
// optimized interval plus correction terms needed to score the quantized vector.
// NOTE: centers `vector` in place via the centerAndCalculateOSQStats* calls below.
public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
assert vector.length <= destination.length;
assert bits > 0 && bits <= 8;
int points = 1 << bits;
// Center the vector on the centroid (writing back into `vector`) and gather
// mean/variance/norm/min/max statistics into statsScratch in a single pass.
if (similarityFunction == EUCLIDEAN) {
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
} else {
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
}
float vecMean = statsScratch[0];
float vecVar = statsScratch[1];
float norm2 = statsScratch[2];
float min = statsScratch[3];
float max = statsScratch[4];
float vecStd = (float) Math.sqrt(vecVar);
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
optimizeIntervals(intervalScratch, vector, norm2, points);
// Now we have the optimized intervals, quantize the vector
int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
return new QuantizationResult(
intervalScratch[0],
intervalScratch[1],
// statsScratch[5] holds the dot-product correction computed by centerAndCalculateOSQStatsDp
similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
sumQuery
);
}
/**
* Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
* loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the

View File

@ -77,7 +77,7 @@ public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer {
VectorUtil.l2normalize(copy);
}
target = copy;
byte[] initial = new byte[target.length];
int[] initial = new int[target.length];
byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
BQSpaceUtils.transposeHalfByte(initial, quantized);

View File

@ -198,7 +198,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer)
throws IOException {
int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
byte[] quantizationScratch = new byte[discreteDims];
int[] quantizationScratch = new int[discreteDims];
byte[] vector = new byte[discreteDims / 8];
for (int i = 0; i < fieldData.getVectors().size(); i++) {
float[] v = fieldData.getVectors().get(i);
@ -246,7 +246,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
OptimizedScalarQuantizer scalarQuantizer
) throws IOException {
int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
byte[] quantizationScratch = new byte[discreteDims];
int[] quantizationScratch = new int[discreteDims];
byte[] vector = new byte[discreteDims / 8];
for (int ordinal : ordMap) {
float[] v = fieldData.getVectors().get(ordinal);
@ -364,7 +364,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
) throws IOException {
int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64);
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()];
int[][] quantizationScratch = new int[2][floatVectorValues.dimension()];
byte[] toIndex = new byte[discretizedDimension / 8];
byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY];
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
@ -801,7 +801,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
private OptimizedScalarQuantizer.QuantizationResult corrections;
private final byte[] binarized;
private final byte[] initQuantized;
private final int[] initQuantized;
private final float[] centroid;
private final FloatVectorValues values;
private final OptimizedScalarQuantizer quantizer;
@ -812,7 +812,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
this.values = delegate;
this.quantizer = quantizer;
this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8];
this.initQuantized = new byte[delegate.dimension()];
this.initQuantized = new int[delegate.dimension()];
this.centroid = centroid;
}

View File

@ -40,25 +40,25 @@ public class BQVectorUtilsTests extends LuceneTestCase {
public void testPackAsBinary() {
// 5 bits
byte[] toPack = new byte[] { 1, 1, 0, 0, 1 };
int[] toPack = new int[] { 1, 1, 0, 0, 1 };
byte[] packed = new byte[1];
BQVectorUtils.packAsBinary(toPack, packed);
assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed);
// 8 bits
toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 };
toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0 };
packed = new byte[1];
BQVectorUtils.packAsBinary(toPack, packed);
assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed);
// 10 bits
toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
packed = new byte[2];
BQVectorUtils.packAsBinary(toPack, packed);
assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed);
// 16 bits
toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
packed = new byte[2];
BQVectorUtils.packAsBinary(toPack, packed);
assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed);

View File

@ -19,7 +19,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) {
static float[] deQuantize(int[] quantized, byte bits, float[] interval, float[] centroid) {
float[] dequantized = new float[quantized.length];
float a = interval[0];
float b = interval[1];
@ -52,10 +52,12 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
float[] scratch = new float[dims];
for (byte bit : ALL_BITS) {
float eps = (1f / (float) (1 << (bit)));
byte[] destination = new byte[dims];
byte[] legacyDestination = new byte[dims];
int[] destination = new int[dims];
for (int i = 0; i < numVectors; ++i) {
System.arraycopy(vectors[i], 0, scratch, 0, dims);
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid);
assertValidResults(result);
assertValidQuantizedRange(destination, bit);
@ -71,6 +73,19 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
}
mae /= dims;
assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps);
// check we get the same result from the int version
System.arraycopy(vectors[i], 0, scratch, 0, dims);
OptimizedScalarQuantizer.QuantizationResult intResults = osq.legacyScalarQuantize(
scratch,
legacyDestination,
bit,
centroid
);
assertEquals(result, intResults);
for (int h = 0; h < dims; ++h) {
assertEquals((byte) destination[h], legacyDestination[h]);
}
}
}
}
@ -84,18 +99,18 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
float[] vector = new float[4096];
float[] centroid = new float[4096];
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096];
int[][] destinations = new int[MINIMUM_MSE_GRID.length][4096];
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
assertEquals(MINIMUM_MSE_GRID.length, results.length);
assertValidResults(results);
for (byte[] destination : destinations) {
assertArrayEquals(new byte[4096], destination);
for (int[] destination : destinations) {
assertArrayEquals(new int[4096], destination);
}
byte[] destination = new byte[4096];
int[] destination = new int[4096];
for (byte bit : ALL_BITS) {
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
assertValidResults(result);
assertArrayEquals(new byte[4096], destination);
assertArrayEquals(new int[4096], destination);
}
}
@ -108,7 +123,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
VectorUtil.l2normalize(centroid);
}
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1];
int[][] destinations = new int[MINIMUM_MSE_GRID.length][1];
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
assertEquals(MINIMUM_MSE_GRID.length, results.length);
assertValidResults(results);
@ -122,7 +137,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
VectorUtil.l2normalize(vector);
VectorUtil.l2normalize(centroid);
}
byte[] destination = new byte[1];
int[] destination = new int[1];
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
assertValidResults(result);
assertValidQuantizedRange(destination, bit);
@ -150,7 +165,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
VectorUtil.l2normalize(centroid);
}
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims];
int[][] destinations = new int[MINIMUM_MSE_GRID.length][dims];
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid);
assertEquals(MINIMUM_MSE_GRID.length, results.length);
assertValidResults(results);
@ -158,7 +173,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
}
for (byte bit : ALL_BITS) {
byte[] destination = new byte[dims];
int[] destination = new int[dims];
System.arraycopy(vector, 0, copy, 0, dims);
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
VectorUtil.l2normalize(copy);
@ -171,8 +186,8 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
}
}
static void assertValidQuantizedRange(byte[] quantized, byte bits) {
for (byte b : quantized) {
static void assertValidQuantizedRange(int[] quantized, byte bits) {
for (int b : quantized) {
if (bits < 8) {
assertTrue(b >= 0);
}

View File

@ -243,7 +243,7 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat
assertEquals(centroid.length, dims);
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
byte[] quantizedVector = new byte[dims];
int[] quantizedVector = new int[dims];
byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8];
if (similarityFunction == VectorSimilarityFunction.COSINE) {
vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);