Fix incorrect accounting of semantic text indexing memory pressure (#130221)

Mike Pellegrini 2025-06-27 14:29:54 -04:00 committed by GitHub
parent ddef899837
commit 52495aa5fc
4 changed files with 28 additions and 20 deletions


@@ -0,0 +1,5 @@
+pr: 130221
+summary: Fix incorrect accounting of semantic text indexing memory pressure
+area: Distributed
+type: bug
+issues: []


@@ -159,7 +159,10 @@ public interface BytesReference extends Comparable<BytesReference>, ToXContentFragment
     BytesReference slice(int from, int length);
 
     /**
-     * The amount of memory used by this BytesReference
+     * The amount of memory used by this BytesReference.
+     * <p>
+     * Note that this is not always the same as length and can vary by implementation.
+     * </p>
      */
     long ramBytesUsed();
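
The Javadoc note above is the heart of the bug: length() is the number of addressable bytes in the reference, while ramBytesUsed() is the heap retained by its backing structures, which can be far larger when the reference is a view over a bigger buffer. A minimal sketch of the divergence, assuming an Elasticsearch artifact on the classpath (the printed values are illustrative, not asserted):

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;

public class LengthVsRamBytesUsed {
    public static void main(String[] args) {
        // A 1 MiB backing buffer of which only 16 bytes are addressed.
        byte[] backing = new byte[1024 * 1024];
        BytesReference slice = new BytesArray(backing).slice(0, 16);

        // length() counts the addressable bytes: 16.
        System.out.println("length       = " + slice.length());

        // ramBytesUsed() counts retained heap; depending on the
        // implementation it can include the entire 1 MiB backing array.
        System.out.println("ramBytesUsed = " + slice.ramBytesUsed());
    }
}

Indexing pressure is meant to bound the bytes of in-flight requests, so the serialized length is the meaningful quantity; charging retained heap could trip the coordinating limit on requests that are actually small.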


@@ -631,7 +631,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             if (indexRequest.isIndexingPressureIncremented() == false) {
                 try {
                     // Track operation count as one operation per document source update
-                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().length());
                     indexRequest.setIndexingPressureIncremented();
                 } catch (EsRejectedExecutionException e) {
                     addInferenceResponseFailure(
@@ -737,13 +737,13 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 indexRequest.source(builder);
             }
         }
-        long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+        long modifiedSourceSize = indexRequest.source().length();
         // Add the indexing pressure from the source modifications.
         // Don't increment operation count because we count one source update as one operation, and we already accounted for those
        // in addFieldInferenceRequests.
         try {
-            coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.length());
         } catch (EsRejectedExecutionException e) {
             indexRequest.source(originalSource, indexRequest.getContentType());
             item.abort(
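
Both hunks follow the same accounting protocol: one operation plus the full serialized length when a document first enters the filter, then a zero-operation delta once inference results have been embedded into the source. A toy sketch of that protocol (the class below is hypothetical and models only the increment(ops, bytes) shape and the rejection behavior, not the real IndexingPressure.Coordinating API):

import java.util.concurrent.atomic.AtomicLong;

final class ToyCoordinatingPressure {
    private final long maxBytes;
    final AtomicLong trackedBytes = new AtomicLong();

    ToyCoordinatingPressure(long maxBytes) {
        this.maxBytes = maxBytes;
    }

    // ops is informational in this sketch; bytes are admitted or rejected.
    void increment(int ops, long bytes) {
        long updated = trackedBytes.addAndGet(bytes);
        if (updated > maxBytes) {
            trackedBytes.addAndGet(-bytes); // roll back before rejecting
            throw new IllegalStateException("coordinating limit exceeded");
        }
    }

    public static void main(String[] args) {
        ToyCoordinatingPressure pressure = new ToyCoordinatingPressure(10_000);

        long originalLength = 2_000;            // serialized source on arrival
        pressure.increment(1, originalLength);  // one op per document update

        long modifiedLength = 2_750;  // source after embedding inference results
        // Zero ops: the operation was counted above; only the growth is charged.
        pressure.increment(0, modifiedLength - originalLength);

        System.out.println("tracked = " + pressure.trackedBytes.get()); // 2750
    }
}

On rejection, the real filter restores the original source and aborts the bulk item, so the pressure tracker never counts modification bytes for a document that will not proceed.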


@@ -616,14 +616,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
         assertThat(coordinatingIndexingPressure, notNullValue());
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source));
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source));
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source));
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource));
+        verify(coordinatingIndexingPressure).increment(1, length(doc0Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc2Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc3Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc4Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc0UpdateSource));
         if (useLegacyFormat == false) {
-            verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource));
+            verify(coordinatingIndexingPressure).increment(1, length(doc1UpdateSource));
         }
         verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0));
@@ -720,7 +720,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
         assertThat(coordinatingIndexingPressure, notNullValue());
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
         verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());

         // Verify that the coordinating indexing pressure is maintained through downstream action filters
@@ -759,7 +759,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
         final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
         final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
-            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
+            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (length(doc1Source) + 1) + "b").build()
         );

         final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
@@ -802,7 +802,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
         assertThat(coordinatingIndexingPressure, notNullValue());
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
         verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0));
         verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong());
@@ -862,14 +862,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             );
             XContentBuilder builder = XContentFactory.jsonBuilder();
             semanticTextField.toXContent(builder, EMPTY_PARAMS);
-            return bytesUsed(builder);
+            return length(builder);
         };

         final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
             Settings.builder()
                 .put(
                     MAX_COORDINATING_BYTES.getKey(),
-                    (bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
+                    (length(doc1Source) + length(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
                         + (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b"
                 )
                 .build()
@@ -913,8 +913,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
         assertThat(coordinatingIndexingPressure, notNullValue());
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
-        verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
+        verify(coordinatingIndexingPressure).increment(1, length(doc2Source));
         verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0));
         verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong());
@@ -1124,8 +1124,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
     }

-    private static long bytesUsed(XContentBuilder builder) {
-        return BytesReference.bytes(builder).ramBytesUsed();
+    private static long length(XContentBuilder builder) {
+        return BytesReference.bytes(builder).length();
     }

     @SuppressWarnings({ "unchecked" })
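
The test change is mechanical but meaningful: the renamed helper pins every expectation to the exact serialized size of the source, which is deterministic for a given document, whereas a ramBytesUsed()-based expectation tracks internal buffer sizing. A small probe in the same spirit, again assuming Elasticsearch dependencies on the classpath (output values are illustrative):

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

public class SourceSizeProbe {
    // Mirrors the renamed helper: the exact serialized length of a source.
    static long length(XContentBuilder builder) {
        return BytesReference.bytes(builder).length();
    }

    public static void main(String[] args) throws Exception {
        XContentBuilder builder = XContentFactory.jsonBuilder()
            .startObject()
            .field("sparse_field", "bar")
            .endObject();

        BytesReference source = BytesReference.bytes(builder);
        // The serialized length equals the JSON byte count...
        System.out.println("length       = " + length(builder));
        // ...while the heap estimate may exceed it, e.g. because the
        // builder's output buffer grows in chunks.
        System.out.println("ramBytesUsed = " + source.ramBytesUsed());
    }
}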