optimize OptimizedScalarQuantizer#scalarQuantize (#129874)
Optimize OptimizedScalarQuantizer#scalarQuantize when the destination can be an integer array.
This commit is contained in:
parent
24a5440851
commit
f81d35536d
|
@ -43,7 +43,8 @@ public class OptimizedScalarQuantizerBenchmark {
|
|||
|
||||
float[] vector;
|
||||
float[] centroid;
|
||||
byte[] destination;
|
||||
byte[] legacyDestination;
|
||||
int[] destination;
|
||||
|
||||
@Param({ "1", "4", "7" })
|
||||
byte bits;
|
||||
|
@ -54,7 +55,8 @@ public class OptimizedScalarQuantizerBenchmark {
|
|||
public void init() {
|
||||
ThreadLocalRandom random = ThreadLocalRandom.current();
|
||||
// random byte arrays for binary methods
|
||||
destination = new byte[dims];
|
||||
legacyDestination = new byte[dims];
|
||||
destination = new int[dims];
|
||||
vector = new float[dims];
|
||||
centroid = new float[dims];
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
|
@ -65,13 +67,20 @@ public class OptimizedScalarQuantizerBenchmark {
|
|||
|
||||
@Benchmark
|
||||
public byte[] scalar() {
|
||||
osq.scalarQuantize(vector, destination, bits, centroid);
|
||||
return destination;
|
||||
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
|
||||
return legacyDestination;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||
public byte[] vector() {
|
||||
public byte[] legacyVector() {
|
||||
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
|
||||
return legacyDestination;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||
public int[] vector() {
|
||||
osq.scalarQuantize(vector, destination, bits, centroid);
|
||||
return destination;
|
||||
}
|
||||
|
|
|
@ -258,4 +258,25 @@ public class ESVectorUtil {
|
|||
}
|
||||
return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimized-scalar quantization of the provided vector to the provided destination array.
|
||||
*
|
||||
* @param vector the vector to quantize
|
||||
* @param destination the array to store the result
|
||||
* @param lowInterval the minimum value, lower values in the original array will be replaced by this value
|
||||
* @param upperInterval the maximum value, bigger values in the original array will be replaced by this value
|
||||
* @param bit the number of bits to use for quantization, must be between 1 and 8
|
||||
*
|
||||
* @return return the sum of all the elements of the resulting quantized vector.
|
||||
*/
|
||||
public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) {
|
||||
if (vector.length > destination.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + vector.length + "!=" + destination.length);
|
||||
}
|
||||
if (bit <= 0 || bit > Byte.SIZE) {
|
||||
throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit);
|
||||
}
|
||||
return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -269,4 +269,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
|
||||
float nSteps = ((1 << bits) - 1);
|
||||
float step = (upperInterval - lowInterval) / nSteps;
|
||||
int sumQuery = 0;
|
||||
for (int h = 0; h < vector.length; h++) {
|
||||
float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
|
||||
int assignment = Math.round((xi - lowInterval) / step);
|
||||
sumQuery += assignment;
|
||||
destination[h] = assignment;
|
||||
}
|
||||
return sumQuery;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,4 +39,6 @@ public interface ESVectorUtilSupport {
|
|||
|
||||
float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
|
||||
|
||||
/**
 * Quantizes {@code vector} into {@code quantize}: each component is clamped to
 * {@code [lowInterval, upperInterval]} and replaced by the index of the nearest of the
 * {@code 2^bit} evenly spaced levels of that interval.
 *
 * @param vector the vector to quantize
 * @param quantize the array receiving the quantized components
 * @param lowInterval the lower bound of the quantization interval
 * @param upperInterval the upper bound of the quantization interval
 * @param bit the number of bits used for quantization
 * @return the sum of all quantized components
 */
int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);
|
||||
|
||||
}
|
||||
|
|
|
@ -791,4 +791,32 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
|
||||
return sum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
|
||||
float nSteps = ((1 << bits) - 1);
|
||||
float step = (upperInterval - lowInterval) / nSteps;
|
||||
int sumQuery = 0;
|
||||
int i = 0;
|
||||
if (vector.length > 2 * FLOAT_SPECIES.length()) {
|
||||
int limit = FLOAT_SPECIES.loopBound(vector.length);
|
||||
FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
|
||||
FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
|
||||
FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
|
||||
for (; i < limit; i += FLOAT_SPECIES.length()) {
|
||||
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
|
||||
FloatVector xi = v.max(lowVec).min(upperVec); // clamp
|
||||
IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
|
||||
sumQuery += assignment.reduceLanes(ADD);
|
||||
assignment.intoArray(destination, i);
|
||||
}
|
||||
}
|
||||
for (; i < vector.length; i++) {
|
||||
float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
|
||||
int assignment = Math.round((xi - lowInterval) / step);
|
||||
sumQuery += assignment;
|
||||
destination[i] = assignment;
|
||||
}
|
||||
return sumQuery;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -286,6 +286,29 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
|||
assertEquals(expected, result, deltaEps);
|
||||
}
|
||||
|
||||
public void testQuantizeVectorWithIntervals() {
|
||||
int vectorSize = randomIntBetween(1, 2048);
|
||||
float[] vector = new float[vectorSize];
|
||||
|
||||
byte bits = (byte) randomIntBetween(1, 8);
|
||||
for (int i = 0; i < vectorSize; ++i) {
|
||||
vector[i] = random().nextFloat();
|
||||
}
|
||||
float low = random().nextFloat();
|
||||
float high = random().nextFloat();
|
||||
if (low > high) {
|
||||
float tmp = low;
|
||||
low = high;
|
||||
high = tmp;
|
||||
}
|
||||
int[] quantizeExpected = new int[vectorSize];
|
||||
int[] quantizeResult = new int[vectorSize];
|
||||
var expected = defaultedProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeExpected, low, high, bits);
|
||||
var result = defOrPanamaProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeResult, low, high, bits);
|
||||
assertArrayEquals(quantizeExpected, quantizeResult);
|
||||
assertEquals(expected, result, 0f);
|
||||
}
|
||||
|
||||
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
||||
int iterations = atLeast(50);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
|
|
|
@ -57,4 +57,33 @@ public class BQSpaceUtils {
|
|||
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Same as {@link #transposeHalfByte(byte[], byte[])} but the input vector is provided as
|
||||
* an array of integers.
|
||||
*
|
||||
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
|
||||
* @param quantQueryByte the byte array to store the transposed query vector
|
||||
* */
|
||||
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
|
||||
for (int i = 0; i < q.length;) {
|
||||
assert q[i] >= 0 && q[i] <= 15;
|
||||
int lowerByte = 0;
|
||||
int lowerMiddleByte = 0;
|
||||
int upperMiddleByte = 0;
|
||||
int upperByte = 0;
|
||||
for (int j = 7; j >= 0 && i < q.length; j--) {
|
||||
lowerByte |= (q[i] & 1) << j;
|
||||
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
|
||||
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
|
||||
upperByte |= ((q[i] >> 3) & 1) << j;
|
||||
i++;
|
||||
}
|
||||
int index = ((i + 7) / 8) - 1;
|
||||
quantQueryByte[index] = (byte) lowerByte;
|
||||
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
|
||||
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
|
||||
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ public class BQVectorUtils {
|
|||
return Math.abs(l1norm - 1.0d) <= EPSILON;
|
||||
}
|
||||
|
||||
public static void packAsBinary(byte[] vector, byte[] packed) {
|
||||
public static void packAsBinary(int[] vector, byte[] packed) {
|
||||
for (int i = 0; i < vector.length;) {
|
||||
byte result = 0;
|
||||
for (int j = 7; j >= 0 && i < vector.length; j--) {
|
||||
|
|
|
@ -52,13 +52,17 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
|||
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
|
||||
final float globalCentroidDp = fieldEntry.globalCentroidDp();
|
||||
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
final byte[] quantized = new byte[targetQuery.length];
|
||||
final int[] scratch = new int[targetQuery.length];
|
||||
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
|
||||
ArrayUtil.copyArray(targetQuery),
|
||||
quantized,
|
||||
scratch,
|
||||
(byte) 4,
|
||||
fieldEntry.globalCentroid()
|
||||
);
|
||||
final byte[] quantized = new byte[targetQuery.length];
|
||||
for (int i = 0; i < quantized.length; i++) {
|
||||
quantized[i] = (byte) scratch[i];
|
||||
}
|
||||
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
|
||||
return new CentroidQueryScorer() {
|
||||
int currentCentroid = -1;
|
||||
|
@ -182,7 +186,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
|||
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||
|
||||
final float[] scratch;
|
||||
final byte[] quantizationScratch;
|
||||
final int[] quantizationScratch;
|
||||
final byte[] quantizedQueryScratch;
|
||||
final OptimizedScalarQuantizer quantizer;
|
||||
final float[] correctiveValues = new float[3];
|
||||
|
@ -202,7 +206,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
|||
this.needsScoring = needsScoring;
|
||||
|
||||
scratch = new float[target.length];
|
||||
quantizationScratch = new byte[target.length];
|
||||
quantizationScratch = new int[target.length];
|
||||
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
|
||||
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
|
||||
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;
|
||||
|
|
|
@ -122,8 +122,9 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
|
||||
throws IOException {
|
||||
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
|
||||
int[] quantizedScratch = new int[fieldInfo.getVectorDimension()];
|
||||
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
|
||||
final byte[] quantized = new byte[fieldInfo.getVectorDimension()];
|
||||
// TODO do we want to store these distances as well for future use?
|
||||
// TODO: sort centroids by global centroid (was doing so previously here)
|
||||
// TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
|
||||
|
@ -135,7 +136,10 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
|||
(byte) 4,
|
||||
globalCentroid
|
||||
);
|
||||
writeQuantizedValue(centroidOutput, quantizedScratch, result);
|
||||
for (int i = 0; i < quantizedScratch.length; i++) {
|
||||
quantized[i] = (byte) quantizedScratch[i];
|
||||
}
|
||||
writeQuantizedValue(centroidOutput, quantized, result);
|
||||
}
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (float[] centroid : centroids) {
|
||||
|
|
|
@ -66,13 +66,13 @@ public abstract class DiskBBQBulkWriter {
|
|||
|
||||
public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
|
||||
private final byte[] binarized;
|
||||
private final byte[] initQuantized;
|
||||
private final int[] initQuantized;
|
||||
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
|
||||
|
||||
public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
|
||||
super(bulkSize, quantizer, fvv, out);
|
||||
this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
|
||||
this.initQuantized = new byte[fvv.dimension()];
|
||||
this.initQuantized = new int[fvv.dimension()];
|
||||
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ public class OptimizedScalarQuantizer {
|
|||
|
||||
public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {}
|
||||
|
||||
public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
|
||||
public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinations, byte[] bits, float[] centroid) {
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||
assert bits.length == destinations.length;
|
||||
|
@ -79,18 +79,14 @@ public class OptimizedScalarQuantizer {
|
|||
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
||||
initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
|
||||
optimizeIntervals(intervalScratch, vector, norm2, points);
|
||||
float nSteps = ((1 << bits[i]) - 1);
|
||||
float a = intervalScratch[0];
|
||||
float b = intervalScratch[1];
|
||||
float step = (b - a) / nSteps;
|
||||
int sumQuery = 0;
|
||||
// Now we have the optimized intervals, quantize the vector
|
||||
for (int h = 0; h < vector.length; h++) {
|
||||
float xi = (float) clamp(vector[h], a, b);
|
||||
int assignment = Math.round((xi - a) / step);
|
||||
sumQuery += assignment;
|
||||
destinations[i][h] = (byte) assignment;
|
||||
}
|
||||
int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
|
||||
vector,
|
||||
destinations[i],
|
||||
intervalScratch[0],
|
||||
intervalScratch[1],
|
||||
bits[i]
|
||||
);
|
||||
results[i] = new QuantizationResult(
|
||||
intervalScratch[0],
|
||||
intervalScratch[1],
|
||||
|
@ -101,7 +97,8 @@ public class OptimizedScalarQuantizer {
|
|||
return results;
|
||||
}
|
||||
|
||||
public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
|
||||
// This method is only used for benchmarking purposes, it is not used in production
|
||||
public QuantizationResult legacyScalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||
assert vector.length <= destination.length;
|
||||
|
@ -141,6 +138,36 @@ public class OptimizedScalarQuantizer {
|
|||
);
|
||||
}
|
||||
|
||||
public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||
assert vector.length <= destination.length;
|
||||
assert bits > 0 && bits <= 8;
|
||||
int points = 1 << bits;
|
||||
if (similarityFunction == EUCLIDEAN) {
|
||||
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
|
||||
} else {
|
||||
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
|
||||
}
|
||||
float vecMean = statsScratch[0];
|
||||
float vecVar = statsScratch[1];
|
||||
float norm2 = statsScratch[2];
|
||||
float min = statsScratch[3];
|
||||
float max = statsScratch[4];
|
||||
float vecStd = (float) Math.sqrt(vecVar);
|
||||
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
||||
initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
|
||||
optimizeIntervals(intervalScratch, vector, norm2, points);
|
||||
// Now we have the optimized intervals, quantize the vector
|
||||
int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
|
||||
return new QuantizationResult(
|
||||
intervalScratch[0],
|
||||
intervalScratch[1],
|
||||
similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
|
||||
sumQuery
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
|
||||
* loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the
|
||||
|
|
|
@ -77,7 +77,7 @@ public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer {
|
|||
VectorUtil.l2normalize(copy);
|
||||
}
|
||||
target = copy;
|
||||
byte[] initial = new byte[target.length];
|
||||
int[] initial = new int[target.length];
|
||||
byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
|
||||
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
|
||||
BQSpaceUtils.transposeHalfByte(initial, quantized);
|
||||
|
|
|
@ -198,7 +198,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
|
|||
private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer)
|
||||
throws IOException {
|
||||
int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
|
||||
byte[] quantizationScratch = new byte[discreteDims];
|
||||
int[] quantizationScratch = new int[discreteDims];
|
||||
byte[] vector = new byte[discreteDims / 8];
|
||||
for (int i = 0; i < fieldData.getVectors().size(); i++) {
|
||||
float[] v = fieldData.getVectors().get(i);
|
||||
|
@ -246,7 +246,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
|
|||
OptimizedScalarQuantizer scalarQuantizer
|
||||
) throws IOException {
|
||||
int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
|
||||
byte[] quantizationScratch = new byte[discreteDims];
|
||||
int[] quantizationScratch = new int[discreteDims];
|
||||
byte[] vector = new byte[discreteDims / 8];
|
||||
for (int ordinal : ordMap) {
|
||||
float[] v = fieldData.getVectors().get(ordinal);
|
||||
|
@ -364,7 +364,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
|
|||
) throws IOException {
|
||||
int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64);
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()];
|
||||
int[][] quantizationScratch = new int[2][floatVectorValues.dimension()];
|
||||
byte[] toIndex = new byte[discretizedDimension / 8];
|
||||
byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY];
|
||||
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
|
||||
|
@ -801,7 +801,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
|
|||
static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
|
||||
private OptimizedScalarQuantizer.QuantizationResult corrections;
|
||||
private final byte[] binarized;
|
||||
private final byte[] initQuantized;
|
||||
private final int[] initQuantized;
|
||||
private final float[] centroid;
|
||||
private final FloatVectorValues values;
|
||||
private final OptimizedScalarQuantizer quantizer;
|
||||
|
@ -812,7 +812,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
|
|||
this.values = delegate;
|
||||
this.quantizer = quantizer;
|
||||
this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8];
|
||||
this.initQuantized = new byte[delegate.dimension()];
|
||||
this.initQuantized = new int[delegate.dimension()];
|
||||
this.centroid = centroid;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,25 +40,25 @@ public class BQVectorUtilsTests extends LuceneTestCase {
|
|||
|
||||
public void testPackAsBinary() {
|
||||
// 5 bits
|
||||
byte[] toPack = new byte[] { 1, 1, 0, 0, 1 };
|
||||
int[] toPack = new int[] { 1, 1, 0, 0, 1 };
|
||||
byte[] packed = new byte[1];
|
||||
BQVectorUtils.packAsBinary(toPack, packed);
|
||||
assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed);
|
||||
|
||||
// 8 bits
|
||||
toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 };
|
||||
toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0 };
|
||||
packed = new byte[1];
|
||||
BQVectorUtils.packAsBinary(toPack, packed);
|
||||
assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed);
|
||||
|
||||
// 10 bits
|
||||
toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
|
||||
toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
|
||||
packed = new byte[2];
|
||||
BQVectorUtils.packAsBinary(toPack, packed);
|
||||
assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed);
|
||||
|
||||
// 16 bits
|
||||
toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
|
||||
toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
|
||||
packed = new byte[2];
|
||||
BQVectorUtils.packAsBinary(toPack, packed);
|
||||
assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed);
|
||||
|
|
|
@ -19,7 +19,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
|
||||
static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
|
||||
|
||||
static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) {
|
||||
static float[] deQuantize(int[] quantized, byte bits, float[] interval, float[] centroid) {
|
||||
float[] dequantized = new float[quantized.length];
|
||||
float a = interval[0];
|
||||
float b = interval[1];
|
||||
|
@ -52,10 +52,12 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
float[] scratch = new float[dims];
|
||||
for (byte bit : ALL_BITS) {
|
||||
float eps = (1f / (float) (1 << (bit)));
|
||||
byte[] destination = new byte[dims];
|
||||
byte[] legacyDestination = new byte[dims];
|
||||
int[] destination = new int[dims];
|
||||
for (int i = 0; i < numVectors; ++i) {
|
||||
System.arraycopy(vectors[i], 0, scratch, 0, dims);
|
||||
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid);
|
||||
|
||||
assertValidResults(result);
|
||||
assertValidQuantizedRange(destination, bit);
|
||||
|
||||
|
@ -71,6 +73,19 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
}
|
||||
mae /= dims;
|
||||
assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps);
|
||||
|
||||
// check we get the same result from the legacy byte[] version
|
||||
System.arraycopy(vectors[i], 0, scratch, 0, dims);
|
||||
OptimizedScalarQuantizer.QuantizationResult intResults = osq.legacyScalarQuantize(
|
||||
scratch,
|
||||
legacyDestination,
|
||||
bit,
|
||||
centroid
|
||||
);
|
||||
assertEquals(result, intResults);
|
||||
for (int h = 0; h < dims; ++h) {
|
||||
assertEquals((byte) destination[h], legacyDestination[h]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -84,18 +99,18 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
float[] vector = new float[4096];
|
||||
float[] centroid = new float[4096];
|
||||
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
|
||||
byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096];
|
||||
int[][] destinations = new int[MINIMUM_MSE_GRID.length][4096];
|
||||
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
|
||||
assertEquals(MINIMUM_MSE_GRID.length, results.length);
|
||||
assertValidResults(results);
|
||||
for (byte[] destination : destinations) {
|
||||
assertArrayEquals(new byte[4096], destination);
|
||||
for (int[] destination : destinations) {
|
||||
assertArrayEquals(new int[4096], destination);
|
||||
}
|
||||
byte[] destination = new byte[4096];
|
||||
int[] destination = new int[4096];
|
||||
for (byte bit : ALL_BITS) {
|
||||
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
|
||||
assertValidResults(result);
|
||||
assertArrayEquals(new byte[4096], destination);
|
||||
assertArrayEquals(new int[4096], destination);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -108,7 +123,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
VectorUtil.l2normalize(centroid);
|
||||
}
|
||||
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
|
||||
byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1];
|
||||
int[][] destinations = new int[MINIMUM_MSE_GRID.length][1];
|
||||
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
|
||||
assertEquals(MINIMUM_MSE_GRID.length, results.length);
|
||||
assertValidResults(results);
|
||||
|
@ -122,7 +137,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
VectorUtil.l2normalize(vector);
|
||||
VectorUtil.l2normalize(centroid);
|
||||
}
|
||||
byte[] destination = new byte[1];
|
||||
int[] destination = new int[1];
|
||||
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
|
||||
assertValidResults(result);
|
||||
assertValidQuantizedRange(destination, bit);
|
||||
|
@ -150,7 +165,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
VectorUtil.l2normalize(centroid);
|
||||
}
|
||||
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
|
||||
byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims];
|
||||
int[][] destinations = new int[MINIMUM_MSE_GRID.length][dims];
|
||||
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid);
|
||||
assertEquals(MINIMUM_MSE_GRID.length, results.length);
|
||||
assertValidResults(results);
|
||||
|
@ -158,7 +173,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
|
||||
}
|
||||
for (byte bit : ALL_BITS) {
|
||||
byte[] destination = new byte[dims];
|
||||
int[] destination = new int[dims];
|
||||
System.arraycopy(vector, 0, copy, 0, dims);
|
||||
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
VectorUtil.l2normalize(copy);
|
||||
|
@ -171,8 +186,8 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
static void assertValidQuantizedRange(byte[] quantized, byte bits) {
|
||||
for (byte b : quantized) {
|
||||
static void assertValidQuantizedRange(int[] quantized, byte bits) {
|
||||
for (int b : quantized) {
|
||||
if (bits < 8) {
|
||||
assertTrue(b >= 0);
|
||||
}
|
||||
|
|
|
@ -243,7 +243,7 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat
|
|||
assertEquals(centroid.length, dims);
|
||||
|
||||
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
|
||||
byte[] quantizedVector = new byte[dims];
|
||||
int[] quantizedVector = new int[dims];
|
||||
byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8];
|
||||
if (similarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);
|
||||
|
|
Loading…
Reference in New Issue