ESQL: Speed up grouping by bytes (#114021) (#114652)

This speeds up grouping by bytes-valued fields (keyword, text, ip, and
wildcard) when the input is an ordinal block:
```
    bytes_refs 22.213 ± 0.322 -> 19.848 ± 0.205 ns/op (*maybe* real, maybe noise. still good)
       ordinal didn't exist   ->  2.988 ± 0.011 ns/op
```
I see this as 20ns -> 3ns, an 85% speed up. We never had the ordinals
branch before, so I expect it previously performed like the bytes_refs
case - about 20ns per op.
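
An ordinal block stores each distinct value once in a dictionary plus one
int ordinal per row. The new fast path (`BytesRefBlockHash#addOrdinalsVector`,
below) hashes the dictionary once and then maps each row's ordinal to a group
id with a plain int lookup instead of re-hashing the bytes for every row. A
minimal sketch of the idea - the names here are illustrative, not the real
compute-engine API:
```
import org.apache.lucene.util.BytesRef;

import java.util.HashMap;
import java.util.Map;

/**
 * Illustrative sketch only: group an ordinal-encoded column by hashing the
 * dictionary once and remapping per-row ordinals. The real implementation is
 * BytesRefBlockHash#addOrdinalsVector in the diff below.
 */
class OrdinalGroupingSketch {
    static int[] groupIds(BytesRef[] dictionary, int[] ordinals) {
        Map<BytesRef, Integer> hash = new HashMap<>();
        int[] dictToGroup = new int[dictionary.length];
        for (int d = 0; d < dictionary.length; d++) {
            // each distinct value is hashed once, not once per row
            dictToGroup[d] = hash.computeIfAbsent(dictionary[d], k -> hash.size());
        }
        int[] groups = new int[ordinals.length];
        for (int p = 0; p < ordinals.length; p++) {
            groups[p] = dictToGroup[ordinals[p]]; // per-row cost: one array read
        }
        return groups;
    }
}
```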

This also speeds up grouping by a pair of bytes-valued fields:
```
two_bytes_refs 83.112 ± 42.348  -> 46.521 ± 0.386 ns/op
  two_ordinals 83.531 ± 23.473  ->  8.617 ± 0.105 ns/op
```
The speed up is much better when the fields are ordinals because hashing
bytes is comparatively slow.
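
For two bytes keys, the new `BytesRef2BlockHash` (added later in this diff)
first resolves a per-column ordinal for each key and then packs the two ints
into a single long that it feeds to a `LongHash`. Roughly, as a sketch that
assumes both ordinals are non-negative and uses a `HashMap` in place of
`LongHash`:
```
import java.util.HashMap;
import java.util.Map;

/**
 * Sketch of combining two per-column group ids into one, mirroring
 * BytesRef2BlockHash#ord below. A HashMap stands in for LongHash here.
 */
class TwoKeySketch {
    private final Map<Long, Integer> finalHash = new HashMap<>();

    int ord(int k1, int k2) {
        // low 32 bits: first key's ordinal, high 32 bits: second key's ordinal
        long packed = ((long) k2 << 32) | (k1 & 0xffffffffL);
        return finalHash.computeIfAbsent(packed, key -> finalHash.size());
    }
}
```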

I believe the ordinals case is quite common. I've run into it in quite a
few profiles.
Nik Everett, 2024-10-11 16:22:18 -04:00, committed by GitHub
parent 0e2f832516
commit 1212dee8b4
13 changed files with 632 additions and 66 deletions

@ -30,10 +30,13 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.DriverContext;
@ -78,7 +81,10 @@ public class AggregatorBenchmark {
private static final String DOUBLES = "doubles";
private static final String BOOLEANS = "booleans";
private static final String BYTES_REFS = "bytes_refs";
private static final String ORDINALS = "ordinals";
private static final String TWO_LONGS = "two_" + LONGS;
private static final String TWO_BYTES_REFS = "two_" + BYTES_REFS;
private static final String TWO_ORDINALS = "two_" + ORDINALS;
private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS;
private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS;
@ -119,7 +125,21 @@ public class AggregatorBenchmark {
}
}
@Param({ NONE, LONGS, INTS, DOUBLES, BOOLEANS, BYTES_REFS, TWO_LONGS, LONGS_AND_BYTES_REFS, TWO_LONGS_AND_BYTES_REFS })
@Param(
{
NONE,
LONGS,
INTS,
DOUBLES,
BOOLEANS,
BYTES_REFS,
ORDINALS,
TWO_LONGS,
TWO_BYTES_REFS,
TWO_ORDINALS,
LONGS_AND_BYTES_REFS,
TWO_LONGS_AND_BYTES_REFS }
)
public String grouping;
@Param({ COUNT, COUNT_DISTINCT, MIN, MAX, SUM })
@ -144,8 +164,12 @@ public class AggregatorBenchmark {
case INTS -> List.of(new BlockHash.GroupSpec(0, ElementType.INT));
case DOUBLES -> List.of(new BlockHash.GroupSpec(0, ElementType.DOUBLE));
case BOOLEANS -> List.of(new BlockHash.GroupSpec(0, ElementType.BOOLEAN));
case BYTES_REFS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
case BYTES_REFS, ORDINALS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
case TWO_LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG), new BlockHash.GroupSpec(1, ElementType.LONG));
case TWO_BYTES_REFS, TWO_ORDINALS -> List.of(
new BlockHash.GroupSpec(0, ElementType.BYTES_REF),
new BlockHash.GroupSpec(1, ElementType.BYTES_REF)
);
case LONGS_AND_BYTES_REFS -> List.of(
new BlockHash.GroupSpec(0, ElementType.LONG),
new BlockHash.GroupSpec(1, ElementType.BYTES_REF)
@ -218,6 +242,10 @@ public class AggregatorBenchmark {
checkGroupingBlock(prefix, LONGS, page.getBlock(0));
checkGroupingBlock(prefix, LONGS, page.getBlock(1));
}
case TWO_BYTES_REFS, TWO_ORDINALS -> {
checkGroupingBlock(prefix, BYTES_REFS, page.getBlock(0));
checkGroupingBlock(prefix, BYTES_REFS, page.getBlock(1));
}
case LONGS_AND_BYTES_REFS -> {
checkGroupingBlock(prefix, LONGS, page.getBlock(0));
checkGroupingBlock(prefix, BYTES_REFS, page.getBlock(1));
@ -379,7 +407,7 @@ public class AggregatorBenchmark {
throw new AssertionError(prefix + "bad group expected [true] but was [" + groups.getBoolean(1) + "]");
}
}
case BYTES_REFS -> {
case BYTES_REFS, ORDINALS -> {
BytesRefBlock groups = (BytesRefBlock) block;
for (int g = 0; g < GROUPS; g++) {
if (false == groups.getBytesRef(g, new BytesRef()).equals(bytesGroup(g))) {
@ -508,6 +536,8 @@ public class AggregatorBenchmark {
private static List<Block> groupingBlocks(String grouping, String blockType) {
return switch (grouping) {
case TWO_LONGS -> List.of(groupingBlock(LONGS, blockType), groupingBlock(LONGS, blockType));
case TWO_BYTES_REFS -> List.of(groupingBlock(BYTES_REFS, blockType), groupingBlock(BYTES_REFS, blockType));
case TWO_ORDINALS -> List.of(groupingBlock(ORDINALS, blockType), groupingBlock(ORDINALS, blockType));
case LONGS_AND_BYTES_REFS -> List.of(groupingBlock(LONGS, blockType), groupingBlock(BYTES_REFS, blockType));
case TWO_LONGS_AND_BYTES_REFS -> List.of(
groupingBlock(LONGS, blockType),
@ -570,6 +600,19 @@ public class AggregatorBenchmark {
}
yield builder.build();
}
case ORDINALS -> {
IntVector.Builder ordinals = blockFactory.newIntVectorBuilder(BLOCK_LENGTH * valuesPerGroup);
for (int i = 0; i < BLOCK_LENGTH; i++) {
for (int v = 0; v < valuesPerGroup; v++) {
ordinals.appendInt(i % GROUPS);
}
}
BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(BLOCK_LENGTH * valuesPerGroup);
for (int i = 0; i < GROUPS; i++) {
bytes.appendBytesRef(bytesGroup(i));
}
yield new OrdinalBytesRefVector(ordinals.build(), bytes.build()).asBlock();
}
default -> throw new UnsupportedOperationException("unsupported grouping [" + grouping + "]");
};
}

@ -0,0 +1,5 @@
pr: 114021
summary: "ESQL: Speed up grouping by bytes"
area: ES|QL
type: enhancement
issues: []

@ -23,15 +23,18 @@ import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBytesRef;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt;
import org.elasticsearch.core.ReleasableIterator;
import java.io.IOException;
/**
* Maps a {@link BytesRefBlock} column to group ids.
* This class is generated. Do not edit it.
*/
final class BytesRefBlockHash extends BlockHash {
private final int channel;
@ -54,6 +57,7 @@ final class BytesRefBlockHash extends BlockHash {
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
// TODO track raw counts and which implementation we pick for the profiler - #114008
var block = page.getBlock(channel);
if (block.areAllValuesNull()) {
seenNull = true;
@ -76,6 +80,10 @@ final class BytesRefBlockHash extends BlockHash {
}
IntVector add(BytesRefVector vector) {
var ordinals = vector.asOrdinals();
if (ordinals != null) {
return addOrdinalsVector(ordinals);
}
BytesRef scratch = new BytesRef();
int positions = vector.getPositionCount();
try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@ -113,15 +121,29 @@ final class BytesRefBlockHash extends BlockHash {
return ReleasableIterator.single(lookup(vector));
}
private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
var inputOrds = inputBlock.getOrdinalsBlock();
private IntVector addOrdinalsVector(OrdinalBytesRefVector inputBlock) {
IntVector inputOrds = inputBlock.getOrdinalsVector();
try (
var builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
var builder = blockFactory.newIntVectorBuilder(inputOrds.getPositionCount());
var hashOrds = add(inputBlock.getDictionaryVector())
) {
for (int i = 0; i < inputOrds.getPositionCount(); i++) {
int valueCount = inputOrds.getValueCount(i);
int firstIndex = inputOrds.getFirstValueIndex(i);
for (int p = 0; p < inputOrds.getPositionCount(); p++) {
int ord = hashOrds.getInt(inputOrds.getInt(p));
builder.appendInt(ord);
}
return builder.build();
}
}
private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
try (
IntBlock inputOrds = new MultivalueDedupeInt(inputBlock.getOrdinalsBlock()).dedupeToBlockAdaptive(blockFactory);
IntBlock.Builder builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
IntVector hashOrds = add(inputBlock.getDictionaryVector())
) {
for (int p = 0; p < inputOrds.getPositionCount(); p++) {
int valueCount = inputOrds.getValueCount(p);
int firstIndex = inputOrds.getFirstValueIndex(p);
switch (valueCount) {
case 0 -> {
builder.appendInt(0);
@ -132,9 +154,11 @@ final class BytesRefBlockHash extends BlockHash {
builder.appendInt(ord);
}
default -> {
int start = firstIndex;
int end = firstIndex + valueCount;
builder.beginPositionEntry();
for (int v = 0; v < valueCount; v++) {
int ord = hashOrds.getInt(inputOrds.getInt(firstIndex + i));
for (int i = start; i < end; i++) {
int ord = hashOrds.getInt(inputOrds.getInt(i));
builder.appendInt(ord);
}
builder.endPositionEntry();

@ -28,6 +28,7 @@ import java.util.BitSet;
/**
* Maps a {@link DoubleBlock} column to group ids.
* This class is generated. Do not edit it.
*/
final class DoubleBlockHash extends BlockHash {
private final int channel;
@ -50,6 +51,7 @@ final class DoubleBlockHash extends BlockHash {
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
// TODO track raw counts and which implementation we pick for the profiler - #114008
var block = page.getBlock(channel);
if (block.areAllValuesNull()) {
seenNull = true;

@ -26,6 +26,7 @@ import java.util.BitSet;
/**
* Maps a {@link IntBlock} column to group ids.
* This class is generated. Do not edit it.
*/
final class IntBlockHash extends BlockHash {
private final int channel;
@ -48,6 +49,7 @@ final class IntBlockHash extends BlockHash {
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
// TODO track raw counts and which implementation we pick for the profiler - #114008
var block = page.getBlock(channel);
if (block.areAllValuesNull()) {
seenNull = true;

@ -28,6 +28,7 @@ import java.util.BitSet;
/**
* Maps a {@link LongBlock} column to group ids.
* This class is generated. Do not edit it.
*/
final class LongBlockHash extends BlockHash {
private final int channel;
@ -50,6 +51,7 @@ final class LongBlockHash extends BlockHash {
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
// TODO track raw counts and which implementation we pick for the profiler - #114008
var block = page.getBlock(channel);
if (block.areAllValuesNull()) {
seenNull = true;

@ -11,6 +11,7 @@ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.common.util.Int3Hash;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.common.util.LongLongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
@ -28,14 +29,37 @@ import java.util.Iterator;
import java.util.List;
/**
* A specialized hash table implementation maps values of a {@link Block} to ids (in longs).
* This class delegates to {@link LongHash} or {@link BytesRefHash}.
*
* @see LongHash
* @see BytesRefHash
* Specialized hash table implementations that map rows to a <strong>set</strong>
* of bucket IDs to which they belong to implement {@code GROUP BY} expressions.
* <p>
* A row is always in at least one bucket so the results are never {@code null}.
* {@code null} valued key columns will map to some integer bucket id.
* If none of the key columns are multivalued then the output is always an
* {@link IntVector}. If any of the keys are multivalued then a row is
* in a bucket for each value. If more than one key is multivalued then
* the row is in the combinatorial explosion of all value combinations.
* Luckily, a row can only be in each bucket once, no matter how many values
* it has. Unluckily, it's the responsibility of {@link BlockHash} to remove
* those duplicates.
* </p>
* <p>
* These classes typically delegate to some combination of {@link BytesRefHash},
* {@link LongHash}, {@link LongLongHash}, {@link Int3Hash}. They don't
* <strong>technically</strong> have to be hash tables, so long as they
* implement the deduplication semantics above and vend integer ids.
* </p>
* <p>
* The integer ids are assigned to offsets into arrays of aggregation states
* so it's permissible to have gaps in the ints. But large gaps are a bad
* idea because they'll waste space in the aggregations that use these
* positions. For example, {@link BooleanBlockHash} assigns {@code 0} to
* {@code null}, {@code 1} to {@code false}, and {@code 2} to {@code true}
* and that's <strong>fine</strong> and simple and good because it'll never
* leave a big gap, even if we never see {@code null}.
* </p>
*/
public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef3BlockHash, //
permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
protected final BlockFactory blockFactory;
@ -98,8 +122,19 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
if (groups.size() == 1) {
return newForElementType(groups.get(0).channel(), groups.get(0).elementType(), blockFactory);
}
if (groups.size() == 3 && groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
return new BytesRef3BlockHash(blockFactory, groups.get(0).channel, groups.get(1).channel, groups.get(2).channel, emitBatchSize);
if (groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
switch (groups.size()) {
case 2:
return new BytesRef2BlockHash(blockFactory, groups.get(0).channel, groups.get(1).channel, emitBatchSize);
case 3:
return new BytesRef3BlockHash(
blockFactory,
groups.get(0).channel,
groups.get(1).channel,
groups.get(2).channel,
emitBatchSize
);
}
}
if (allowBrokenOptimizations && groups.size() == 2) {
var g1 = groups.get(0);

@ -25,8 +25,9 @@ import static org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolea
import static org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolean.TRUE_ORD;
/**
* Maps a {@link BooleanBlock} column to group ids. Assigns group
* {@code 0} to {@code false} and group {@code 1} to {@code true}.
* Maps a {@link BooleanBlock} column to group ids. Assigns
* {@code 0} to {@code null}, {@code 1} to {@code false}, and
* {@code 2} to {@code true}.
*/
final class BooleanBlockHash extends BlockHash {
private final int channel;

@ -0,0 +1,196 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import java.util.Locale;
/**
* Maps two {@link BytesRefBlock}s to group ids.
*/
final class BytesRef2BlockHash extends BlockHash {
private final int emitBatchSize;
private final int channel1;
private final int channel2;
private final BytesRefBlockHash hash1;
private final BytesRefBlockHash hash2;
private final LongHash finalHash;
BytesRef2BlockHash(BlockFactory blockFactory, int channel1, int channel2, int emitBatchSize) {
super(blockFactory);
this.emitBatchSize = emitBatchSize;
this.channel1 = channel1;
this.channel2 = channel2;
boolean success = false;
try {
this.hash1 = new BytesRefBlockHash(channel1, blockFactory);
this.hash2 = new BytesRefBlockHash(channel2, blockFactory);
this.finalHash = new LongHash(1, blockFactory.bigArrays());
success = true;
} finally {
if (success == false) {
close();
}
}
}
@Override
public void close() {
Releasables.close(hash1, hash2, finalHash);
}
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
BytesRefBlock b1 = page.getBlock(channel1);
BytesRefBlock b2 = page.getBlock(channel2);
BytesRefVector v1 = b1.asVector();
BytesRefVector v2 = b2.asVector();
if (v1 != null && v2 != null) {
addVectors(v1, v2, addInput);
} else {
try (IntBlock k1 = hash1.add(b1); IntBlock k2 = hash2.add(b2)) {
try (AddWork work = new AddWork(k1, k2, addInput)) {
work.add();
}
}
}
}
private void addVectors(BytesRefVector v1, BytesRefVector v2, GroupingAggregatorFunction.AddInput addInput) {
final int positionCount = v1.getPositionCount();
try (IntVector.FixedBuilder ordsBuilder = blockFactory.newIntVectorFixedBuilder(positionCount)) {
try (IntVector k1 = hash1.add(v1); IntVector k2 = hash2.add(v2)) {
for (int p = 0; p < positionCount; p++) {
long ord = ord(k1.getInt(p), k2.getInt(p));
ordsBuilder.appendInt(p, Math.toIntExact(ord));
}
}
try (IntVector ords = ordsBuilder.build()) {
addInput.add(0, ords);
}
}
}
private class AddWork extends AddPage {
final IntBlock b1;
final IntBlock b2;
AddWork(IntBlock b1, IntBlock b2, GroupingAggregatorFunction.AddInput addInput) {
super(blockFactory, emitBatchSize, addInput);
this.b1 = b1;
this.b2 = b2;
}
void add() {
int positionCount = b1.getPositionCount();
for (int i = 0; i < positionCount; i++) {
int v1 = b1.getValueCount(i);
int v2 = b2.getValueCount(i);
int first1 = b1.getFirstValueIndex(i);
int first2 = b2.getFirstValueIndex(i);
if (v1 == 1 && v2 == 1) {
long ord = ord(b1.getInt(first1), b2.getInt(first2));
appendOrdSv(i, Math.toIntExact(ord));
continue;
}
for (int i1 = 0; i1 < v1; i1++) {
int k1 = b1.getInt(first1 + i1);
for (int i2 = 0; i2 < v2; i2++) {
int k2 = b2.getInt(first2 + i2);
long ord = ord(k1, k2);
appendOrdInMv(i, Math.toIntExact(ord));
}
}
finishMv();
}
flushRemaining();
}
}
private long ord(int k1, int k2) {
return hashOrdToGroup(finalHash.add((long) k2 << 32 | k1));
}
@Override
public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
throw new UnsupportedOperationException("TODO");
}
@Override
public Block[] getKeys() {
// TODO Build Ordinals blocks #114010
final int positions = (int) finalHash.size();
final BytesRef scratch = new BytesRef();
final BytesRefBlock[] outputBlocks = new BytesRefBlock[2];
try {
try (BytesRefBlock.Builder b1 = blockFactory.newBytesRefBlockBuilder(positions)) {
for (int i = 0; i < positions; i++) {
int k1 = (int) (finalHash.get(i) & 0xffffL);
if (k1 == 0) {
b1.appendNull();
} else {
b1.appendBytesRef(hash1.hash.get(k1 - 1, scratch));
}
}
outputBlocks[0] = b1.build();
}
try (BytesRefBlock.Builder b2 = blockFactory.newBytesRefBlockBuilder(positions)) {
for (int i = 0; i < positions; i++) {
int k2 = (int) (finalHash.get(i) >>> 32);
if (k2 == 0) {
b2.appendNull();
} else {
b2.appendBytesRef(hash2.hash.get(k2 - 1, scratch));
}
}
outputBlocks[1] = b2.build();
}
return outputBlocks;
} finally {
if (outputBlocks[outputBlocks.length - 1] == null) {
Releasables.close(outputBlocks);
}
}
}
@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
return new Range(0, Math.toIntExact(finalHash.size())).seenGroupIds(bigArrays);
}
@Override
public IntVector nonEmpty() {
return IntVector.range(0, Math.toIntExact(finalHash.size()), blockFactory);
}
@Override
public String toString() {
return String.format(
Locale.ROOT,
"BytesRef2BlockHash{keys=[channel1=%d, channel2=%d], entries=%d}",
channel1,
channel2,
finalHash.size()
);
}
}

@ -85,7 +85,6 @@ final class BytesRef3BlockHash extends BlockHash {
private void addVectors(BytesRefVector v1, BytesRefVector v2, BytesRefVector v3, GroupingAggregatorFunction.AddInput addInput) {
final int positionCount = v1.getPositionCount();
try (IntVector.FixedBuilder ordsBuilder = blockFactory.newIntVectorFixedBuilder(positionCount)) {
// TODO: enable ordinal vectors in BytesRefBlockHash
try (IntVector k1 = hash1.add(v1); IntVector k2 = hash2.add(v2); IntVector k3 = hash3.add(v3)) {
for (int p = 0; p < positionCount; p++) {
long ord = hashOrdToGroup(finalHash.add(k1.getInt(p), k2.getInt(p), k3.getInt(p)));
@ -148,6 +147,7 @@ final class BytesRef3BlockHash extends BlockHash {
@Override
public Block[] getKeys() {
// TODO Build Ordinals blocks #114010
final int positions = (int) finalHash.size();
final BytesRef scratch = new BytesRef();
final BytesRefBlock[] outputBlocks = new BytesRefBlock[3];

@ -28,6 +28,7 @@ import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
$elseif(double)$
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
@ -51,6 +52,9 @@ $endif$
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe$Type$;
$if(BytesRef)$
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt;
$endif$
import org.elasticsearch.core.ReleasableIterator;
$if(BytesRef)$
@ -62,6 +66,7 @@ import java.util.BitSet;
$endif$
/**
* Maps a {@link $Type$Block} column to group ids.
* This class is generated. Do not edit it.
*/
final class $Type$BlockHash extends BlockHash {
private final int channel;
@ -84,6 +89,7 @@ final class $Type$BlockHash extends BlockHash {
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
// TODO track raw counts and which implementation we pick for the profiler - #114008
var block = page.getBlock(channel);
if (block.areAllValuesNull()) {
seenNull = true;
@ -107,6 +113,10 @@ final class $Type$BlockHash extends BlockHash {
IntVector add($Type$Vector vector) {
$if(BytesRef)$
var ordinals = vector.asOrdinals();
if (ordinals != null) {
return addOrdinalsVector(ordinals);
}
BytesRef scratch = new BytesRef();
$endif$
int positions = vector.getPositionCount();
@ -154,15 +164,29 @@ $endif$
}
$if(BytesRef)$
private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
var inputOrds = inputBlock.getOrdinalsBlock();
private IntVector addOrdinalsVector(OrdinalBytesRefVector inputBlock) {
IntVector inputOrds = inputBlock.getOrdinalsVector();
try (
var builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
var builder = blockFactory.newIntVectorBuilder(inputOrds.getPositionCount());
var hashOrds = add(inputBlock.getDictionaryVector())
) {
for (int i = 0; i < inputOrds.getPositionCount(); i++) {
int valueCount = inputOrds.getValueCount(i);
int firstIndex = inputOrds.getFirstValueIndex(i);
for (int p = 0; p < inputOrds.getPositionCount(); p++) {
int ord = hashOrds.getInt(inputOrds.getInt(p));
builder.appendInt(ord);
}
return builder.build();
}
}
private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
try (
IntBlock inputOrds = new MultivalueDedupeInt(inputBlock.getOrdinalsBlock()).dedupeToBlockAdaptive(blockFactory);
IntBlock.Builder builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
IntVector hashOrds = add(inputBlock.getDictionaryVector())
) {
for (int p = 0; p < inputOrds.getPositionCount(); p++) {
int valueCount = inputOrds.getValueCount(p);
int firstIndex = inputOrds.getFirstValueIndex(p);
switch (valueCount) {
case 0 -> {
builder.appendInt(0);
@ -173,9 +197,11 @@ $if(BytesRef)$
builder.appendInt(ord);
}
default -> {
int start = firstIndex;
int end = firstIndex + valueCount;
builder.beginPositionEntry();
for (int v = 0; v < valueCount; v++) {
int ord = hashOrds.getInt(inputOrds.getInt(firstIndex + i));
for (int i = start; i < end; i++) {
int ord = hashOrds.getInt(inputOrds.getInt(i));
builder.appendInt(ord);
}
builder.endPositionEntry();

@ -21,10 +21,13 @@ import org.elasticsearch.compute.data.BasicBlockTests;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockTestUtils;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.MockBlockFactory;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.TestBlockFactory;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeTests;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
@ -38,11 +41,13 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Stream;
import static org.elasticsearch.test.ListMatcher.matchesList;
import static org.elasticsearch.test.MapMatcher.assertMap;
@ -58,26 +63,40 @@ import static org.mockito.Mockito.when;
public class BlockHashRandomizedTests extends ESTestCase {
@ParametersFactory
public static List<Object[]> params() {
List<Object[]> params = new ArrayList<>();
List<List<? extends Type>> allowedTypesChoices = List.of(
/*
* Run with only `LONG` elements because we have some
* optimizations that hit if you only have those.
*/
List.of(new Basic(ElementType.LONG)),
/*
* Run with only `BYTES_REF` elements because we have some
* optimizations that hit if you only have those.
*/
List.of(new Basic(ElementType.BYTES_REF)),
/*
* Run with only `BYTES_REF` elements in an OrdinalBytesRefBlock
* because we have a few optimizations that use it.
*/
List.of(new Ordinals(10)),
/*
* Run with only `LONG` and `BYTES_REF` elements because
* we have some optimizations that hit if you only have
* those.
*/
List.of(new Basic(ElementType.LONG), new Basic(ElementType.BYTES_REF)),
/*
* Any random source.
*/
Stream.concat(Stream.of(new Ordinals(10)), MultivalueDedupeTests.supportedTypes().stream().map(Basic::new)).toList()
);
List<Object[]> params = new ArrayList<>();
for (boolean forcePackedHash : new boolean[] { false, true }) {
for (int groups : new int[] { 1, 2, 3, 4, 5, 10 }) {
for (int maxValuesPerPosition : new int[] { 1, 3 }) {
for (int dups : new int[] { 0, 2 }) {
for (List<ElementType> allowedTypes : List.of(
/*
* Run with only `LONG` elements because we have some
* optimizations that hit if you only have those.
*/
List.of(ElementType.LONG),
/*
* Run with only `LONG` and `BYTES_REF` elements because
* we have some optimizations that hit if you only have
* those.
*/
List.of(ElementType.LONG, ElementType.BYTES_REF),
MultivalueDedupeTests.supportedTypes()
)) {
for (List<? extends Type> allowedTypes : allowedTypesChoices) {
params.add(new Object[] { forcePackedHash, groups, maxValuesPerPosition, dups, allowedTypes });
}
}
@ -87,18 +106,33 @@ public class BlockHashRandomizedTests extends ESTestCase {
return params;
}
/**
* The type of {@link Block} being tested.
*/
interface Type {
/**
* The type of the {@link ElementType elements} in the {@link Block}.
*/
ElementType elementType();
/**
* Build a random {@link Block}.
*/
BasicBlockTests.RandomBlock randomBlock(int positionCount, int maxValuesPerPosition, int dups);
}
private final boolean forcePackedHash;
private final int groups;
private final int maxValuesPerPosition;
private final int dups;
private final List<ElementType> allowedTypes;
private final List<? extends Type> allowedTypes;
public BlockHashRandomizedTests(
@Name("forcePackedHash") boolean forcePackedHash,
@Name("groups") int groups,
@Name("maxValuesPerPosition") int maxValuesPerPosition,
@Name("dups") int dups,
@Name("allowedTypes") List<ElementType> allowedTypes
@Name("allowedTypes") List<Type> allowedTypes
) {
this.forcePackedHash = forcePackedHash;
this.groups = groups;
@ -127,21 +161,22 @@ public class BlockHashRandomizedTests extends ESTestCase {
}
private void test(MockBlockFactory blockFactory) {
List<ElementType> types = randomList(groups, groups, () -> randomFrom(allowedTypes));
List<Type> types = randomList(groups, groups, () -> randomFrom(allowedTypes));
List<ElementType> elementTypes = types.stream().map(Type::elementType).toList();
BasicBlockTests.RandomBlock[] randomBlocks = new BasicBlockTests.RandomBlock[types.size()];
Block[] blocks = new Block[types.size()];
int pageCount = between(1, 10);
int pageCount = between(1, groups < 10 ? 10 : 5);
int positionCount = 100;
int emitBatchSize = 100;
try (BlockHash blockHash = newBlockHash(blockFactory, emitBatchSize, types)) {
try (BlockHash blockHash = newBlockHash(blockFactory, emitBatchSize, elementTypes)) {
/*
* Only the long/long, long/bytes_ref, and bytes_ref/long implementations don't collect nulls.
*/
Oracle oracle = new Oracle(
forcePackedHash
|| false == (types.equals(List.of(ElementType.LONG, ElementType.LONG))
|| types.equals(List.of(ElementType.LONG, ElementType.BYTES_REF))
|| types.equals(List.of(ElementType.BYTES_REF, ElementType.LONG)))
|| false == (elementTypes.equals(List.of(ElementType.LONG, ElementType.LONG))
|| elementTypes.equals(List.of(ElementType.LONG, ElementType.BYTES_REF))
|| elementTypes.equals(List.of(ElementType.BYTES_REF, ElementType.LONG)))
);
/*
* Expected ordinals for checking lookup. Skipped if we have more than 5 groups because
@ -151,15 +186,7 @@ public class BlockHashRandomizedTests extends ESTestCase {
for (int p = 0; p < pageCount; p++) {
for (int g = 0; g < blocks.length; g++) {
randomBlocks[g] = BasicBlockTests.randomBlock(
types.get(g),
positionCount,
types.get(g) == ElementType.NULL ? true : randomBoolean(),
1,
maxValuesPerPosition,
0,
dups
);
randomBlocks[g] = types.get(g).randomBlock(positionCount, maxValuesPerPosition, dups);
blocks[g] = randomBlocks[g].block();
}
oracle.add(randomBlocks);
@ -209,6 +236,7 @@ public class BlockHashRandomizedTests extends ESTestCase {
if (blockHash instanceof LongLongBlockHash == false
&& blockHash instanceof BytesRefLongBlockHash == false
&& blockHash instanceof BytesRef2BlockHash == false
&& blockHash instanceof BytesRef3BlockHash == false) {
assertLookup(blockFactory, expectedOrds, types, blockHash, oracle);
}
@ -235,14 +263,14 @@ public class BlockHashRandomizedTests extends ESTestCase {
private void assertLookup(
BlockFactory blockFactory,
Map<List<Object>, Set<Integer>> expectedOrds,
List<ElementType> types,
List<Type> types,
BlockHash blockHash,
Oracle oracle
) {
Block.Builder[] builders = new Block.Builder[types.size()];
try {
for (int b = 0; b < builders.length; b++) {
builders[b] = types.get(b).newBlockBuilder(LOOKUP_POSITIONS, blockFactory);
builders[b] = types.get(b).elementType().newBlockBuilder(LOOKUP_POSITIONS, blockFactory);
}
for (int p = 0; p < LOOKUP_POSITIONS; p++) {
/*
@ -408,8 +436,8 @@ public class BlockHashRandomizedTests extends ESTestCase {
return breakerService;
}
private static List<Object> randomKey(List<ElementType> types) {
return types.stream().map(BlockHashRandomizedTests::randomKeyElement).toList();
private static List<Object> randomKey(List<Type> types) {
return types.stream().map(t -> randomKeyElement(t.elementType())).toList();
}
public static Object randomKeyElement(ElementType type) {
@ -423,4 +451,75 @@ public class BlockHashRandomizedTests extends ESTestCase {
default -> throw new IllegalArgumentException("unsupported element type [" + type + "]");
};
}
private record Basic(ElementType elementType) implements Type {
@Override
public BasicBlockTests.RandomBlock randomBlock(int positionCount, int maxValuesPerPosition, int dups) {
return BasicBlockTests.randomBlock(
elementType,
positionCount,
elementType == ElementType.NULL | randomBoolean(),
1,
maxValuesPerPosition,
0,
dups
);
}
}
private record Ordinals(int dictionarySize) implements Type {
@Override
public ElementType elementType() {
return ElementType.BYTES_REF;
}
@Override
public BasicBlockTests.RandomBlock randomBlock(int positionCount, int maxValuesPerPosition, int dups) {
List<Map.Entry<String, Integer>> dictionary = new ArrayList<>();
List<List<Object>> values = new ArrayList<>(positionCount);
try (
IntBlock.Builder ordinals = TestBlockFactory.getNonBreakingInstance()
.newIntBlockBuilder(positionCount * maxValuesPerPosition);
BytesRefVector.Builder bytes = TestBlockFactory.getNonBreakingInstance().newBytesRefVectorBuilder(maxValuesPerPosition);
) {
for (String value : dictionary(maxValuesPerPosition)) {
bytes.appendBytesRef(new BytesRef(value));
dictionary.add(Map.entry(value, dictionary.size()));
}
for (int p = 0; p < positionCount; p++) {
int valueCount = between(1, maxValuesPerPosition);
int dupCount = between(0, dups);
List<Integer> ordsAtPosition = new ArrayList<>();
List<Object> valuesAtPosition = new ArrayList<>();
values.add(valuesAtPosition);
if (valueCount != 1 || dupCount != 0) {
ordinals.beginPositionEntry();
}
for (int v = 0; v < valueCount; v++) {
Map.Entry<String, Integer> value = randomFrom(dictionary);
valuesAtPosition.add(new BytesRef(value.getKey()));
ordinals.appendInt(value.getValue());
ordsAtPosition.add(value.getValue());
}
for (int v = 0; v < dupCount; v++) {
ordinals.appendInt(randomFrom(ordsAtPosition));
}
if (valueCount != 1 || dupCount != 0) {
ordinals.endPositionEntry();
}
}
return new BasicBlockTests.RandomBlock(values, new OrdinalBytesRefBlock(ordinals.build(), bytes.build()));
}
}
private Set<String> dictionary(int maxValuesPerPosition) {
int count = Math.max(dictionarySize, maxValuesPerPosition);
Set<String> values = new HashSet<>();
while (values.size() < count) {
values.add(randomAlphaOfLength(5));
}
return values;
}
}
}

@ -20,12 +20,15 @@ import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.MockBlockFactory;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.TestBlockFactory;
import org.elasticsearch.core.Releasable;
@ -460,6 +463,133 @@ public class BlockHashTests extends ESTestCase {
}
}
public void testBasicOrdinals() {
try (
IntVector.Builder ords = blockFactory.newIntVectorFixedBuilder(8);
BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(8)
) {
ords.appendInt(1);
ords.appendInt(0);
ords.appendInt(3);
ords.appendInt(1);
ords.appendInt(3);
ords.appendInt(0);
ords.appendInt(2);
ords.appendInt(3);
bytes.appendBytesRef(new BytesRef("item-1"));
bytes.appendBytesRef(new BytesRef("item-2"));
bytes.appendBytesRef(new BytesRef("item-3"));
bytes.appendBytesRef(new BytesRef("item-4"));
hash(ordsAndKeys -> {
if (forcePackedHash) {
assertThat(ordsAndKeys.description, startsWith("PackedValuesBlockHash{groups=[0:BYTES_REF], entries=4, size="));
assertThat(ordsAndKeys.description, endsWith("b}"));
assertOrds(ordsAndKeys.ords, 0, 1, 2, 0, 2, 1, 3, 2);
assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(0, 4)));
assertKeys(ordsAndKeys.keys, "item-2", "item-1", "item-4", "item-3");
} else {
assertThat(ordsAndKeys.description, startsWith("BytesRefBlockHash{channel=0, entries=4, size="));
assertThat(ordsAndKeys.description, endsWith("b, seenNull=false}"));
assertOrds(ordsAndKeys.ords, 2, 1, 4, 2, 4, 1, 3, 4);
assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(1, 5)));
assertKeys(ordsAndKeys.keys, "item-1", "item-2", "item-3", "item-4");
}
}, new OrdinalBytesRefVector(ords.build(), bytes.build()).asBlock());
}
}
public void testOrdinalsWithNulls() {
try (
IntBlock.Builder ords = blockFactory.newIntBlockBuilder(4);
BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(2)
) {
ords.appendInt(0);
ords.appendNull();
ords.appendInt(1);
ords.appendNull();
bytes.appendBytesRef(new BytesRef("cat"));
bytes.appendBytesRef(new BytesRef("dog"));
hash(ordsAndKeys -> {
if (forcePackedHash) {
assertThat(ordsAndKeys.description, startsWith("PackedValuesBlockHash{groups=[0:BYTES_REF], entries=3, size="));
assertThat(ordsAndKeys.description, endsWith("b}"));
assertOrds(ordsAndKeys.ords, 0, 1, 2, 1);
assertKeys(ordsAndKeys.keys, "cat", null, "dog");
} else {
assertThat(ordsAndKeys.description, startsWith("BytesRefBlockHash{channel=0, entries=2, size="));
assertThat(ordsAndKeys.description, endsWith("b, seenNull=true}"));
assertOrds(ordsAndKeys.ords, 1, 0, 2, 0);
assertKeys(ordsAndKeys.keys, null, "cat", "dog");
}
assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(0, 3)));
}, new OrdinalBytesRefBlock(ords.build(), bytes.build()));
}
}
public void testOrdinalsWithMultiValuedFields() {
try (
IntBlock.Builder ords = blockFactory.newIntBlockBuilder(4);
BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(2)
) {
ords.appendInt(0);
ords.beginPositionEntry();
ords.appendInt(0);
ords.appendInt(1);
ords.endPositionEntry();
ords.beginPositionEntry();
ords.appendInt(1);
ords.appendInt(2);
ords.endPositionEntry();
ords.beginPositionEntry();
ords.appendInt(2);
ords.appendInt(1);
ords.endPositionEntry();
ords.appendNull();
ords.beginPositionEntry();
ords.appendInt(2);
ords.appendInt(2);
ords.appendInt(1);
ords.endPositionEntry();
bytes.appendBytesRef(new BytesRef("foo"));
bytes.appendBytesRef(new BytesRef("bar"));
bytes.appendBytesRef(new BytesRef("bort"));
hash(ordsAndKeys -> {
if (forcePackedHash) {
assertThat(ordsAndKeys.description, startsWith("PackedValuesBlockHash{groups=[0:BYTES_REF], entries=4, size="));
assertThat(ordsAndKeys.description, endsWith("b}"));
assertOrds(
ordsAndKeys.ords,
new int[] { 0 },
new int[] { 0, 1 },
new int[] { 1, 2 },
new int[] { 2, 1 },
new int[] { 3 },
new int[] { 2, 1 }
);
assertKeys(ordsAndKeys.keys, "foo", "bar", "bort", null);
} else {
assertThat(ordsAndKeys.description, startsWith("BytesRefBlockHash{channel=0, entries=3, size="));
assertThat(ordsAndKeys.description, endsWith("b, seenNull=true}"));
assertOrds(
ordsAndKeys.ords,
new int[] { 1 },
new int[] { 1, 2 },
new int[] { 2, 3 },
new int[] { 3, 2 },
new int[] { 0 },
new int[] { 3, 2 }
);
assertKeys(ordsAndKeys.keys, null, "foo", "bar", "bort");
}
assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(0, 4)));
}, new OrdinalBytesRefBlock(ords.build(), bytes.build()));
}
}
public void testBooleanHashFalseFirst() {
boolean[] values = new boolean[] { false, true, true, true, true };
hash(ordsAndKeys -> {
@ -1315,6 +1445,7 @@ public class BlockHashTests extends ESTestCase {
});
if (blockHash instanceof LongLongBlockHash == false
&& blockHash instanceof BytesRefLongBlockHash == false
&& blockHash instanceof BytesRef2BlockHash == false
&& blockHash instanceof BytesRef3BlockHash == false) {
Block[] keys = blockHash.getKeys();
try (ReleasableIterator<IntBlock> lookup = blockHash.lookup(new Page(keys), ByteSizeValue.ofKb(between(1, 100)))) {