ESQL - Add K mandatory param for KNN function (#129763)

This commit is contained in:
Carlos Delgado 2025-07-02 16:06:37 +02:00 committed by GitHub
parent a9625cec7a
commit 315aba696a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 194 additions and 138 deletions

View File

@ -1 +1 @@
<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="568" height="61" viewbox="0 0 568 61"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m80 0h10m32 0h10m80 0h10m32 0h30m104 0h20m-139 0q5 0 5 5v10q0 5 5 5h114q5 0 5-5v-10q0-5 5-5m5 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">KNN</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="80" height="36" rx="7"/><text class="k" x="123" y="31">field</text><rect class="s" x="203" y="5" width="32" height="36" rx="7"/><text class="syn" x="213" y="31">,</text><rect class="s" x="245" y="5" width="80" height="36" rx="7"/><text class="k" x="255" y="31">query</text><rect class="s" x="335" y="5" width="32" height="36" rx="7"/><text class="syn" x="345" y="31">,</text><rect class="s" x="397" y="5" width="104" height="36" rx="7"/><text class="k" x="407" y="31">options</text><rect class="s" x="531" y="5" width="32" height="36" rx="7"/><text class="syn" x="541" y="31">)</text></svg>
<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="652" height="61" viewbox="0 0 652 61"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m80 0h10m32 0h10m80 0h10m32 0h10m32 0h10m32 0h30m104 0h20m-139 0q5 0 5 5v10q0 5 5 5h114q5 0 5-5v-10q0-5 5-5m5 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">KNN</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="80" height="36" rx="7"/><text class="k" x="123" y="31">field</text><rect class="s" x="203" y="5" width="32" height="36" rx="7"/><text class="syn" x="213" y="31">,</text><rect class="s" x="245" y="5" width="80" height="36" rx="7"/><text class="k" x="255" y="31">query</text><rect class="s" x="335" y="5" width="32" height="36" rx="7"/><text class="syn" x="345" y="31">,</text><rect class="s" x="377" y="5" width="32" height="36" rx="7"/><text class="k" x="387" y="31">k</text><rect class="s" x="419" y="5" width="32" height="36" rx="7"/><text class="syn" x="429" y="31">,</text><rect class="s" x="481" y="5" width="104" height="36" rx="7"/><text class="k" x="491" y="31">options</text><rect class="s" x="615" y="5" width="32" height="36" rx="7"/><text class="syn" x="625" y="31">)</text></svg>

Before

Width:  |  Height:  |  Size: 1.5 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@ -5,8 +5,7 @@
"description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.",
"signatures" : [ ],
"examples" : [
"from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc, color asc",
"from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc, color asc"
"from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc"
],
"preview" : true,
"snapshot_only" : true

View File

@ -1,10 +1,10 @@
% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
### KNN
Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.
```esql
from colors metadata _score
| where knn(rgb_vector, [0, 120, 0])
| where knn(rgb_vector, [0, 120, 0], 10)
| sort _score desc, color asc
```

View File

@ -236,6 +236,9 @@ tests:
- class: org.elasticsearch.packaging.test.DockerTests
method: test012SecurityCanBeDisabled
issue: https://github.com/elastic/elasticsearch/issues/116636
- class: org.elasticsearch.index.shard.StoreRecoveryTests
method: testAddIndices
issue: https://github.com/elastic/elasticsearch/issues/124104
- class: org.elasticsearch.smoketest.MlWithSecurityIT
method: test {yaml=ml/data_frame_analytics_crud/Test get stats on newly created config}
issue: https://github.com/elastic/elasticsearch/issues/121726
@ -455,6 +458,12 @@ tests:
- class: org.elasticsearch.packaging.test.DockerTests
method: test073RunEsAsDifferentUserAndGroupWithoutBindMounting
issue: https://github.com/elastic/elasticsearch/issues/128996
- class: org.elasticsearch.upgrades.UpgradeClusterClientYamlTestSuiteIT
method: test {p0=upgraded_cluster/70_ilm/Test Lifecycle Still There And Indices Are Still Managed}
issue: https://github.com/elastic/elasticsearch/issues/129097
- class: org.elasticsearch.upgrades.UpgradeClusterClientYamlTestSuiteIT
method: test {p0=upgraded_cluster/90_ml_data_frame_analytics_crud/Get mixed cluster outlier_detection job}
issue: https://github.com/elastic/elasticsearch/issues/129098
- class: org.elasticsearch.packaging.test.DockerTests
method: test081SymlinksAreFollowedWithEnvironmentVariableFiles
issue: https://github.com/elastic/elasticsearch/issues/128867
@ -473,21 +482,27 @@ tests:
- class: org.elasticsearch.entitlement.runtime.policy.FileAccessTreeTests
method: testWindowsAbsolutPathAccess
issue: https://github.com/elastic/elasticsearch/issues/129168
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
method: test {knn-function.KnnSearchWithKOption ASYNC}
issue: https://github.com/elastic/elasticsearch/issues/129447
- class: org.elasticsearch.xpack.ml.integration.ClassificationIT
method: testWithDatastreams
issue: https://github.com/elastic/elasticsearch/issues/129457
- class: org.elasticsearch.index.engine.ThreadPoolMergeExecutorServiceDiskSpaceTests
method: testMergeTasksAreUnblockedWhenMoreDiskSpaceBecomesAvailable
issue: https://github.com/elastic/elasticsearch/issues/129296
- class: org.elasticsearch.xpack.security.PermissionsIT
method: testCanManageIndexWithNoPermissions
issue: https://github.com/elastic/elasticsearch/issues/129471
- class: org.elasticsearch.xpack.security.PermissionsIT
method: testCanManageIndexAndPolicyDifferentUsers
issue: https://github.com/elastic/elasticsearch/issues/129479
- class: org.elasticsearch.xpack.security.PermissionsIT
method: testCanViewExplainOnUnmanagedIndex
issue: https://github.com/elastic/elasticsearch/issues/129480
- class: org.elasticsearch.xpack.profiling.action.GetStatusActionIT
method: testWaitsUntilResourcesAreCreated
issue: https://github.com/elastic/elasticsearch/issues/129486
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
method: test {knn-function.KnnSearchWithKOption SYNC}
issue: https://github.com/elastic/elasticsearch/issues/129512
- class: org.elasticsearch.xpack.security.PermissionsIT
method: testWhenUserLimitedByOnlyAliasOfIndexCanWriteToIndexWhichWasRolledoverByILMPolicy
issue: https://github.com/elastic/elasticsearch/issues/129481
- class: org.elasticsearch.index.engine.ThreadPoolMergeExecutorServiceTests
method: testIORateIsAdjustedForAllRunningMergeTasks
issue: https://github.com/elastic/elasticsearch/issues/129531
@ -503,15 +518,24 @@ tests:
- class: org.elasticsearch.search.query.VectorIT
method: testFilteredQueryStrategy
issue: https://github.com/elastic/elasticsearch/issues/129517
- class: org.elasticsearch.snapshots.SnapshotShutdownIT
method: testSnapshotShutdownProgressTracker
issue: https://github.com/elastic/elasticsearch/issues/129752
- class: org.elasticsearch.xpack.security.SecurityRolesMultiProjectIT
method: testUpdatingFileBasedRoleAffectsAllProjects
issue: https://github.com/elastic/elasticsearch/issues/129775
- class: org.elasticsearch.qa.verify_version_constants.VerifyVersionConstantsIT
method: testLuceneVersionConstant
issue: https://github.com/elastic/elasticsearch/issues/125638
- class: org.elasticsearch.index.store.FsDirectoryFactoryTests
method: testPreload
issue: https://github.com/elastic/elasticsearch/issues/129852
- class: org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT
method: test {yaml=rrf/950_pinned_interaction/rrf with pinned retriever as a sub-retriever}
issue: https://github.com/elastic/elasticsearch/issues/129845
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
method: test {p0=esql/60_usage/Basic ESQL usage output (telemetry) non-snapshot version}
issue: https://github.com/elastic/elasticsearch/issues/129888
- class: org.elasticsearch.gradle.internal.InternalDistributionBwcSetupPluginFuncTest
method: "builds distribution from branches via archives extractedAssemble [bwcDistVersion: 8.2.1, bwcProject: bugfix, expectedAssembleTaskName:
extractedAssemble, #2]"
@ -525,9 +549,14 @@ tests:
- class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT
method: test
issue: https://github.com/elastic/elasticsearch/issues/130067
- class: geoip.GeoIpMultiProjectIT
issue: https://github.com/elastic/elasticsearch/issues/130073
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeIT
method: test
issue: https://github.com/elastic/elasticsearch/issues/130067
- class: org.elasticsearch.xpack.esql.action.EnrichIT
method: testTopN
issue: https://github.com/elastic/elasticsearch/issues/130122
- class: org.elasticsearch.action.support.ThreadedActionListenerTests
method: testRejectionHandling
issue: https://github.com/elastic/elasticsearch/issues/130129

View File

@ -3,11 +3,11 @@
# top-n query at the shard level
knnSearch
required_capability: knn_function
required_capability: knn_function_v2
// tag::knn-function[]
from colors metadata _score
| where knn(rgb_vector, [0, 120, 0])
| where knn(rgb_vector, [0, 120, 0], 10)
| sort _score desc, color asc
// end::knn-function[]
| keep color, rgb_vector
@ -29,31 +29,12 @@ chartreuse | [127.0, 255.0, 0.0]
// end::knn-function-result[]
;
knnSearchWithKOption
required_capability: knn_function
// tag::knn-function-options[]
from colors metadata _score
| where knn(rgb_vector, [0,255,255], {"k": 4})
| sort _score desc, color asc
// end::knn-function-options[]
| keep color, rgb_vector
| limit 4
;
color:text | rgb_vector:dense_vector
cyan | [0.0, 255.0, 255.0]
turquoise | [64.0, 224.0, 208.0]
aqua marine | [127.0, 255.0, 212.0]
teal | [0.0, 128.0, 128.0]
;
# https://github.com/elastic/elasticsearch/issues/129550
# https://github.com/elastic/elasticsearch/issues/129550 - Add as an example to knn function documentation
knnSearchWithSimilarityOption-Ignore
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40})
| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
| sort _score desc, color asc
| keep color, rgb_vector
;
@ -63,14 +44,13 @@ pink | [255.0, 192.0, 203.0]
peach puff | [255.0, 218.0, 185.0]
bisque | [255.0, 228.0, 196.0]
wheat | [245.0, 222.0, 179.0]
;
knnHybridSearch
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140})
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140)
| where primary == true
| sort _score desc, color asc
| keep color, rgb_vector
@ -90,10 +70,10 @@ yellow | [255.0, 255.0, 0.0]
;
knnWithMultipleFunctions
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive")
| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive")
| sort _score desc, color asc
| keep color, rgb_vector
;
@ -103,11 +83,11 @@ olive | [128.0, 128.0, 0.0]
;
knnAfterKeep
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| keep rgb_vector, color, _score
| where knn(rgb_vector, [128,255,0], {"k": 140})
| where knn(rgb_vector, [128,255,0], 140)
| sort _score desc, color asc
| keep rgb_vector
| limit 5
@ -122,11 +102,11 @@ rgb_vector:dense_vector
;
knnAfterDrop
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| drop primary
| where knn(rgb_vector, [128,250,0], {"k": 140})
| where knn(rgb_vector, [128,250,0], 140)
| sort _score desc, color asc
| keep color, rgb_vector
| limit 5
@ -141,11 +121,11 @@ lime | [0.0, 255.0, 0.0]
;
knnAfterEval
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
| where knn(rgb_vector, [128,128,0], {"k": 140})
| where knn(rgb_vector, [128,128,0], 140)
| sort _score desc, color asc
| keep color, composed_name
| limit 5
@ -160,11 +140,11 @@ golden rod | true
;
knnWithConjunction
required_capability: knn_function
required_capability: knn_function_v2
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*"
| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*"
| sort _score desc, color asc
| keep color, hex_code, rgb_vector
| limit 10
@ -181,11 +161,11 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
;
knnWithDisjunctionAndFiltersConjunction
required_capability: knn_function
required_capability: knn_function_v2
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true
| keep color, rgb_vector, _score
| sort _score desc, color asc
| drop _score
@ -205,11 +185,11 @@ yellow | [255.0, 255.0, 0.0]
;
knnWithNonPushableConjunction
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false
| where knn(rgb_vector, [128,128,0], 140) and composed_name == false
| sort _score desc, color asc
| keep color, composed_name
| limit 10
@ -230,10 +210,10 @@ maroon | false
# https://github.com/elastic/elasticsearch/issues/129550
testKnnWithNonPushableDisjunctions-Ignore
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10
| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
| sort _score desc, color asc
| keep color
;
@ -247,10 +227,10 @@ papaya whip
# https://github.com/elastic/elasticsearch/issues/129550
testKnnWithNonPushableDisjunctionsOnComplexExpressions-Ignore
required_capability: knn_function
required_capability: knn_function_v2
from colors metadata _score
| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false)
| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
| sort _score desc, color asc
| keep color, primary
;
@ -262,11 +242,11 @@ indigo | false
;
testKnnInStatsNonPushable
required_capability: knn_function
required_capability: knn_function_v2
from colors
| where length(color) < 10
| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140})
| stats c = count(*) where knn(rgb_vector, [128,128,255], 140)
;
c: long
@ -274,12 +254,12 @@ c: long
;
testKnnInStatsWithGrouping
required_capability: knn_function
required_capability: knn_function_v2
required_capability: full_text_functions_in_stats_where
from colors
| where length(color) < 10
| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary
| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary
;
c: long | primary: boolean

