From 315aba696af254d535e530d707b6674b8f7b31e8 Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:06:37 +0200 Subject: [PATCH] ESQL - Add K mandatory param for KNN function (#129763) --- .../esql/images/functions/knn.svg | 2 +- .../esql/kibana/definition/functions/knn.json | 3 +- .../esql/kibana/docs/functions/knn.md | 4 +- muted-tests.yml | 41 +++++++-- .../src/main/resources/knn-function.csv-spec | 78 ++++++---------- .../xpack/esql/plugin/KnnFunctionIT.java | 8 +- .../xpack/esql/action/EsqlCapabilities.java | 2 +- .../esql/expression/ExpressionWritables.java | 2 +- .../function/EsqlFunctionRegistry.java | 2 +- .../esql/expression/function/vector/Knn.java | 90 ++++++++++++------- .../elasticsearch/xpack/esql/CsvTests.java | 2 +- .../xpack/esql/analysis/AnalyzerTests.java | 2 +- .../xpack/esql/analysis/VerifierTests.java | 58 ++++++------ .../function/fulltext/KnnTests.java | 32 +++++-- .../LocalPhysicalPlanOptimizerTests.java | 6 +- 15 files changed, 194 insertions(+), 138 deletions(-) diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg index 75a104a7cdcf..6e20dbc21720 100644 --- a/docs/reference/query-languages/esql/images/functions/knn.svg +++ b/docs/reference/query-languages/esql/images/functions/knn.svg @@ -1 +1 @@ -KNN(field,query,options) \ No newline at end of file +KNN(field,query,k,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json index 21addd4b6442..d347891393dc 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json @@ -5,8 +5,7 @@ "description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. 
knn function finds nearest vectors through approximate search on indexed dense_vectors.", "signatures" : [ ], "examples" : [ - "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc, color asc", - "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc, color asc" + "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc" ], "preview" : true, "snapshot_only" : true diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md index bea09b0bf50d..c7af797488ba 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md @@ -1,10 +1,10 @@ -% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. ### KNN Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.
```esql from colors metadata _score -| where knn(rgb_vector, [0, 120, 0]) +| where knn(rgb_vector, [0, 120, 0], 10) | sort _score desc, color asc ``` diff --git a/muted-tests.yml b/muted-tests.yml index 5c4e29a83453..d9abae06ba91 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -236,6 +236,9 @@ tests: - class: org.elasticsearch.packaging.test.DockerTests method: test012SecurityCanBeDisabled issue: https://github.com/elastic/elasticsearch/issues/116636 +- class: org.elasticsearch.index.shard.StoreRecoveryTests + method: testAddIndices + issue: https://github.com/elastic/elasticsearch/issues/124104 - class: org.elasticsearch.smoketest.MlWithSecurityIT method: test {yaml=ml/data_frame_analytics_crud/Test get stats on newly created config} issue: https://github.com/elastic/elasticsearch/issues/121726 @@ -455,6 +458,12 @@ tests: - class: org.elasticsearch.packaging.test.DockerTests method: test073RunEsAsDifferentUserAndGroupWithoutBindMounting issue: https://github.com/elastic/elasticsearch/issues/128996 +- class: org.elasticsearch.upgrades.UpgradeClusterClientYamlTestSuiteIT + method: test {p0=upgraded_cluster/70_ilm/Test Lifecycle Still There And Indices Are Still Managed} + issue: https://github.com/elastic/elasticsearch/issues/129097 +- class: org.elasticsearch.upgrades.UpgradeClusterClientYamlTestSuiteIT + method: test {p0=upgraded_cluster/90_ml_data_frame_analytics_crud/Get mixed cluster outlier_detection job} + issue: https://github.com/elastic/elasticsearch/issues/129098 - class: org.elasticsearch.packaging.test.DockerTests method: test081SymlinksAreFollowedWithEnvironmentVariableFiles issue: https://github.com/elastic/elasticsearch/issues/128867 @@ -473,21 +482,27 @@ tests: - class: org.elasticsearch.entitlement.runtime.policy.FileAccessTreeTests method: testWindowsAbsolutPathAccess issue: https://github.com/elastic/elasticsearch/issues/129168 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {knn-function.KnnSearchWithKOption 
ASYNC} - issue: https://github.com/elastic/elasticsearch/issues/129447 - class: org.elasticsearch.xpack.ml.integration.ClassificationIT method: testWithDatastreams issue: https://github.com/elastic/elasticsearch/issues/129457 - class: org.elasticsearch.index.engine.ThreadPoolMergeExecutorServiceDiskSpaceTests method: testMergeTasksAreUnblockedWhenMoreDiskSpaceBecomesAvailable issue: https://github.com/elastic/elasticsearch/issues/129296 +- class: org.elasticsearch.xpack.security.PermissionsIT + method: testCanManageIndexWithNoPermissions + issue: https://github.com/elastic/elasticsearch/issues/129471 +- class: org.elasticsearch.xpack.security.PermissionsIT + method: testCanManageIndexAndPolicyDifferentUsers + issue: https://github.com/elastic/elasticsearch/issues/129479 +- class: org.elasticsearch.xpack.security.PermissionsIT + method: testCanViewExplainOnUnmanagedIndex + issue: https://github.com/elastic/elasticsearch/issues/129480 - class: org.elasticsearch.xpack.profiling.action.GetStatusActionIT method: testWaitsUntilResourcesAreCreated issue: https://github.com/elastic/elasticsearch/issues/129486 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {knn-function.KnnSearchWithKOption SYNC} - issue: https://github.com/elastic/elasticsearch/issues/129512 +- class: org.elasticsearch.xpack.security.PermissionsIT + method: testWhenUserLimitedByOnlyAliasOfIndexCanWriteToIndexWhichWasRolledoverByILMPolicy + issue: https://github.com/elastic/elasticsearch/issues/129481 - class: org.elasticsearch.index.engine.ThreadPoolMergeExecutorServiceTests method: testIORateIsAdjustedForAllRunningMergeTasks issue: https://github.com/elastic/elasticsearch/issues/129531 @@ -503,15 +518,24 @@ tests: - class: org.elasticsearch.search.query.VectorIT method: testFilteredQueryStrategy issue: https://github.com/elastic/elasticsearch/issues/129517 +- class: org.elasticsearch.snapshots.SnapshotShutdownIT + method: testSnapshotShutdownProgressTracker + issue: 
https://github.com/elastic/elasticsearch/issues/129752 - class: org.elasticsearch.xpack.security.SecurityRolesMultiProjectIT method: testUpdatingFileBasedRoleAffectsAllProjects issue: https://github.com/elastic/elasticsearch/issues/129775 - class: org.elasticsearch.qa.verify_version_constants.VerifyVersionConstantsIT method: testLuceneVersionConstant issue: https://github.com/elastic/elasticsearch/issues/125638 +- class: org.elasticsearch.index.store.FsDirectoryFactoryTests + method: testPreload + issue: https://github.com/elastic/elasticsearch/issues/129852 - class: org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT method: test {yaml=rrf/950_pinned_interaction/rrf with pinned retriever as a sub-retriever} issue: https://github.com/elastic/elasticsearch/issues/129845 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + method: test {p0=esql/60_usage/Basic ESQL usage output (telemetry) non-snapshot version} + issue: https://github.com/elastic/elasticsearch/issues/129888 - class: org.elasticsearch.gradle.internal.InternalDistributionBwcSetupPluginFuncTest method: "builds distribution from branches via archives extractedAssemble [bwcDistVersion: 8.2.1, bwcProject: bugfix, expectedAssembleTaskName: extractedAssemble, #2]" @@ -525,9 +549,14 @@ tests: - class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT method: test issue: https://github.com/elastic/elasticsearch/issues/130067 +- class: geoip.GeoIpMultiProjectIT + issue: https://github.com/elastic/elasticsearch/issues/130073 - class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeIT method: test issue: https://github.com/elastic/elasticsearch/issues/130067 +- class: org.elasticsearch.xpack.esql.action.EnrichIT + method: testTopN + issue: https://github.com/elastic/elasticsearch/issues/130122 - class: org.elasticsearch.action.support.ThreadedActionListenerTests method: testRejectionHandling issue: https://github.com/elastic/elasticsearch/issues/130129 diff --git 
a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index ac6c16f35de0..c6105b82f230 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,11 +3,11 @@ # top-n query at the shard level knnSearch -required_capability: knn_function +required_capability: knn_function_v2 // tag::knn-function[] from colors metadata _score -| where knn(rgb_vector, [0, 120, 0]) +| where knn(rgb_vector, [0, 120, 0], 10) | sort _score desc, color asc // end::knn-function[] | keep color, rgb_vector @@ -29,31 +29,12 @@ chartreuse | [127.0, 255.0, 0.0] // end::knn-function-result[] ; -knnSearchWithKOption -required_capability: knn_function - -// tag::knn-function-options[] -from colors metadata _score -| where knn(rgb_vector, [0,255,255], {"k": 4}) -| sort _score desc, color asc -// end::knn-function-options[] -| keep color, rgb_vector -| limit 4 -; - -color:text | rgb_vector:dense_vector -cyan | [0.0, 255.0, 255.0] -turquoise | [64.0, 224.0, 208.0] -aqua marine | [127.0, 255.0, 212.0] -teal | [0.0, 128.0, 128.0] -; - -# https://github.com/elastic/elasticsearch/issues/129550 +# https://github.com/elastic/elasticsearch/issues/129550 - Add as an example to knn function documentation knnSearchWithSimilarityOption-Ignore -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40}) +| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) | sort _score desc, color asc | keep color, rgb_vector ; @@ -63,14 +44,13 @@ pink | [255.0, 192.0, 203.0] peach puff | [255.0, 218.0, 185.0] bisque | [255.0, 228.0, 196.0] wheat | [245.0, 222.0, 179.0] - ; knnHybridSearch -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| 
where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140}) +| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140) | where primary == true | sort _score desc, color asc | keep color, rgb_vector @@ -90,10 +70,10 @@ yellow | [255.0, 255.0, 0.0] ; knnWithMultipleFunctions -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive") +| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive") | sort _score desc, color asc | keep color, rgb_vector ; @@ -103,11 +83,11 @@ olive | [128.0, 128.0, 0.0] ; knnAfterKeep -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | keep rgb_vector, color, _score -| where knn(rgb_vector, [128,255,0], {"k": 140}) +| where knn(rgb_vector, [128,255,0], 140) | sort _score desc, color asc | keep rgb_vector | limit 5 @@ -122,11 +102,11 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | drop primary -| where knn(rgb_vector, [128,250,0], {"k": 140}) +| where knn(rgb_vector, [128,250,0], 140) | sort _score desc, color asc | keep color, rgb_vector | limit 5 @@ -141,11 +121,11 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], {"k": 140}) +| where knn(rgb_vector, [128,128,0], 140) | sort _score desc, color asc | keep color, composed_name | limit 5 @@ -160,11 +140,11 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function +required_capability: knn_function_v2 # TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*" +| where 
knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*" | sort _score desc, color asc | keep color, hex_code, rgb_vector | limit 10 @@ -181,11 +161,11 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function +required_capability: knn_function_v2 # TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true +| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true | keep color, rgb_vector, _score | sort _score desc, color asc | drop _score @@ -205,11 +185,11 @@ yellow | [255.0, 255.0, 0.0] ; knnWithNonPushableConjunction -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false +| where knn(rgb_vector, [128,128,0], 140) and composed_name == false | sort _score desc, color asc | keep color, composed_name | limit 10 @@ -230,10 +210,10 @@ maroon | false # https://github.com/elastic/elasticsearch/issues/129550 testKnnWithNonPushableDisjunctions-Ignore -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10 +| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 | sort _score desc, color asc | keep color ; @@ -247,10 +227,10 @@ papaya whip # https://github.com/elastic/elasticsearch/issues/129550 testKnnWithNonPushableDisjunctionsOnComplexExpressions-Ignore -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, 
[128,0,128], {"k": 140, "similarity": 60}) and primary == false) +| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) | sort _score desc, color asc | keep color, primary ; @@ -262,11 +242,11 @@ indigo | false ; testKnnInStatsNonPushable -required_capability: knn_function +required_capability: knn_function_v2 from colors | where length(color) < 10 -| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) +| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) ; c: long @@ -274,12 +254,12 @@ c: long ; testKnnInStatsWithGrouping -required_capability: knn_function +required_capability: knn_function_v2 required_capability: full_text_functions_in_stats_where from colors | where length(color) < 10 -| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary +| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary ; c: long | primary: boolean diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index a26294390993..11f9bd6c5aeb 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -39,7 +39,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s) + | WHERE knn(vector, %s, 10) | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -73,7 +73,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, {"k": 5}) + | WHERE 
knn(vector, %s, 5) | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -94,7 +94,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, {"k": 5}) OR id > 10 + | WHERE knn(vector, %s, 5) OR id > 10 | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -111,7 +111,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); var indexName = "test"; var client = client().admin().indices(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 9a8b71e8e5ee..4dcab5f0c927 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1195,7 +1195,7 @@ public class EsqlCapabilities { /** * Support knn function */ - KNN_FUNCTION(Build.current().isSnapshot()), + KNN_FUNCTION_V2(Build.current().isSnapshot()), LIKE_WITH_LIST_OF_PATTERNS, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 901f364a6004..a3f6d3a089d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -259,7 +259,7 @@ public class ExpressionWritables { } private static List vector() { - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { return List.of(Knn.ENTRY); } return List.of(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index ede0537d5d3d..630c9c2008a1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -478,7 +478,7 @@ public class EsqlFunctionRegistry { def(LastOverTime.class, uni(LastOverTime::new), "last_over_time"), def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), def(Term.class, bi(Term::new), "term"), - def(Knn.class, tri(Knn::new), "knn"), + def(Knn.class, Knn::new, "knn"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index ecce0b069693..abe63b9b57bf 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; import java.io.IOException; +import 
java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -48,6 +49,7 @@ import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMI import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; @@ -62,10 +64,11 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); private final Expression field; + // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes + private final transient Expression k; private final Expression options; public static final Map ALLOWED_OPTIONS = Map.ofEntries( - entry(K_FIELD.getPreferredName(), INTEGER), entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT), entry(BOOST_FIELD.getPreferredName(), FLOAT), @@ -77,9 +80,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun preview = true, description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. 
" + "knn function finds nearest vectors through approximate search on indexed dense_vectors.", - examples = { - @Example(file = "knn-function", tag = "knn-function"), - @Example(file = "knn-function", tag = "knn-function-options"), }, + examples = { @Example(file = "knn-function", tag = "knn-function") }, appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } ) public Knn( @@ -90,6 +91,13 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun type = { "dense_vector" }, description = "Vector value to find top nearest neighbours for." ) Expression query, + @Param( + name = "k", + type = { "integer" }, + description = "The number of nearest neighbors to return from each shard. " + + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " + + "This value must be less than or equal to num_candidates." + ) Expression k, @MapParam( name = "options", params = { @@ -100,14 +108,6 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun description = "Floating point number used to decrease or increase the relevance scores of the query." + "Defaults to 1.0." ), - @MapParam.MapParamEntry( - name = "k", - type = "integer", - valueHint = { "10" }, - description = "The number of nearest neighbors to return from each shard. " - + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " - + "This value must be less than or equal to num_candidates. Defaults to 10." 
- ), @MapParam.MapParamEntry( name = "num_candidates", type = "integer", @@ -136,19 +136,37 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun optional = true ) Expression options ) { - this(source, field, query, options, null); + this(source, field, query, k, options, null); } - private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) { - super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder); + private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) { + super(source, query, expressionList(field, query, k, options), queryBuilder); this.field = field; + this.k = k; this.options = options; } + private static List expressionList(Expression field, Expression query, Expression k, Expression options) { + List result = new ArrayList<>(); + result.add(field); + result.add(query); + if (k != null) { + result.add(k); + } + if (options != null) { + result.add(options); + } + return result; + } + public Expression field() { return field; } + public Expression k() { + return k; + } + public Expression options() { return options; } @@ -160,7 +178,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveOptions()); + return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions()); } private TypeResolution resolveField() { @@ -173,14 +191,24 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun ); } + private TypeResolution resolveK() { + if (k == null) { + // Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed + return TypeResolution.TYPE_RESOLVED; + } + + return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, 
"integer").and(isFoldable(k(), sourceText(), THIRD)) + .and(isNotNull(k(), sourceText(), THIRD)); + } + private TypeResolution resolveOptions() { if (options() != null) { - TypeResolution resolution = isNotNull(options(), sourceText(), THIRD); + TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); if (resolution.unresolved()) { return resolution; } // MapExpression does not have a DataType associated with it - resolution = isMapExpression(options(), sourceText(), THIRD); + resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); if (resolution.unresolved()) { return resolution; } @@ -200,7 +228,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun } Map matchOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS); + populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); return matchOptions; } @@ -216,22 +244,24 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun for (int i = 0; i < queryFolded.size(); i++) { queryAsFloats[i] = queryFolded.get(i).floatValue(); } + int kValue = ((Number) k().fold(FoldContext.small())).intValue(); - return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions()); + Map opts = queryOptions(); + opts.put(K_FIELD.getPreferredName(), kValue); + + return new KnnQuery(source(), fieldName, queryAsFloats, opts); } @Override public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { - return new Knn(source(), field(), query(), options(), queryBuilder); + return new Knn(source(), field(), query(), k(), options(), queryBuilder); } private Map queryOptions() throws InvalidArgumentException { - if (options() == null) { - return Map.of(); - } - Map options = new HashMap<>(); - populateOptionsMap((MapExpression) options(), options, THIRD, 
sourceText(), ALLOWED_OPTIONS); + if (options() != null) { + populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); + } return options; } @@ -241,14 +271,15 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun source(), newChildren.get(0), newChildren.get(1), - newChildren.size() > 2 ? newChildren.get(2) : null, + newChildren.get(2), + newChildren.size() > 3 ? newChildren.get(3) : null, queryBuilder() ); } @Override protected NodeInfo info() { - return NodeInfo.create(this, Knn::new, field(), query(), options()); + return NodeInfo.create(this, Knn::new, field(), query(), k(), options()); } @Override @@ -261,8 +292,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun Expression field = in.readNamedWriteable(Expression.class); Expression query = in.readNamedWriteable(Expression.class); QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); - - return new Knn(source, field, query, null, queryBuilder); + return new Knn(source, field, query, null, null, queryBuilder); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 6935a1efac2a..9062bdef62d7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -298,7 +298,7 @@ public class CsvTests extends ESTestCase { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName()) ); assumeFalse( "lookup join disabled for csv tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index b99050b8ef09..34dc741713c0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2375,7 +2375,7 @@ public class AnalyzerTests extends ESTestCase { Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); var plan = analyze(""" - from test | where knn(vector, [0.342, 0.164, 0.234]) + from test | where knn(vector, [0.342, 0.164, 0.234], 10) """, "mapping-dense_vector.json"); var limit = as(plan, Limit.class); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 14e5c0615e2b..8936a02e599d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1235,8 +1235,8 @@ public class VerifierTests extends ESTestCase { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); } } @@ -1368,8 +1368,8 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - 
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); } } @@ -1407,8 +1407,8 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); } } @@ -1472,8 +1472,8 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); } } @@ -1543,8 +1543,8 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -2071,8 +2071,8 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - 
checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); } } @@ -2159,9 +2159,10 @@ public class VerifierTests extends ESTestCase { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first"); - checkFullTextFunctionNullArgs("knn(vector, null)", "second"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); + checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); + checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); } } @@ -2172,24 +2173,25 @@ public class VerifierTests extends ESTestCase { ); } - public void testFullTextFunctionsConstantQuery() throws Exception { - checkFullTextFunctionsConstantQuery("match(title, category)", "second"); - checkFullTextFunctionsConstantQuery("qstr(title)", ""); - checkFullTextFunctionsConstantQuery("kql(title)", ""); - checkFullTextFunctionsConstantQuery("match_phrase(title, tags)", "second"); + public void testFullTextFunctionsConstantArg() throws Exception { + checkFullTextFunctionsConstantArg("match(title, category)", "second"); + checkFullTextFunctionsConstantArg("qstr(title)", ""); + checkFullTextFunctionsConstantArg("kql(title)", ""); + checkFullTextFunctionsConstantArg("match_phrase(title, tags)", "second"); if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { - checkFullTextFunctionsConstantQuery("multi_match(category, body)", "first"); - checkFullTextFunctionsConstantQuery("multi_match(concat(title, \"world\"), title)", "first"); + checkFullTextFunctionsConstantArg("multi_match(category, body)", "first"); + 
checkFullTextFunctionsConstantArg("multi_match(concat(title, \"world\"), title)", "first"); } if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { - checkFullTextFunctionsConstantQuery("term(title, tags)", "second"); + checkFullTextFunctionsConstantArg("term(title, tags)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second"); + checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third"); } } - private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception { + private void checkFullTextFunctionsConstantArg(String functionInvocation, String argOrdinal) throws Exception { assertThat( error("from test | where " + functionInvocation, fullTextAnalyzer), containsString(argOrdinal + " argument of [" + functionInvocation + "] must be a constant") @@ -2214,8 +2216,8 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 76db793c4e77..4a5708b398b1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -18,6 +18,7 @@ import 
org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; @@ -27,6 +28,7 @@ import org.junit.Before; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize; @@ -49,19 +51,33 @@ public class KnnTests extends AbstractFunctionTestCase { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); } private static List testCaseSuppliers() { List suppliers = new ArrayList<>(); suppliers.add( - TestCaseSupplier.testCaseSupplier( - new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR), - new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true), - (d1, d2) -> equalTo("string"), - BOOLEAN, - (o1, o2) -> true + new TestCaseSupplier( + List.of(DENSE_VECTOR, DENSE_VECTOR, DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData( + new FieldAttribute( + Source.EMPTY, + randomIdentifier(), + new EsField(randomIdentifier(), DENSE_VECTOR, Map.of(), false) + ), + DENSE_VECTOR, + "dense_vector field" + ), + new TestCaseSupplier.TypedData(randomDenseVector(), DENSE_VECTOR, "query"), + new TestCaseSupplier.TypedData(randomIntBetween(1, 1000), DataType.INTEGER, "k") + ), + equalTo("KnnEvaluator" + 
KnnTests.class.getSimpleName()), + BOOLEAN, + equalTo(true) + ) ) ); @@ -104,7 +120,7 @@ public class KnnTests extends AbstractFunctionTestCase { @Override protected Expression build(Source source, List args) { - Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null); + Knn knn = new Knn(source, args.get(0), args.get(1), args.get(2), args.size() > 3 ? args.get(3) : null); // We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and // thus test the serialization methods. But we can only do this if the parameters make sense . if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index b7d124375949..66b797afa426 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1363,12 +1363,12 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase { public void testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); String query = """ from test - | where KNN(dense_vector, [0.1, 0.2, 0.3], - { "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) + | where KNN(dense_vector, [0.1, 0.2, 0.3], 5, + { "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) """; var analyzer = makeAnalyzer("mapping-all-types.json"); 
var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);