This is an automated email from the ASF dual-hosted git repository.
szehon-ho pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 206356c63689 [SPARK-53890][SDP] Test (and fix) read/readstream options
are respected for pipelines
206356c63689 is described below
commit 206356c6368910327c3915f8a12fed4910871405
Author: AnishMahto <[email protected]>
AuthorDate: Thu May 28 11:20:06 2026 -0700
[SPARK-53890][SDP] Test (and fix) read/readstream options are respected for
pipelines
### What changes were proposed in this pull request?
Today, read options attached to any `UnresolvedRelation` that is analyzed
by the pipelines flow analyzer are dropped. This PR fixes that bug, and in
doing so also makes the following micro refactors:
- Get rid of `StreamingReadOptions`/`BatchReadOptions`. Previously, none of
the fields of either class were ever populated; the classes were used only to
determine whether a streaming read or a batch read was being executed.
- Propagate the streaming or batch dataframe reader as the sole source of
truth for options to execute reads with, rather than passing in both a reader
and read options side-by-side.
- Correct the `Table` class hierarchy. `Table` _is_ a `GraphElement`, but it
is _not_ an `Input`. Because it previously inherited from `Input`, it had a
`load` override, but that override was dead code: a `Table` could never be
passed into the polymorphic call sites of `Input.load`.
- Get rid of `AnalysisWarning` (whose only warning,
`StreamingReaderOptionsDropped`, could never be emitted) and
`LoadTableException`, both of which were dead code.
- Refactor `State.findElementsToReset` -> `State.findFlowsToReset`, since the
table elements previously returned by this function were never acted on by
any caller; they were dead return values. `State.reset` now returns `Unit`.
### Why are the changes needed?
Prior to these changes, any options specified in
`UnresolvedRelation.options` were dropped when the relation was analyzed via
`FlowAnalysis.analyze`. To my knowledge, in a vanilla installation of Spark
(e.g. without Delta IO) there are currently no options attached to an
`UnresolvedRelation` (e.g. one created via `spark.read.table`) that would have
been respected had they not been dropped. Even so, this change fixes a
definite bug and future-proofs the analyzer against options that do matter.
### How was this patch tested?
`org.apache.spark.sql.pipelines.analysis.ReadOptionsPropagationOnAnalysisSuite`
Closes #53073 from AnishMahto/sdp-fix-read-options-propagation.
Lead-authored-by: AnishMahto <[email protected]>
Co-authored-by: anishm-db <[email protected]>
Signed-off-by: Szehon Ho <[email protected]>
(cherry picked from commit 02f8e3f05c84f56a27de45b9316b093510a0754e)
Signed-off-by: Szehon Ho <[email protected]>
---
.../spark/sql/pipelines/AnalysisWarning.scala | 33 ---
.../apache/spark/sql/pipelines/graph/Flow.scala | 68 +++---
.../spark/sql/pipelines/graph/FlowAnalysis.scala | 91 ++++----
.../sql/pipelines/graph/FlowAnalysisContext.scala | 3 -
.../sql/pipelines/graph/GraphValidations.scala | 2 +-
.../sql/pipelines/graph/PipelinesErrors.scala | 14 --
.../apache/spark/sql/pipelines/graph/State.scala | 26 +--
.../spark/sql/pipelines/graph/elements.scala | 98 ++++----
.../spark/sql/pipelines/util/InputReadInfo.scala | 48 ----
.../ReadOptionsPropagationOnAnalysisSuite.scala | 259 +++++++++++++++++++++
.../sql/pipelines/autocdc/AutoCdcFlowSuite.scala | 59 ++---
.../graph/ConnectValidPipelineSuite.scala | 2 +-
.../pipelines/graph/MaterializeTablesSuite.scala | 1 +
.../utils/TestGraphRegistrationContext.scala | 38 +--
14 files changed, 438 insertions(+), 304 deletions(-)
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala
deleted file mode 100644
index 35b8185c255e..000000000000
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.pipelines
-
-/** Represents a warning generated as part of graph analysis. */
-sealed trait AnalysisWarning
-
-object AnalysisWarning {
-
- /**
- * Warning that some streaming reader options are being dropped
- *
- * @param sourceName Source for which reader options are being dropped.
- * @param droppedOptions Set of reader options that are being dropped for a
specific source.
- */
- case class StreamingReaderOptionsDropped(sourceName: String, droppedOptions:
Seq[String])
- extends AnalysisWarning
-}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
index 9f357ef026b0..740533d7504e 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.pipelines.graph
import scala.util.Try
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{functions => F, AnalysisException, Column}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.classic.DataFrame
-import org.apache.spark.sql.pipelines.AnalysisWarning
import org.apache.spark.sql.pipelines.autocdc.{
AutoCdcReservedNames,
CaseSensitivityLabels,
@@ -32,7 +32,6 @@ import org.apache.spark.sql.pipelines.autocdc.{
Scd1BatchProcessor,
ScdType
}
-import org.apache.spark.sql.pipelines.util.InputReadOptions
import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
@@ -108,8 +107,7 @@ case class FlowFunctionResult(
streamingInputs: Set[ResolvedInput],
usedExternalInputs: Set[TableIdentifier],
dataFrame: Try[DataFrame],
- sqlConf: Map[String, String],
- analysisWarnings: Seq[AnalysisWarning] = Nil) {
+ sqlConf: Map[String, String]) {
/**
* Returns the names of all of the [[Input]]s used when resolving this
[[Flow]]. If the
@@ -217,7 +215,8 @@ trait ResolvedFlow extends ResolutionCompletedFlow with
Input {
/** Returns the schema of the output of this [[Flow]]. */
def schema: StructType = df.schema
- override def load(readOptions: InputReadOptions): DataFrame = df
+ override def load(asStreaming: Boolean): DataFrame = df
+
def inputs: Set[TableIdentifier] = funcResult.inputs
}
@@ -303,31 +302,48 @@ class AutoCdcMergeFlow(
}
/**
- * Returns an empty dataframe whose schema matches
[[AutoCdcMergeFlow.schema]].
+ * Returns an empty dataframe whose schema matches
[[AutoCdcMergeFlow.schema]]. By construction,
+ * the returned dataframe will be a streaming dataframe.
*
- * Today, [[AutoCdcMergeFlow.load]] is not actually ever called during graph
analysis or
- * execution. An AutoCdcMergeFlow can only be an input to a streaming table
(not an MV or
- * persisted/temp view), and streaming tables take a [[VirtualTableInput]]
as input, not
- * the producing [[Flow]] directly. [[VirtualTableInput]] overrides its own
[[load]] to do
- * schema inference on its input flows, rather than a transitive
[[Flow.load]].
+ * In practice, [[AutoCdcMergeFlow.load]] is not invoked during graph
analysis or execution.
+ * An AutoCdcMergeFlow can only be an input to a streaming table (not an MV
or
+ * persisted/temp view), and streaming tables consume a
[[VirtualTableInput]] rather than the
+ * producing [[Flow]] directly. [[VirtualTableInput]] overrides its own
[[load]] to do schema
+ * inference on its input flows, rather than a transitive [[Flow.load]].
*
- * The [[AutoCdcMergeFlow.load]] implementation exists solely for API
consistency.
+ * The implementation exists for API consistency and throws an internal
error if invoked with
+ * `asStreaming = false`, or if the underlying source dataframe is not
streaming, to surface
+ * a misuse loudly rather than silently producing a non-streaming dataframe.
*/
- override def load(readOptions: InputReadOptions): DataFrame =
changeArgs.storedAsScdType match {
- case ScdType.Type1 =>
- val userSelectedCols: Seq[Column] =
userSelectedSchema.fieldNames.toSeq.map(F.col)
- val emptyCdcMetadataCol: Column =
Scd1BatchProcessor.constructCdcMetadataCol(
- deleteSequence = F.lit(null),
- upsertSequence = F.lit(null),
- sequencingType = sequencingType
- ).as(Scd1BatchProcessor.cdcMetadataColName)
-
- df.select(userSelectedCols :+ emptyCdcMetadataCol: _*)
- case ScdType.Type2 =>
- throw new AnalysisException(
- errorClass = "AUTOCDC_SCD2_NOT_SUPPORTED",
- messageParameters = Map.empty
+ override def load(asStreaming: Boolean): DataFrame = {
+ if (!asStreaming) {
+ throw SparkException.internalError(
+ "Attempted to load AutoCDC flow as a batch flow. AutoCDC flows are
strictly streaming " +
+ "flows, and must be loaded as such."
)
+ }
+ if (!df.isStreaming) {
+ throw SparkException.internalError(
+ "AutoCDC source dataframe is not streaming. AutoCDC flows are strictly
streaming flows, " +
+ "and must be backed by a streaming source."
+ )
+ }
+ changeArgs.storedAsScdType match {
+ case ScdType.Type1 =>
+ val userSelectedCols: Seq[Column] =
userSelectedSchema.fieldNames.toSeq.map(F.col)
+ val emptyCdcMetadataCol: Column =
Scd1BatchProcessor.constructCdcMetadataCol(
+ deleteSequence = F.lit(null),
+ upsertSequence = F.lit(null),
+ sequencingType = sequencingType
+ ).as(Scd1BatchProcessor.cdcMetadataColName)
+
+ df.select(userSelectedCols :+ emptyCdcMetadataCol: _*)
+ case ScdType.Type2 =>
+ throw new AnalysisException(
+ errorClass = "AUTOCDC_SCD2_NOT_SUPPORTED",
+ messageParameters = Map.empty
+ )
+ }
}
/**
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
index 1a00a6339c4b..7e174f2b3107 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
@@ -23,10 +23,8 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{CTESubstitution,
UnresolvedRelation}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader,
SparkSession}
-import org.apache.spark.sql.pipelines.AnalysisWarning
+import org.apache.spark.sql.classic.{DataFrame, DataFrameReader, Dataset,
DataStreamReader, SparkSession}
import
org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier,
InternalDatasetIdentifier}
-import org.apache.spark.sql.pipelines.util.{BatchReadOptions,
InputReadOptions, StreamingReadOptions}
object FlowAnalysis {
@@ -64,8 +62,7 @@ object FlowAnalysis {
streamingInputs = ctx.streamingInputs.toSet,
usedExternalInputs = ctx.externalInputs.toSet,
dataFrame = df,
- sqlConf = confs,
- analysisWarnings = ctx.analysisWarnings.toList
+ sqlConf = confs
)
}
}
@@ -112,8 +109,7 @@ object FlowAnalysis {
val resolved = readStreamInput(
context,
name = IdentifierHelper.toQuotedString(u.multipartIdentifier),
- spark.readStream,
- streamingReadOptions = StreamingReadOptions()
+ streamReader = spark.readStream.options(u.options)
).queryExecution.analyzed
// Spark Connect requires the PLAN_ID_TAG to be propagated to the
resolved plan
// to allow correct analysis of the parent plan that contains this
subquery
@@ -124,7 +120,7 @@ object FlowAnalysis {
val resolved = readBatchInput(
context,
name = IdentifierHelper.toQuotedString(u.multipartIdentifier),
- batchReadOptions = BatchReadOptions()
+ batchReader = spark.read.options(u.options)
).queryExecution.analyzed
// Spark Connect requires the PLAN_ID_TAG to be propagated to the
resolved plan
// to allow correct analysis of the parent plan that contains this
subquery
@@ -143,23 +139,25 @@ object FlowAnalysis {
* All the public APIs that read from a dataset should call this function to
read the dataset.
*
* @param name the name of the Dataset to be read.
- * @param batchReadOptions Options for this batch read
+ * @param batchReader the batch dataframe reader, possibly with options, to
execute the read
+ * with.
* @return batch DataFrame that represents data from the specified Dataset.
*/
final private def readBatchInput(
context: FlowAnalysisContext,
name: String,
- batchReadOptions: BatchReadOptions
+ batchReader: DataFrameReader
): DataFrame = {
GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match
{
case inputIdentifier: InternalDatasetIdentifier =>
- readGraphInput(context, inputIdentifier, batchReadOptions)
+ readGraphInput(context, inputIdentifier, isStreamingRead = false)
case inputIdentifier: ExternalDatasetIdentifier =>
readExternalBatchInput(
context,
inputIdentifier = inputIdentifier,
- name = name
+ name = name,
+ batchReader = batchReader
)
}
}
@@ -173,21 +171,19 @@ object FlowAnalysis {
*
* @param name the name of the Dataset to be read.
* @param streamReader The [[DataStreamReader]] that may hold read options
specified by the user.
- * @param streamingReadOptions Options for this streaming read.
* @return streaming DataFrame that represents data from the specified
Dataset.
*/
final private def readStreamInput(
context: FlowAnalysisContext,
name: String,
- streamReader: DataStreamReader,
- streamingReadOptions: StreamingReadOptions
+ streamReader: DataStreamReader
): DataFrame = {
GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match
{
case inputIdentifier: InternalDatasetIdentifier =>
readGraphInput(
context,
inputIdentifier,
- streamingReadOptions
+ isStreamingRead = true
)
case inputIdentifier: ExternalDatasetIdentifier =>
@@ -204,13 +200,13 @@ object FlowAnalysis {
* Internal helper to reference dataset defined in the same
[[DataflowGraph]].
*
* @param inputIdentifier The identifier of the Dataset to be read.
- * @param readOptions Options for this read (may be either streaming or
batch options)
+ * @param isStreamingRead Whether this is a streaming read or batch read.
* @return streaming or batch DataFrame that represents data from the
specified Dataset.
*/
final private def readGraphInput(
ctx: FlowAnalysisContext,
inputIdentifier: InternalDatasetIdentifier,
- readOptions: InputReadOptions
+ isStreamingRead: Boolean
): DataFrame = {
val datasetIdentifier = inputIdentifier.identifier
@@ -227,7 +223,27 @@ object FlowAnalysis {
ctx.availableInput(datasetIdentifier)
}
- val inputDF = input.load(readOptions)
+ val inputDF = input.load(asStreaming = isStreamingRead)
+
+ // Validate that the loaded DataFrame's streaming-ness matches the
requested read mode. Tables
+ // pass through trivially as their [[VirtualTableInput.load]] honors
`asStreaming` by
+ // construction. The check only ever fires for flows.
+ val incompatibleViewReadCheck =
+ ctx.spark.conf.get("pipelines.incompatibleViewCheck.enabled",
"true").toBoolean
+
+ if (incompatibleViewReadCheck && isStreamingRead && !inputDF.isStreaming) {
+ throw new AnalysisException(
+ "INCOMPATIBLE_BATCH_VIEW_READ",
+ Map("datasetIdentifier" -> datasetIdentifier.toString)
+ )
+ }
+ if (incompatibleViewReadCheck && !isStreamingRead && inputDF.isStreaming) {
+ throw new AnalysisException(
+ "INCOMPATIBLE_STREAMING_VIEW_READ",
+ Map("datasetIdentifier" -> datasetIdentifier.toString)
+ )
+ }
+
input match {
// If the referenced input is a [[Flow]], because the query plans will
be fused
// together, we also need to fuse their confs.
@@ -235,9 +251,6 @@ object FlowAnalysis {
case _ =>
}
- val incompatibleViewReadCheck =
- ctx.spark.conf.get("pipelines.incompatibleViewCheck.enabled",
"true").toBoolean
-
// Wrap the DF in an alias so that columns in the DF can be referenced with
// the following in the query:
// - <catalog>.<schema>.<dataset>.<column>
@@ -248,30 +261,10 @@ object FlowAnalysis {
qualifier = Seq(datasetIdentifier.catalog,
datasetIdentifier.database).flatten
)
- readOptions match {
- case sro: StreamingReadOptions =>
- if (!inputDF.isStreaming && incompatibleViewReadCheck) {
- throw new AnalysisException(
- "INCOMPATIBLE_BATCH_VIEW_READ",
- Map("datasetIdentifier" -> datasetIdentifier.toString)
- )
- }
-
- if (sro.droppedUserOptions.nonEmpty) {
- ctx.analysisWarnings +=
AnalysisWarning.StreamingReaderOptionsDropped(
- sourceName = datasetIdentifier.unquotedString,
- droppedOptions = sro.droppedUserOptions.keys.toSeq
- )
- }
- ctx.streamingInputs += ResolvedInput(input, aliasIdentifier)
- case _ =>
- if (inputDF.isStreaming && incompatibleViewReadCheck) {
- throw new AnalysisException(
- "INCOMPATIBLE_STREAMING_VIEW_READ",
- Map("datasetIdentifier" -> datasetIdentifier.toString)
- )
- }
- ctx.batchInputs += ResolvedInput(input, aliasIdentifier)
+ if (isStreamingRead) {
+ ctx.streamingInputs += ResolvedInput(input, aliasIdentifier)
+ } else {
+ ctx.batchInputs += ResolvedInput(input, aliasIdentifier)
}
Dataset.ofRows(
ctx.spark,
@@ -289,11 +282,11 @@ object FlowAnalysis {
final private def readExternalBatchInput(
context: FlowAnalysisContext,
inputIdentifier: ExternalDatasetIdentifier,
- name: String): DataFrame = {
+ name: String,
+ batchReader: DataFrameReader): DataFrame = {
- val spark = context.spark
context.externalInputs += inputIdentifier.identifier
- spark.read.table(inputIdentifier.identifier.quotedString)
+ batchReader.table(inputIdentifier.identifier.quotedString)
}
/**
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala
index 1139946df59a..e5f7cddc4d32 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala
@@ -18,11 +18,9 @@
package org.apache.spark.sql.pipelines.graph
import scala.collection.mutable
-import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.classic.SparkSession
-import org.apache.spark.sql.pipelines.AnalysisWarning
/**
* A context used when evaluating a `Flow`'s query into a concrete DataFrame.
@@ -44,7 +42,6 @@ private[pipelines] case class FlowAnalysisContext(
streamingInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty,
requestedInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty,
shouldLowerCaseNames: Boolean = false,
- analysisWarnings: mutable.Buffer[AnalysisWarning] = new
ListBuffer[AnalysisWarning],
spark: SparkSession,
externalInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty
) {
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
index e5ad3de44a8a..a80fdafd1c18 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
@@ -252,7 +252,7 @@ trait GraphValidations extends Logging {
}
protected def validateUserSpecifiedSchemas(): Unit = {
- flows.flatMap(f => table.get(f.identifier)).foreach { t: TableInput =>
+ flows.flatMap(f => table.get(f.identifier)).foreach { t: TableElement =>
// The output inferred schema of a table is the declared schema merged
with the
// schema of all incoming flows. This must be equivalent to the declared
schema.
val inferredSchema = SchemaInferenceUtils
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
index 7116f5fbcf06..b194e9c235fb 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.pipelines.graph
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
@@ -35,19 +34,6 @@ case class UnresolvedDatasetException(identifier:
TableIdentifier)
s"pipeline but could not be resolved."
)
-/**
- * Exception raised when a flow fails to read from a table defined within the
pipeline
- *
- * @param name The name of the table
- * @param cause The cause of the failure
- */
-case class LoadTableException(name: String, cause: Option[Throwable])
- extends SparkException(
- errorClass = "INTERNAL_ERROR",
- messageParameters = Map("message" -> s"Failed to load table '$name'"),
- cause = cause.orNull
- )
-
object PipelinesErrors extends Logging {
/**
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
index efe5849d1cbd..90c9030cf75e 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
@@ -25,11 +25,13 @@ import org.apache.spark.sql.AnalysisException
object State extends Logging {
/**
- * Find the graph elements to reset given the current update context.
+ * Find the flows to reset given the current update context.
* @param graph The graph to reset.
* @param env The current update context.
*/
- private def findElementsToReset(graph: DataflowGraph, env:
PipelineUpdateContext): Seq[Input] = {
+ private def findFlowsToReset(
+ graph: DataflowGraph,
+ env: PipelineUpdateContext): Seq[ResolvedFlow] = {
// If tableFilter is an instance of SomeTables, this is a refresh
selection and all tables
// to reset should be resettable; Otherwise, this is a full graph update,
and we reset all
// tables that are resettable.
@@ -62,25 +64,17 @@ object State extends Logging {
}
}
- specifiedTablesToReset.flatMap(t => t +:
graph.resolvedFlowsTo(t.identifier)) ++
+ specifiedTablesToReset.flatMap(t => graph.resolvedFlowsTo(t.identifier)) ++
specifiedSinksToReset.flatMap(s => graph.resolvedFlowsTo(s.identifier))
}
/**
- * Performs the following on targets selected for full refresh:
- * - Clearing checkpoint data
- * - Truncating table data
+ * Rolls the streaming checkpoint directory of every flow selected for full
refresh. Table
+ * truncation is handled in [[DatasetManager.materializeTables]] since the
Hive metastore does
+ * not support removing all columns from a table.
*/
- def reset(resolvedGraph: DataflowGraph, env: PipelineUpdateContext):
Seq[Input] = {
- val elementsToReset: Seq[Input] = findElementsToReset(resolvedGraph, env)
-
- elementsToReset.foreach {
- case f: ResolvedFlow => reset(f, env, resolvedGraph)
- case _ => // tables is handled in materializeTables since hive metastore
does not support
- // removing all columns from a table.
- }
-
- elementsToReset
+ def reset(resolvedGraph: DataflowGraph, env: PipelineUpdateContext): Unit = {
+ findFlowsToReset(resolvedGraph, env).foreach(reset(_, env, resolvedGraph))
}
/**
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
index ce3a63de6a33..885755fd78ec 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.pipelines.graph
import java.util
-import scala.util.control.NonFatal
-
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
@@ -29,12 +27,7 @@ import
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
import org.apache.spark.sql.pipelines.common.DatasetType
-import org.apache.spark.sql.pipelines.util.{
- BatchReadOptions,
- InputReadOptions,
- SchemaInferenceUtils,
- StreamingReadOptions
-}
+import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils
import org.apache.spark.sql.types.StructType
/** An element in a [[DataflowGraph]]. */
@@ -68,10 +61,10 @@ trait Input extends GraphElement {
/**
* Returns a DataFrame that is a result of loading data from this [[Input]].
- * @param readOptions Type of input. Used to determine streaming/batch
+ * @param asStreaming whether to try loading this input as a streaming or
batch input.
* @return Streaming or batch DataFrame of this Input's data.
*/
- def load(readOptions: InputReadOptions): DataFrame
+ def load(asStreaming: Boolean): DataFrame
}
/**
@@ -101,8 +94,8 @@ sealed trait Dataset extends Output {
def path: String
}
-/** A type of [[Input]] where data is loaded from a table. */
-sealed trait TableInput extends Input {
+/** A graph element backed by a table: either a concrete [[Table]] or a
[[VirtualTableInput]]. */
+sealed trait TableElement extends GraphElement {
/** The user-specified schema for this table. */
def specifiedSchema: Option[StructType]
@@ -132,29 +125,9 @@ case class Table(
override val origin: QueryOrigin,
isStreamingTable: Boolean,
format: Option[String]
-) extends TableInput
+) extends TableElement
with Dataset {
- // Load this table's data from underlying storage.
- override def load(readOptions: InputReadOptions): DataFrame = {
- try {
- lazy val tableName = identifier.quotedString
-
- val df = readOptions match {
- case sro: StreamingReadOptions =>
- spark.readStream.options(sro.userOptions).table(tableName)
- case _: BatchReadOptions =>
- spark.read.table(tableName)
- case _ =>
- throw new IllegalArgumentException("Unhandled `InputReadOptions`
type when loading table")
- }
-
- df
- } catch {
- case NonFatal(e) => throw LoadTableException(displayName, Option(e))
- }
- }
-
/** Returns the normalized storage location to this [[Table]]. */
override def path: String = {
if (!normalized) {
@@ -176,42 +149,55 @@ case class Table(
}
/**
- * A type of [[TableInput]] that returns data from a specified schema or from
the inferred
- * [[Flow]]s that write to the table.
+ * A virtual table is a representation of a pipeline table used during
analysis. During analysis we
+ * only care about the schemas of declared tables, and its possible the
declared tables do not yet
+ * exist in the catalog. Hence we represent all tables in the graph with their
"virtual"
+ * counterparts, which are simply empty dataframes but with the same schemas.
+ *
+ * We refer to the declared table that the virtual counterpart represents as
the "parent" table
+ * below.
+ *
+ * @param identifier The identifier of the parent table.
+ * @param specifiedSchema The user-specified schema for the parent table.
+ * @param incomingFlowIdentifiers The identifiers of all flows that write to
the parent table.
+ * @param availableFlows All resolved flows that write to the parent table.
*/
case class VirtualTableInput(
identifier: TableIdentifier,
specifiedSchema: Option[StructType],
incomingFlowIdentifiers: Set[TableIdentifier],
availableFlows: Seq[ResolvedFlow] = Nil
-) extends TableInput
+) extends TableElement with Input
with Logging {
override def origin: QueryOrigin = QueryOrigin()
assert(availableFlows.forall(_.destinationIdentifier == identifier))
- override def load(readOptions: InputReadOptions): DataFrame = {
- // Infer the schema for this virtual table
- def getFinalSchema: StructType = {
- specifiedSchema match {
- // This is not a backing table, and we have a user-specified schema,
so use it directly.
- case Some(ss) => ss
- // Otherwise infer the schema from a combination of the incoming flows
and the
- // user-specified schema, if provided.
- case _ =>
- SchemaInferenceUtils.inferSchemaFromFlows(availableFlows,
specifiedSchema)
- }
- }
- // create empty streaming/batch df based on input type.
- def createEmptyDF(schema: StructType): DataFrame = readOptions match {
- case _: StreamingReadOptions =>
- MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark)
- .toDF()
- case _ => spark.createDataFrame(new util.ArrayList[Row](), schema)
+ /**
+ * Loads this virtual table as a dataframe
+ *
+ * @param asStreaming whether to load as a streaming DF or batch DF. There
are cases where we may
+ * want to batch read from a streaming table, for example.
+ */
+ def load(asStreaming: Boolean): DataFrame = {
+ val deducedSchema = specifiedSchema match {
+ // If the user specified a schema, use it directly.
+ case Some(ss) => ss
+ // Otherwise infer the schema from a combination of the incoming flows
and the
+ // user-specified schema, if provided.
+ case _ =>
+ SchemaInferenceUtils.inferSchemaFromFlows(availableFlows,
specifiedSchema)
}
- val df = createEmptyDF(getFinalSchema)
- df
+ // Produce either a streaming or batch dataframe, depending on whether
this is a virtual
+ // representation of a streaming or non-streaming table. Return the
[empty] dataframe with the
+ // deduced schema.
+ if (asStreaming) {
+ MemoryStream[Row](ExpressionEncoder(deducedSchema, lenient = false),
spark)
+ .toDF()
+ } else {
+ spark.createDataFrame(new util.ArrayList[Row](), deducedSchema)
+ }
}
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala
deleted file mode 100644
index 070927aea295..000000000000
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.pipelines.util
-
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import
org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions
-
-/**
- * Generic options for a read of an input.
- */
-sealed trait InputReadOptions
-
-/**
- * Options for a batch read of an input.
- */
-final case class BatchReadOptions() extends InputReadOptions
-
-/**
- * Options for a streaming read of an input.
- *
- * @param userOptions Holds the user defined read options.
- * @param droppedUserOptions Holds the options that were specified by the user
but
- * not actually used. This is a bug but we are
preserving this behavior
- * for now to avoid making a backwards incompatible
change.
- */
-final case class StreamingReadOptions(
- userOptions: CaseInsensitiveMap[String] = EmptyUserOptions,
- droppedUserOptions: CaseInsensitiveMap[String] = EmptyUserOptions
-) extends InputReadOptions
-
-object StreamingReadOptions {
- val EmptyUserOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map())
-}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/analysis/ReadOptionsPropagationOnAnalysisSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/analysis/ReadOptionsPropagationOnAnalysisSuite.scala
new file mode 100644
index 000000000000..763a6f500fdd
--- /dev/null
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/analysis/ReadOptionsPropagationOnAnalysisSuite.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.analysis
+
+import scala.collection.concurrent.TrieMap
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation,
LogicalRelation}
+import org.apache.spark.sql.execution.streaming.runtime.StreamingRelation
+import org.apache.spark.sql.pipelines.graph.{FlowFunction, FlowFunctionResult,
Input, QueryContext, QueryOrigin}
+import org.apache.spark.sql.pipelines.utils.{ExecutionTest,
TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Holds the most recent [[FlowFunctionResult]] recorded for each instrumented flow.
+ *
+ * SDP analyzes flows in parallel (see [[DataflowGraphTransformer.transformDownNodes]]), so
+ * results may be written from several worker threads at once. The backing store is therefore a
+ * [[TrieMap]]: a thread-safe [[scala.collection.concurrent.Map]] that remains usable wherever a
+ * [[scala.collection.mutable.Map]] is expected.
+ *
+ * @param flowFunctionResults Latest result per flow name; a fresh empty map by default.
+ */
+case class FlowFunctionResultTracker(
+    flowFunctionResults: scala.collection.concurrent.Map[String, FlowFunctionResult] =
+      TrieMap.empty[String, FlowFunctionResult])
+
+/**
+ * A [[FlowFunction]] decorator that delegates to `flowFunction` and stores whatever result it
+ * returns in `flowFunctionResultTracker`, keyed by `flowName`.
+ *
+ * @param flowName Key under which this flow function's results are recorded.
+ * @param flowFunction The wrapped flow function that does the real work.
+ * @param flowFunctionResultTracker Destination for recorded results.
+ */
+class InstrumentedFlowFunction(
+    flowName: String,
+    flowFunction: FlowFunction,
+    flowFunctionResultTracker: FlowFunctionResultTracker)
+  extends FlowFunction {
+
+  override def call(
+      allInputs: Set[TableIdentifier],
+      availableInputs: Seq[Input],
+      configuration: Map[String, String],
+      queryContext: QueryContext,
+      queryOrigin: QueryOrigin): FlowFunctionResult = {
+    val result =
+      flowFunction.call(allInputs, availableInputs, configuration, queryContext, queryOrigin)
+    // A later call for the same flow name overwrites the earlier entry.
+    flowFunctionResultTracker.flowFunctionResults(flowName) = result
+    result
+  }
+}
+
+/**
+ * [[TestGraphRegistrationContext]] with overloads of `readFlowFunc` / `readStreamFlowFunc` that
+ * wrap the built flow function in an [[InstrumentedFlowFunction]], so its results are recorded
+ * in `flowFunctionResultTracker` under a caller-chosen flow name.
+ *
+ * @param spark Session passed through to [[TestGraphRegistrationContext]].
+ * @param flowFunctionResultTracker Tracker that the instrumented flow functions record into.
+ */
+class InstrumentedTestGraphRegistrationContext(
+    spark: SparkSession,
+    flowFunctionResultTracker: FlowFunctionResultTracker)
+  extends TestGraphRegistrationContext(spark) {
+
+  /**
+   * Builds a batch read of `tableName` carrying `extraOptions`, and records the resulting
+   * flow function's results under `flowNameForTracking`.
+   */
+  def readFlowFunc(
+      flowNameForTracking: String,
+      tableName: String,
+      extraOptions: CaseInsensitiveStringMap): FlowFunction = {
+    val readOfTable = readFlowFunc(tableName, extraOptions)
+    new InstrumentedFlowFunction(flowNameForTracking, readOfTable, flowFunctionResultTracker)
+  }
+
+  /**
+   * Builds a streaming read of `tableName` carrying `extraOptions`, and records the resulting
+   * flow function's results under `flowNameForTracking`.
+   */
+  def readStreamFlowFunc(
+      flowNameForTracking: String,
+      tableName: String,
+      extraOptions: CaseInsensitiveStringMap): FlowFunction = {
+    val streamOfTable = readStreamFlowFunc(tableName, extraOptions)
+    new InstrumentedFlowFunction(flowNameForTracking, streamOfTable, flowFunctionResultTracker)
+  }
+}
+
+/**
+ * Test suite for verifying propagation of read options during pipelines analysis.
+ *
+ * Each test registers a flow whose read carries the option `x -> y`, runs the pipeline, and
+ * checks that the analyzed plan recorded for that flow still carries the option.
+ */
+class ReadOptionsPropagationOnAnalysisSuite extends ExecutionTest with SharedSparkSession {
+
+  /** The read options attached to every instrumented read in this suite. */
+  private def testReadOptions: CaseInsensitiveStringMap =
+    new CaseInsensitiveStringMap(Map("x" -> "y").asJava)
+
+  /**
+   * Converts `registrationContext` into a dataflow graph, runs it to completion, and returns the
+   * [[FlowFunctionResult]] recorded in `tracker` for `flowName`.
+   *
+   * Fails the test with a descriptive message (instead of an opaque `NoSuchElementException`
+   * from `Option.get`) if nothing was recorded for that flow.
+   */
+  private def runPipelineAndGetFlowResult(
+      registrationContext: TestGraphRegistrationContext,
+      tracker: FlowFunctionResultTracker,
+      flowName: String): FlowFunctionResult = {
+    val unresolvedGraph = registrationContext.toDataflowGraph
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot)
+    updateContext.pipelineExecution.runPipeline()
+    updateContext.pipelineExecution.awaitCompletion()
+    tracker.flowFunctionResults.getOrElse(
+      flowName,
+      fail(s"No flow function result was recorded for flow '$flowName'"))
+  }
+
+  /**
+   * Returns the value of read option `key` on the batch relation the flow's analyzed plan is
+   * expected to consist of: a `SubqueryAlias` over a `LogicalRelation` wrapping a
+   * `HadoopFsRelation`. Fails the test, showing the offending plan, if the shape differs
+   * (rather than surfacing an opaque `ClassCastException`).
+   */
+  private def batchReadOption(result: FlowFunctionResult, key: String): Option[String] = {
+    val plan = result.dataFrame.get.logicalPlan
+    plan match {
+      case SubqueryAlias(_, logicalRelation: LogicalRelation) =>
+        logicalRelation.relation match {
+          case fsRelation: HadoopFsRelation => fsRelation.options.get(key)
+          case other => fail(s"Expected a HadoopFsRelation, but found: $other")
+        }
+      case _ => fail(s"Expected SubqueryAlias over a LogicalRelation, but found:\n$plan")
+    }
+  }
+
+  /**
+   * Returns the value of read option `key` on the streaming relation the flow's analyzed plan
+   * is expected to consist of: a `SubqueryAlias` over a `StreamingRelation`. Fails the test,
+   * showing the offending plan, if the shape differs.
+   */
+  private def streamingReadOption(result: FlowFunctionResult, key: String): Option[String] = {
+    val plan = result.dataFrame.get.logicalPlan
+    plan match {
+      case SubqueryAlias(_, streamingRelation: StreamingRelation) =>
+        streamingRelation.dataSource.options.get(key)
+      case _ => fail(s"Expected SubqueryAlias over a StreamingRelation, but found:\n$plan")
+    }
+  }
+
+  test("internal pipeline batch read options are propagated during flow function analysis") {
+    val session = spark
+    import session.implicits._
+
+    val tracker = FlowFunctionResultTracker()
+
+    withTable("spark_catalog.test_db.a", "spark_catalog.test_db.b") {
+      val registrationContext = new InstrumentedTestGraphRegistrationContext(spark, tracker) {
+        registerMaterializedView(name = "a", query = dfFlowFunc(Seq(1, 2).toDF("id")))
+        registerMaterializedView(
+          name = "b",
+          query = readFlowFunc(
+            flowNameForTracking = "bFlow",
+            tableName = "a",
+            extraOptions = testReadOptions))
+      }
+
+      val bFlow = runPipelineAndGetFlowResult(registrationContext, tracker, "bFlow")
+
+      // The flow function's analyzed plan must still carry the user-specified read options.
+      assert(batchReadOption(bFlow, "x").contains("y"))
+    }
+  }
+
+  test("internal pipeline stream read options are propagated during flow function analysis") {
+    val tracker = FlowFunctionResultTracker()
+
+    withTable("spark_catalog.default.a", "spark_catalog.test_db.b", "spark_catalog.test_db.c") {
+      // Create a regular (non-pipeline) catalog table that ST "b" streams from, then have
+      // ST "c" stream from "b" with read options attached.
+      spark.range(10).write.saveAsTable("spark_catalog.default.a")
+
+      val registrationContext = new InstrumentedTestGraphRegistrationContext(spark, tracker) {
+        registerTable(
+          name = "b",
+          query = Option(readStreamFlowFunc(name = "spark_catalog.default.a")))
+        registerTable(
+          name = "c",
+          query = Option(
+            readStreamFlowFunc(
+              flowNameForTracking = "cFlow",
+              tableName = "b",
+              extraOptions = testReadOptions)))
+      }
+
+      val cFlow = runPipelineAndGetFlowResult(registrationContext, tracker, "cFlow")
+
+      // The flow function's analyzed plan must still carry the user-specified read options.
+      assert(streamingReadOption(cFlow, "x").contains("y"))
+    }
+  }
+
+  test("external pipeline batch read options are propagated during flow function analysis") {
+    val tracker = FlowFunctionResultTracker()
+
+    withTable("spark_catalog.default.a", "spark_catalog.test_db.b") {
+      // Create a regular (non-pipeline) catalog table to batch read from with options.
+      spark.range(10).write.saveAsTable("spark_catalog.default.a")
+
+      val registrationContext = new InstrumentedTestGraphRegistrationContext(spark, tracker) {
+        registerMaterializedView(
+          name = "b",
+          query = readFlowFunc(
+            flowNameForTracking = "bFlow",
+            tableName = "spark_catalog.default.a",
+            extraOptions = testReadOptions))
+      }
+
+      val bFlow = runPipelineAndGetFlowResult(registrationContext, tracker, "bFlow")
+
+      // The flow function's analyzed plan must still carry the user-specified read options.
+      assert(batchReadOption(bFlow, "x").contains("y"))
+    }
+  }
+
+  test("external pipeline stream read options are propagated during flow function analysis") {
+    val tracker = FlowFunctionResultTracker()
+
+    withTable("spark_catalog.default.a", "spark_catalog.test_db.b") {
+      // Create a regular (non-pipeline) catalog table to stream from with read options.
+      spark.range(10).write.saveAsTable("spark_catalog.default.a")
+
+      val registrationContext = new InstrumentedTestGraphRegistrationContext(spark, tracker) {
+        registerTable(
+          name = "b",
+          query = Option(
+            readStreamFlowFunc(
+              flowNameForTracking = "bFlow",
+              tableName = "spark_catalog.default.a",
+              extraOptions = testReadOptions)))
+      }
+
+      val bFlow = runPipelineAndGetFlowResult(registrationContext, tracker, "bFlow")
+
+      // The flow function's analyzed plan must still carry the user-specified read options.
+      assert(streamingReadOption(bFlow, "x").contains("y"))
+    }
+  }
+}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/autocdc/AutoCdcFlowSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/autocdc/AutoCdcFlowSuite.scala
index 65eafd6c7dcc..932110b94afd 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/autocdc/AutoCdcFlowSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/autocdc/AutoCdcFlowSuite.scala
@@ -21,9 +21,10 @@ import java.util.Locale
import scala.util.Success
-import org.apache.spark.sql.{functions => F, AnalysisException, Column,
QueryTest}
+import org.apache.spark.sql.{functions => F, AnalysisException, Column,
QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.classic.DataFrame
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.pipelines.graph.{
AutoCdcFlow,
@@ -182,13 +183,11 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
new AutoCdcMergeFlow(flow, successfulFuncResult(sourceDf))
}
- /** A stable 3-column source CDF schema used across most schema tests. */
+ /** A stable 3-column source streaming dataframe used across most schema
tests. */
private def threeColumnSourceDf(): DataFrame = {
- val schema = new StructType()
- .add("id", IntegerType, nullable = false)
- .add("name", StringType)
- .add("seq", LongType)
-
spark.createDataFrame(spark.sparkContext.emptyRDD[org.apache.spark.sql.Row],
schema)
+ val session = spark
+ import session.implicits._
+ MemoryStream[(Int, String, Option[Long])].toDS().toDF("id", "name", "seq")
}
/** Convenience to extract the [[StructType]] of the projected
`_cdc_metadata` column. */
@@ -319,7 +318,7 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
test("AutoCdcMergeFlow.load() schema matches AutoCdcMergeFlow.schema") {
val resolvedFlow = newAutoCdcMergeFlow(threeColumnSourceDf())
- val loadedDf = resolvedFlow.load(readOptions = null)
+ val loadedDf = resolvedFlow.load(asStreaming = true)
assert(loadedDf.schema == resolvedFlow.schema)
}
@@ -332,7 +331,7 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
)
)
)
- val loadedDf = resolvedFlow.load(readOptions = null)
+ val loadedDf = resolvedFlow.load(asStreaming = true)
assert(loadedDf.schema == resolvedFlow.schema)
// The user-selected portion drops `name`; the trailing column is the SCD1
metadata.
assert(
@@ -348,7 +347,7 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
ColumnSelection.ExcludeColumns(Seq(UnqualifiedColumnName("name")))
)
)
- val loadedDf = resolvedFlow.load(readOptions = null)
+ val loadedDf = resolvedFlow.load(asStreaming = true)
assert(loadedDf.schema == resolvedFlow.schema)
assert(
loadedDf.schema.fieldNames.toSeq ==
@@ -356,34 +355,6 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
)
}
- test("AutoCdcMergeFlow.load() materializes the CDC metadata column as
null-filled") {
- // The merge engine fills in the metadata at execution time; at planning
time we synthesize
- // a typed null placeholder so that load().schema matches schema. This
test pins down the
- // placeholder shape: outer struct non-null, inner fields null-filled.
- val schema = new StructType()
- .add("id", IntegerType, nullable = false)
- .add("name", StringType)
- .add("seq", LongType)
- val sourceRows = java.util.Arrays.asList(
- org.apache.spark.sql.Row(1, "a", 100L),
- org.apache.spark.sql.Row(2, "b", 200L)
- )
- val sourceDf = spark.createDataFrame(sourceRows, schema)
- val resolvedFlow = newAutoCdcMergeFlow(sourceDf)
-
- val loadedDf = resolvedFlow.load(readOptions = null)
- val collected = loadedDf.collect()
- assert(collected.length == 2)
-
- val metaIdx =
loadedDf.schema.fieldIndex(Scd1BatchProcessor.cdcMetadataColName)
- collected.foreach { row =>
- assert(!row.isNullAt(metaIdx), "_cdc_metadata struct itself should be
non-null")
- val metaRow = row.getStruct(metaIdx)
- assert(metaRow.isNullAt(0), "deleteSequence placeholder should be null")
- assert(metaRow.isNullAt(1), "upsertSequence placeholder should be null")
- }
- }
-
//
===========================================================================================
// AutoCdcMergeFlow reserved-prefix validation tests
//
@@ -399,10 +370,12 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
/** Builds an empty source df with `id` + `seq` + the supplied extra
columns. */
private def sourceDfWithExtraColumns(extraColumns: (String, DataType)*):
DataFrame = {
- val schema = extraColumns.foldLeft(
- new StructType().add("id", IntegerType, nullable = false).add("seq",
LongType)
- ) { case (acc, (name, dt)) => acc.add(name, dt) }
-
spark.createDataFrame(spark.sparkContext.emptyRDD[org.apache.spark.sql.Row],
schema)
+ val session = spark
+ import session.implicits._
+ val baseStream = MemoryStream[(Int, Option[Long])].toDS().toDF("id", "seq")
+ extraColumns.foldLeft(baseStream) { case (acc, (name, dt)) =>
+ acc.withColumn(name, F.lit(null).cast(dt))
+ }
}
test(
@@ -544,7 +517,7 @@ class AutoCdcFlowSuite extends QueryTest with
SharedSparkSession {
.add("name", StringType)
.add("seq", LongType)
val sourceDf =
-
spark.createDataFrame(spark.sparkContext.emptyRDD[org.apache.spark.sql.Row],
schema)
+ spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema)
checkError(
exception = intercept[AnalysisException] {
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
index 3c7db2cca889..58a6dff709c7 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
@@ -407,7 +407,7 @@ class ConnectValidPipelineSuite extends PipelineTest with
SharedSparkSession {
mem.addData(1, 2)
registerPersistedView("complete-view", query = dfFlowFunc(Seq(1,
2).toDF("x")))
registerPersistedView("incremental-view", query = dfFlowFunc(mem.toDF()))
- registerTable("`complete-table`", query =
Option(readFlowFunc("complete-view")))
+ registerTable("`complete-table`", query =
Option(readFlowFunc("`complete-view`")))
registerTable("`incremental-table`")
registerFlow(
"`incremental-table`",
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index ecb810dec291..29d85e9b4439 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -327,6 +327,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
}
test("specified schema incompatible with existing table") {
+ implicit val sqlCtx: SQLContext = spark.sqlContext
sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t6(x
BOOLEAN)")
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index f5bdf87a6cc6..068171a46aa1 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.pipelines.utils
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{LocalTempView, PersistedView =>
PersistedViewType, UnresolvedRelation, ViewType}
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
-import org.apache.spark.sql.pipelines.graph.{DataflowGraph, FlowAnalysis,
FlowFunction, GraphIdentifierManager, GraphRegistrationContext, PersistedView,
QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, Table,
TemporaryView, UntypedFlow}
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, FlowAnalysis,
FlowFunction, GraphIdentifierManager, GraphRegistrationContext, PersistedView,
QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, Table,
TemporaryView, UnresolvedFlow, UntypedFlow}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -29,7 +28,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* A test class to simplify the creation of pipelines and datasets for unit
testing.
*/
class TestGraphRegistrationContext(
- val _spark: SparkSession,
+ val spark: SparkSession,
val sqlConf: Map[String, String] = Map.empty)
extends GraphRegistrationContext(
defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG,
@@ -37,9 +36,8 @@ class TestGraphRegistrationContext(
defaultSqlConf = sqlConf
) {
- /** Re-expose as implicit so nested anonymous classes can use it without
shadowing issues */
- implicit def spark: SparkSession = _spark
- implicit def sqlContext: SQLContext = _spark.sqlContext
+ /** Expose all registered flows for tests */
+ def getFlows: List[UnresolvedFlow] = flows.toList
// scalastyle:off
// Disable scalastyle to ignore argument count.
@@ -150,7 +148,7 @@ class TestGraphRegistrationContext(
val qualifiedIdentifier = GraphIdentifierManager
.parseAndQualifyTableIdentifier(
rawTableIdentifier = GraphIdentifierManager
- .parseTableIdentifier(name, _spark),
+ .parseTableIdentifier(name, spark),
currentCatalog = catalog.orElse(Some(defaultCatalog)),
currentDatabase = database.orElse(Some(defaultDatabase)))
.identifier
@@ -309,9 +307,9 @@ class TestGraphRegistrationContext(
catalog: Option[String] = None,
database: Option[String] = None
): Unit = {
- val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name,
_spark)
+ val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name,
spark)
val rawDestinationIdentifier =
- GraphIdentifierManager.parseTableIdentifier(destinationName, _spark)
+ GraphIdentifierManager.parseTableIdentifier(destinationName, spark)
val flowWritesToView = getViews
.filter(_.isInstanceOf[TemporaryView])
@@ -360,19 +358,31 @@ class TestGraphRegistrationContext(
/**
* Creates a flow function from a logical plan that reads from a table with
the given name.
*/
- def readFlowFunc(name: String): FlowFunction = {
-
FlowAnalysis.createFlowFunctionFromLogicalPlan(UnresolvedRelation(TableIdentifier(name)))
+ def readFlowFunc(
+ name: String,
+ extraOptions: CaseInsensitiveStringMap =
CaseInsensitiveStringMap.empty()
+ ): FlowFunction = {
+ FlowAnalysis.createFlowFunctionFromLogicalPlan(
+ UnresolvedRelation(
+ tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name,
spark),
+ extraOptions = extraOptions,
+ isStreaming = false
+ )
+ )
}
/**
* Creates a flow function from a logical plan that reads a stream from a
table with the given
* name.
*/
- def readStreamFlowFunc(name: String): FlowFunction = {
+ def readStreamFlowFunc(
+ name: String,
+ extraOptions: CaseInsensitiveStringMap =
CaseInsensitiveStringMap.empty()
+ ): FlowFunction = {
FlowAnalysis.createFlowFunctionFromLogicalPlan(
UnresolvedRelation(
- TableIdentifier(name),
- extraOptions = CaseInsensitiveStringMap.empty(),
+ tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name,
spark),
+ extraOptions = extraOptions,
isStreaming = true
)
)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org