View File

@ -39,7 +39,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s)
| WHERE knn(vector, %s, 10)
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
@ -73,7 +73,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s, {"k": 5})
| WHERE knn(vector, %s, 5)
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
@ -94,7 +94,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
// TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s, {"k": 5}) OR id > 10
| WHERE knn(vector, %s, 5) OR id > 10
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
@ -111,7 +111,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
@Before
public void setup() throws IOException {
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
var indexName = "test";
var client = client().admin().indices();

View File

@ -1195,7 +1195,7 @@ public class EsqlCapabilities {
/**
* Support knn function
*/
KNN_FUNCTION(Build.current().isSnapshot()),
KNN_FUNCTION_V2(Build.current().isSnapshot()),
LIKE_WITH_LIST_OF_PATTERNS,

View File

@ -259,7 +259,7 @@ public class ExpressionWritables {
}
private static List<NamedWriteableRegistry.Entry> vector() {
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
return List.of(Knn.ENTRY);
}
return List.of();

View File

@ -478,7 +478,7 @@ public class EsqlFunctionRegistry {
def(LastOverTime.class, uni(LastOverTime::new), "last_over_time"),
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, tri(Knn::new), "knn"),
def(Knn.class, Knn::new, "knn"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),

View File

@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -48,6 +49,7 @@ import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMI
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
@ -62,10 +64,11 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
private final Expression field;
// k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
private final transient Expression k;
private final Expression options;
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
entry(K_FIELD.getPreferredName(), INTEGER),
entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
entry(BOOST_FIELD.getPreferredName(), FLOAT),
@ -77,9 +80,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
preview = true,
description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. "
+ "knn function finds nearest vectors through approximate search on indexed dense_vectors.",
examples = {
@Example(file = "knn-function", tag = "knn-function"),
@Example(file = "knn-function", tag = "knn-function-options"), },
examples = { @Example(file = "knn-function", tag = "knn-function") },
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
)
public Knn(
@ -90,6 +91,13 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
type = { "dense_vector" },
description = "Vector value to find top nearest neighbours for."
) Expression query,
@Param(
name = "k",
type = { "integer" },
description = "The number of nearest neighbors to return from each shard. "
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
+ "This value must be less than or equal to num_candidates."
) Expression k,
@MapParam(
name = "options",
params = {
@ -100,14 +108,6 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
description = "Floating point number used to decrease or increase the relevance scores of the query."
+ "Defaults to 1.0."
),
@MapParam.MapParamEntry(
name = "k",
type = "integer",
valueHint = { "10" },
description = "The number of nearest neighbors to return from each shard. "
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
+ "This value must be less than or equal to num_candidates. Defaults to 10."
),
@MapParam.MapParamEntry(
name = "num_candidates",
type = "integer",
@ -136,19 +136,37 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
optional = true
) Expression options
) {
this(source, field, query, options, null);
this(source, field, query, k, options, null);
}
private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) {
super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder);
private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) {
super(source, query, expressionList(field, query, k, options), queryBuilder);
this.field = field;
this.k = k;
this.options = options;
}
private static List<Expression> expressionList(Expression field, Expression query, Expression k, Expression options) {
List<Expression> result = new ArrayList<>();
result.add(field);
result.add(query);
if (k != null) {
result.add(k);
}
if (options != null) {
result.add(options);
}
return result;
}
public Expression field() {
return field;
}
public Expression k() {
return k;
}
public Expression options() {
return options;
}
@ -160,7 +178,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
@Override
protected TypeResolution resolveParams() {
return resolveField().and(resolveQuery()).and(resolveOptions());
return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions());
}
private TypeResolution resolveField() {
@ -173,14 +191,24 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
);
}
private TypeResolution resolveK() {
if (k == null) {
// Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed
return TypeResolution.TYPE_RESOLVED;
}
return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD))
.and(isNotNull(k(), sourceText(), THIRD));
}
private TypeResolution resolveOptions() {
if (options() != null) {
TypeResolution resolution = isNotNull(options(), sourceText(), THIRD);
TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
if (resolution.unresolved()) {
return resolution;
}
// MapExpression does not have a DataType associated with it
resolution = isMapExpression(options(), sourceText(), THIRD);
resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
if (resolution.unresolved()) {
return resolution;
}
@ -200,7 +228,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
}
Map<String, Object> matchOptions = new HashMap<>();
populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS);
populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
return matchOptions;
}
@ -216,22 +244,24 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
for (int i = 0; i < queryFolded.size(); i++) {
queryAsFloats[i] = queryFolded.get(i).floatValue();
}
int kValue = ((Number) k().fold(FoldContext.small())).intValue();
return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
Map<String, Object> opts = queryOptions();
opts.put(K_FIELD.getPreferredName(), kValue);
return new KnnQuery(source(), fieldName, queryAsFloats, opts);
}
@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
return new Knn(source(), field(), query(), options(), queryBuilder);
return new Knn(source(), field(), query(), k(), options(), queryBuilder);
}
private Map<String, Object> queryOptions() throws InvalidArgumentException {
if (options() == null) {
return Map.of();
}
Map<String, Object> options = new HashMap<>();
populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS);
if (options() != null) {
populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
}
return options;
}
@ -241,14 +271,15 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
source(),
newChildren.get(0),
newChildren.get(1),
newChildren.size() > 2 ? newChildren.get(2) : null,
newChildren.get(2),
newChildren.size() > 3 ? newChildren.get(3) : null,
queryBuilder()
);
}
@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Knn::new, field(), query(), options());
return NodeInfo.create(this, Knn::new, field(), query(), k(), options());
}
@Override
@ -261,8 +292,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
Expression field = in.readNamedWriteable(Expression.class);
Expression query = in.readNamedWriteable(Expression.class);
QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
return new Knn(source, field, query, null, queryBuilder);
return new Knn(source, field, query, null, null, queryBuilder);
}
@Override

