This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch branch-4.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.2 by this push:
new 9f65c53b5e66 [SPARK-55978][SQL] Add TABLESAMPLE SYSTEM block sampling with DSv2 pushdown
9f65c53b5e66 is described below
commit 9f65c53b5e6673f2f62c10d9d09da8469104de40
Author: Stanley Yao <[email protected]>
AuthorDate: Tue May 12 09:09:09 2026 +0800
[SPARK-55978][SQL] Add TABLESAMPLE SYSTEM block sampling with DSv2 pushdown
### What changes were proposed in this pull request?
This PR adds support for ANSI SQL `TABLESAMPLE SYSTEM` (block-level
sampling) alongside the existing `TABLESAMPLE BERNOULLI` (row-level sampling).
Key changes:
- **SQL grammar**: Extended `TABLESAMPLE` to accept an optional `SYSTEM` or
`BERNOULLI` qualifier before the sample method. Added both as non-reserved
keywords.
- **Logical plan**: Introduced `SampleMethod` sealed trait
(`Bernoulli`/`System`) and added it to the `Sample` node. Default is
`Bernoulli` for backward compatibility.
- **Parser**: `TABLESAMPLE SYSTEM` only supports `PERCENT` sampling and
does not support `REPEATABLE`. Other sample clauses (ROWS, BUCKET, BYTES) are
rejected with dedicated `UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_*` error conditions.
- **DSv2 pushdown**: `TABLESAMPLE SYSTEM` is pushed down to data sources
via a new default `SupportsPushDownTableSample.pushTableSample()` overload that
takes a `SampleMethod` argument (a new `BERNOULLI`/`SYSTEM` enum in the connector
API). Sources that don't override the new overload reject SYSTEM sampling by
default.
- **Physical planning**: SYSTEM samples that aren't pushed down to a DSv2
source raise an `AnalysisException`; there is no row-level fallback because
block sampling is data-source dependent.
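The new overload is backward compatible because it is a Java default method. The standalone sketch below reproduces that pattern. It is not Spark code: `SamplePushdownSketch`, `TableSamplePushdown`, `LegacySource`, `BlockSamplingSource` and `plan` are hypothetical names. `plan` is only a rough model of the planning rule above, and it throws `IllegalStateException` where Spark raises `AnalysisException`.

```java
/**
 * Standalone sketch of the default-method pattern behind the new
 * pushTableSample overload. Not Spark code; all names are hypothetical.
 */
final class SamplePushdownSketch {
  enum SampleMethod { BERNOULLI, SYSTEM }

  interface TableSamplePushdown {
    // Pre-existing hook: row-level (BERNOULLI) sampling only.
    boolean push(double lower, double upper, boolean withReplacement, long seed);

    // New overload. Implementations that never override it reject SYSTEM
    // and route BERNOULLI to the old hook, so existing sources keep working.
    default boolean push(double lower, double upper, boolean withReplacement,
                         long seed, SampleMethod method) {
      if (method == SampleMethod.SYSTEM) {
        return false;
      }
      return push(lower, upper, withReplacement, seed);
    }
  }

  /** Written before SYSTEM existed: implements only the old hook. */
  static final class LegacySource implements TableSamplePushdown {
    @Override
    public boolean push(double lower, double upper, boolean withReplacement, long seed) {
      return true; // accepts row-level sampling
    }
  }

  /** Opts in to block sampling by overriding the new overload. */
  static final class BlockSamplingSource implements TableSamplePushdown {
    SampleMethod accepted;

    @Override
    public boolean push(double lower, double upper, boolean withReplacement, long seed) {
      return push(lower, upper, withReplacement, seed, SampleMethod.BERNOULLI);
    }

    @Override
    public boolean push(double lower, double upper, boolean withReplacement,
                        long seed, SampleMethod method) {
      accepted = method;
      return true;
    }
  }

  /** Rough model of planning: an un-pushed SYSTEM sample fails instead of falling back. */
  static String plan(TableSamplePushdown source, double fraction, SampleMethod method) {
    boolean pushed = source.push(0.0, fraction, false, 42L, method);
    if (!pushed && method == SampleMethod.SYSTEM) {
      throw new IllegalStateException("TABLESAMPLE SYSTEM was not pushed down");
    }
    return pushed ? "pushed" : "row-level Sample";
  }
}
```

A source compiled against the old interface therefore never sees a SYSTEM request it cannot handle. The planner refuses the query rather than silently switching to row-level sampling.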
### Why are the changes needed?
ANSI SQL defines two sampling methods: `BERNOULLI` (row-level) and `SYSTEM`
(implementation-dependent, typically block/split-level). Block sampling can be
significantly faster on large tables because unselected blocks or splits are
never read, whereas row-level sampling still reads and evaluates every row.
This makes it useful for approximate queries and data exploration. Many
databases (PostgreSQL, Hive, Trino) support this distinction.
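The sketch below contrasts the two methods. It assumes splits are chosen by a deterministic draw seeded from the seed and the split index. That scheme is a hypothetical stand-in, because each data source decides its own block-selection rule. `BlockSamplingSketch` and its methods are illustrative names only.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;

/** Hypothetical contrast between SYSTEM (per-split) and BERNOULLI (per-row) sampling. */
final class BlockSamplingSketch {
  /**
   * SYSTEM-style: one decision per split. The draw is derived from the seed and
   * the split index (an assumed scheme), so the same inputs pick the same splits.
   */
  static List<Integer> selectSplits(int numSplits, double fraction, long seed) {
    List<Integer> kept = new ArrayList<>();
    for (int i = 0; i < numSplits; i++) {
      double u = new SplittableRandom(seed * 31L + i).nextDouble(); // in [0, 1)
      if (u < fraction) {
        kept.add(i); // the whole split is read; unselected splits are never opened
      }
    }
    return kept;
  }

  /** BERNOULLI-style: every row is read, then kept or dropped individually. */
  static long bernoulliKeptRows(long[] rowsPerSplit, double fraction, long seed) {
    SplittableRandom rng = new SplittableRandom(seed);
    long kept = 0;
    for (long rows : rowsPerSplit) {
      for (long r = 0; r < rows; r++) {
        if (rng.nextDouble() < fraction) {
          kept++;
        }
      }
    }
    return kept;
  }
}
```

Block sampling saves I/O because skipped splits are never opened. The trade-off is coarser granularity: all rows in a split are kept or dropped together, so the sampled row count can deviate more from the requested fraction.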
### Does this PR introduce _any_ user-facing change?
Yes. New SQL syntax `TABLESAMPLE SYSTEM (x PERCENT)` and `TABLESAMPLE
BERNOULLI (x PERCENT)`. `BERNOULLI` and `SYSTEM` are added as non-reserved
keywords. Existing queries without these keywords behave identically to before.
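The SYSTEM-specific syntax checks can be modeled as follows. `TableSampleRulesSketch` is a hypothetical model of the rules in `AstBuilder.withSample`, not the parser itself. It does not model the pre-existing checks that apply with `BERNOULLI` or no qualifier, such as the BYTES and `BUCKET ... ON` rejections. It does follow the parser's ordering: `REPEATABLE` is rejected before the sample clause is inspected.

```java
import java.util.Optional;

/** Hypothetical model of the SYSTEM-specific checks in AstBuilder.withSample. */
final class TableSampleRulesSketch {
  enum Qualifier { NONE, BERNOULLI, SYSTEM }
  enum Clause { PERCENT, ROWS, BYTES, BUCKET }

  /** Returns the error condition raised for SYSTEM misuse, or empty if accepted. */
  static Optional<String> check(Qualifier qualifier, Clause clause, boolean repeatable) {
    if (qualifier != Qualifier.SYSTEM) {
      return Optional.empty(); // BERNOULLI / no qualifier: pre-existing rules apply
    }
    // REPEATABLE is rejected before the sample clause is inspected.
    if (repeatable) {
      return Optional.of("UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_REPEATABLE");
    }
    if (clause != Clause.PERCENT) {
      return Optional.of(
          "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD(" + clause.name() + ")");
    }
    return Optional.empty();
  }
}
```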
### How was this patch tested?
- 9 new test cases in `PlanParserSuite` covering: basic parsing, case
insensitivity, boundary fractions, unsupported methods (ROWS/BUCKET with
SYSTEM), REPEATABLE rejection, fraction validation, identifier preservation,
and subquery contexts.
- Existing `SQLQuerySuite` tests pass.
- Scalastyle passes with 0 errors/warnings.
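The boundary-fraction and fraction-validation cases follow from how a percentage becomes a `Sample` fraction. The Java sketch below reproduces that conversion and range check. Two parts of it are assumptions. First, `EPS = 1e-6` is taken to be the value of Spark's `RandomSampler.roundingEpsilon`. Second, `PercentFractionSketch` is a hypothetical name.

```java
/** Sketch of TABLESAMPLE (x PERCENT) conversion and range check. */
final class PercentFractionSketch {
  // Assumed value of Spark's RandomSampler.roundingEpsilon.
  static final double EPS = 1e-6;

  static double toFraction(double percent) {
    double fraction = percent / 100.0d;
    // Accept [0, 1] with a small tolerance for floating-point rounding.
    if (fraction < 0.0 - EPS || fraction > 1.0 + EPS) {
      throw new IllegalArgumentException(
          "Sampling fraction (" + fraction + ") must be on interval [0, 1]");
    }
    return fraction;
  }
}
```

The parser applies the sign separately (`sign * fraction / 100.0d`). Passing a signed percentage here is a simplification that yields the same fraction.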
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (thoroughly refined, reviewed, and tested by human)
Closes #54972 from stanyao/spark-55978-tablesample-system.
Lead-authored-by: Stanley Yao <[email protected]>
Co-authored-by: Stanley Yao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 5fe5451c3680351ff96811477589617685d06cf6)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 20 ++
docs/sql-ref-ansi-compliance.md | 2 +
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 2 +
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 8 +-
.../spark/sql/errors/QueryParsingErrors.scala | 16 ++
...sPushDownTableSample.java => SampleMethod.java} | 20 +-
.../read/SupportsPushDownTableSample.java | 19 +-
.../analysis/UnsupportedOperationChecker.scala | 2 +-
.../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 25 +-
.../plans/logical/basicLogicalOperators.scala | 22 +-
.../sql/catalyst/parser/PlanParserSuite.scala | 201 ++++++++++++++++
.../catalog/InMemoryTableWithTableSample.scala | 258 +++++++++++++++++++++
.../InMemoryTableWithTableSampleCatalog.scala | 102 ++++++++
.../jdbc/SparkConnectDatabaseMetaDataSuite.scala | 4 +-
.../explain-results/sample_fraction_seed.explain | 2 +-
.../sample_withReplacement_fraction_seed.explain | 2 +-
.../spark/sql/execution/DataSourceScanExec.scala | 8 +-
.../spark/sql/execution/SparkStrategies.scala | 9 +-
.../execution/datasources/v2/PushDownUtils.scala | 9 +-
.../execution/datasources/v2/TableSampleInfo.scala | 5 +-
.../datasources/v2/V2ScanRelationPushDown.scala | 23 +-
.../analyzer-results/pipe-operators.sql.out | 12 +-
.../sql-tests/results/keywords-enforced.sql.out | 2 +
.../resources/sql-tests/results/keywords.sql.out | 2 +
.../sql-tests/results/nonansi/keywords.sql.out | 2 +
.../connector/DataSourceV2TableSampleSuite.scala | 210 +++++++++++++++++
.../ThriftServerWithSparkContextSuite.scala | 2 +-
28 files changed, 951 insertions(+), 40 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 95f0c303e35f..889ecf9f7b08 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -8095,6 +8095,26 @@
"Store backend <stateStoreProvider> is not supported by TransformWithState operator. Please use RocksDBStateStoreProvider."
]
},
+ "TABLESAMPLE_SYSTEM" : {
+ "message" : [
+ "TABLESAMPLE SYSTEM is only supported by data sources that implement block-level sampling."
+ ]
+ },
+ "TABLESAMPLE_SYSTEM_NO_SCAN" : {
+ "message" : [
+ "TABLESAMPLE SYSTEM requires a direct reference to a data source table that supports block-level sampling. It cannot be applied to subqueries, views, or tables with intervening operations."
+ ]
+ },
+ "TABLESAMPLE_SYSTEM_REPEATABLE" : {
+ "message" : [
+ "TABLESAMPLE SYSTEM does not support the REPEATABLE clause. Use TABLESAMPLE BERNOULLI for repeatable sampling with a seed."
+ ]
+ },
+ "TABLESAMPLE_SYSTEM_SAMPLE_METHOD" : {
+ "message" : [
+ "TABLESAMPLE SYSTEM does not support <sampleMethod> sampling. Only PERCENT sampling is supported."
+ ]
+ },
"TABLE_OPERATION" : {
"message" : [
"Table <tableName> does not support <operation>. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"."
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 8542cd3d8986..4f21b7b4b3c7 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -430,6 +430,7 @@ Below is a list of all the keywords in Spark SQL.
|ATOMIC|non-reserved|non-reserved|non-reserved|
|AUTHORIZATION|reserved|non-reserved|reserved|
|BEGIN|non-reserved|non-reserved|non-reserved|
+|BERNOULLI|non-reserved|non-reserved|non-reserved|
|BETWEEN|non-reserved|non-reserved|reserved|
|BIGINT|non-reserved|non-reserved|reserved|
|BINARY|non-reserved|non-reserved|reserved|
@@ -765,6 +766,7 @@ Below is a list of all the keywords in Spark SQL.
|SUBSTR|non-reserved|non-reserved|non-reserved|
|SUBSTRING|non-reserved|non-reserved|non-reserved|
|SYNC|non-reserved|non-reserved|non-reserved|
+|SYSTEM|non-reserved|non-reserved|reserved|
|SYSTEM_PATH|non-reserved|non-reserved|not a keyword|
|SYSTEM_TIME|non-reserved|non-reserved|non-reserved|
|SYSTEM_VERSION|non-reserved|non-reserved|non-reserved|
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index f4834b4ecf62..af71f441012c 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -149,6 +149,7 @@ AT: 'AT';
ATOMIC: 'ATOMIC';
AUTHORIZATION: 'AUTHORIZATION';
BEGIN: 'BEGIN';
+BERNOULLI: 'BERNOULLI';
BETWEEN: 'BETWEEN';
BIGINT: 'BIGINT';
BINARY: 'BINARY';
@@ -483,6 +484,7 @@ STRUCT: 'STRUCT' {incComplexTypeLevelCounter();};
SUBSTR: 'SUBSTR';
SUBSTRING: 'SUBSTRING';
SYNC: 'SYNC';
+SYSTEM: 'SYSTEM';
SYSTEM_TIME: 'SYSTEM_TIME';
SYSTEM_VERSION: 'SYSTEM_VERSION';
SYSTEM_PATH: 'SYSTEM_PATH';
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 735921681cdc..1e3acbc001b3 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1073,7 +1073,9 @@ nearestByClause
;
sample
- : TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=integerValue RIGHT_PAREN)?
+ : TABLESAMPLE (sampleType=(SYSTEM | BERNOULLI))?
+ LEFT_PAREN sampleMethod? RIGHT_PAREN
+ (REPEATABLE LEFT_PAREN seed=integerValue RIGHT_PAREN)?
;
sampleMethod
@@ -1942,6 +1944,7 @@ ansiNonReserved
| AT
| ATOMIC
| BEGIN
+ | BERNOULLI
| BETWEEN
| BIGINT
| BINARY
@@ -2216,6 +2219,7 @@ ansiNonReserved
| SUBSTR
| SUBSTRING
| SYNC
+ | SYSTEM
| SYSTEM_PATH
| SYSTEM_TIME
| SYSTEM_VERSION
@@ -2322,6 +2326,7 @@ nonReserved
| ATOMIC
| AUTHORIZATION
| BEGIN
+ | BERNOULLI
| BETWEEN
| BIGINT
| BINARY
@@ -2645,6 +2650,7 @@ nonReserved
| SUBSTR
| SUBSTRING
| SYNC
+ | SYSTEM
| SYSTEM_PATH
| SYSTEM_TIME
| SYSTEM_VERSION
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index 33d7aaef17b8..eca7342a2d9e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -509,6 +509,22 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}
+ def tableSampleSystemRepeatableError(ctx: ParserRuleContext): Throwable = {
+ new ParseException(
+ errorClass = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_REPEATABLE",
+ messageParameters = Map.empty,
+ ctx)
+ }
+
+ def tableSampleSystemSampleMethodError(
+ sampleMethod: String,
+ ctx: ParserRuleContext): Throwable = {
+ new ParseException(
+ errorClass = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD",
+ messageParameters = Map("sampleMethod" -> sampleMethod),
+ ctx)
+ }
+
def invalidStatementError(operation: String, ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_STATEMENT_OR_CLAUSE",
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SampleMethod.java
similarity index 71%
copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SampleMethod.java
index 3630feb4680e..b9af8f9d5ac7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SampleMethod.java
@@ -20,20 +20,14 @@ package org.apache.spark.sql.connector.read;
import org.apache.spark.annotation.Evolving;
/**
- * A mix-in interface for {@link Scan}. Data sources can implement this interface to
- * push down SAMPLE.
+ * The sampling method for TABLESAMPLE.
*
- * @since 3.3.0
+ * @since 4.2.0
*/
@Evolving
-public interface SupportsPushDownTableSample extends ScanBuilder {
-
- /**
- * Pushes down SAMPLE to the data source.
- */
- boolean pushTableSample(
- double lowerBound,
- double upperBound,
- boolean withReplacement,
- long seed);
+public enum SampleMethod {
+ /** Row-level sampling (BERNOULLI). Each row is independently selected. */
+ BERNOULLI,
+ /** Block-level sampling (SYSTEM). Entire partitions/splits are included or skipped. */
+ SYSTEM
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
index 3630feb4680e..3ceb7ed2de14 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
@@ -29,11 +29,28 @@ import org.apache.spark.annotation.Evolving;
public interface SupportsPushDownTableSample extends ScanBuilder {
/**
- * Pushes down SAMPLE to the data source.
+ * Pushes down BERNOULLI (row-level) SAMPLE to the data source.
*/
boolean pushTableSample(
double lowerBound,
double upperBound,
boolean withReplacement,
long seed);
+
+ /**
+ * Pushes down SAMPLE to the data source with the specified sampling method.
+ */
+ default boolean pushTableSample(
+ double lowerBound,
+ double upperBound,
+ boolean withReplacement,
+ long seed,
+ SampleMethod sampleMethod) {
+ if (sampleMethod == SampleMethod.SYSTEM) {
+ // If the data source hasn't overridden this method, it must not have added support
+ // for SYSTEM sampling. Don't apply sample pushdown.
+ return false;
+ }
+ return pushTableSample(lowerBound, upperBound, withReplacement, seed);
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index b925d9da3342..83ad97fdc4fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -579,7 +579,7 @@ object UnsupportedOperationChecker extends Logging {
throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +
"aggregated DataFrame/Dataset in Complete output mode")
- case Sample(_, _, _, _, child) if child.isStreaming =>
+ case Sample(_, _, _, _, child, _) if child.isStreaming =>
throwError("Sampling is not supported on streaming DataFrames/Datasets")
case Window(windowExpression, _, _, child, _) if child.isStreaming =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4d53b697af8..ddfe80443d56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1296,7 +1296,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
limit.copy(child = p2.copy(projectList = newProjectList))
case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) =>
r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList)))
- case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
+ case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _), _)) if isRenaming(l1, l2) =>
s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
case o => o
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 929fb2b4ceb1..95b21eb01b4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2448,10 +2448,14 @@ class AstBuilder extends DataTypeAstBuilder
* - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
* - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with
* seed 'y'. Note that percentages are defined as a number between 0 and 100.
+ * - TABLESAMPLE SYSTEM(x PERCENT): Sample by data-source-dependent blocks or file splits.
* - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to a 'x' divided by
* 'y' fraction with seed 'z'.
*/
private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val isSystem = ctx.sampleType != null &&
+ ctx.sampleType.getType == SqlBaseParser.SYSTEM
+
// Create a sampled plan if we need one.
def sample(fraction: Double, seed: Option[Long]): Sample = {
// The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
@@ -2461,17 +2465,25 @@ class AstBuilder extends DataTypeAstBuilder
validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
s"Sampling fraction ($fraction) must be on interval [0, 1]",
ctx)
- Sample(0.0, fraction, withReplacement = false, seed, query)
+ val method = if (isSystem) SampleMethod.System else SampleMethod.Bernoulli
+ Sample(0.0, fraction, withReplacement = false, seed, query, method)
}
if (ctx.sampleMethod() == null) {
throw QueryParsingErrors.emptyInputForTableSampleError(ctx)
}
+ if (isSystem && ctx.seed != null) {
+ throw QueryParsingErrors.tableSampleSystemRepeatableError(ctx)
+ }
+
val seed: Option[Long] = Option(ctx.seed).map(_.getText.toLong)
ctx.sampleMethod() match {
case ctx: SampleByRowsContext =>
+ if (isSystem) {
+ throw QueryParsingErrors.tableSampleSystemSampleMethodError("ROWS", ctx)
+ }
Limit(expression(ctx.expression), query)
case ctx: SampleByPercentileContext =>
@@ -2483,6 +2495,9 @@ class AstBuilder extends DataTypeAstBuilder
sample(sign * fraction / 100.0d, seed)
case ctx: SampleByBytesContext =>
+ if (isSystem) {
+ throw QueryParsingErrors.tableSampleSystemSampleMethodError("BYTES", ctx)
+ }
val bytesStr = ctx.bytes.getText
if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx)
@@ -2491,6 +2506,9 @@ class AstBuilder extends DataTypeAstBuilder
}
case ctx: SampleByBucketContext if ctx.ON() != null =>
+ if (isSystem) {
+ throw QueryParsingErrors.tableSampleSystemSampleMethodError("BUCKET", ctx)
+ }
if (ctx.identifier != null) {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
"BUCKET x OUT OF y ON colname", ctx)
@@ -2500,6 +2518,9 @@ class AstBuilder extends DataTypeAstBuilder
}
case ctx: SampleByBucketContext =>
+ if (isSystem) {
+ throw QueryParsingErrors.tableSampleSystemSampleMethodError("BUCKET", ctx)
+ }
sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed)
}
}
@@ -2912,7 +2933,7 @@ class AstBuilder extends DataTypeAstBuilder
// inline table comes in two styles:
// style 1: values (1), (2), (3) -- multiple columns are supported
// style 2: values 1, 2, 3 -- only a single column is supported here
- // Strip Alias wrappers from row values — CreateStruct.apply preserves them for
+ // Strip Alias wrappers from row values - CreateStruct.apply preserves them for
// expressions like `(1 AS id, 'a' AS name)`, but they are redundant here since
// column names are determined by the table alias or generated defaults.
case struct: CreateNamedStruct => struct.valExprs.map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 8e9f264698ca..6d37aa0f9f6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1912,6 +1912,14 @@ object SubqueryAlias {
}
}
+sealed trait SampleMethod extends Serializable
+object SampleMethod {
+ /** Row-level sampling (BERNOULLI). Each row independently selected. No I/O savings. */
+ case object Bernoulli extends SampleMethod
+ /** System-level sampling (SYSTEM). Entire partitions/splits included or skipped. */
+ case object System extends SampleMethod
+}
+
object Sample {
/**
* Convenience constructor that wraps a concrete seed in [[Some]].
@@ -1926,6 +1934,16 @@ object Sample {
child: LogicalPlan): Sample = {
new Sample(lowerBound, upperBound, withReplacement, Some(seed), child)
}
+
+ def apply(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long,
+ child: LogicalPlan,
+ sampleMethod: SampleMethod): Sample = {
+ new Sample(lowerBound, upperBound, withReplacement, Some(seed), child, sampleMethod)
+ }
}
/**
@@ -1939,13 +1957,15 @@ object Sample {
* (SQL `REPEATABLE` clause or programmatic API), `None` when no seed was
* specified and a random seed should be generated at execution time.
* @param child the LogicalPlan
+ * @param sampleMethod the sampling method (Bernoulli or System)
*/
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Option[Long],
- child: LogicalPlan) extends UnaryNode {
+ child: LogicalPlan,
+ sampleMethod: SampleMethod = SampleMethod.Bernoulli) extends UnaryNode {
val eps = RandomSampler.roundingEpsilon
val fraction = upperBound - lowerBound
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 6124c69fbedd..1ecb7fa539c2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -1024,6 +1024,207 @@ class PlanParserSuite extends AnalysisTest {
stop = 65))
}
+ test("SPARK-55978: TABLESAMPLE SYSTEM and BERNOULLI - basic parsing") {
+ val sql = "select * from t"
+ // SYSTEM produces SampleMethod.System
+ assertEqual(
+ s"$sql tablesample system (43 percent) as x",
+ Sample(0, .43d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.System).select(star()))
+ // BERNOULLI produces SampleMethod.Bernoulli
+ assertEqual(
+ s"$sql tablesample bernoulli (43 percent) as x",
+ Sample(0, .43d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.Bernoulli).select(star()))
+ // No qualifier defaults to Bernoulli (backward compat)
+ assertEqual(
+ s"$sql tablesample(43 percent) as x",
+ Sample(0, .43d, withReplacement = false, None,
+ table("t").as("x")).select(star()))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM - case insensitivity") {
+ val sql = "select * from t"
+ // Keywords are case-insensitive
+ assertEqual(
+ s"$sql TABLESAMPLE SYSTEM (43 PERCENT) as x",
+ Sample(0, .43d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.System).select(star()))
+ assertEqual(
+ s"$sql TabLeSaMpLe SyStEm (43 PeRcEnT) as x",
+ Sample(0, .43d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.System).select(star()))
+ assertEqual(
+ s"$sql TABLESAMPLE BERNOULLI (43 PERCENT) as x",
+ Sample(0, .43d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.Bernoulli).select(star()))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM - boundary fractions") {
+ val sql = "select * from t"
+ // 0 PERCENT
+ assertEqual(
+ s"$sql tablesample system (0 percent) as x",
+ Sample(0, 0d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.System).select(star()))
+ // 100 PERCENT
+ assertEqual(
+ s"$sql tablesample system (100 percent) as x",
+ Sample(0, 1d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.System).select(star()))
+ // Fractional percent
+ assertEqual(
+ s"$sql tablesample system (0.1 percent) as x",
+ Sample(0, 0.001d, withReplacement = false, None,
+ table("t").as("x"), SampleMethod.System).select(star()))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM - unsupported sample methods") {
+ val sql = "select * from t"
+ // SYSTEM + ROWS -> error
+ checkError(
+ exception = parseException(s"$sql tablesample system (100 rows)"),
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD",
+ sqlState = "0A000",
+ parameters = Map("sampleMethod" -> "ROWS"),
+ context = ExpectedContext(
+ fragment = "tablesample system (100 rows)",
+ start = 16,
+ stop = 44))
+ // SYSTEM + BYTES -> error
+ checkError(
+ exception = parseException(s"$sql tablesample system (300M)"),
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD",
+ sqlState = "0A000",
+ parameters = Map("sampleMethod" -> "BYTES"),
+ context = ExpectedContext(
+ fragment = "tablesample system (300M)",
+ start = 16,
+ stop = 40))
+ // SYSTEM + BUCKET -> error
+ checkError(
+ exception = parseException(s"$sql tablesample system (bucket 4 out of 10)"),
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD",
+ sqlState = "0A000",
+ parameters = Map("sampleMethod" -> "BUCKET"),
+ context = ExpectedContext(
+ fragment = "tablesample system (bucket 4 out of 10)",
+ start = 16,
+ stop = 54))
+ // SYSTEM + BUCKET ON colname -> error
+ checkError(
+ exception = parseException(s"$sql tablesample system (bucket 4 out of 10 on x)"),
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD",
+ sqlState = "0A000",
+ parameters = Map("sampleMethod" -> "BUCKET"),
+ context = ExpectedContext(
+ fragment = "tablesample system (bucket 4 out of 10 on x)",
+ start = 16,
+ stop = 59))
+ // SYSTEM + BUCKET ON function -> error
+ checkError(
+ exception = parseException(s"$sql tablesample system (bucket 3 out of 32 on rand())"),
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_SAMPLE_METHOD",
+ sqlState = "0A000",
+ parameters = Map("sampleMethod" -> "BUCKET"),
+ context = ExpectedContext(
+ fragment = "tablesample system (bucket 3 out of 32 on rand())",
+ start = 16,
+ stop = 64))
+ }
+
+ test("SPARK-55978: TABLESAMPLE BERNOULLI - REPEATABLE is supported") {
+ assertEqual(
+ "select * from t tablesample bernoulli (43 percent) repeatable (123) as x",
+ Sample(0, .43d, withReplacement = false, 123L,
+ table("t").as("x"), SampleMethod.Bernoulli).select(star()))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM - REPEATABLE not supported") {
+ val sql = "select * from t"
+ checkError(
+ exception = parseException(s"$sql tablesample system (43 percent) repeatable (123)"),
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_REPEATABLE",
+ sqlState = "0A000",
+ context = ExpectedContext(
+ fragment = "tablesample system (43 percent) repeatable (123)",
+ start = 16,
+ stop = 63))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM - fraction out of range") {
+ val sql = "select * from t"
+ // > 100 PERCENT
+ checkError(
+ exception = parseException(s"$sql tablesample system (150 percent) as x"),
+ condition = "_LEGACY_ERROR_TEMP_0064",
+ parameters = Map("msg" -> "Sampling fraction (1.5) must be on interval [0, 1]"),
+ context = ExpectedContext(
+ fragment = "tablesample system (150 percent)",
+ start = 16,
+ stop = 47))
+ // Negative PERCENT
+ checkError(
+ exception = parseException(s"$sql tablesample system (-10 percent) as x"),
+ condition = "_LEGACY_ERROR_TEMP_0064",
+ parameters = Map("msg" -> "Sampling fraction (-0.1) must be on interval [0, 1]"),
+ context = ExpectedContext(
+ fragment = "tablesample system (-10 percent)",
+ start = 16,
+ stop = 47))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM and BERNOULLI as identifiers") {
+ // SYSTEM usable as column name (nonReserved)
+ assertEqual("SELECT system FROM t",
+ table("t").select($"system"))
+ // BERNOULLI usable as column name
+ assertEqual("SELECT bernoulli FROM t",
+ table("t").select($"bernoulli"))
+ // Usable as table alias
+ assertEqual("SELECT * FROM t system",
+ table("t").as("system").select(star()))
+ assertEqual("SELECT * FROM t bernoulli",
+ table("t").as("bernoulli").select(star()))
+ // SYSTEM as table name with default (Bernoulli) TABLESAMPLE
+ assertEqual("SELECT * FROM system TABLESAMPLE(10 PERCENT) AS x",
+ Sample(0, .1d, withReplacement = false, None,
+ table("system").as("x")).select(star()))
+ // SYSTEM as table name with TABLESAMPLE SYSTEM qualifier
+ assertEqual("SELECT * FROM system TABLESAMPLE SYSTEM (10 PERCENT) AS x",
+ Sample(0, .1d, withReplacement = false, None,
+ table("system").as("x"), SampleMethod.System).select(star()))
+ // SYSTEM as both table name and alias with TABLESAMPLE
+ assertEqual("SELECT * FROM system TABLESAMPLE(10 PERCENT) system",
+ Sample(0, .1d, withReplacement = false, None,
+ table("system").as("system")).select(star()))
+ // BERNOULLI as table name with TABLESAMPLE BERNOULLI qualifier
+ assertEqual("SELECT * FROM bernoulli TABLESAMPLE BERNOULLI (10 PERCENT) AS x",
+ Sample(0, .1d, withReplacement = false, None,
+ table("bernoulli").as("x"), SampleMethod.Bernoulli).select(star()))
+ // SYSTEM as table name with TABLESAMPLE BERNOULLI (cross-keyword)
+ assertEqual("SELECT * FROM system TABLESAMPLE BERNOULLI (10 PERCENT) AS x",
+ Sample(0, .1d, withReplacement = false, None,
+ table("system").as("x"), SampleMethod.Bernoulli).select(star()))
+ // BERNOULLI as both table name and alias with TABLESAMPLE
+ assertEqual("SELECT * FROM bernoulli TABLESAMPLE(10 PERCENT) bernoulli",
+ Sample(0, .1d, withReplacement = false, None,
+ table("bernoulli").as("bernoulli")).select(star()))
+ // Schema-qualified SYSTEM table name with TABLESAMPLE SYSTEM
+ assertEqual("SELECT * FROM mydb.system TABLESAMPLE SYSTEM (10 PERCENT) AS x",
+ Sample(0, .1d, withReplacement = false, None,
+ table("mydb", "system").as("x"), SampleMethod.System).select(star()))
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM - subquery and join contexts") {
+ // SYSTEM sample in subquery
+ assertEqual(
+ "SELECT * FROM (SELECT * FROM t TABLESAMPLE SYSTEM (50 PERCENT)) sub",
+ Sample(0, .5d, withReplacement = false, None,
+ table("t"), SampleMethod.System)
+ .select(star()).as("sub").select(star()))
+ }
+
test("sub-query") {
val plan = table("t0").select($"id")
assertEqual("select id from (t0)", plan)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithTableSample.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithTableSample.scala
new file mode 100644
index 000000000000..514a7f3beda4
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithTableSample.scala
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.Locale
+
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.join.JoinType
+import org.apache.spark.sql.connector.read.{InputPartition, SampleMethod, Scan, ScanBuilder, SupportsPushDownJoin, SupportsPushDownTableSample, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.SupportsPushDownJoin.ColumnWithAlias
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.ArrayImplicits._
+
+/**
+ * An in-memory table that supports TABLESAMPLE pushdown (both BERNOULLI and SYSTEM).
+ *
+ * For SYSTEM sampling, entire splits (InputPartitions) are included or skipped based on
+ * a hash of their index and the seed. For BERNOULLI sampling, the pushdown is accepted
+ * but rows are not actually filtered (Spark's row-level Sample operator handles it).
+ */
+class InMemoryTableWithTableSample(
+ name: String,
+ columns: Array[Column],
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+ extends InMemoryBaseTable(name, columns, partitioning, properties) {
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+ new InMemoryWriterBuilder(info) {
+ override def truncate(): WriteBuilder = {
+ writer = new TruncateAndAppend(this.info)
+ streamingWriter = new StreamingTruncateAndAppend(this.info)
+ this
+ }
+ }
+ }
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryTableSampleScanBuilder(schema, options)
+ }
+
+ class InMemoryTableSampleScanBuilder(
+ tableSchema: StructType,
+ options: CaseInsensitiveStringMap)
+ extends InMemoryScanBuilder(tableSchema, options) with SupportsPushDownTableSample {
+
+ private var sampleFraction: Double = 1.0
+ private var sampleSeed: Long = 0L
+ private var sampleMethod: SampleMethod = SampleMethod.BERNOULLI
+ private var sampleWithReplacement: Boolean = false
+ private var samplePushed: Boolean = false
+
+ override def pushTableSample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long): Boolean = {
+ this.sampleFraction = upperBound - lowerBound
+ this.sampleSeed = seed
+ this.sampleMethod = SampleMethod.BERNOULLI
+ this.sampleWithReplacement = withReplacement
+ this.samplePushed = true
+ true
+ }
+
+ override def pushTableSample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long,
+ sampleMethod: SampleMethod): Boolean = {
+ this.sampleFraction = upperBound - lowerBound
+ this.sampleSeed = seed
+ this.sampleMethod = sampleMethod
+ this.sampleWithReplacement = withReplacement
+ this.samplePushed = true
+ true
+ }
+
+ override def build: Scan = {
+ val allPartitions = data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq
+ val filteredPartitions = if (samplePushed && sampleMethod == SampleMethod.SYSTEM) {
+ // SYSTEM sampling: include/skip entire splits based on hash of index + seed
+ allPartitions.zipWithIndex.filter { case (_, idx) =>
+ val hash = ((idx.toLong * 31 + sampleSeed) & Long.MaxValue).toDouble / Long.MaxValue
+ hash < sampleFraction
+ }.map(_._1)
+ } else {
+ allPartitions
+ }
+ if (samplePushed) {
+ new InMemoryBatchScanWithSample(
+ filteredPartitions, schema, tableSchema, options,
+ sampleFraction, sampleSeed, sampleMethod, sampleWithReplacement)
+ } else {
+ InMemoryBatchScan(filteredPartitions, schema, tableSchema, options)
+ }
+ }
+ }
+
+ private class InMemoryBatchScanWithSample(
+ data: Seq[InputPartition],
+ readSchema: StructType,
+ tableSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ sampleFraction: Double,
+ sampleSeed: Long,
+ sampleMethod: SampleMethod,
+ sampleWithReplacement: Boolean)
+ extends InMemoryBatchScan(data, readSchema, tableSchema, options) {
+
+ override def description(): String = {
+ val pct = sampleFraction * 100
+ val method = sampleMethod.toString.toUpperCase(Locale.ROOT)
+ s"${super.description()} $method SAMPLE ($pct) $sampleWithReplacement SEED($sampleSeed)"
+ }
+ }
+}
+
+/**
+ * An in-memory table that supports both TABLESAMPLE pushdown and JOIN pushdown.
+ * Used to test the guard that prevents join pushdown when a side has a pushed sample.
+ */
+class InMemoryTableWithJoinAndSample(
+ name: String,
+ columns: Array[Column],
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+ extends InMemoryTableWithTableSample(name, columns, partitioning, properties) {
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryJoinAndSampleScanBuilder(schema, options)
+ }
+
+ class InMemoryJoinAndSampleScanBuilder(
+ tableSchema: StructType,
+ options: CaseInsensitiveStringMap)
+ extends InMemoryTableSampleScanBuilder(tableSchema, options)
+ with SupportsPushDownJoin with SupportsPushDownV2Filters {
+
+ private[catalog] val ownSchema: StructType = tableSchema
+ private var pushed: Array[Predicate] = Array.empty
+ private var joinedSchema: Option[StructType] = None
+
+ override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
+ pushed = predicates
+ // Return empty: all predicates are accepted (not actually filtered, just cleared
+ // so that the join pushdown pattern's Nil filter requirement is satisfied).
+ Array.empty
+ }
+
+ // Override V1 pushFilters (inherited from InMemoryScanBuilder) to also accept all
+ // filters. PushDownUtils.pushFilters matches SupportsPushDownFilters before
+ // SupportsPushDownV2Filters, so without this override isnotnull predicates remain
+ // as post-scan Filter nodes and block the join pushdown pattern match.
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ Array.empty
+ }
+
+ override def pushedPredicates(): Array[Predicate] = pushed
+
+ override def isOtherSideCompatibleForJoin(other: SupportsPushDownJoin): Boolean = true
+
+ override def pushDownJoin(
+ other: SupportsPushDownJoin,
+ joinType: JoinType,
+ leftSideRequiredColumnsWithAliases: Array[ColumnWithAlias],
+ rightSideRequiredColumnsWithAliases: Array[ColumnWithAlias],
+ condition: Predicate): Boolean = {
+ val otherSchema = other.asInstanceOf[InMemoryJoinAndSampleScanBuilder].ownSchema
+ val leftFields = leftSideRequiredColumnsWithAliases.map { col =>
+ val name = if (col.alias() != null) col.alias() else col.colName()
+ tableSchema(col.colName()).copy(name = name)
+ }
+ val rightFields = rightSideRequiredColumnsWithAliases.map { col =>
+ val name = if (col.alias() != null) col.alias() else col.colName()
+ otherSchema(col.colName()).copy(name = name)
+ }
+ joinedSchema = Some(StructType(leftFields ++ rightFields))
+ true
+ }
+
+ override def build: Scan = {
+ joinedSchema match {
+ case Some(js) =>
+ InMemoryBatchScan(
+ data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq,
+ js, tableSchema, options)
+ case None => super.build
+ }
+ }
+ }
+}
+
+/**
+ * An in-memory table that supports TABLESAMPLE pushdown using only the legacy 4-arg
+ * pushTableSample method (does NOT override the 5-arg default). Used to test backward
+ * compatibility: BERNOULLI should push down via the default delegation, and SYSTEM
+ * should fail because the default returns false for SYSTEM.
+ */
+class InMemoryTableWithLegacyTableSample(
+ name: String,
+ columns: Array[Column],
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+ extends InMemoryBaseTable(name, columns, partitioning, properties) {
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+ new InMemoryWriterBuilder(info) {
+ override def truncate(): WriteBuilder = {
+ writer = new TruncateAndAppend(this.info)
+ streamingWriter = new StreamingTruncateAndAppend(this.info)
+ this
+ }
+ }
+ }
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryLegacySampleScanBuilder(schema, options)
+ }
+
+ class InMemoryLegacySampleScanBuilder(
+ tableSchema: StructType,
+ options: CaseInsensitiveStringMap)
+ extends InMemoryScanBuilder(tableSchema, options) with SupportsPushDownTableSample {
+
+ // Only the 4-arg method is overridden; the 5-arg default method is inherited.
+ override def pushTableSample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long): Boolean = true
+ }
+}
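The SYSTEM branch of `InMemoryTableSampleScanBuilder.build` keeps or drops whole partitions and never looks at individual rows. A minimal standalone sketch of that selection rule, assuming splits are identified only by their position in the scan (the names `SplitSampling`, `keepSplit`, and `selectSplits` are hypothetical and not part of the patch):

```scala
object SplitSampling {
  // Same rule as the test table's SYSTEM branch: the split at position `idx`
  // is kept when (idx * 31 + seed), masked to be non-negative and divided by
  // Long.MaxValue, is below the requested fraction.
  def keepSplit(idx: Int, seed: Long, fraction: Double): Boolean = {
    val hash = ((idx.toLong * 31 + seed) & Long.MaxValue).toDouble / Long.MaxValue
    hash < fraction
  }

  // Keeps or drops whole splits; rows inside a kept split are never
  // filtered individually.
  def selectSplits[T](splits: Seq[T], seed: Long, fraction: Double): Seq[T] =
    splits.zipWithIndex.collect {
      case (split, idx) if keepSplit(idx, seed, fraction) => split
    }
}
```

With a seed in the 0 to 999 range that the rule falls back to when no seed is given, the scaled value stays far below any positive fraction, so `selectSplits(Seq("p0", "p1", "p2"), 42L, 0.5)` keeps all three splits. The rule is therefore a deterministic stand-in for a real hash: it exercises the 0 and 100 PERCENT cases the suite asserts on, but does not produce a uniform split fraction.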
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithTableSampleCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithTableSampleCatalog.scala
new file mode 100644
index 000000000000..12da978ea11a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithTableSampleCatalog.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.expressions.Transform
+
+class InMemoryTableWithTableSampleCatalog extends InMemoryTableCatalog {
+ import CatalogV2Implicits._
+
+ override def createTable(
+ ident: Identifier,
+ columns: Array[Column],
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val tableName = s"$name.${ident.quoted}"
+ val table = new InMemoryTableWithTableSample(tableName, columns, partitions, properties)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
+ }
+}
+
+class InMemoryTableWithJoinAndSampleCatalog extends InMemoryTableCatalog {
+ import CatalogV2Implicits._
+
+ override def createTable(
+ ident: Identifier,
+ columns: Array[Column],
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val tableName = s"$name.${ident.quoted}"
+ val table = new InMemoryTableWithJoinAndSample(tableName, columns, partitions, properties)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
+ }
+}
+
+class InMemoryTableWithLegacyTableSampleCatalog extends InMemoryTableCatalog {
+ import CatalogV2Implicits._
+
+ override def createTable(
+ ident: Identifier,
+ columns: Array[Column],
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val tableName = s"$name.${ident.quoted}"
+ val table = new InMemoryTableWithLegacyTableSample(
+ tableName, columns, partitions, properties)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
+ }
+}
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala
index 1cfe05a2b5c1..1f525a541daa 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala
@@ -209,8 +209,8 @@ class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSpark
withConnection { conn =>
val metadata = conn.getMetaData
// scalastyle:off line.size.limit
- // CURRENT_PATH is excluded: getSQLKeywords drops SQL:2003 reserved words (see companion).
- assert(metadata.getSQLKeywords === "ADD,AFTER,AGGREGATE,ALWAYS,ANALYZE,ANTI,ANY_VALUE,APPROX,ARCHIVE,ASC,BINDING,BUCKET,BUCKETS,BYTE,CACHE,CASCADE,CATALOG,CATALOGS,CHANGE,CHANGES,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATION,COLLATIONS,COLLECTION,COLUMNS,COMMENT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONTAINS,CONTINUE,COST,CURRENT_DATABASE,CURRENT_SCHEMA,DATA,DATABASE,DATABASES,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAYOFYEAR,DAYS,DBPROPERTIES,DEFAULT_PATH,DEFINED,DEFINER,DE
[...]
+ // CURRENT_PATH and SYSTEM are excluded: getSQLKeywords drops SQL:2003 reserved words (see companion).
+ assert(metadata.getSQLKeywords === "ADD,AFTER,AGGREGATE,ALWAYS,ANALYZE,ANTI,ANY_VALUE,APPROX,ARCHIVE,ASC,BERNOULLI,BINDING,BUCKET,BUCKETS,BYTE,CACHE,CASCADE,CATALOG,CATALOGS,CHANGE,CHANGES,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATION,COLLATIONS,COLLECTION,COLUMNS,COMMENT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONTAINS,CONTINUE,COST,CURRENT_DATABASE,CURRENT_SCHEMA,DATA,DATABASE,DATABASES,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAYOFYEAR,DAYS,DBPROPERTIES,DEFAULT_PATH,DEFINED,
[...]
// scalastyle:on line.size.limit
}
}
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/sample_fraction_seed.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/sample_fraction_seed.explain
index f94e0a850e40..9bcbf8813539 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/sample_fraction_seed.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/sample_fraction_seed.explain
@@ -1,2 +1,2 @@
-Sample 0.0, 0.43, false, 9890823
+Sample 0.0, 0.43, false, 9890823, Bernoulli
+- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/sample_withReplacement_fraction_seed.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/sample_withReplacement_fraction_seed.explain
index 340c25ab6d01..5af5314e48f9 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/sample_withReplacement_fraction_seed.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/sample_withReplacement_fraction_seed.explain
@@ -1,2 +1,2 @@
-Sample 0.0, 0.23, true, 898
+Sample 0.0, 0.23, true, 898, Bernoulli
+- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 2488b6aa5115..be7013188f2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import java.util.Locale
import java.util.concurrent.TimeUnit._
import org.apache.hadoop.fs.Path
@@ -159,8 +160,11 @@ case class RowDataSourceScanExec(
private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
- private def pushedSampleMetadataString(s: TableSampleInfo): String =
- s"SAMPLE (${(s.upperBound - s.lowerBound) * 100}) ${s.withReplacement} SEED(${s.seed})"
+ private def pushedSampleMetadataString(s: TableSampleInfo): String = {
+ val pct = (s.upperBound - s.lowerBound) * 100
+ val method = s.sampleMethod.toString.toUpperCase(Locale.ROOT)
+ s"$method SAMPLE ($pct) ${s.withReplacement} SEED(${s.seed})"
+ }
override val metadata: Map[String, String] = {
val markedFilters = if (filters.nonEmpty) {
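The reworked `pushedSampleMetadataString` (and the matching `description()` override in the test scan) now leads with the upper-cased method name. A self-contained sketch of the resulting string, assuming simplified local stand-ins for `TableSampleInfo` and the Catalyst `SampleMethod` (the object `SampleMetadataFormat` is hypothetical; these are not the patch's definitions):

```scala
import java.util.Locale

object SampleMetadataFormat {
  // Local stand-ins for the Catalyst types touched by this patch.
  sealed trait SampleMethod
  object SampleMethod {
    case object Bernoulli extends SampleMethod
    case object System extends SampleMethod
  }

  // The trailing default keeps 4-argument construction source-compatible.
  final case class TableSampleInfo(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long,
      sampleMethod: SampleMethod = SampleMethod.Bernoulli)

  // Same shape as the patched helper:
  // "<METHOD> SAMPLE (<percent>) <withReplacement> SEED(<seed>)".
  def pushedSampleMetadataString(s: TableSampleInfo): String = {
    val pct = (s.upperBound - s.lowerBound) * 100
    val method = s.sampleMethod.toString.toUpperCase(Locale.ROOT)
    s"$method SAMPLE ($pct) ${s.withReplacement} SEED(${s.seed})"
  }
}
```

Because the method name comes first and the seed last, the suite's `checkPushedInfo(df, "SYSTEM SAMPLE (50.0) false SEED(")` can match on a prefix without knowing the randomly chosen seed.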
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2c060aa3f9a5..92818c12bfa0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -1040,7 +1040,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
- case logical.Sample(lb, ub, withReplacement, seed, child) =>
+ case logical.Sample(lb, ub, withReplacement, seed, child, sampleMethod) =>
+ if (sampleMethod == logical.SampleMethod.System) {
+ // V2ScanRelationPushDown is non-excludable and always handles SYSTEM samples
+ // (either pushes down or throws). Reaching here indicates an internal invariant
+ // violation.
+ throw SparkException.internalError(
+ "TABLESAMPLE SYSTEM node was not properly handled by V2ScanRelationPushDown.")
+ }
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _, stream) =>
LocalTableScanExec(output, data, stream) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 0d34dfc91c39..e31e81fc1fa9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -22,13 +22,14 @@ import scala.collection.mutable
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, DynamicPruning, DynamicPruningExpression, Expression, ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.plans.logical.SampleMethod
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.{IdentityTransform, SortOrder}
import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.read.{SampleMethod => SampleMethodV2, Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.execution.{ScalarSubquery => ExecScalarSubquery}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils}
import org.apache.spark.sql.internal.SQLConf
@@ -398,7 +399,11 @@ object PushDownUtils extends Logging {
scanBuilder match {
case s: SupportsPushDownTableSample =>
s.pushTableSample(
- sample.lowerBound, sample.upperBound, sample.withReplacement, sample.seed)
+ sample.lowerBound, sample.upperBound, sample.withReplacement, sample.seed,
+ sample.sampleMethod match {
+ case SampleMethod.Bernoulli => SampleMethodV2.BERNOULLI
+ case SampleMethod.System => SampleMethodV2.SYSTEM
+ })
case _ => false
}
}
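`PushDownUtils` now always calls the five-argument `pushTableSample`, translating the Catalyst `SampleMethod` into the connector-level enum. Sources written against the old four-argument method keep working because the new overload defaults to delegating BERNOULLI and refusing SYSTEM, which is what the legacy test table relies on. The real interface is Java; the Scala trait below is a hypothetical analogue (`SamplePushdownCompat`, `TableSamplePushdown`, `LegacySource` are illustrative names) sketched from that described behavior, not the actual `SupportsPushDownTableSample` source:

```scala
object SamplePushdownCompat {
  // Stand-in for the connector-level SampleMethod enum.
  sealed trait SampleMethodV2
  object SampleMethodV2 {
    case object BERNOULLI extends SampleMethodV2
    case object SYSTEM extends SampleMethodV2
  }

  // Hypothetical Scala analogue of the pushdown mixin.
  trait TableSamplePushdown {
    // Legacy entry point that existing sources already implement.
    def pushTableSample(
        lowerBound: Double,
        upperBound: Double,
        withReplacement: Boolean,
        seed: Long): Boolean

    // New overload: by default, row-level sampling is forwarded to the
    // legacy method and block-level sampling is rejected.
    def pushTableSample(
        lowerBound: Double,
        upperBound: Double,
        withReplacement: Boolean,
        seed: Long,
        method: SampleMethodV2): Boolean = method match {
      case SampleMethodV2.BERNOULLI =>
        pushTableSample(lowerBound, upperBound, withReplacement, seed)
      case SampleMethodV2.SYSTEM => false
    }
  }

  // A source that predates SYSTEM sampling and overrides only the legacy method.
  class LegacySource extends TableSamplePushdown {
    override def pushTableSample(
        lowerBound: Double,
        upperBound: Double,
        withReplacement: Boolean,
        seed: Long): Boolean = true
  }
}
```

Per the `V2ScanRelationPushDown` change in this commit, a `false` returned for SYSTEM then surfaces as `UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM` instead of silently falling back to row-level sampling.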
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
index cb4fb9eb0809..441ed28c813c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.execution.datasources.v2
+import org.apache.spark.sql.catalyst.plans.logical.SampleMethod
+
case class TableSampleInfo(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
- seed: Long)
+ seed: Long,
+ sampleMethod: SampleMethod = SampleMethod.Bernoulli)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index c0b72123065f..60a2017e6947 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -23,11 +23,12 @@ import scala.collection.mutable
import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExpressionSet, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, SampleMethod, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
@@ -150,6 +151,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
rightProjections.forall(_.isInstanceOf[AttributeReference]) &&
// Cross joins are not supported because they increase the amount of data.
condition.isDefined &&
+ // Do not push down join if either side has a pushed sample, because
+ // the merged scan builder would silently discard it.
+ // TODO(SPARK-56504): Extend SupportsPushDownJoin to accept pushed
+ // samples so sources supporting both can handle the composition.
+ leftHolder.pushedSample.isEmpty && rightHolder.pushedSample.isEmpty &&
lBuilder.isOtherSideCompatibleForJoin(rBuilder) =>
// Process left and right columns in original order
val (leftSideRequiredColumnsWithAliases, rightSideRequiredColumnsWithAliases) =
@@ -844,15 +850,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
sample.lowerBound,
sample.upperBound,
sample.withReplacement,
- sample.seed.getOrElse((math.random() * 1000).toLong))
+ // TODO(SPARK-56573): The * 1000 limits the seed to only 1000 distinct values.
+ // Kept here for consistency with SampleExec.resolvedSeed; will be fixed
+ // across all call sites in SPARK-56573.
+ sample.seed.getOrElse((math.random() * 1000).toLong),
+ sampleMethod = sample.sampleMethod)
val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample)
if (pushed) {
sHolder.pushedSample = Some(tableSample)
sample.child
+ } else if (sample.sampleMethod == SampleMethod.System) {
+ throw new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM",
+ messageParameters = Map.empty)
} else {
sample
}
-
+ case _ if sample.sampleMethod == SampleMethod.System =>
+ throw new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_NO_SCAN",
+ messageParameters = Map.empty)
case _ => sample
}
}
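Taken together, the `V2ScanRelationPushDown` changes give each `Sample` node one of three outcomes. A minimal sketch of that decision, assuming a simplified model in which the rule only needs to know whether a scan builder sits directly under the sample and, if so, whether it accepted the pushdown (`SamplePushdownDecision`, `Outcome`, and `resolve` are hypothetical names; the real rule operates on logical plans and records the pushed `TableSampleInfo` on the `ScanBuilderHolder`):

```scala
object SamplePushdownDecision {
  sealed trait SampleMethod
  object SampleMethod {
    case object Bernoulli extends SampleMethod
    case object System extends SampleMethod
  }

  sealed trait Outcome
  // The source accepted the sample; the Sample node is removed from the plan.
  case object PushedToSource extends Outcome
  // Row-level sampling stays in the plan and is executed by SampleExec.
  case object KeepSampleNode extends Outcome
  // SYSTEM sampling that cannot be honored is reported as an analysis error.
  final case class AnalysisError(errorClass: String) extends Outcome

  // `pushResult` is None when no scan builder sits directly under the Sample
  // (for example an aggregate or a WHERE filter in between); otherwise it is
  // Some(accepted), the value returned by the source's pushTableSample.
  def resolve(method: SampleMethod, pushResult: Option[Boolean]): Outcome =
    (pushResult, method) match {
      case (Some(true), _) => PushedToSource
      case (Some(false), SampleMethod.System) =>
        AnalysisError("UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM")
      case (None, SampleMethod.System) =>
        AnalysisError("UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_NO_SCAN")
      case (_, SampleMethod.Bernoulli) => KeepSampleNode
    }
}
```

Because SYSTEM can never reach `KeepSampleNode`, the new check in `SparkStrategies` can treat a SYSTEM `Sample` seen during physical planning as an internal error rather than planning a `SampleExec`.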
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out
index 84ec13334ffd..a6a86f9ebe1d 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out
@@ -1979,7 +1979,7 @@ org.apache.spark.sql.catalyst.parser.ParseException
table t
|> tablesample (100 percent) repeatable (0)
-- !query analysis
-Sample 0.0, 1.0, false, 0
+Sample 0.0, 1.0, false, 0, Bernoulli
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv
@@ -1998,7 +1998,7 @@ GlobalLimit 2
table t
|> tablesample (bucket 1 out of 1) repeatable (0)
-- !query analysis
-Sample 0.0, 1.0, false, 0
+Sample 0.0, 1.0, false, 0, Bernoulli
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv
@@ -2009,10 +2009,10 @@ table t
|> tablesample (5 rows) repeatable (0)
|> tablesample (bucket 1 out of 1) repeatable (0)
-- !query analysis
-Sample 0.0, 1.0, false, 0
+Sample 0.0, 1.0, false, 0, Bernoulli
+- GlobalLimit 5
+- LocalLimit 5
- +- Sample 0.0, 1.0, false, 0
+ +- Sample 0.0, 1.0, false, 0, Bernoulli
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv
@@ -2435,7 +2435,7 @@ Project [a#x]
: +- Project [a#x]
: +- SubqueryAlias grouping
: +- LocalRelation [a#x]
- +- Sample 0.0, 1.0, false, 0
+ +- Sample 0.0, 1.0, false, 0, Bernoulli
+- SubqueryAlias jt2
+- SubqueryAlias join_test_t2
+- View (`join_test_t2`, [a#x])
@@ -2458,7 +2458,7 @@ Project [a#x]
: +- SubqueryAlias grouping
: +- LocalRelation [a#x]
+- SubqueryAlias jt2
- +- Sample 0.0, 1.0, false, 0
+ +- Sample 0.0, 1.0, false, 0, Bernoulli
+- Project [1 AS a#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out
index 6f9e8fde5d9f..6bcbdd2840f9 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords-enforced.sql.out
@@ -25,6 +25,7 @@ AT false
ATOMIC false
AUTHORIZATION true
BEGIN false
+BERNOULLI false
BETWEEN false
BIGINT false
BINARY false
@@ -358,6 +359,7 @@ STRUCT false
SUBSTR false
SUBSTRING false
SYNC false
+SYSTEM false
SYSTEM_PATH false
SYSTEM_TIME false
SYSTEM_VERSION false
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index 1fdb51507bc1..a01034326446 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -25,6 +25,7 @@ AT false
ATOMIC false
AUTHORIZATION false
BEGIN false
+BERNOULLI false
BETWEEN false
BIGINT false
BINARY false
@@ -358,6 +359,7 @@ STRUCT false
SUBSTR false
SUBSTRING false
SYNC false
+SYSTEM false
SYSTEM_PATH false
SYSTEM_TIME false
SYSTEM_VERSION false
diff --git a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out
index 1fdb51507bc1..a01034326446 100644
--- a/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/nonansi/keywords.sql.out
@@ -25,6 +25,7 @@ AT false
ATOMIC false
AUTHORIZATION false
BEGIN false
+BERNOULLI false
BETWEEN false
BIGINT false
BINARY false
@@ -358,6 +359,7 @@ STRUCT false
SUBSTR false
SUBSTRING false
SYNC false
+SYSTEM false
SYSTEM_PATH false
SYSTEM_TIME false
SYSTEM_VERSION false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2TableSampleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2TableSampleSuite.scala
new file mode 100644
index 000000000000..76ec2e588eae
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2TableSampleSuite.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.connector.catalog.{InMemoryTableWithJoinAndSampleCatalog, InMemoryTableWithLegacyTableSampleCatalog, InMemoryTableWithTableSampleCatalog}
+import org.apache.spark.sql.internal.SQLConf
+
+class DataSourceV2TableSampleSuite extends DatasourceV2SQLBase
+ with DataSourcePushdownTestUtils {
+
+ private val sampleCatalog = "testsample"
+
+ private def withSampleTable(testFunc: String => Unit): Unit = {
+ registerCatalog(sampleCatalog, classOf[InMemoryTableWithTableSampleCatalog])
+ val tableName = s"$sampleCatalog.ns.sample_tbl"
+ sql(s"CREATE TABLE $tableName (id bigint, data string) USING _")
+ try {
+ sql(s"INSERT INTO $tableName VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e')")
+ testFunc(tableName)
+ } finally {
+ sql(s"DROP TABLE IF EXISTS $tableName")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM pushdown removes Sample node") {
+ withSampleTable { table =>
+ val df = sql(s"SELECT * FROM $table TABLESAMPLE SYSTEM (50 PERCENT)")
+ checkSamplePushed(df, pushed = true)
+ checkPushedInfo(df, "SYSTEM SAMPLE (50.0) false SEED(")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE BERNOULLI pushdown removes Sample node") {
+ withSampleTable { table =>
+ val df = sql(s"SELECT * FROM $table TABLESAMPLE BERNOULLI (50 PERCENT)")
+ checkSamplePushed(df, pushed = true)
+ checkPushedInfo(df, "BERNOULLI SAMPLE (50.0) false SEED(")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE default (no qualifier) pushdown removes Sample node") {
+ withSampleTable { table =>
+ val df = sql(s"SELECT * FROM $table TABLESAMPLE (50 PERCENT)")
+ checkSamplePushed(df, pushed = true)
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM 0 PERCENT returns no rows") {
+ withSampleTable { table =>
+ val df = sql(s"SELECT * FROM $table TABLESAMPLE SYSTEM (0 PERCENT)")
+ checkSamplePushed(df, pushed = true)
+ assert(df.collect().isEmpty)
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM 100 PERCENT returns all rows") {
+ withSampleTable { table =>
+ val df = sql(s"SELECT * FROM $table TABLESAMPLE SYSTEM (100 PERCENT)")
+ checkSamplePushed(df, pushed = true)
+ assert(df.collect().length == 5)
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM composes with projection") {
+ withSampleTable { table =>
+ val df = sql(s"SELECT id FROM $table TABLESAMPLE SYSTEM (100 PERCENT)")
+ checkSamplePushed(df, pushed = true)
+ assert(df.columns.sameElements(Array("id")))
+ assert(df.collect().length == 5)
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE on non-pushdown catalog falls back to Sample node") {
+ val table = "testcat.ns.no_sample_tbl"
+ sql(s"CREATE TABLE $table (id bigint, data string) USING _")
+ try {
+ sql(s"INSERT INTO $table VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ val df = sql(s"SELECT * FROM $table TABLESAMPLE (50 PERCENT)")
+ // testcat uses InMemoryCatalog which does NOT implement SupportsPushDownTableSample,
+ // so the Sample node should remain in the plan.
+ checkSamplePushed(df, pushed = false)
+ } finally {
+ sql(s"DROP TABLE IF EXISTS $table")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM on non-pushdown catalog errors") {
+ val table = "testcat.ns.no_sample_tbl"
+ sql(s"CREATE TABLE $table (id bigint, data string) USING _")
+ try {
+ sql(s"INSERT INTO $table VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ // testcat uses InMemoryCatalog whose ScanBuilder does not implement
+ // SupportsPushDownTableSample, so SYSTEM sampling cannot be pushed down.
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"SELECT * FROM $table TABLESAMPLE SYSTEM (50 PERCENT)").collect()
+ },
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM")
+ } finally {
+ sql(s"DROP TABLE IF EXISTS $table")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM on subquery errors") {
+ withSampleTable { table =>
+ // SYSTEM sampling requires a direct table scan; applying it to a derived
+ // query (here an aggregate) means there is no ScanBuilderHolder to push into.
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"SELECT * FROM (SELECT id, count(*) AS cnt FROM $table GROUP BY id) " +
+ s"TABLESAMPLE SYSTEM (50 PERCENT)").collect()
+ },
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_NO_SCAN")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM with WHERE filter errors") {
+ withSampleTable { table =>
+ // A WHERE clause between the Sample and the scan produces a non-empty filter list
+ // in PhysicalOperation, which falls through to the catch-all error branch.
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"SELECT * FROM (SELECT * FROM $table WHERE id > 1) " +
+ s"TABLESAMPLE SYSTEM (50 PERCENT)").collect()
+ },
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_NO_SCAN")
+ }
+ }
+
+ test("SPARK-55978: TABLESAMPLE SYSTEM on DSv1 table errors") {
+ withTable("dsv1_tbl") {
+ sql("CREATE TABLE dsv1_tbl (id bigint, data string) USING parquet")
+ sql("INSERT INTO dsv1_tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ // DSv1 tables have no ScanBuilderHolder, so SYSTEM sampling cannot be pushed down.
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("SELECT * FROM dsv1_tbl TABLESAMPLE SYSTEM (50 PERCENT)").collect()
+ },
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_NO_SCAN")
+ }
+ }
+
+ test("SPARK-55978: join pushdown is skipped when a side has a pushed sample") {
+ val joinSampleCatalog = "testjoinsample"
+ registerCatalog(joinSampleCatalog, classOf[InMemoryTableWithJoinAndSampleCatalog])
+ val t1 = s"$joinSampleCatalog.ns.t1"
+ val t2 = s"$joinSampleCatalog.ns.t2"
+ sql(s"CREATE TABLE $t1 (id bigint, data string) USING _")
+ sql(s"CREATE TABLE $t2 (id bigint, data string) USING _")
+ try {
+ sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ sql(s"INSERT INTO $t2 VALUES (2, 'x'), (3, 'y'), (4, 'z')")
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ // Without sample: join should be pushed down
+ val dfNoSample = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.id = $t2.id")
+ checkJoinPushed(dfNoSample)
+
+ // With SYSTEM sample on one side: join pushdown should be skipped
+ val dfWithSample = sql(
+ s"SELECT * FROM $t1 TABLESAMPLE SYSTEM (100 PERCENT) " +
+ s"JOIN $t2 ON $t1.id = $t2.id")
+ checkJoinNotPushed(dfWithSample)
+ // The sample should still be pushed down though
+ checkSamplePushed(dfWithSample, pushed = true)
+ }
+ } finally {
+ sql(s"DROP TABLE IF EXISTS $t1")
+ sql(s"DROP TABLE IF EXISTS $t2")
+ }
+ }
+
+ test("SPARK-55978: legacy connector with only 4-arg pushTableSample - BERNOULLI pushes down") {
+ val legacyCatalog = "testlegacysample"
+ registerCatalog(legacyCatalog, classOf[InMemoryTableWithLegacyTableSampleCatalog])
+ val tableName = s"$legacyCatalog.ns.legacy_tbl"
+ sql(s"CREATE TABLE $tableName (id bigint, data string) USING _")
+ try {
+ sql(s"INSERT INTO $tableName VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ // BERNOULLI should push down via the default 5-arg method delegating to 4-arg
+ val dfBernoulli = sql(s"SELECT * FROM $tableName TABLESAMPLE (50 PERCENT)")
+ checkSamplePushed(dfBernoulli, pushed = true)
+
+ // SYSTEM should fail because the default 5-arg method returns false for SYSTEM,
+ // and SYSTEM requires successful pushdown.
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"SELECT * FROM $tableName TABLESAMPLE SYSTEM (50 PERCENT)").collect()
+ },
+ condition = "UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM")
+ } finally {
+ sql(s"DROP TABLE IF EXISTS $tableName")
+ }
+ }
+}
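The legacy-connector test above depends on a default-method pattern: the new 5-argument `pushTableSample` overload delegates BERNOULLI to the old 4-argument hook and returns `false` for SYSTEM, and the planner turns a rejected SYSTEM sample into an error instead of a row-level fallback. The following is a minimal Scala sketch of that control flow, assuming simplified semantics inferred from the commit description and the error conditions the tests expect. `TableSamplePushdown`, `ScanBuilder`, `PlanOutcome`, and `planSample` are hypothetical stand-ins, not Spark's actual `SupportsPushDownTableSample` or planner code.

```scala
// Hypothetical stand-ins, NOT Spark's real classes: they only mirror the
// behaviour described in this commit and exercised by the tests above.
sealed trait SampleMethod
object SampleMethod {
  case object Bernoulli extends SampleMethod
  case object System extends SampleMethod
}

trait ScanBuilder

trait TableSamplePushdown extends ScanBuilder {
  // Pre-existing 4-argument hook (row-level sampling only).
  def pushTableSample(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long): Boolean

  // New 5-argument overload. Its default delegates BERNOULLI to the
  // 4-argument hook and rejects SYSTEM, so legacy connectors keep their
  // old behaviour without opting in to block sampling.
  def pushTableSample(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long,
      isSystemSampling: Boolean): Boolean =
    if (isSystemSampling) false
    else pushTableSample(lowerBound, upperBound, withReplacement, seed)
}

// A connector that only implements the 4-argument hook.
class LegacyScanBuilder extends TableSamplePushdown {
  var pushedBounds: Option[(Double, Double)] = None
  override def pushTableSample(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long): Boolean = {
    pushedBounds = Some((lowerBound, upperBound))
    true
  }
}

// A scan builder with no sample pushdown support at all.
class PlainScanBuilder extends ScanBuilder

sealed trait PlanOutcome
case object SamplePushed extends PlanOutcome   // Sample node removed
case object KeepSampleNode extends PlanOutcome // row-level fallback
final case class Failed(condition: String) extends PlanOutcome

// `scan` is None when no DSv2 scan sits directly under the sample
// (aggregate subquery, intervening filter, DSv1 table).
def planSample(
    method: SampleMethod,
    fraction: Double,
    seed: Long,
    scan: Option[ScanBuilder]): PlanOutcome = {
  val isSystem = method == SampleMethod.System
  scan match {
    case None =>
      if (isSystem) Failed("UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM_NO_SCAN")
      else KeepSampleNode
    case Some(builder) =>
      // Only builders that mix in the pushdown trait can accept a sample.
      val accepted = builder match {
        case p: TableSamplePushdown =>
          p.pushTableSample(0.0, fraction, false, seed, isSystem)
        case _ => false
      }
      if (accepted) SamplePushed
      else if (isSystem) Failed("UNSUPPORTED_FEATURE.TABLESAMPLE_SYSTEM")
      else KeepSampleNode
  }
}
```

Because the 5-argument overload has a default body, an existing connector needs no source changes: it keeps accepting BERNOULLI through its 4-argument method and simply never accepts SYSTEM.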
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 5067f7dfbcc5..1ecf5b3dae4a 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,APPROX,ARCHIVE,ARRAY,AS,ASC,ASENSITIVE,AT,ATOMIC,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHANGES,CHAR,CHARACTER,CHECK,CLEAR,CLOSE,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLATIONS,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENA
[...]
+ assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,APPROX,ARCHIVE,ARRAY,AS,ASC,ASENSITIVE,AT,ATOMIC,AUTHORIZATION,BEGIN,BERNOULLI,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHANGES,CHAR,CHARACTER,CHECK,CLEAR,CLOSE,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLATIONS,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE
[...]
// scalastyle:on line.size.limit
}
}
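The `TABLESAMPLE SYSTEM (0 PERCENT)` and `(100 PERCENT)` tests in the suite above pin down the two boundary cases of block-level sampling: no block is kept, or every block is kept. The commit leaves block selection to the data source, so the sketch below uses one simple, hypothetical scheme (keep each block independently with probability `fraction`, driven by a seeded RNG) and contrasts it with row-level BERNOULLI sampling. `Block`, `systemSample`, and `bernoulliSample` are illustrative names, not Spark APIs.

```scala
import scala.util.Random

// Hypothetical "block": a split or file holding several rows.
final case class Block(id: Int, rows: Seq[Int])

// Block-level (SYSTEM-style) sampling: one random draw per block; a kept
// block contributes all of its rows, a dropped block contributes none.
def systemSample(blocks: Seq[Block], fraction: Double, seed: Long): Seq[Block] = {
  require(fraction >= 0.0 && fraction <= 1.0, "fraction must be in [0, 1]")
  val rng = new Random(seed)
  // nextDouble() is in [0, 1): fraction 0.0 keeps nothing, 1.0 keeps everything.
  blocks.filter(_ => rng.nextDouble() < fraction)
}

// Row-level (BERNOULLI-style) sampling: one random draw per row.
def bernoulliSample(blocks: Seq[Block], fraction: Double, seed: Long): Seq[Int] = {
  require(fraction >= 0.0 && fraction <= 1.0, "fraction must be in [0, 1]")
  val rng = new Random(seed)
  blocks.flatMap(_.rows).filter(_ => rng.nextDouble() < fraction)
}
```

The block-level variant makes one draw per block rather than one per row, which is why it can skip whole splits without reading them; the trade-off is coarser granularity, since rows within a block are kept or dropped together.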