This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 39681148eb73 [SPARK-56395][CONNECT][PYTHON] Add NEAREST BY DataFrame
API
39681148eb73 is described below
commit 39681148eb736a76eb6be7d7205fe3185d70c12f
Author: Dilip Biswal <[email protected]>
AuthorDate: Thu May 14 09:18:20 2026 +0800
[SPARK-56395][CONNECT][PYTHON] Add NEAREST BY DataFrame API
### What changes were proposed in this pull request?
Builds on the Catalyst-side changes merged in SPARK-56395
([link](https://github.com/apache/spark/pull/55629)). Adds the DataFrame
`nearestByJoin` method in Scala / Java / PySpark and wires up Spark Connect.
### Why are the changes needed?
API completeness. The prior PR exposed `NEAREST BY` only via SQL; this PR
brings the same capability to DataFrame / PySpark / Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes. It adds the new `nearestByJoin` DataFrame method, for example:
// Scala
```
users.nearestByJoin(
products,
-abs(users("score") - products("pscore")),
numResults = 1,
mode = "exact",
direction = "similarity",
joinType = "leftouter")
```
// PySpark
```
users.nearestByJoin(
products,
-sf.abs(users.score - products.pscore),
1,
"exact",
"similarity",
joinType="leftouter",
).select("user_id", "product").show()
```
### How was this patch tested?
DataFrameNearestByJoinSuite, RewriteNearestByJoinSuite, and Python doctests.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Opus 4.7), human-reviewed and tested
Closes #55682 from dilipbiswal/SPARK-56395-DF-CONNECT2.
Authored-by: Dilip Biswal <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 13380e780e5e398d4f498f0a97fe8b97257c80bb)
Signed-off-by: Wenchen Fan <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
project/MimaExcludes.scala | 4 +-
.../source/reference/pyspark.sql/dataframe.rst | 1 +
python/pyspark/errors/error-conditions.json | 28 ++
python/pyspark/sql/classic/dataframe.py | 15 +
python/pyspark/sql/connect/dataframe.py | 24 ++
python/pyspark/sql/connect/plan.py | 102 +++++
python/pyspark/sql/connect/proto/relations_pb2.py | 350 ++++++++--------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 85 ++++
python/pyspark/sql/dataframe.py | 67 ++++
.../tests/connect/test_parity_nearest_by_join.py | 30 ++
python/pyspark/sql/tests/test_nearest_by_join.py | 270 +++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 70 ++++
.../catalyst/plans/NearestByJoinValidation.scala | 43 ++
.../spark/sql/catalyst/plans/joinTypes.scala | 16 +-
.../sql/catalyst/plans/logical/NearestByJoin.scala | 6 +-
.../spark/sql/DataFrameNearestByJoinSuite.scala | 103 +++++
.../apache/spark/sql/PlanGenerationTestSuite.scala | 23 ++
.../main/protobuf/spark/connect/relations.proto | 31 ++
.../org/apache/spark/sql/connect/Dataset.scala | 91 +++++
.../nearestByJoin_inner_approx_similarity.explain | 5 +
.../nearestByJoin_leftouter_exact_distance.explain | 5 +
.../nearestByJoin_inner_approx_similarity.json | 109 +++++
...nearestByJoin_inner_approx_similarity.proto.bin | Bin 0 -> 708 bytes
.../nearestByJoin_leftouter_exact_distance.json | 109 +++++
...earestByJoin_leftouter_exact_distance.proto.bin | Bin 0 -> 709 bytes
.../sql/connect/planner/SparkConnectPlanner.scala | 24 ++
.../org/apache/spark/sql/classic/Dataset.scala | 60 +++
.../spark/sql/DataFrameNearestByJoinSuite.scala | 444 +++++++++++++++++++++
29 files changed, 1931 insertions(+), 186 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 693d43f10f57..a03732a3554e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -612,6 +612,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
+ "pyspark.sql.tests.test_nearest_by_join",
"pyspark.sql.tests.test_subquery",
"pyspark.sql.tests.test_types",
"pyspark.sql.tests.test_geographytype",
@@ -1175,6 +1176,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_observation",
"pyspark.sql.tests.connect.test_parity_repartition",
"pyspark.sql.tests.connect.test_parity_stat",
+ "pyspark.sql.tests.connect.test_parity_nearest_by_join",
"pyspark.sql.tests.connect.test_parity_subquery",
"pyspark.sql.tests.connect.test_parity_types",
"pyspark.sql.tests.connect.test_parity_column",
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index bf2984ba8c6d..25e60a4fecd7 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -57,7 +57,9 @@ object MimaExcludes {
// [SPARK-56330][CORE] Add TaskInterruptListener to TaskContext for
interrupt notifications
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.addTaskInterruptListener"),
// [SPARK-34591][ML] Add pruneTree parameter to Strategy
-
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this")
+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this"),
+ // [SPARK-56395][SQL] Add NEAREST BY top-K ranking join
+
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.nearestByJoin")
)
// Exclude rules for 4.1.x from 4.0.0
diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst
b/python/docs/source/reference/pyspark.sql/dataframe.rst
index 9652eb7c4275..91cf0961318b 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -84,6 +84,7 @@ DataFrame
DataFrame.metadataColumn
DataFrame.melt
DataFrame.na
+ DataFrame.nearestByJoin
DataFrame.observe
DataFrame.offset
DataFrame.orderBy
diff --git a/python/pyspark/errors/error-conditions.json
b/python/pyspark/errors/error-conditions.json
index 485d48b6d7d5..7cc5a73e254b 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -602,6 +602,34 @@
"Multiple pipeline spec files found in the directory `<dir_path>`.
Please remove one or choose a particular one with the --spec argument."
]
},
+ "NEAREST_BY_JOIN": {
+ "message": [
+ "Invalid nearest-by join."
+ ],
+ "sub_class": {
+ "NUM_RESULTS_OUT_OF_RANGE": {
+ "message": [
+ "The number of results <numResults> must be between <min> and <max>.
Update the literal in `APPROX NEAREST <numResults> BY ...` (or `EXACT NEAREST
<numResults> BY ...`) to fall within that range."
+ ]
+ },
+ "UNSUPPORTED_DIRECTION": {
+ "message": [
+ "Unsupported nearest-by join direction '<direction>'. Supported
nearest-by join directions include: <supported>."
+ ]
+ },
+ "UNSUPPORTED_JOIN_TYPE": {
+ "message": [
+ "Unsupported nearest-by join type <joinType>. Supported types:
<supported>."
+ ]
+ },
+ "UNSUPPORTED_MODE": {
+ "message": [
+ "Unsupported nearest-by join mode '<mode>'. Supported modes include:
<supported>."
+ ]
+ }
+ },
+ "sqlState": "42604"
+ },
"NEGATIVE_VALUE": {
"message": [
"Value for `<arg_name>` must be greater than or equal to 0, got
'<arg_value>'."
diff --git a/python/pyspark/sql/classic/dataframe.py
b/python/pyspark/sql/classic/dataframe.py
index 9ea2ee4b86eb..e1c3380416e4 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -820,6 +820,21 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
jdf = self._jdf.lateralJoin(other._jdf, on._jc, how)
return DataFrame(jdf, self.sparkSession)
+ def nearestByJoin(
+ self,
+ other: ParentDataFrame,
+ rankingExpression: Column,
+ numResults: int,
+ mode: str,
+ direction: str,
+ *,
+ joinType: str = "inner",
+ ) -> ParentDataFrame:
+ jdf = self._jdf.nearestByJoin(
+ other._jdf, rankingExpression._jc, int(numResults), mode,
direction, joinType
+ )
+ return DataFrame(jdf, self.sparkSession)
+
# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index c6602e08fac4..b0a9692f289a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -726,6 +726,30 @@ class DataFrame(ParentDataFrame):
session=self._session,
)
+ def nearestByJoin(
+ self,
+ other: ParentDataFrame,
+ rankingExpression: Column,
+ numResults: int,
+ mode: str,
+ direction: str,
+ *,
+ joinType: str = "inner",
+ ) -> ParentDataFrame:
+ other = self._check_same_session(other)
+ return DataFrame(
+ plan.NearestByJoin(
+ left=self._plan,
+ right=other._plan,
+ ranking_expression=rankingExpression,
+ num_results=int(numResults),
+ join_type=joinType,
+ mode=mode,
+ direction=direction,
+ ),
+ session=self._session,
+ )
+
def _joinAsOf(
self,
other: ParentDataFrame,
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index 8e13cf360657..540d81ffc690 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1345,6 +1345,108 @@ class LateralJoin(LogicalPlan):
"""
+# Acceptance lists for `nearestByJoin`. Must stay aligned with
`NearestByJoinValidation` in
+# `sql/api/.../catalyst/plans/NearestByJoinValidation.scala`.
+_NEAREST_BY_JOIN_MAX_NUM_RESULTS = 100000
+_NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPES = frozenset({"inner", "leftouter",
"left"})
+_NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPE_DISPLAY = "'INNER', 'LEFT OUTER'"
+_NEAREST_BY_JOIN_SUPPORTED_MODES = ("approx", "exact")
+_NEAREST_BY_JOIN_SUPPORTED_DIRECTIONS = ("distance", "similarity")
+
+
+class NearestByJoin(LogicalPlan):
+ def __init__(
+ self,
+ left: Optional[LogicalPlan],
+ right: LogicalPlan,
+ ranking_expression: Column,
+ num_results: int,
+ join_type: str,
+ mode: str,
+ direction: str,
+ ) -> None:
+ super().__init__(left, self._collect_references([ranking_expression]))
+ self.left = cast(LogicalPlan, left)
+ self.right = right
+ self.ranking_expression = ranking_expression
+ # Mirror of the Scala `Dataset.validateNearestByJoinArgs` validator --
raises the same
+ # `NEAREST_BY_JOIN.*` error classes the server would, so the user sees
a consistent
+ # error regardless of where the check fires.
+ if num_results < 1 or num_results > _NEAREST_BY_JOIN_MAX_NUM_RESULTS:
+ raise AnalysisException(
+ errorClass="NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
+ messageParameters={
+ "numResults": str(num_results),
+ "min": "1",
+ "max": str(_NEAREST_BY_JOIN_MAX_NUM_RESULTS),
+ },
+ )
+ if join_type.lower().replace("_", "") not in
_NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPES:
+ raise AnalysisException(
+ errorClass="NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE",
+ messageParameters={
+ "joinType": join_type,
+ "supported": _NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPE_DISPLAY,
+ },
+ )
+ if mode.lower() not in _NEAREST_BY_JOIN_SUPPORTED_MODES:
+ raise AnalysisException(
+ errorClass="NEAREST_BY_JOIN.UNSUPPORTED_MODE",
+ messageParameters={
+ "mode": mode,
+ "supported": "'" + "',
'".join(_NEAREST_BY_JOIN_SUPPORTED_MODES) + "'",
+ },
+ )
+ if direction.lower() not in _NEAREST_BY_JOIN_SUPPORTED_DIRECTIONS:
+ raise AnalysisException(
+ errorClass="NEAREST_BY_JOIN.UNSUPPORTED_DIRECTION",
+ messageParameters={
+ "direction": direction,
+ "supported": "'" + "',
'".join(_NEAREST_BY_JOIN_SUPPORTED_DIRECTIONS) + "'",
+ },
+ )
+ self.num_results = int(num_results)
+ self.join_type = join_type
+ self.mode = mode
+ self.direction = direction
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ plan = self._create_proto_relation()
+ plan.nearest_by_join.left.CopyFrom(self.left.plan(session))
+ plan.nearest_by_join.right.CopyFrom(self.right.plan(session))
+
plan.nearest_by_join.ranking_expression.CopyFrom(self.ranking_expression.to_plan(session))
+ plan.nearest_by_join.num_results = self.num_results
+ plan.nearest_by_join.join_type = self.join_type
+ plan.nearest_by_join.mode = self.mode
+ plan.nearest_by_join.direction = self.direction
+ return self._with_relations(plan, session)
+
+ @property
+ def observations(self) -> Dict[str, "Observation"]:
+ return {**super().observations, **self.right.observations}
+
+ def print(self, indent: int = 0) -> str:
+ i = " " * indent
+ o = " " * (indent + LogicalPlan.INDENT)
+ n = indent + LogicalPlan.INDENT * 2
+ return (
+ f"{i}<NearestByJoin numResults={self.num_results}
joinType={self.join_type} "
+ f"mode={self.mode} direction={self.direction}>\n{o}"
+ f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
+ )
+
+ def _repr_html_(self) -> str:
+ return f"""
+ <ul>
+ <li>
+ <b>NearestByJoin</b><br />
+ Left: {self.left._repr_html_()}
+ Right: {self.right._repr_html_()}
+ </li>
+ </uL>
+ """
+
+
class SetOperation(LogicalPlan):
def __init__(
self,
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index d024c6a07ada..f63b61fc344e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -44,7 +44,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as
spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\xd9\x1e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\xa1\x1f\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x [...]
)
_globals = globals()
@@ -82,177 +82,179 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PARSE_OPTIONSENTRY"]._loaded_options = None
_globals["_PARSE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_RELATION"]._serialized_start = 224
- _globals["_RELATION"]._serialized_end = 4153
- _globals["_MLRELATION"]._serialized_start = 4156
- _globals["_MLRELATION"]._serialized_end = 4640
- _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4368
- _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4603
- _globals["_FETCH"]._serialized_start = 4643
- _globals["_FETCH"]._serialized_end = 4974
- _globals["_FETCH_METHOD"]._serialized_start = 4759
- _globals["_FETCH_METHOD"]._serialized_end = 4974
- _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4847
- _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4974
- _globals["_UNKNOWN"]._serialized_start = 4976
- _globals["_UNKNOWN"]._serialized_end = 4985
- _globals["_RELATIONCOMMON"]._serialized_start = 4988
- _globals["_RELATIONCOMMON"]._serialized_end = 5130
- _globals["_SQL"]._serialized_start = 5133
- _globals["_SQL"]._serialized_end = 5611
- _globals["_SQL_ARGSENTRY"]._serialized_start = 5427
- _globals["_SQL_ARGSENTRY"]._serialized_end = 5517
- _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5519
- _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5611
- _globals["_WITHRELATIONS"]._serialized_start = 5613
- _globals["_WITHRELATIONS"]._serialized_end = 5730
- _globals["_READ"]._serialized_start = 5733
- _globals["_READ"]._serialized_end = 6450
- _globals["_READ_NAMEDTABLE"]._serialized_start = 5911
- _globals["_READ_NAMEDTABLE"]._serialized_end = 6103
- _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 6045
- _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 6103
- _globals["_READ_DATASOURCE"]._serialized_start = 6106
- _globals["_READ_DATASOURCE"]._serialized_end = 6437
- _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 6045
- _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 6103
- _globals["_RELATIONCHANGES"]._serialized_start = 6453
- _globals["_RELATIONCHANGES"]._serialized_end = 6685
- _globals["_RELATIONCHANGES_OPTIONSENTRY"]._serialized_start = 6045
- _globals["_RELATIONCHANGES_OPTIONSENTRY"]._serialized_end = 6103
- _globals["_PROJECT"]._serialized_start = 6687
- _globals["_PROJECT"]._serialized_end = 6804
- _globals["_FILTER"]._serialized_start = 6806
- _globals["_FILTER"]._serialized_end = 6918
- _globals["_JOIN"]._serialized_start = 6921
- _globals["_JOIN"]._serialized_end = 7582
- _globals["_JOIN_JOINDATATYPE"]._serialized_start = 7260
- _globals["_JOIN_JOINDATATYPE"]._serialized_end = 7352
- _globals["_JOIN_JOINTYPE"]._serialized_start = 7355
- _globals["_JOIN_JOINTYPE"]._serialized_end = 7563
- _globals["_SETOPERATION"]._serialized_start = 7585
- _globals["_SETOPERATION"]._serialized_end = 8064
- _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7901
- _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 8015
- _globals["_LIMIT"]._serialized_start = 8066
- _globals["_LIMIT"]._serialized_end = 8142
- _globals["_OFFSET"]._serialized_start = 8144
- _globals["_OFFSET"]._serialized_end = 8223
- _globals["_TAIL"]._serialized_start = 8225
- _globals["_TAIL"]._serialized_end = 8300
- _globals["_AGGREGATE"]._serialized_start = 8303
- _globals["_AGGREGATE"]._serialized_end = 9069
- _globals["_AGGREGATE_PIVOT"]._serialized_start = 8718
- _globals["_AGGREGATE_PIVOT"]._serialized_end = 8829
- _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8831
- _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8907
- _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8910
- _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 9069
- _globals["_SORT"]._serialized_start = 9072
- _globals["_SORT"]._serialized_end = 9232
- _globals["_DROP"]._serialized_start = 9235
- _globals["_DROP"]._serialized_end = 9376
- _globals["_DEDUPLICATE"]._serialized_start = 9379
- _globals["_DEDUPLICATE"]._serialized_end = 9619
- _globals["_LOCALRELATION"]._serialized_start = 9621
- _globals["_LOCALRELATION"]._serialized_end = 9710
- _globals["_CACHEDLOCALRELATION"]._serialized_start = 9712
- _globals["_CACHEDLOCALRELATION"]._serialized_end = 9784
- _globals["_CHUNKEDCACHEDLOCALRELATION"]._serialized_start = 9786
- _globals["_CHUNKEDCACHEDLOCALRELATION"]._serialized_end = 9898
- _globals["_CACHEDREMOTERELATION"]._serialized_start = 9900
- _globals["_CACHEDREMOTERELATION"]._serialized_end = 9955
- _globals["_SAMPLE"]._serialized_start = 9958
- _globals["_SAMPLE"]._serialized_end = 10231
- _globals["_RANGE"]._serialized_start = 10234
- _globals["_RANGE"]._serialized_end = 10379
- _globals["_SUBQUERYALIAS"]._serialized_start = 10381
- _globals["_SUBQUERYALIAS"]._serialized_end = 10495
- _globals["_REPARTITION"]._serialized_start = 10498
- _globals["_REPARTITION"]._serialized_end = 10640
- _globals["_SHOWSTRING"]._serialized_start = 10643
- _globals["_SHOWSTRING"]._serialized_end = 10785
- _globals["_HTMLSTRING"]._serialized_start = 10787
- _globals["_HTMLSTRING"]._serialized_end = 10901
- _globals["_STATSUMMARY"]._serialized_start = 10903
- _globals["_STATSUMMARY"]._serialized_end = 10995
- _globals["_STATDESCRIBE"]._serialized_start = 10997
- _globals["_STATDESCRIBE"]._serialized_end = 11078
- _globals["_STATCROSSTAB"]._serialized_start = 11080
- _globals["_STATCROSSTAB"]._serialized_end = 11181
- _globals["_STATCOV"]._serialized_start = 11183
- _globals["_STATCOV"]._serialized_end = 11279
- _globals["_STATCORR"]._serialized_start = 11282
- _globals["_STATCORR"]._serialized_end = 11419
- _globals["_STATAPPROXQUANTILE"]._serialized_start = 11422
- _globals["_STATAPPROXQUANTILE"]._serialized_end = 11586
- _globals["_STATFREQITEMS"]._serialized_start = 11588
- _globals["_STATFREQITEMS"]._serialized_end = 11713
- _globals["_STATSAMPLEBY"]._serialized_start = 11716
- _globals["_STATSAMPLEBY"]._serialized_end = 12025
- _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11917
- _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 12016
- _globals["_NAFILL"]._serialized_start = 12028
- _globals["_NAFILL"]._serialized_end = 12162
- _globals["_NADROP"]._serialized_start = 12165
- _globals["_NADROP"]._serialized_end = 12299
- _globals["_NAREPLACE"]._serialized_start = 12302
- _globals["_NAREPLACE"]._serialized_end = 12598
- _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 12457
- _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12598
- _globals["_TODF"]._serialized_start = 12600
- _globals["_TODF"]._serialized_end = 12688
- _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12691
- _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 13073
- _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start =
12935
- _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end =
13002
- _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 13004
- _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 13073
- _globals["_WITHCOLUMNS"]._serialized_start = 13075
- _globals["_WITHCOLUMNS"]._serialized_end = 13194
- _globals["_WITHWATERMARK"]._serialized_start = 13197
- _globals["_WITHWATERMARK"]._serialized_end = 13331
- _globals["_HINT"]._serialized_start = 13334
- _globals["_HINT"]._serialized_end = 13466
- _globals["_UNPIVOT"]._serialized_start = 13469
- _globals["_UNPIVOT"]._serialized_end = 13796
- _globals["_UNPIVOT_VALUES"]._serialized_start = 13726
- _globals["_UNPIVOT_VALUES"]._serialized_end = 13785
- _globals["_TRANSPOSE"]._serialized_start = 13798
- _globals["_TRANSPOSE"]._serialized_end = 13920
- _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13922
- _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 14047
- _globals["_TOSCHEMA"]._serialized_start = 14049
- _globals["_TOSCHEMA"]._serialized_end = 14155
- _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 14158
- _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 14361
- _globals["_MAPPARTITIONS"]._serialized_start = 14364
- _globals["_MAPPARTITIONS"]._serialized_end = 14596
- _globals["_GROUPMAP"]._serialized_start = 14599
- _globals["_GROUPMAP"]._serialized_end = 15449
- _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 15452
- _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15675
- _globals["_COGROUPMAP"]._serialized_start = 15678
- _globals["_COGROUPMAP"]._serialized_end = 16204
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 16207
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 16564
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 16567
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16811
- _globals["_PYTHONUDTF"]._serialized_start = 16814
- _globals["_PYTHONUDTF"]._serialized_end = 16991
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16994
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 17145
- _globals["_PYTHONDATASOURCE"]._serialized_start = 17147
- _globals["_PYTHONDATASOURCE"]._serialized_end = 17222
- _globals["_COLLECTMETRICS"]._serialized_start = 17225
- _globals["_COLLECTMETRICS"]._serialized_end = 17361
- _globals["_PARSE"]._serialized_start = 17364
- _globals["_PARSE"]._serialized_end = 17774
- _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 6045
- _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 6103
- _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17653
- _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17763
- _globals["_ASOFJOIN"]._serialized_start = 17777
- _globals["_ASOFJOIN"]._serialized_end = 18252
- _globals["_LATERALJOIN"]._serialized_start = 18255
- _globals["_LATERALJOIN"]._serialized_end = 18485
+ _globals["_RELATION"]._serialized_end = 4225
+ _globals["_MLRELATION"]._serialized_start = 4228
+ _globals["_MLRELATION"]._serialized_end = 4712
+ _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4440
+ _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4675
+ _globals["_FETCH"]._serialized_start = 4715
+ _globals["_FETCH"]._serialized_end = 5046
+ _globals["_FETCH_METHOD"]._serialized_start = 4831
+ _globals["_FETCH_METHOD"]._serialized_end = 5046
+ _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4919
+ _globals["_FETCH_METHOD_ARGS"]._serialized_end = 5046
+ _globals["_UNKNOWN"]._serialized_start = 5048
+ _globals["_UNKNOWN"]._serialized_end = 5057
+ _globals["_RELATIONCOMMON"]._serialized_start = 5060
+ _globals["_RELATIONCOMMON"]._serialized_end = 5202
+ _globals["_SQL"]._serialized_start = 5205
+ _globals["_SQL"]._serialized_end = 5683
+ _globals["_SQL_ARGSENTRY"]._serialized_start = 5499
+ _globals["_SQL_ARGSENTRY"]._serialized_end = 5589
+ _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5591
+ _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5683
+ _globals["_WITHRELATIONS"]._serialized_start = 5685
+ _globals["_WITHRELATIONS"]._serialized_end = 5802
+ _globals["_READ"]._serialized_start = 5805
+ _globals["_READ"]._serialized_end = 6522
+ _globals["_READ_NAMEDTABLE"]._serialized_start = 5983
+ _globals["_READ_NAMEDTABLE"]._serialized_end = 6175
+ _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 6117
+ _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 6175
+ _globals["_READ_DATASOURCE"]._serialized_start = 6178
+ _globals["_READ_DATASOURCE"]._serialized_end = 6509
+ _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 6117
+ _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 6175
+ _globals["_RELATIONCHANGES"]._serialized_start = 6525
+ _globals["_RELATIONCHANGES"]._serialized_end = 6757
+ _globals["_RELATIONCHANGES_OPTIONSENTRY"]._serialized_start = 6117
+ _globals["_RELATIONCHANGES_OPTIONSENTRY"]._serialized_end = 6175
+ _globals["_PROJECT"]._serialized_start = 6759
+ _globals["_PROJECT"]._serialized_end = 6876
+ _globals["_FILTER"]._serialized_start = 6878
+ _globals["_FILTER"]._serialized_end = 6990
+ _globals["_JOIN"]._serialized_start = 6993
+ _globals["_JOIN"]._serialized_end = 7654
+ _globals["_JOIN_JOINDATATYPE"]._serialized_start = 7332
+ _globals["_JOIN_JOINDATATYPE"]._serialized_end = 7424
+ _globals["_JOIN_JOINTYPE"]._serialized_start = 7427
+ _globals["_JOIN_JOINTYPE"]._serialized_end = 7635
+ _globals["_SETOPERATION"]._serialized_start = 7657
+ _globals["_SETOPERATION"]._serialized_end = 8136
+ _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7973
+ _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 8087
+ _globals["_LIMIT"]._serialized_start = 8138
+ _globals["_LIMIT"]._serialized_end = 8214
+ _globals["_OFFSET"]._serialized_start = 8216
+ _globals["_OFFSET"]._serialized_end = 8295
+ _globals["_TAIL"]._serialized_start = 8297
+ _globals["_TAIL"]._serialized_end = 8372
+ _globals["_AGGREGATE"]._serialized_start = 8375
+ _globals["_AGGREGATE"]._serialized_end = 9141
+ _globals["_AGGREGATE_PIVOT"]._serialized_start = 8790
+ _globals["_AGGREGATE_PIVOT"]._serialized_end = 8901
+ _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8903
+ _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8979
+ _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8982
+ _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 9141
+ _globals["_SORT"]._serialized_start = 9144
+ _globals["_SORT"]._serialized_end = 9304
+ _globals["_DROP"]._serialized_start = 9307
+ _globals["_DROP"]._serialized_end = 9448
+ _globals["_DEDUPLICATE"]._serialized_start = 9451
+ _globals["_DEDUPLICATE"]._serialized_end = 9691
+ _globals["_LOCALRELATION"]._serialized_start = 9693
+ _globals["_LOCALRELATION"]._serialized_end = 9782
+ _globals["_CACHEDLOCALRELATION"]._serialized_start = 9784
+ _globals["_CACHEDLOCALRELATION"]._serialized_end = 9856
+ _globals["_CHUNKEDCACHEDLOCALRELATION"]._serialized_start = 9858
+ _globals["_CHUNKEDCACHEDLOCALRELATION"]._serialized_end = 9970
+ _globals["_CACHEDREMOTERELATION"]._serialized_start = 9972
+ _globals["_CACHEDREMOTERELATION"]._serialized_end = 10027
+ _globals["_SAMPLE"]._serialized_start = 10030
+ _globals["_SAMPLE"]._serialized_end = 10303
+ _globals["_RANGE"]._serialized_start = 10306
+ _globals["_RANGE"]._serialized_end = 10451
+ _globals["_SUBQUERYALIAS"]._serialized_start = 10453
+ _globals["_SUBQUERYALIAS"]._serialized_end = 10567
+ _globals["_REPARTITION"]._serialized_start = 10570
+ _globals["_REPARTITION"]._serialized_end = 10712
+ _globals["_SHOWSTRING"]._serialized_start = 10715
+ _globals["_SHOWSTRING"]._serialized_end = 10857
+ _globals["_HTMLSTRING"]._serialized_start = 10859
+ _globals["_HTMLSTRING"]._serialized_end = 10973
+ _globals["_STATSUMMARY"]._serialized_start = 10975
+ _globals["_STATSUMMARY"]._serialized_end = 11067
+ _globals["_STATDESCRIBE"]._serialized_start = 11069
+ _globals["_STATDESCRIBE"]._serialized_end = 11150
+ _globals["_STATCROSSTAB"]._serialized_start = 11152
+ _globals["_STATCROSSTAB"]._serialized_end = 11253
+ _globals["_STATCOV"]._serialized_start = 11255
+ _globals["_STATCOV"]._serialized_end = 11351
+ _globals["_STATCORR"]._serialized_start = 11354
+ _globals["_STATCORR"]._serialized_end = 11491
+ _globals["_STATAPPROXQUANTILE"]._serialized_start = 11494
+ _globals["_STATAPPROXQUANTILE"]._serialized_end = 11658
+ _globals["_STATFREQITEMS"]._serialized_start = 11660
+ _globals["_STATFREQITEMS"]._serialized_end = 11785
+ _globals["_STATSAMPLEBY"]._serialized_start = 11788
+ _globals["_STATSAMPLEBY"]._serialized_end = 12097
+ _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11989
+ _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 12088
+ _globals["_NAFILL"]._serialized_start = 12100
+ _globals["_NAFILL"]._serialized_end = 12234
+ _globals["_NADROP"]._serialized_start = 12237
+ _globals["_NADROP"]._serialized_end = 12371
+ _globals["_NAREPLACE"]._serialized_start = 12374
+ _globals["_NAREPLACE"]._serialized_end = 12670
+ _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 12529
+ _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12670
+ _globals["_TODF"]._serialized_start = 12672
+ _globals["_TODF"]._serialized_end = 12760
+ _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12763
+ _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 13145
+ _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start =
13007
+ _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end =
13074
+ _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 13076
+ _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 13145
+ _globals["_WITHCOLUMNS"]._serialized_start = 13147
+ _globals["_WITHCOLUMNS"]._serialized_end = 13266
+ _globals["_WITHWATERMARK"]._serialized_start = 13269
+ _globals["_WITHWATERMARK"]._serialized_end = 13403
+ _globals["_HINT"]._serialized_start = 13406
+ _globals["_HINT"]._serialized_end = 13538
+ _globals["_UNPIVOT"]._serialized_start = 13541
+ _globals["_UNPIVOT"]._serialized_end = 13868
+ _globals["_UNPIVOT_VALUES"]._serialized_start = 13798
+ _globals["_UNPIVOT_VALUES"]._serialized_end = 13857
+ _globals["_TRANSPOSE"]._serialized_start = 13870
+ _globals["_TRANSPOSE"]._serialized_end = 13992
+ _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13994
+ _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 14119
+ _globals["_TOSCHEMA"]._serialized_start = 14121
+ _globals["_TOSCHEMA"]._serialized_end = 14227
+ _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 14230
+ _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 14433
+ _globals["_MAPPARTITIONS"]._serialized_start = 14436
+ _globals["_MAPPARTITIONS"]._serialized_end = 14668
+ _globals["_GROUPMAP"]._serialized_start = 14671
+ _globals["_GROUPMAP"]._serialized_end = 15521
+ _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 15524
+ _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15747
+ _globals["_COGROUPMAP"]._serialized_start = 15750
+ _globals["_COGROUPMAP"]._serialized_end = 16276
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 16279
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 16636
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 16639
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16883
+ _globals["_PYTHONUDTF"]._serialized_start = 16886
+ _globals["_PYTHONUDTF"]._serialized_end = 17063
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 17066
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 17217
+ _globals["_PYTHONDATASOURCE"]._serialized_start = 17219
+ _globals["_PYTHONDATASOURCE"]._serialized_end = 17294
+ _globals["_COLLECTMETRICS"]._serialized_start = 17297
+ _globals["_COLLECTMETRICS"]._serialized_end = 17433
+ _globals["_PARSE"]._serialized_start = 17436
+ _globals["_PARSE"]._serialized_end = 17846
+ _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 6117
+ _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 6175
+ _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17725
+ _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17835
+ _globals["_ASOFJOIN"]._serialized_start = 17849
+ _globals["_ASOFJOIN"]._serialized_end = 18324
+ _globals["_LATERALJOIN"]._serialized_start = 18327
+ _globals["_LATERALJOIN"]._serialized_end = 18557
+ _globals["_NEARESTBYJOIN"]._serialized_start = 18560
+ _globals["_NEARESTBYJOIN"]._serialized_end = 18853
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7b3968545ce0..c99de778db4c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -111,6 +111,7 @@ class Relation(google.protobuf.message.Message):
LATERAL_JOIN_FIELD_NUMBER: builtins.int
CHUNKED_CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
RELATION_CHANGES_FIELD_NUMBER: builtins.int
+ NEAREST_BY_JOIN_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -223,6 +224,8 @@ class Relation(google.protobuf.message.Message):
@property
def relation_changes(self) -> global___RelationChanges: ...
@property
+ def nearest_by_join(self) -> global___NearestByJoin: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -310,6 +313,7 @@ class Relation(google.protobuf.message.Message):
lateral_join: global___LateralJoin | None = ...,
chunked_cached_local_relation: global___ChunkedCachedLocalRelation |
None = ...,
relation_changes: global___RelationChanges | None = ...,
+ nearest_by_join: global___NearestByJoin | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -395,6 +399,8 @@ class Relation(google.protobuf.message.Message):
b"map_partitions",
"ml_relation",
b"ml_relation",
+ "nearest_by_join",
+ b"nearest_by_join",
"offset",
b"offset",
"parse",
@@ -524,6 +530,8 @@ class Relation(google.protobuf.message.Message):
b"map_partitions",
"ml_relation",
b"ml_relation",
+ "nearest_by_join",
+ b"nearest_by_join",
"offset",
b"offset",
"parse",
@@ -633,6 +641,7 @@ class Relation(google.protobuf.message.Message):
"lateral_join",
"chunked_cached_local_relation",
"relation_changes",
+ "nearest_by_join",
"fill_na",
"drop_na",
"replace",
@@ -4657,3 +4666,79 @@ class LateralJoin(google.protobuf.message.Message):
) -> None: ...
global___LateralJoin = LateralJoin
+
+class NearestByJoin(google.protobuf.message.Message):
+ """Relation of type [[NearestByJoin]].
+
+ For each row on the left side, returns up to `num_results` rows from the right side ranked
+ by `ranking_expression`.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ LEFT_FIELD_NUMBER: builtins.int
+ RIGHT_FIELD_NUMBER: builtins.int
+ RANKING_EXPRESSION_FIELD_NUMBER: builtins.int
+ NUM_RESULTS_FIELD_NUMBER: builtins.int
+ JOIN_TYPE_FIELD_NUMBER: builtins.int
+ MODE_FIELD_NUMBER: builtins.int
+ DIRECTION_FIELD_NUMBER: builtins.int
+ @property
+ def left(self) -> global___Relation:
+ """(Required) Left (query) input relation."""
+ @property
+ def right(self) -> global___Relation:
+ """(Required) Right (base) input relation."""
+ @property
+ def ranking_expression(self) ->
pyspark.sql.connect.proto.expressions_pb2.Expression:
+ """(Required) Scalar expression used to rank candidate rows on the
right side."""
+ num_results: builtins.int
+ """(Required) Maximum number of matches per left row. Must be between 1
and 100000."""
+ join_type: builtins.str
+ """The following three fields use `string` (not typed enums) for parity
with `AsOfJoin`,
+ which models analogous fields the same way. Validation happens server-side
at planning time.
+
+ (Required) The join type. Must be one of: "inner", "leftouter".
+ """
+ mode: builtins.str
+ """(Required) Search algorithm contract. Must be one of: "approx",
"exact"."""
+ direction: builtins.str
+ """(Required) Ranking direction. Must be one of: "distance",
"similarity"."""
+ def __init__(
+ self,
+ *,
+ left: global___Relation | None = ...,
+ right: global___Relation | None = ...,
+ ranking_expression:
pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
+ num_results: builtins.int = ...,
+ join_type: builtins.str = ...,
+ mode: builtins.str = ...,
+ direction: builtins.str = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "left", b"left", "ranking_expression", b"ranking_expression",
"right", b"right"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "direction",
+ b"direction",
+ "join_type",
+ b"join_type",
+ "left",
+ b"left",
+ "mode",
+ b"mode",
+ "num_results",
+ b"num_results",
+ "ranking_expression",
+ b"ranking_expression",
+ "right",
+ b"right",
+ ],
+ ) -> None: ...
+
+global___NearestByJoin = NearestByJoin
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d5172afe9bd2..6b17abdc5a9a 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2870,6 +2870,73 @@ class DataFrame:
"""
...
+ def nearestByJoin(
+ self,
+ other: "DataFrame",
+ rankingExpression: Column,
+ numResults: int,
+ mode: str,
+ direction: str,
+ *,
+ joinType: str = "inner",
+ ) -> "DataFrame":
+ """
+ Nearest-by top-K ranking join with another :class:`DataFrame`. For each row on the
+ left (query side), returns up to ``numResults`` rows from ``other`` (base side), ranked
+ by ``rankingExpression``.
+
+ The current implementation evaluates the full cross-product of left and right and
+ bounds memory per left row by ``numResults``. Index-backed approximate strategies
+ (transparent to ``approx`` mode) are planned for a future release; until then,
+ pre-filter ``other`` when it is large. Tie-breaking among rows with equal ranking
+ values is unspecified.
+
+ .. versionadded:: 4.2.0
+
+ Parameters
+ ----------
+ other : :class:`DataFrame`
+ Right (base side) of the join - the candidate pool searched for each row of this
+ DataFrame.
+ rankingExpression : :class:`Column`
+ Scalar expression used to rank candidate rows on the right side.
+ numResults : int
+ Maximum number of matches per query row. Must be between 1 and 100000.
+ mode : str
+ Search algorithm contract. Must be one of: ``approx``, ``exact``. ``approx`` allows
+ the optimizer to use indexed or other approximate strategies when available;
+ ``exact`` forces brute-force evaluation and requires the ranking expression to be
+ deterministic.
+ direction : str
+ ``"distance"`` (smallest value first) or ``"similarity"`` (largest value first).
+ joinType : str, keyword-only, optional
+ Default ``inner``. Must be one of: ``inner``, ``leftouter``.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ Joined DataFrame.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> users = spark.createDataFrame(
+ ... [(1, 10.0), (2, 20.0), (3, 30.0)], ["user_id", "score"])
+ >>> products = spark.createDataFrame(
+ ... [("A", 11.0), ("B", 22.0), ("C", 5.0)], ["product", "pscore"])
+ >>> users.nearestByJoin(
+ ... products, -sf.abs(users.score - products.pscore), 1, "exact", "similarity"
+ ... ).select("user_id", "product").orderBy("user_id").show()
+ +-------+-------+
+ |user_id|product|
+ +-------+-------+
+ | 1| A|
+ | 2| B|
+ | 3| B|
+ +-------+-------+
+ """
+ ...
+
# TODO(SPARK-22947): Fix the DataFrame API.
@dispatch_df_method
def _joinAsOf(
diff --git a/python/pyspark/sql/tests/connect/test_parity_nearest_by_join.py b/python/pyspark/sql/tests/connect/test_parity_nearest_by_join.py
new file mode 100644
index 000000000000..1fb0f5b62046
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_nearest_by_join.py
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from pyspark.sql.tests.test_nearest_by_join import NearestByJoinTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class NearestByJoinParityTests(NearestByJoinTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.testing import main
+
+ main()
diff --git a/python/pyspark/sql/tests/test_nearest_by_join.py b/python/pyspark/sql/tests/test_nearest_by_join.py
new file mode 100644
index 000000000000..fdee3043289e
--- /dev/null
+++ b/python/pyspark/sql/tests/test_nearest_by_join.py
@@ -0,0 +1,270 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from pyspark.errors import AnalysisException
+from pyspark.sql import Row
+from pyspark.sql import functions as sf
+from pyspark.testing import assertDataFrameEqual
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class NearestByJoinTestsMixin:
+ """Mixin run against both classic (`ReusedSQLTestCase`) and Connect
+ (`ReusedConnectTestCase`) to ensure parity between the two paths."""
+
+ @property
+ def users(self):
+ return self.spark.createDataFrame([(1, 10.0), (2, 20.0), (3, 30.0)],
["user_id", "score"])
+
+ @property
+ def products(self):
+ return self.spark.createDataFrame(
+ [("A", 11.0), ("B", 22.0), ("C", 5.0)], ["product", "pscore"]
+ )
+
+ def test_inner_similarity_k1(self):
+ users, products = self.users, self.products
+ result = (
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=1,
+ mode="approx",
+ direction="similarity",
+ )
+ .select("user_id", "product")
+ .orderBy("user_id")
+ )
+ assertDataFrameEqual(
+ result,
+ [Row(user_id=1, product="A"), Row(user_id=2, product="B"),
Row(user_id=3, product="B")],
+ )
+
+ def test_inner_distance_k2(self):
+ users, products = self.users, self.products
+ result = (
+ users.nearestByJoin(
+ products,
+ sf.abs(users.score - products.pscore),
+ numResults=2,
+ mode="approx",
+ direction="distance",
+ )
+ .select("user_id", "product")
+ .orderBy("user_id", "product")
+ )
+ assertDataFrameEqual(
+ result,
+ [
+ Row(user_id=1, product="A"),
+ Row(user_id=1, product="C"),
+ Row(user_id=2, product="A"),
+ Row(user_id=2, product="B"),
+ Row(user_id=3, product="A"),
+ Row(user_id=3, product="B"),
+ ],
+ )
+
+ def test_left_outer_with_empty_right(self):
+ users, products = self.users, self.products
+ empty = products.filter(sf.lit(False))
+ result = (
+ users.nearestByJoin(
+ empty,
+ -sf.abs(users.score - empty.pscore),
+ numResults=1,
+ mode="exact",
+ direction="similarity",
+ joinType="leftouter",
+ )
+ .select("user_id", "product")
+ .orderBy("user_id")
+ )
+ assertDataFrameEqual(
+ result,
+ [
+ Row(user_id=1, product=None),
+ Row(user_id=2, product=None),
+ Row(user_id=3, product=None),
+ ],
+ )
+
+ def test_select_star_schema_has_no_internal_columns(self):
+ users, products = self.users, self.products
+ result = users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=1,
+ mode="exact",
+ direction="similarity",
+ )
+ # No `__qid`, `__nearest_matches__`, or other rewrite-internal columns leak through.
+ assert sorted(result.columns) == ["product", "pscore", "score",
"user_id"]
+
+ def test_invalid_num_results_low(self):
+ users, products = self.users, self.products
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=0,
+ mode="approx",
+ direction="similarity",
+ )
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
+ messageParameters={"numResults": "0", "min": "1", "max": "100000"},
+ )
+
+ def test_invalid_num_results_high(self):
+ users, products = self.users, self.products
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=200000,
+ mode="approx",
+ direction="similarity",
+ )
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
+ messageParameters={"numResults": "200000", "min": "1", "max":
"100000"},
+ )
+
+ def test_invalid_join_type(self):
+ users, products = self.users, self.products
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=1,
+ mode="approx",
+ direction="similarity",
+ joinType="outer",
+ )
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE",
+ messageParameters={"joinType": "outer", "supported": "'INNER',
'LEFT OUTER'"},
+ )
+
+ def test_invalid_mode(self):
+ users, products = self.users, self.products
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=1,
+ mode="bogus",
+ direction="similarity",
+ )
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.UNSUPPORTED_MODE",
+ messageParameters={"mode": "bogus", "supported": "'approx',
'exact'"},
+ )
+
+ def test_invalid_direction(self):
+ users, products = self.users, self.products
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=1,
+ mode="approx",
+ direction="elsewhere",
+ )
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.UNSUPPORTED_DIRECTION",
+ messageParameters={
+ "direction": "elsewhere",
+ "supported": "'distance', 'similarity'",
+ },
+ )
+
+ def test_rejected_when_crossjoin_disabled(self):
+ users, products = self.users, self.products
+ with self.sql_conf({"spark.sql.crossJoin.enabled": "false"}):
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ -sf.abs(users.score - products.pscore),
+ numResults=1,
+ mode="exact",
+ direction="similarity",
+ ).collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.CROSS_JOIN_NOT_ENABLED",
+ messageParameters={},
+ )
+
+ def test_exact_with_nondeterministic_ranking_rejected(self):
+ users, products = self.users, self.products
+ # Use an explicit seed (`rand(0)`) so the rendered expression in the error message is
+ # byte-stable. Without it, Spark assigns a random seed at analysis and the message
+ # parameter becomes `"(rand(<random-long>) + pscore)"`, which can't be asserted on.
+ with self.assertRaises(AnalysisException) as pe:
+ users.nearestByJoin(
+ products,
+ sf.rand(0) + products.pscore,
+ numResults=1,
+ mode="exact",
+ direction="similarity",
+ ).collect()
+ self.check_error(
+ exception=pe.exception,
+
errorClass="NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION",
+ messageParameters={"expression": '"(rand(0) + pscore)"'},
+ )
+
+ def test_streaming_inputs_rejected(self):
+ streaming_users = (
+ self.spark.readStream.format("rate")
+ .option("rowsPerSecond", 1)
+ .load()
+ .selectExpr("CAST(value AS INT) AS user_id", "CAST(value AS
DOUBLE) AS score")
+ )
+ products = self.products
+ with self.assertRaises(AnalysisException) as pe:
+ # `.schema` forces analysis without starting the streaming query.
+ _ = streaming_users.nearestByJoin(
+ products,
+ -sf.abs(streaming_users.score - products.pscore),
+ numResults=1,
+ mode="exact",
+ direction="similarity",
+ ).schema
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NEAREST_BY_JOIN.STREAMING_NOT_SUPPORTED",
+ messageParameters={},
+ )
+
+
+class NearestByJoinTests(NearestByJoinTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.testing import main
+
+ main()
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
index c3c983c17bb0..38765262e1fc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -912,6 +912,76 @@ abstract class Dataset[T] extends Serializable {
*/
def lateralJoin(right: Dataset[_], joinExprs: Column, joinType: String):
DataFrame
+ /**
+ * Nearest-by top-K ranking join with another `DataFrame`, using the default `inner` join type.
+ * For each row on the left (query side), returns up to `numResults` rows from `right` (base
+ * side), ranked by `rankingExpression`.
+ *
+ * Equivalent SQL (with `mode = "exact"` and `direction = "similarity"`):
+ * {{{
+ * left INNER JOIN right EXACT NEAREST numResults BY SIMILARITY rankingExpression
+ * }}}
+ *
+ * The current implementation evaluates the full cross-product of left and right and bounds
+ * memory per left row by `numResults`. Index-backed approximate strategies (transparent to
+ * `approx` mode) are planned for a future release; until then, pre-filter the right side when
+ * it is large. Tie-breaking among rows with equal ranking values is unspecified.
+ *
+ * @param right
+ * Right (base side) of the join - the candidate pool searched for each row of this Dataset.
+ * @param rankingExpression
+ * Scalar expression used to rank candidate rows.
+ * @param numResults
+ * Maximum number of matches per query row. Must be between 1 and 100000.
+ * @param mode
+ * Search algorithm contract. Must be one of: `approx`, `exact`. `approx` allows the optimizer
+ * to use indexed or other approximate strategies when available; `exact` forces brute-force
+ * evaluation and requires the ranking expression to be deterministic.
+ * @param direction
+ * `"distance"` (smallest value first) or `"similarity"` (largest value first).
+ * @group untypedrel
+ * @since 4.2.0
+ */
+ def nearestByJoin(
+ right: Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ mode: String,
+ direction: String): DataFrame
+
+ /**
+ * Nearest-by top-K ranking join with another `DataFrame`.
+ *
+ * The current implementation evaluates the full cross-product of left and right and bounds
+ * memory per left row by `numResults`. Index-backed approximate strategies (transparent to
+ * `approx` mode) are planned for a future release; until then, pre-filter the right side when
+ * it is large. Tie-breaking among rows with equal ranking values is unspecified.
+ *
+ * @param right
+ * Right (base side) of the join - the candidate pool searched for each row of this Dataset.
+ * @param rankingExpression
+ * Scalar expression used to rank candidate rows.
+ * @param numResults
+ * Maximum number of matches per query row. Must be between 1 and 100000.
+ * @param mode
+ * Search algorithm contract. Must be one of: `approx`, `exact`. `approx` allows the optimizer
+ * to use indexed or other approximate strategies when available; `exact` forces brute-force
+ * evaluation and requires the ranking expression to be deterministic.
+ * @param direction
+ * `"distance"` (smallest value first) or `"similarity"` (largest value first).
+ * @param joinType
+ * Type of join to perform. Must be one of: `inner`, `leftouter`.
+ * @group untypedrel
+ * @since 4.2.0
+ */
+ def nearestByJoin(
+ right: Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ mode: String,
+ direction: String,
+ joinType: String): DataFrame
+
protected def sortInternal(global: Boolean, sortExprs: Seq[Column]):
Dataset[T]
/**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/NearestByJoinValidation.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/NearestByJoinValidation.scala
new file mode 100644
index 000000000000..8ebac8e73c67
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/NearestByJoinValidation.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+/**
+ * Acceptance lists for the `NEAREST BY` join API.
+ */
+private[sql] object NearestByJoinValidation {
+
+ /** Upper bound on `numResults`. Mirrors the K-overload limit of `MaxMinByK`. */
+ val MaxNumResults: Int = 100000
+
+ /**
+ * Strings accepted by `joinType` after lower-casing and stripping `_` (so e.g. `LEFT_OUTER`
+ * canonicalizes to `leftouter`). Every consumer must apply the same canonicalization before
+ * checking membership.
+ */
+ val SupportedJoinTypes: Seq[String] = Seq("inner", "leftouter", "left")
+
+ /** Display form for `supported` in `NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE` error messages. */
+ val SupportedJoinTypeDisplay: String = "'INNER', 'LEFT OUTER'"
+
+ /** Strings accepted by `mode`. Lower-cased before membership check. */
+ val SupportedModes: Seq[String] = Seq("approx", "exact")
+
+ /** Strings accepted by `direction`. Lower-cased before membership check. */
+ val SupportedDirections: Seq[String] = Seq("distance", "similarity")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 569cd05a46ba..790307e44ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -184,7 +184,8 @@ object LateralJoinType {
object NearestByDirection {
- val supported = Seq("distance", "similarity")
+ /** @see [[NearestByJoinValidation.SupportedDirections]] */
+ val supported: Seq[String] = NearestByJoinValidation.SupportedDirections
def apply(direction: String): NearestByDirection = {
direction.toLowerCase(Locale.ROOT) match {
@@ -207,13 +208,11 @@ case object NearestBySimilarity extends NearestByDirection
object NearestByJoinType {
- /** Strings accepted by the Dataset API. */
- val supported = Seq("inner", "leftouter", "left", "left_outer")
+ /** @see [[NearestByJoinValidation.SupportedJoinTypes]] */
+ val supported: Seq[String] = NearestByJoinValidation.SupportedJoinTypes
- /** Display string used in `NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE` error
messages. Matches the
- * parser-side wording so the same error class reports the same `supported`
value across the
- * SQL and DataFrame paths. */
- val supportedDisplay = "'INNER', 'LEFT OUTER'"
+ /** @see [[NearestByJoinValidation.SupportedJoinTypeDisplay]] */
+ val supportedDisplay: String =
NearestByJoinValidation.SupportedJoinTypeDisplay
def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_",
"") match {
case "inner" => Inner
@@ -229,7 +228,8 @@ object NearestByJoinType {
object NearestByJoinMode {
- val supported = Seq("approx", "exact")
+ /** @see [[NearestByJoinValidation.SupportedModes]] */
+ val supported: Seq[String] = NearestByJoinValidation.SupportedModes
/** Returns true for APPROX, false for EXACT. */
def apply(mode: String): Boolean = mode.toLowerCase(Locale.ROOT) match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala
index 9df79ba128b8..6a5c94d4a1df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NearestByJoin.scala
@@ -18,12 +18,12 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter,
NearestByDirection}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter,
NearestByDirection, NearestByJoinValidation}
import org.apache.spark.sql.catalyst.trees.TreePattern._
object NearestByJoin {
- /** Upper bound on `numResults`. Mirrors the K-overload limit of
`MaxMinByK`. */
- val MaxNumResults: Int = 100000
+ /** @see [[NearestByJoinValidation.MaxNumResults]] */
+ val MaxNumResults: Int = NearestByJoinValidation.MaxNumResults
}
/**
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNearestByJoinSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNearestByJoinSuite.scala
new file mode 100644
index 000000000000..00d7c4f80b09
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNearestByJoinSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions._
+
+/**
+ * End-to-end Connect-side coverage for `Dataset.nearestByJoin`. Mirrors the
+ * `DataFrameNearestByJoinSuite` in `sql/core` for the classic path; this suite ensures the same
+ * API behaves correctly when invoked through the Connect client (proto serialization, server-side
+ * proto-to-catalyst translation in `SparkConnectPlanner.transformNearestByJoin`, and result
+ * roundtrip).
+ */
+class DataFrameNearestByJoinSuite extends QueryTest with RemoteSparkSession {
+ import testImplicits._
+
+ private lazy val users = Seq((1, 10.0), (2, 20.0), (3,
30.0)).toDF("user_id", "score")
+
+ private lazy val products = Seq(("A", 11.0), ("B", 22.0), ("C",
5.0)).toDF("product", "pscore")
+
+ test("inner approx similarity k=1") {
+ checkAnswer(
+ users
+ .nearestByJoin(
+ right = products,
+ rankingExpression = -abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "approx",
+ direction = "similarity")
+ .select("user_id", "product")
+ .orderBy("user_id"),
+ Seq(Row(1, "A"), Row(2, "B"), Row(3, "B")))
+ }
+
+ test("inner approx distance k=2") {
+ checkAnswer(
+ users
+ .nearestByJoin(
+ right = products,
+ rankingExpression = abs(users("score") - products("pscore")),
+ numResults = 2,
+ mode = "approx",
+ direction = "distance")
+ .select("user_id", "product")
+ .orderBy("user_id", "product"),
+ Seq(Row(1, "A"), Row(1, "C"), Row(2, "A"), Row(2, "B"), Row(3, "A"),
Row(3, "B")))
+ }
+
+ test("left outer with empty right preserves left rows with NULLs") {
+ val emptyProducts = products.filter(lit(false))
+ checkAnswer(
+ users
+ .nearestByJoin(
+ right = emptyProducts,
+ rankingExpression = -abs(users("score") - emptyProducts("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity",
+ joinType = "leftouter")
+ .select("user_id", "product")
+ .orderBy("user_id"),
+ Seq(Row(1, null), Row(2, null), Row(3, null)))
+ }
+
+ test("output schema has no rewrite-internal columns") {
+ val result = users.nearestByJoin(
+ right = products,
+ rankingExpression = -abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity")
+ // Only the user-visible columns flow through; no `__qid`,
`__nearest_matches__`, etc.
+ assert(result.columns.toSet === Set("user_id", "score", "product",
"pscore"))
+ }
+
+ test("invalid mode is rejected") {
+ val ex = intercept[AnalysisException] {
+ users.nearestByJoin(
+ right = products,
+ rankingExpression = -abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "bogus",
+ direction = "similarity")
+ }
+ assert(ex.getCondition === "NEAREST_BY_JOIN.UNSUPPORTED_MODE")
+ }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 16a2bf85de4a..199736da92ac 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -516,6 +516,29 @@ class PlanGenerationTestSuite extends ConnectFunSuite with Logging {
left.crossJoin(right)
}
+ test("nearestByJoin inner_approx_similarity") {
+ left
+ .as("l")
+ .nearestByJoin(
+ right = right.as("r"),
+ rankingExpression = fn.col("l.a") + fn.col("r.a"),
+ numResults = 1,
+ mode = "approx",
+ direction = "similarity")
+ }
+
+ test("nearestByJoin leftouter_exact_distance") {
+ left
+ .as("l")
+ .nearestByJoin(
+ right = right.as("r"),
+ rankingExpression = fn.col("l.a") + fn.col("r.a"),
+ numResults = 5,
+ mode = "exact",
+ direction = "distance",
+ joinType = "leftouter")
+ }
+
test("sortWithinPartitions strings") {
simple.sortWithinPartitions("a", "id")
}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
index 57c4ed7be3c8..95cc9281d8ca 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -82,6 +82,7 @@ message Relation {
LateralJoin lateral_join = 44;
ChunkedCachedLocalRelation chunked_cached_local_relation = 45;
RelationChanges relation_changes = 46;
+ NearestByJoin nearest_by_join = 47;
// NA functions
NAFill fill_na = 90;
@@ -1276,3 +1277,33 @@ message LateralJoin {
// (Required) The join type.
Join.JoinType join_type = 4;
}
+
+// Relation of type [[NearestByJoin]].
+//
+// For each row on the left side, returns up to `num_results` rows from the right side ranked
+// by `ranking_expression`.
+message NearestByJoin {
+ // (Required) Left (query) input relation.
+ Relation left = 1;
+
+ // (Required) Right (base) input relation.
+ Relation right = 2;
+
+ // (Required) Scalar expression used to rank candidate rows on the right side.
+ Expression ranking_expression = 3;
+
+ // (Required) Maximum number of matches per left row. Must be between 1 and 100000.
+ int32 num_results = 4;
+
+ // The following three fields use `string` (not typed enums) for parity with `AsOfJoin`,
+ // which models analogous fields the same way. Validation happens server-side at planning time.
+
+ // (Required) The join type. Must be one of: "inner", "leftouter".
+ string join_type = 5;
+
+ // (Required) Search algorithm contract. Must be one of: "approx", "exact".
+ string mode = 6;
+
+ // (Required) Ranking direction. Must be one of: "distance", "similarity".
+ string direction = 7;
+}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index b57ea66bb1f7..34c685213711 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.catalyst.plans.NearestByJoinValidation
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr,
toLiteral, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
@@ -421,6 +422,52 @@ class Dataset[T] private[sql] (
lateralJoin(right, Some(joinExprs), joinType)
}
+ private def nearestByJoinImpl(
+ right: sql.Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ joinType: String,
+ mode: String,
+ direction: String): DataFrame = {
+ // Validate locally so Connect users see the same errors as the classic
path without a
+ // server round-trip. The validation logic mirrors
`NearestByJoinType.apply` /
+ // `NearestByJoinMode.apply` / `NearestByDirection.apply` in sql/catalyst,
which
+ // `sql/connect/common` cannot import; the acceptance lists themselves are
shared via
+ // `NearestByJoinValidation` in sql-api.
+ Dataset.validateNearestByJoinArgs(numResults, joinType, mode, direction)
+ sparkSession.newDataFrame(Seq(rankingExpression)) { builder =>
+ builder.getNearestByJoinBuilder
+ .setLeft(plan.getRoot)
+ .setRight(right.plan.getRoot)
+ .setRankingExpression(toExpr(rankingExpression))
+ .setNumResults(numResults)
+ .setJoinType(joinType)
+ .setMode(mode)
+ .setDirection(direction)
+ }
+ }
+
+ /** @inheritdoc */
+ def nearestByJoin(
+ right: sql.Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ mode: String,
+ direction: String): DataFrame = {
+ nearestByJoinImpl(right, rankingExpression, numResults, "inner", mode,
direction)
+ }
+
+ /** @inheritdoc */
+ def nearestByJoin(
+ right: sql.Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ mode: String,
+ direction: String,
+ joinType: String): DataFrame = {
+ nearestByJoinImpl(right, rankingExpression, numResults, joinType, mode,
direction)
+ }
+
override protected def sortInternal(global: Boolean, sortCols: Seq[Column]):
Dataset[T] = {
val sortExprs = sortCols.map { c =>
ColumnNodeToProtoConverter(c.sortOrder).getSortOrder
@@ -1569,3 +1616,47 @@ class Dataset[T] private[sql] (
override def queryExecution: QueryExecution =
throw ConnectClientUnsupportedErrors.queryExecution()
}
+
+private[sql] object Dataset {
+
+ private[connect] def validateNearestByJoinArgs(
+ numResults: Int,
+ joinType: String,
+ mode: String,
+ direction: String): Unit = {
+ if (numResults < 1 || numResults > NearestByJoinValidation.MaxNumResults) {
+ throw new AnalysisException(
+ errorClass = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
+ messageParameters = Map(
+ "numResults" -> numResults.toString,
+ "min" -> "1",
+ "max" -> NearestByJoinValidation.MaxNumResults.toString))
+ }
+ val canonicalJoinType =
joinType.toLowerCase(java.util.Locale.ROOT).replace("_", "")
+ if
(!NearestByJoinValidation.SupportedJoinTypes.contains(canonicalJoinType)) {
+ throw new AnalysisException(
+ errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE",
+ messageParameters = Map(
+ "joinType" -> joinType,
+ "supported" -> NearestByJoinValidation.SupportedJoinTypeDisplay))
+ }
+ if (!NearestByJoinValidation.SupportedModes.contains(
+ mode.toLowerCase(java.util.Locale.ROOT))) {
+ throw new AnalysisException(
+ errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_MODE",
+ messageParameters = Map(
+ "mode" -> mode,
+ "supported" ->
+ NearestByJoinValidation.SupportedModes.mkString("'", "', '", "'")))
+ }
+ if (!NearestByJoinValidation.SupportedDirections.contains(
+ direction.toLowerCase(java.util.Locale.ROOT))) {
+ throw new AnalysisException(
+ errorClass = "NEAREST_BY_JOIN.UNSUPPORTED_DIRECTION",
+ messageParameters = Map(
+ "direction" -> direction,
+ "supported" ->
+ NearestByJoinValidation.SupportedDirections.mkString("'", "', '",
"'")))
+ }
+ }
+}
diff --git
a/sql/connect/common/src/test/resources/query-tests/explain-results/nearestByJoin_inner_approx_similarity.explain
b/sql/connect/common/src/test/resources/query-tests/explain-results/nearestByJoin_inner_approx_similarity.explain
new file mode 100644
index 000000000000..8e3750b4c4a7
--- /dev/null
+++
b/sql/connect/common/src/test/resources/query-tests/explain-results/nearestByJoin_inner_approx_similarity.explain
@@ -0,0 +1,5 @@
+'NearestByJoin Inner, true, 1, (a#0 + a#0), NearestBySimilarity
+:- SubqueryAlias l
+: +- LocalRelation <empty>, [id#0L, a#0, b#0]
++- SubqueryAlias r
+ +- LocalRelation <empty>, [a#0, id#0L, payload#0]
diff --git
a/sql/connect/common/src/test/resources/query-tests/explain-results/nearestByJoin_leftouter_exact_distance.explain
b/sql/connect/common/src/test/resources/query-tests/explain-results/nearestByJoin_leftouter_exact_distance.explain
new file mode 100644
index 000000000000..67539c3964b1
--- /dev/null
+++
b/sql/connect/common/src/test/resources/query-tests/explain-results/nearestByJoin_leftouter_exact_distance.explain
@@ -0,0 +1,5 @@
+'NearestByJoin LeftOuter, false, 5, (a#0 + a#0), NearestByDistance
+:- SubqueryAlias l
+: +- LocalRelation <empty>, [id#0L, a#0, b#0]
++- SubqueryAlias r
+ +- LocalRelation <empty>, [a#0, id#0L, payload#0]
diff --git
a/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_inner_approx_similarity.json
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_inner_approx_similarity.json
new file mode 100644
index 000000000000..ca4f2919e55c
--- /dev/null
+++
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_inner_approx_similarity.json
@@ -0,0 +1,109 @@
+{
+ "common": {
+ "planId": "4"
+ },
+ "nearestByJoin": {
+ "left": {
+ "common": {
+ "planId": "1"
+ },
+ "subqueryAlias": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "alias": "l"
+ }
+ },
+ "right": {
+ "common": {
+ "planId": "3"
+ },
+ "subqueryAlias": {
+ "input": {
+ "common": {
+ "planId": "2"
+ },
+ "localRelation": {
+ "schema": "struct\u003ca:int,id:bigint,payload:binary\u003e"
+ }
+ },
+ "alias": "r"
+ }
+ },
+ "rankingExpression": {
+ "unresolvedFunction": {
+ "functionName": "+",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "l.a"
+ },
+ "common": {
+ "origin": {
+ "jvmOrigin": {
+ "stackTrace": [{
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.functions$",
+ "methodName": "col",
+ "fileName": "functions.scala"
+ }, {
+ "classLoaderName": "app",
+ "declaringClass":
"org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
+ }]
+ }
+ }
+ }
+ }, {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "r.a"
+ },
+ "common": {
+ "origin": {
+ "jvmOrigin": {
+ "stackTrace": [{
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.functions$",
+ "methodName": "col",
+ "fileName": "functions.scala"
+ }, {
+ "classLoaderName": "app",
+ "declaringClass":
"org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
+ }]
+ }
+ }
+ }
+ }],
+ "isInternal": false
+ },
+ "common": {
+ "origin": {
+ "jvmOrigin": {
+ "stackTrace": [{
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.Column",
+ "methodName": "$plus",
+ "fileName": "Column.scala"
+ }, {
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
+ }]
+ }
+ }
+ }
+ },
+ "numResults": 1,
+ "joinType": "inner",
+ "mode": "approx",
+ "direction": "similarity"
+ }
+}
\ No newline at end of file
diff --git
a/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_inner_approx_similarity.proto.bin
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_inner_approx_similarity.proto.bin
new file mode 100644
index 000000000000..8dbeb994d8fc
Binary files /dev/null and
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_inner_approx_similarity.proto.bin
differ
diff --git
a/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_leftouter_exact_distance.json
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_leftouter_exact_distance.json
new file mode 100644
index 000000000000..877bff8f90c8
--- /dev/null
+++
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_leftouter_exact_distance.json
@@ -0,0 +1,109 @@
+{
+ "common": {
+ "planId": "4"
+ },
+ "nearestByJoin": {
+ "left": {
+ "common": {
+ "planId": "1"
+ },
+ "subqueryAlias": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "alias": "l"
+ }
+ },
+ "right": {
+ "common": {
+ "planId": "3"
+ },
+ "subqueryAlias": {
+ "input": {
+ "common": {
+ "planId": "2"
+ },
+ "localRelation": {
+ "schema": "struct\u003ca:int,id:bigint,payload:binary\u003e"
+ }
+ },
+ "alias": "r"
+ }
+ },
+ "rankingExpression": {
+ "unresolvedFunction": {
+ "functionName": "+",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "l.a"
+ },
+ "common": {
+ "origin": {
+ "jvmOrigin": {
+ "stackTrace": [{
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.functions$",
+ "methodName": "col",
+ "fileName": "functions.scala"
+ }, {
+ "classLoaderName": "app",
+ "declaringClass":
"org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
+ }]
+ }
+ }
+ }
+ }, {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "r.a"
+ },
+ "common": {
+ "origin": {
+ "jvmOrigin": {
+ "stackTrace": [{
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.functions$",
+ "methodName": "col",
+ "fileName": "functions.scala"
+ }, {
+ "classLoaderName": "app",
+ "declaringClass":
"org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
+ }]
+ }
+ }
+ }
+ }],
+ "isInternal": false
+ },
+ "common": {
+ "origin": {
+ "jvmOrigin": {
+ "stackTrace": [{
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.Column",
+ "methodName": "$plus",
+ "fileName": "Column.scala"
+ }, {
+ "classLoaderName": "app",
+ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
+ }]
+ }
+ }
+ }
+ },
+ "numResults": 5,
+ "joinType": "leftouter",
+ "mode": "exact",
+ "direction": "distance"
+ }
+}
\ No newline at end of file
diff --git
a/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_leftouter_exact_distance.proto.bin
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_leftouter_exact_distance.proto.bin
new file mode 100644
index 000000000000..a671071c556e
Binary files /dev/null and
b/sql/connect/common/src/test/resources/query-tests/queries/nearestByJoin_leftouter_exact_distance.proto.bin
differ
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 12d0c1ce12a4..dff80cb24268 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -159,6 +159,8 @@ class SparkConnectPlanner(
case proto.Relation.RelTypeCase.JOIN =>
transformJoinOrJoinWith(rel.getJoin)
case proto.Relation.RelTypeCase.AS_OF_JOIN =>
transformAsOfJoin(rel.getAsOfJoin)
case proto.Relation.RelTypeCase.LATERAL_JOIN =>
transformLateralJoin(rel.getLateralJoin)
+ case proto.Relation.RelTypeCase.NEAREST_BY_JOIN =>
+ transformNearestByJoin(rel.getNearestByJoin)
case proto.Relation.RelTypeCase.DEDUPLICATE =>
transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP =>
transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
@@ -2567,6 +2569,28 @@ class SparkConnectPlanner(
condition = joinCondition)
}
+ private def transformNearestByJoin(rel: proto.NearestByJoin): LogicalPlan = {
+ assertPlan(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+ assertPlan(rel.hasRankingExpression, "Ranking expression must be present")
+ // proto3 string fields default to "" when not set; reject the empty case
explicitly so the
+ // user sees a "must be set" error instead of a misleading "unsupported
value" error.
+ assertPlan(rel.getJoinType.nonEmpty, "NearestByJoin.join_type must be set")
+ assertPlan(rel.getMode.nonEmpty, "NearestByJoin.mode must be set")
+ assertPlan(rel.getDirection.nonEmpty, "NearestByJoin.direction must be
set")
+ val left = Dataset.ofRows(session, transformRelation(rel.getLeft))
+ val right = Dataset.ofRows(session, transformRelation(rel.getRight))
+ val rankingExpression =
Column(transformExpression(rel.getRankingExpression))
+ left
+ .nearestByJoin(
+ right,
+ rankingExpression,
+ rel.getNumResults,
+ rel.getMode,
+ rel.getDirection,
+ rel.getJoinType)
+ .logicalPlan
+ }
+
private def transformSort(sort: proto.Sort): LogicalPlan = {
assertPlan(sort.getOrderCount > 0, "'order' must be present and contain
elements.")
logical.Sort(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 91d51163b319..d83a4df51cd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -764,6 +764,66 @@ class Dataset[T] private[sql](
lateralJoin(right, Some(joinExprs), LateralJoinType(joinType))
}
+ private[sql] def nearestByJoin(
+ right: sql.Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ joinType: JoinType,
+ approx: Boolean,
+ direction: NearestByDirection): DataFrame = {
+ if (numResults < 1 || numResults > NearestByJoin.MaxNumResults) {
+ throw new AnalysisException(
+ errorClass = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
+ messageParameters = Map(
+ "numResults" -> numResults.toString,
+ "min" -> "1",
+ "max" -> NearestByJoin.MaxNumResults.toString))
+ }
+ withPlan {
+ NearestByJoin(
+ logicalPlan,
+ right.logicalPlan,
+ joinType,
+ approx,
+ numResults,
+ rankingExpression.expr,
+ direction)
+ }
+ }
+
+ /** @inheritdoc */
+ def nearestByJoin(
+ right: sql.Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ mode: String,
+ direction: String): DataFrame = {
+ nearestByJoin(
+ right,
+ rankingExpression,
+ numResults,
+ Inner,
+ NearestByJoinMode(mode),
+ NearestByDirection(direction))
+ }
+
+ /** @inheritdoc */
+ def nearestByJoin(
+ right: sql.Dataset[_],
+ rankingExpression: Column,
+ numResults: Int,
+ mode: String,
+ direction: String,
+ joinType: String): DataFrame = {
+ nearestByJoin(
+ right,
+ rankingExpression,
+ numResults,
+ NearestByJoinType(joinType),
+ NearestByJoinMode(mode),
+ NearestByDirection(direction))
+ }
+
// TODO(SPARK-22947): Fix the DataFrame API.
private[sql] def joinAsOf(
other: Dataset[_],
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNearestByJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNearestByJoinSuite.scala
new file mode 100644
index 000000000000..b34880b71f5b
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNearestByJoinSuite.scala
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.{NearestByDirection,
NearestByJoinMode, NearestByJoinType}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.tags.SlowSQLTest
+
+@SlowSQLTest
+class DataFrameNearestByJoinSuite extends QueryTest with SharedSparkSession {
+
+ private def prepareForNearestByJoin(): (classic.DataFrame,
classic.DataFrame) = {
+ val users = spark.createDataFrame(
+ Seq((1, 10.0), (2, 20.0), (3, 30.0))).toDF("user_id", "score")
+ val products = spark.createDataFrame(
+ Seq(("A", 11.0), ("B", 22.0), ("C", 5.0))).toDF("product", "pscore")
+ (users, products)
+ }
+
+ test("similarity, inner, k=1") {
+ val (users, products) = prepareForNearestByJoin()
+ val result = users.nearestByJoin(
+ products,
+ -abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity")
+
+ checkAnswer(
+ result.select("user_id", "product").orderBy("user_id"),
+ Seq(Row(1, "A"), Row(2, "B"), Row(3, "B"))
+ )
+ }
+
+ test("distance, inner, k=2") {
+ val (users, products) = prepareForNearestByJoin()
+ val result = users.nearestByJoin(
+ products,
+ abs(users("score") - products("pscore")),
+ numResults = 2,
+ mode = "exact",
+ direction = "distance")
+
+ // For each user_id, closest 2 by |score - pscore|:
+ // user 1 (10): A (|10-11|=1), C (|10-5|=5)
+ // user 2 (20): B (|20-22|=2), A (|20-11|=9)
+ // user 3 (30): B (|30-22|=8), A (|30-11|=19)
+ checkAnswer(
+ result.select("user_id", "product").orderBy("user_id", "product"),
+ Seq(
+ Row(1, "A"), Row(1, "C"),
+ Row(2, "A"), Row(2, "B"),
+ Row(3, "A"), Row(3, "B"))
+ )
+ }
+
+ test("left outer when right side is empty") {
+ val (users, products) = prepareForNearestByJoin()
+ val emptyProducts = products.filter(lit(false))
+ val result = users.nearestByJoin(
+ emptyProducts,
+ -abs(users("score") - emptyProducts("pscore")),
+ numResults = 1,
+ joinType = "leftouter",
+ mode = "approx",
+ direction = "similarity")
+
+ checkAnswer(
+ result.select("user_id", "product").orderBy("user_id"),
+ Seq(Row(1, null), Row(2, null), Row(3, null))
+ )
+ }
+
+ test("inner drops left rows with no matches") {
+ val (users, products) = prepareForNearestByJoin()
+ val emptyProducts = products.filter(lit(false))
+ val result = users.nearestByJoin(
+ emptyProducts,
+ -abs(users("score") - emptyProducts("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity")
+
+ assert(result.count() === 0)
+ }
+
+ test("self-join: each row finds nearest other rows in the same DataFrame") {
+ val (users, _) = prepareForNearestByJoin()
+ // We pass `users` as both sides; DeduplicateRelations rewrites the right
side to
+ // generate fresh ExprIds, so the join resolves. Both `users("score")`
references in
+ // the ranking expression bind to the original (left) attribute, so the
rank is
+ // identically 0 for every candidate -- this test exercises self-join
resolution,
+ // not nearest-row selection.
+ val result = users.nearestByJoin(
+ users,
+ -abs(users("score") - users("score")),
+ numResults = 2,
+ mode = "exact",
+ direction = "similarity")
+
+ // 3 users x 2 nearest = 6 rows; output schema has user_id and score from
both sides.
+ assert(result.count() === 6)
+ assert(result.columns.length === 4)
+ }
+
+ test("inner: NULL ranking values for all candidates drops the left row") {
+ // Construct a left side where every comparison yields NULL: a NULL score
on the left makes
+ // `abs(left.score - right.pscore)` evaluate to NULL for every right row,
so MaxMinByK skips
+ // every candidate (its `ord == null` early-return path) and the heap
stays empty. With INNER,
+ // the left row is dropped entirely.
+ val users = spark.createDataFrame(
+ Seq[(Int, java.lang.Double)]((1, null), (2, 20.0d))).toDF("user_id",
"score")
+ val products = spark.createDataFrame(
+ Seq(("A", 11.0), ("B", 22.0))).toDF("product", "pscore")
+
+ val result = users.nearestByJoin(
+ products,
+ abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "distance")
+
+ // Only user 2 should appear; user 1 (NULL score) drops because no
candidate has a
+ // non-null ranking value.
+ checkAnswer(
+ result.select("user_id", "product"),
+ Seq(Row(2, "B"))
+ )
+ }
+
+ test("left outer: NULL ranking values for all candidates preserves left with
NULLs") {
+ // Same shape as the previous test, but LEFT OUTER preserves user 1 with
NULL right-side
+ // columns instead of dropping it.
+ val users = spark.createDataFrame(
+ Seq[(Int, java.lang.Double)]((1, null), (2, 20.0d))).toDF("user_id",
"score")
+ val products = spark.createDataFrame(
+ Seq(("A", 11.0), ("B", 22.0))).toDF("product", "pscore")
+
+ val result = users.nearestByJoin(
+ products,
+ abs(users("score") - products("pscore")),
+ numResults = 1,
+ joinType = "leftouter",
+ mode = "exact",
+ direction = "distance")
+
+ checkAnswer(
+ result.select("user_id", "product").orderBy("user_id"),
+ Seq(Row(1, null), Row(2, "B"))
+ )
+ }
+
+ test("numResults larger than right side returns min(k, available) per left
row") {
+ // Right side has 3 rows; ask for 5. Each left row should get exactly 3
matches, not 5
+ // padded with NULLs.
+ val (users, products) = prepareForNearestByJoin()
+ val result = users.nearestByJoin(
+ products,
+ abs(users("score") - products("pscore")),
+ numResults = 5,
+ mode = "exact",
+ direction = "distance")
+
+ // 3 users x min(5, 3) = 9 rows.
+ assert(result.count() === 9)
+ // No NULL padding: every left row pairs with every product exactly once.
+ val perUser = result.groupBy("user_id").count().collect().map(r =>
r.getInt(0) -> r.getLong(1))
+ assert(perUser.toMap === Map(1 -> 3L, 2 -> 3L, 3 -> 3L))
+ }
+
+ test("duplicate left rows each get an independent top-K") {
+ // Two identical user rows must not be collapsed into a single group: each
must independently
+ // produce its own top-K. This proves the per-row __qid tagging in the
rewrite works.
+ val users = spark.createDataFrame(
+ Seq((1, 10.0), (1, 10.0))).toDF("user_id", "score")
+ val products = spark.createDataFrame(
+ Seq(("A", 11.0), ("B", 22.0), ("C", 5.0))).toDF("product", "pscore")
+
+ val result = users.nearestByJoin(
+ products,
+ abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "distance")
+
+ // Two identical left rows -> two output rows, both pairing with product A
(closest to 10.0).
+ checkAnswer(
+ result.select("user_id", "product"),
+ Seq(Row(1, "A"), Row(1, "A"))
+ )
+ }
+
+ test("conflicting column names between sides resolve via DataFrame
qualifiers") {
+ // Both sides have a column named `score`; the ranking expression
disambiguates via
+ // DataFrame-qualified accessors.
+ val left = spark.createDataFrame(Seq((1, 10.0), (2, 20.0))).toDF("id",
"score")
+ val right = spark.createDataFrame(
+ Seq(("A", 11.0), ("B", 22.0), ("C", 5.0))).toDF("name", "score")
+
+ val result = left.nearestByJoin(
+ right,
+ -abs(left("score") - right("score")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity")
+
+ checkAnswer(
+ result.select("id", "name").orderBy("id"),
+ Seq(Row(1, "A"), Row(2, "B"))
+ )
+ // Output schema should carry both `score` columns through (4 columns
total).
+ assert(result.columns.length === 4)
+ }
+
+ test("streaming inputs are rejected at analysis time") {
+ // Build a streaming left side and a static right side; NearestByJoin must
be rejected
+ // at analysis before the optimizer rewrite (an unconditioned
cross-product fed into a
+ // global Aggregate keyed by a per-row identifier) ever runs.
+ import testImplicits._
+ implicit val ctx = spark.sqlContext
+ val streamingUsers = MemoryStream[(Int, Double)].toDF().toDF("user_id",
"score")
+ val products = spark.createDataFrame(
+ Seq(("A", 11.0), ("B", 22.0), ("C", 5.0))).toDF("product", "pscore")
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ streamingUsers.nearestByJoin(
+ products,
+ -abs(streamingUsers("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity").queryExecution.analyzed
+ },
+ condition = "NEAREST_BY_JOIN.STREAMING_NOT_SUPPORTED",
+ parameters = Map.empty)
+ }
+
+ test("rejected when spark.sql.crossJoin.enabled is false") {
+ // The rewrite produces an unconditioned cross-product internally, so when
the user has
+ // opted out of cross-products via `spark.sql.crossJoin.enabled = false`,
NEAREST BY
+ // queries are rejected at analysis time with
`NEAREST_BY_JOIN.CROSS_JOIN_NOT_ENABLED` --
+ // a NEAREST BY-specific error class added so the user does not see
internal rewrite
+ // attributes in the error message.
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
+ val (users, products) = prepareForNearestByJoin()
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ -abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity").queryExecution.analyzed
+ },
+ condition = "NEAREST_BY_JOIN.CROSS_JOIN_NOT_ENABLED",
+ parameters = Map.empty)
+ }
+ }
+
+ test("exact + left outer: empty right side preserves all left rows with
NULLs") {
+ // Exercises the EXACT + LEFT OUTER combination with an empty right side.
+ val (users, products) = prepareForNearestByJoin()
+ val emptyProducts = products.filter(lit(false))
+ val result = users.nearestByJoin(
+ emptyProducts,
+ -abs(users("score") - emptyProducts("pscore")),
+ numResults = 1,
+ joinType = "leftouter",
+ mode = "exact",
+ direction = "similarity")
+
+ checkAnswer(
+ result.select("user_id", "product").orderBy("user_id"),
+ Seq(Row(1, null), Row(2, null), Row(3, null))
+ )
+ }
+
+ test("SQL: APPROX NEAREST SIMILARITY") {
+ val (users, products) = prepareForNearestByJoin()
+ users.createOrReplaceTempView("t_users")
+ products.createOrReplaceTempView("t_products")
+ try {
+ val result = spark.sql(
+ """
+ |SELECT u.user_id, p.product
+ |FROM t_users u JOIN t_products p
+ | APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
+ |""".stripMargin)
+ checkAnswer(
+ result.orderBy("user_id"),
+ Seq(Row(1, "A"), Row(2, "B"), Row(3, "B"))
+ )
+ } finally {
+ spark.catalog.dropTempView("t_users")
+ spark.catalog.dropTempView("t_products")
+ }
+ }
+
+ test("SQL: EXACT NEAREST DISTANCE") {
+ val (users, products) = prepareForNearestByJoin()
+ users.createOrReplaceTempView("t_users")
+ products.createOrReplaceTempView("t_products")
+ try {
+ val result = spark.sql(
+ """
+ |SELECT u.user_id, p.product
+ |FROM t_users u JOIN t_products p
+ | EXACT NEAREST 1 BY DISTANCE abs(u.score - p.pscore)
+ |""".stripMargin)
+ checkAnswer(
+ result.orderBy("user_id"),
+ Seq(Row(1, "A"), Row(2, "B"), Row(3, "B"))
+ )
+ } finally {
+ spark.catalog.dropTempView("t_users")
+ spark.catalog.dropTempView("t_products")
+ }
+ }
+
+ test("invalid numResults is rejected") {
+ val (users, products) = prepareForNearestByJoin()
+ Seq(0, 100001).foreach { k =>
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ -abs(users("score") - products("pscore")),
+ numResults = k,
+ mode = "exact",
+ direction = "similarity")
+ },
+ condition = "NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
+ parameters = Map(
+ "numResults" -> k.toString,
+ "min" -> "1",
+ "max" -> "100000"))
+ }
+ }
+
+ test("invalid joinType is rejected") {
+ val (users, products) = prepareForNearestByJoin()
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ -abs(users("score") - products("pscore")),
+ numResults = 1,
+ joinType = "rightouter",
+ mode = "approx",
+ direction = "similarity")
+ },
+ condition = "NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE",
+ parameters = Map(
+ "joinType" -> "rightouter",
+ "supported" -> NearestByJoinType.supportedDisplay))
+ }
+
+ test("invalid mode is rejected") {
+ val (users, products) = prepareForNearestByJoin()
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ -abs(users("score") - products("pscore")),
+ numResults = 1,
+ joinType = "inner",
+ mode = "bogus",
+ direction = "similarity")
+ },
+ condition = "NEAREST_BY_JOIN.UNSUPPORTED_MODE",
+ parameters = Map(
+ "mode" -> "bogus",
+ "supported" -> NearestByJoinMode.supported.mkString("'", "', '", "'")))
+ }
+
+ test("invalid direction is rejected") {
+ val (users, products) = prepareForNearestByJoin()
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ -abs(users("score") - products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "bogus")
+ },
+ condition = "NEAREST_BY_JOIN.UNSUPPORTED_DIRECTION",
+ parameters = Map(
+ "direction" -> "bogus",
+ "supported" -> NearestByDirection.supported.mkString("'", "', '",
"'")))
+ }
+
+ test("non-orderable ranking expression is rejected") {
+ val (users, products) = prepareForNearestByJoin()
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ map(users("score"), products("pscore")),
+ numResults = 1,
+ mode = "exact",
+ direction = "similarity")
+ },
+ condition = "NEAREST_BY_JOIN.NON_ORDERABLE_RANKING_EXPRESSION",
+ parameters = Map(
+ "expression" -> "\"map(score, pscore)\"",
+ "type" -> "\"MAP<DOUBLE, DOUBLE>\""))
+ }
+
+ test("EXACT mode rejects nondeterministic ranking expression") {
+ val (users, products) = prepareForNearestByJoin()
+ checkError(
+ exception = intercept[AnalysisException] {
+ users.nearestByJoin(
+ products,
+ rand() + products("pscore"),
+ numResults = 1,
+ joinType = "inner",
+ mode = "exact",
+ direction = "similarity")
+ },
+ condition = "NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION",
+ matchPVals = true,
+ parameters = Map("expression" -> ".*rand.*pscore.*"))
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]