View File

@ -298,7 +298,7 @@ public class CsvTests extends ESTestCase {
);
assumeFalse(
"can't use KNN function in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName())
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName())
);
assumeFalse(
"lookup join disabled for csv tests",

View File

@ -2375,7 +2375,7 @@ public class AnalyzerTests extends ESTestCase {
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
var plan = analyze("""
from test | where knn(vector, [0.342, 0.164, 0.234])
from test | where knn(vector, [0.342, 0.164, 0.234], 10)
""", "mapping-dense_vector.json");
var limit = as(plan, Limit.class);

View File

@ -1235,8 +1235,8 @@ public class VerifierTests extends ESTestCase {
checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)");
}
}
@ -1368,8 +1368,8 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function");
}
}
@ -1407,8 +1407,8 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)");
}
}
@ -1472,8 +1472,8 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function");
}
}
@ -1543,8 +1543,8 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)");
}
}
@ -2071,8 +2071,8 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})");
}
}
@ -2159,9 +2159,10 @@ public class VerifierTests extends ESTestCase {
checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
checkFullTextFunctionNullArgs("term(title, null)", "second");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first");
checkFullTextFunctionNullArgs("knn(vector, null)", "second");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first");
checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second");
checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third");
}
}
@ -2172,24 +2173,25 @@ public class VerifierTests extends ESTestCase {
);
}
public void testFullTextFunctionsConstantQuery() throws Exception {
checkFullTextFunctionsConstantQuery("match(title, category)", "second");
checkFullTextFunctionsConstantQuery("qstr(title)", "");
checkFullTextFunctionsConstantQuery("kql(title)", "");
checkFullTextFunctionsConstantQuery("match_phrase(title, tags)", "second");
public void testFullTextFunctionsConstantArg() throws Exception {
checkFullTextFunctionsConstantArg("match(title, category)", "second");
checkFullTextFunctionsConstantArg("qstr(title)", "");
checkFullTextFunctionsConstantArg("kql(title)", "");
checkFullTextFunctionsConstantArg("match_phrase(title, tags)", "second");
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsConstantQuery("multi_match(category, body)", "first");
checkFullTextFunctionsConstantQuery("multi_match(concat(title, \"world\"), title)", "first");
checkFullTextFunctionsConstantArg("multi_match(category, body)", "first");
checkFullTextFunctionsConstantArg("multi_match(concat(title, \"world\"), title)", "first");
}
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkFullTextFunctionsConstantQuery("term(title, tags)", "second");
checkFullTextFunctionsConstantArg("term(title, tags)", "second");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second");
checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third");
}
}
private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception {
private void checkFullTextFunctionsConstantArg(String functionInvocation, String argOrdinal) throws Exception {
assertThat(
error("from test | where " + functionInvocation, fullTextAnalyzer),
containsString(argOrdinal + " argument of [" + functionInvocation + "] must be a constant")
@ -2214,8 +2216,8 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)");
}
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
@ -27,6 +28,7 @@ import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize;
@ -49,19 +51,33 @@ public class KnnTests extends AbstractFunctionTestCase {
@Before
public void checkCapability() {
assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
}
private static List<TestCaseSupplier> testCaseSuppliers() {
List<TestCaseSupplier> suppliers = new ArrayList<>();
suppliers.add(
TestCaseSupplier.testCaseSupplier(
new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR),
new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true),
(d1, d2) -> equalTo("string"),
new TestCaseSupplier(
List.of(DENSE_VECTOR, DENSE_VECTOR, DataType.INTEGER),
() -> new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(
new FieldAttribute(
Source.EMPTY,
randomIdentifier(),
new EsField(randomIdentifier(), DENSE_VECTOR, Map.of(), false)
),
DENSE_VECTOR,
"dense_vector field"
),
new TestCaseSupplier.TypedData(randomDenseVector(), DENSE_VECTOR, "query"),
new TestCaseSupplier.TypedData(randomIntBetween(1, 1000), DataType.INTEGER, "k")
),
equalTo("KnnEvaluator" + KnnTests.class.getSimpleName()),
BOOLEAN,
(o1, o2) -> true
equalTo(true)
)
)
);
@ -104,7 +120,7 @@ public class KnnTests extends AbstractFunctionTestCase {
@Override
protected Expression build(Source source, List<Expression> args) {
Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
Knn knn = new Knn(source, args.get(0), args.get(1), args.get(2), args.size() > 3 ? args.get(3) : null);
// We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and
// thus test the serialization methods. But we can only do this if the parameters make sense .
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {

View File

@ -1363,12 +1363,12 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
public void testKnnOptionsPushDown() {
assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
String query = """
from test
| where KNN(dense_vector, [0.1, 0.2, 0.3],
{ "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 })
| where KNN(dense_vector, [0.1, 0.2, 0.3], 5,
{ "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 })
""";
var analyzer = makeAnalyzer("mapping-all-types.json");
var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);