This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 79a92831cc69 [SPARK-53372][SDP] SDP End to End Testing Suite
79a92831cc69 is described below

commit 79a92831cc69559c823c06f5335d5badcb07df1c
Author: Jacky Wang <jacky.w...@databricks.com>
AuthorDate: Thu Sep 18 01:23:04 2025 +0800

    [SPARK-53372][SDP] SDP End to End Testing Suite

    ### What changes were proposed in this pull request?
    End-to-end testing for SDP that simulates a user using the pipelines CLI to execute each test case end to end. The test harness starts the Spark Connect server, writes the source code and pipeline spec as a real SDP project, and issues the `spark-pipelines run` command.

    ### Why are the changes needed?
    There was no test that simulates the user experience of the Python CLI interacting with the pipeline backend logic; we previously had only Python-side or Scala-side tests.

    ### Does this PR introduce _any_ user-facing change?
    Test only

    ### How was this patch tested?
    Test only change

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #52120 from JiaqiWang18/end-to-end-api-suite.

    Lead-authored-by: Jacky Wang <jacky.w...@databricks.com>
    Co-authored-by: Jacky Wang <jacky200...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .github/workflows/build_and_test.yml | 2 +-
 .../sql/connect/pipelines/EndToEndAPISuite.scala | 173 ++++
 .../apache/spark/sql/pipelines/utils/APITest.scala | 599 +++++++++++++++++++++
 3 files changed, 773 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index d2fac2ffac5a..4fad15d87283 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -362,7 +362,7 @@ jobs:
     - name: Install Python packages (Python 3.11)
       if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') || contains(matrix.modules, 'yarn')
       run: |
-        python3.11 -m pip install 'numpy>=1.22' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.1'
+        python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.1'
         python3.11 -m pip list
     # Run the tests.
     - name: Run tests
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
new file mode 100644
index 000000000000..0901c7ef21c9
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.connect.pipelines + +import java.nio.file.{Files, Path, Paths} +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + +import org.apache.spark.api.python.PythonUtils +import org.apache.spark.sql.connect.SparkConnectServerTest +import org.apache.spark.sql.pipelines.utils.{APITest, PipelineReference, PipelineSourceFile, PipelineTest, TestPipelineConfiguration, TestPipelineSpec} + +case class PipelineReferenceImpl(executionProcess: Process) extends PipelineReference + +/** + * End-to-end test suite for the Spark Declarative Pipelines API using the CLI. + * + * This suite creates a temporary directory for each test case, writes the necessary pipeline + * specification and source files, and invokes the CLI as a separate process. + */ +class EndToEndAPISuite extends PipelineTest with APITest with SparkConnectServerTest { + + // Directory where the pipeline files will be created + private var projectDir: Path = _ + + override def test(testName: String, testTags: org.scalatest.Tag*)(testFun: => Any)(implicit + pos: org.scalactic.source.Position): Unit = { + super.test(testName, testTags: _*) { + withTempDir { dir => + projectDir = dir.toPath + testFun + } + } + } + + override def createAndRunPipeline( + config: TestPipelineConfiguration, + sources: Seq[PipelineSourceFile]): PipelineReference = { + // Create each source file in the temporary directory + sources.foreach { file => + val filePath = Paths.get(file.name) + val fileName = filePath.getFileName.toString + val tempFilePath = projectDir.resolve(fileName) + + // Create the file with the specified contents + Files.write(tempFilePath, file.contents.getBytes("UTF-8")) + logInfo(s"Created file: ${tempFilePath.toAbsolutePath}") + } + + val specFilePath = writePipelineSpecFile(config.pipelineSpec) + val cliCommand: Seq[String] = generateCliCommand(config, specFilePath) + val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath + val py4jPath = Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath + val pythonPath = PythonUtils.mergePythonPaths( + sourcePath.toString, + py4jPath.toString, + sys.env.getOrElse("PYTHONPATH", "")) + + val processBuilder = new ProcessBuilder(cliCommand: _*) + processBuilder.environment().put("PYTHONPATH", pythonPath) + val process = processBuilder.start() + + PipelineReferenceImpl(process) + } + + private def generateCliCommand( + config: TestPipelineConfiguration, + specFilePath: Path): Seq[String] = { + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3")) + + var cliCommand = Seq( + pythonExec, + Paths.get(sparkHome, "python", "pyspark", "pipelines", "cli.py").toAbsolutePath.toString, + if (config.dryRun) "dry-run" else "run", + "--spec", + specFilePath.toString) + if (config.fullRefreshAll) { + cliCommand :+= "--full-refresh-all" + } + if (config.refreshSelection.nonEmpty) { + cliCommand :+= "--refresh" + cliCommand :+= config.refreshSelection.mkString(",") + } + if (config.fullRefreshSelection.nonEmpty) { + cliCommand :+= "--full-refresh" + cliCommand :+= config.fullRefreshSelection.mkString(",") + } + cliCommand + } + + override def awaitPipelineTermination(pipeline: PipelineReference, duration: Duration): Unit = { + pipeline match { + case ref: PipelineReferenceImpl => + val process = ref.executionProcess + process.waitFor(duration.toSeconds, TimeUnit.SECONDS) + val exitCode = process.exitValue() + if (exitCode != 0) { + throw new RuntimeException(s"""Pipeline 
update process failed with exit code $exitCode. + |Output: ${new String(process.getInputStream.readAllBytes(), "UTF-8")} + |Error: ${new String( + process.getErrorStream.readAllBytes(), + "UTF-8")}""".stripMargin) + } else { + logInfo("Pipeline update process completed successfully") + logDebug(s"""Output: ${new String(process.getInputStream.readAllBytes(), "UTF-8")} + |Error: ${new String( + process.getErrorStream.readAllBytes(), + "UTF-8")}""".stripMargin) + } + case _ => throw new IllegalArgumentException("Invalid UpdateReference type") + } + } + + override def stopPipeline(pipeline: PipelineReference): Unit = { + pipeline match { + case ref: PipelineReferenceImpl => + val process = ref.executionProcess + if (process.isAlive) { + process.destroy() + logInfo("Pipeline update process has been stopped") + } else { + logInfo("Pipeline update process was not running") + } + case _ => throw new IllegalArgumentException("Invalid UpdateReference type") + } + } + + private def writePipelineSpecFile(spec: TestPipelineSpec): Path = { + val libraries = spec.include + .map { includePattern => + s""" - glob: + | include: "$includePattern" + |""".stripMargin + } + .mkString("\n") + + val pipelineSpec = s""" + |name: test-pipeline + |${spec.catalog.map(catalog => s"""catalog: "$catalog"""").getOrElse("")} + |${spec.database.map(database => s"""database: "$database"""").getOrElse("")} + |configuration: + | "spark.remote": "sc://localhost:$serverPort" + |libraries: + |$libraries + |""".stripMargin + logInfo(""" + |Generated pipeline spec: + | + |$pipelineSpec + |""".stripMargin) + val specFilePath = projectDir.resolve("pipeline.yaml") + Files.write(specFilePath, pipelineSpec.getBytes("UTF-8")) + logDebug(s"Created pipeline spec: ${specFilePath.toAbsolutePath}") + specFilePath + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala new file mode 100644 index 000000000000..211deacd9830 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala @@ -0,0 +1,599 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.utils + +import scala.concurrent.duration._ +import scala.concurrent.duration.Duration + +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +// scalastyle:off +import org.scalatest.funsuite.AnyFunSuite +// scalastyle:on +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.QueryTest.checkAnswer +import org.apache.spark.sql.Row +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * Representation of a pipeline specification + * @param catalog + * the catalog to publish data from the pipeline + * @param database + * the database to publish data from the pipeline + * @param include + * the list of source files to include in the pipeline spec + */ +case class TestPipelineSpec( + catalog: Option[String] = None, + database: Option[String] = None, + include: Seq[String]) + +/** + * Available configurations for running a test pipeline. + * + * @param pipelineSpec + * the pipeline specification to use Below are CLI options that affect execution, default is to + * update all datasets incrementally + * @param dryRun + * if true, the pipeline will be validated but not executed + * @param fullRefreshAll + * if true, perform a full graph reset and recompute + * @param fullRefreshSelection + * if non-empty, only reset and recompute the subset + * @param refreshSelection + * if non-empty, only update the specified subset of datasets + */ +case class TestPipelineConfiguration( + pipelineSpec: TestPipelineSpec, + dryRun: Boolean = false, + fullRefreshAll: Boolean = false, + fullRefreshSelection: Seq[String] = Seq.empty, + refreshSelection: Seq[String] = Seq.empty) + +/** + * Logical representation of a source file to be included in the pipeline spec. + */ +case class PipelineSourceFile(name: String, contents: String) + +/** + * Extendable traits for PipelineReference and UpdateReference to allow different level of + * implementations which stores pipeline execution and update execution specific information. 
+ */ +trait PipelineReference {} + +trait APITest + extends AnyFunSuite // scalastyle:ignore funsuite + with BeforeAndAfterAll + with BeforeAndAfterEach + with Matchers { + + protected def spark: SparkSession + + def createAndRunPipeline( + config: TestPipelineConfiguration, + sources: Seq[PipelineSourceFile]): PipelineReference + def awaitPipelineTermination(pipeline: PipelineReference, timeout: Duration = 60.seconds): Unit + def stopPipeline(pipeline: PipelineReference): Unit + + /* SQL Language Tests */ + test("SQL Pipeline with mv, st, and flows") { + val pipelineSpec = + TestPipelineSpec(include = Seq("mv.sql", "st.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "st.sql", + contents = s""" + |CREATE STREAMING TABLE st; + |CREATE FLOW f AS INSERT INTO st BY NAME SELECT * FROM STREAM mv WHERE id > 2; + |""".stripMargin), + PipelineSourceFile( + name = "mv.sql", + contents = s""" + |CREATE MATERIALIZED VIEW mv + |AS SELECT * FROM RANGE(5); + |""".stripMargin)) + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + checkAnswer(spark.sql(s"SELECT * FROM st"), Seq(Row(3), Row(4))) + } + + test("SQL Pipeline with CTE") { + val pipelineSpec = + TestPipelineSpec(include = Seq("*.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE MATERIALIZED VIEW a AS SELECT 1; + |CREATE MATERIALIZED VIEW d AS + |WITH c AS ( + | WITH b AS ( + | SELECT * FROM a + | ) + | SELECT * FROM b + |) + |SELECT * FROM c; + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + checkAnswer(spark.sql(s"SELECT * FROM d"), Seq(Row(1))) + } + + test("SQL Pipeline with subquery") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE MATERIALIZED VIEW a AS SELECT * FROM RANGE(5); + |CREATE MATERIALIZED VIEW b AS SELECT * FROM RANGE(5) + |WHERE id = (SELECT max(id) FROM a); + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + checkAnswer(spark.sql(s"SELECT * FROM b"), Seq(Row(4))) + } + + test("SQL Pipeline with join") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE TEMPORARY VIEW a AS SELECT id FROM range(1,3); + |CREATE TEMPORARY VIEW b AS SELECT id FROM range(1,3); + |CREATE MATERIALIZED VIEW c AS SELECT a.id AS id1, b.id AS id2 + |FROM a JOIN b ON a.id=b.id + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + checkAnswer(spark.sql(s"SELECT * FROM c"), Seq(Row(1, 1), Row(2, 2))) + } + + test("SQL Pipeline with aggregation") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE MATERIALIZED VIEW a AS SELECT id AS value, (id % 2) AS isOdd FROM range(1,10); + |CREATE MATERIALIZED VIEW b AS SELECT isOdd, max(value) AS + |maximum FROM a GROUP BY 
isOdd LIMIT 2; + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + checkAnswer(spark.sql(s"SELECT * FROM b"), Seq(Row(0, 8), Row(1, 9))) + } + + test("SQL Pipeline with table properties") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE MATERIALIZED VIEW mv TBLPROPERTIES ('prop1'='foo1', 'prop2'='bar2') + |AS SELECT 1; + |CREATE STREAMING TABLE st TBLPROPERTIES ('prop3'='foo3', 'prop4'='bar4') + |AS SELECT * FROM STREAM(mv); + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + // verify table properties + val mv = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("mv")) + assert(mv.properties.get("prop1").contains("foo1")) + assert(mv.properties.get("prop2").contains("bar2")) + + val st = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("st")) + assert(st.properties.get("prop3").contains("foo3")) + assert(st.properties.get("prop4").contains("bar4")) + } + + test("SQL Pipeline with schema") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE MATERIALIZED VIEW a (id LONG COMMENT 'comment') AS SELECT * FROM RANGE(5); + |CREATE STREAMING TABLE b (id LONG COMMENT 'comment') AS SELECT * FROM STREAM a; + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + val a = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("a")) + assert(a.schema.fields.length == 1) + assert(a.schema.fields(0).name == "id") + + val b = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("b")) + assert(b.schema.fields.length == 1) + assert(b.schema.fields(0).name == "id") + } + + /* Mixed Language Tests */ + test("Pipeline with Python and SQL") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql", "definition.py")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE STREAMING TABLE c; + |CREATE MATERIALIZED VIEW a AS SELECT * FROM RANGE(5); + |""".stripMargin), + PipelineSourceFile( + name = "definition.py", + contents = """ + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |@dp.append_flow(target = "c", name = "append_to_c") + |def flow(): + | return spark.readStream.table("b").filter("id >= 3") + | + |@dp.materialized_view + |def b(): + | return spark.read.table("a") + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + checkAnswer(spark.sql(s"SELECT * FROM b"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + checkAnswer(spark.sql(s"SELECT * FROM c"), Seq(Row(3), Row(4))) + } + + test("Pipeline referencing internal datasets") { + val pipelineSpec = + TestPipelineSpec(include = Seq("mv.py", "st.py", "definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "mv.py", + contents = """ + |from pyspark import pipelines as dp + |from pyspark.sql import 
DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |@dp.materialized_view + |def src(): + | return spark.range(5) + |""".stripMargin), + PipelineSourceFile( + name = "st.py", + contents = """ + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |@dp.materialized_view + |def a(): + | return spark.read.table("src") + | + |@dp.table + |def b(): + | return spark.readStream.table("src") + |""".stripMargin), + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE STREAMING TABLE c; + |CREATE FLOW f AS INSERT INTO c BY NAME SELECT * FROM STREAM b WHERE id > 2; + |""".stripMargin)) + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + checkAnswer(spark.sql(s"SELECT * FROM a"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + checkAnswer(spark.sql(s"SELECT * FROM b"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + checkAnswer(spark.sql(s"SELECT * FROM c"), Seq(Row(3), Row(4))) + } + + test("Pipeline referencing external datasets") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.py", "definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + spark.sql( + s"CREATE TABLE src " + + s"AS SELECT * FROM RANGE(5)") + val sources = Seq( + PipelineSourceFile( + name = "definition.py", + contents = """ + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |@dp.materialized_view + |def a(): + | return spark.read.table("src") + | + |@dp.table + |def b(): + | return spark.readStream.table("src") + |""".stripMargin), + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE STREAMING TABLE c; + |CREATE FLOW f AS INSERT INTO c BY NAME SELECT * FROM STREAM b WHERE id > 2; + |""".stripMargin)) + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + checkAnswer(spark.sql(s"SELECT * FROM a"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + checkAnswer(spark.sql(s"SELECT * FROM b"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + checkAnswer(spark.sql(s"SELECT * FROM c"), Seq(Row(3), Row(4))) + } + + /* Python Language Tests */ + test("Python Pipeline with materialized_view, create_streaming_table, and append_flow") { + val pipelineSpec = + TestPipelineSpec(include = Seq("st.py", "mv.py")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "st.py", + contents = s""" + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |dp.create_streaming_table( + | name = "a", + | schema = "id LONG", + | comment = "streaming table a", + |) + | + |@dp.append_flow(target = "a", name = "append_to_a") + |def flow(): + | return spark.readStream.table("src") + |""".stripMargin), + PipelineSourceFile( + name = "mv.py", + contents = s""" + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |@dp.materialized_view( + | name = "src", + | comment = "source table", + | schema = "id LONG" + |) + |def irrelevant(): + | return spark.range(5) + |""".stripMargin)) + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + checkAnswer(spark.sql(s"SELECT * FROM a"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + } + + test("Python Pipeline with 
temporary_view") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.py")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + spark.sql( + s"CREATE TABLE src " + + s"AS SELECT * FROM RANGE(5)") + val sources = Seq( + PipelineSourceFile( + name = "definition.py", + contents = """ + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + | + |spark = SparkSession.active() + | + |@dp.temporary_view( + | name = "view_1", + | comment = "temporary view 1" + |) + |def irrelevant(): + | return spark.range(5) + | + |@dp.materialized_view( + | name = "mv_1", + |) + |def irrelevant_1(): + | return spark.read.table("view_1") + |""".stripMargin)) + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + // query the mv that depends on the temporary view + checkAnswer(spark.sql(s"SELECT * FROM mv_1"), Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + } + + test("Python Pipeline with partition columns") { + val pipelineSpec = + TestPipelineSpec(include = Seq("*.py")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + val sources = Seq( + PipelineSourceFile( + name = "definition.py", + contents = """ + |from pyspark import pipelines as dp + |from pyspark.sql import DataFrame, SparkSession + |from pyspark.sql.functions import col + | + |spark = SparkSession.active() + | + |@dp.materialized_view(partition_cols = ["id_mod"]) + |def mv(): + | return spark.range(5).withColumn("id_mod", col("id") % 2) + | + |@dp.table(partition_cols = ["id_mod"]) + |def st(): + | return spark.readStream.table("mv") + |""".stripMargin)) + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + Seq("mv", "st").foreach { tbl => + val fullName = s"$tbl" + checkAnswer( + spark.sql(s"SELECT * FROM $fullName"), + Seq(Row(0, 0), Row(1, 1), Row(2, 0), Row(3, 1), Row(4, 0))) + } + } + + /* Below tests pipeline execution configurations */ + + test("Pipeline with dry run") { + val pipelineSpec = + TestPipelineSpec(include = Seq("definition.sql")) + val pipelineConfig = TestPipelineConfiguration(pipelineSpec, dryRun = true) + val sources = Seq( + PipelineSourceFile( + name = "definition.sql", + contents = """ + |CREATE MATERIALIZED VIEW a AS SELECT * FROM RANGE(5); + |CREATE MATERIALIZED VIEW b AS SELECT * FROM a WHERE id > 2; + |""".stripMargin)) + + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + // ensure the table did not get created in dry run mode + assert(!spark.catalog.tableExists(s"a"), "Table a should not exist in dry run mode") + assert(!spark.catalog.tableExists(s"b"), "Table b should not exist in dry run mode") + } + + Seq( + SelectiveRefreshTestCase( + name = "Pipeline with refresh all by default", + fullRefreshAll = false, + // all tables retain old + new data + expectedA = Seq(Row(1), Row(2), Row(3)), + expectedB = Seq(Row(1), Row(2), Row(3)), + expectedMV = Seq(Row(1), Row(2), Row(3))), + SelectiveRefreshTestCase( + name = "Pipeline with full refresh all", + fullRefreshAll = true, + expectedA = Seq(Row(2), Row(3)), + expectedB = Seq(Row(2), Row(3)), + expectedMV = Seq(Row(2), Row(3))), + SelectiveRefreshTestCase( + name = "Pipeline with selective full refresh and refresh", + fullRefreshAll = false, + refreshSelection = Seq("b"), + fullRefreshSelection = Seq("mv", "a"), + expectedA = Seq(Row(2), Row(3)), + expectedB = Seq(Row(1), Row(2), Row(3)), // b keeps old + new + expectedMV = Seq(Row(2), Row(3))), + 
SelectiveRefreshTestCase( + name = "Pipeline with selective full_refresh", + fullRefreshAll = false, + fullRefreshSelection = Seq("a"), + expectedA = Seq(Row(2), Row(3)), + expectedB = Seq(Row(1)), // b not refreshed + expectedMV = Seq(Row(1)) // mv not refreshed + )).foreach(runSelectiveRefreshTest) + + private case class SelectiveRefreshTestCase( + name: String, + fullRefreshAll: Boolean, + refreshSelection: Seq[String] = Seq.empty, + fullRefreshSelection: Seq[String] = Seq.empty, + expectedA: Seq[Row], + expectedB: Seq[Row], + expectedMV: Seq[Row]) + + private def runSelectiveRefreshTest(tc: SelectiveRefreshTestCase): Unit = { + test(tc.name) { + val pipelineSpec = TestPipelineSpec(include = Seq("st.sql", "mv.sql")) + val externalTable = s"source_data" + // create initial source table + spark.sql(s"DROP TABLE IF EXISTS $externalTable") + spark.sql(s"CREATE TABLE $externalTable AS SELECT * FROM RANGE(1, 2)") + + val sources = Seq( + PipelineSourceFile( + name = "st.sql", + contents = s""" + |CREATE STREAMING TABLE a AS SELECT * FROM STREAM $externalTable; + |CREATE STREAMING TABLE b AS SELECT * FROM STREAM $externalTable; + |""".stripMargin), + PipelineSourceFile( + name = "mv.sql", + contents = """ + |CREATE MATERIALIZED VIEW mv AS SELECT * FROM a; + |""".stripMargin)) + + val pipelineConfig = TestPipelineConfiguration(pipelineSpec) + + // run pipeline with possible refresh/full refresh + val pipeline = createAndRunPipeline(pipelineConfig, sources) + awaitPipelineTermination(pipeline) + + // Replace source data to simulate a streaming update + spark.sql( + s"INSERT OVERWRITE TABLE $externalTable " + + "SELECT * FROM VALUES (2), (3) AS t(id)") + + val refreshConfig = pipelineConfig.copy( + refreshSelection = tc.refreshSelection, + fullRefreshSelection = tc.fullRefreshSelection, + fullRefreshAll = tc.fullRefreshAll) + val secondUpdate = createAndRunPipeline(refreshConfig, sources) + awaitPipelineTermination(secondUpdate) + + // clear caches to force reload + Seq("a", "b", "mv").foreach { t => + spark.catalog.refreshTable(s"$t") + } + + // verify results + checkAnswer(spark.sql(s"SELECT * FROM a"), tc.expectedA) + checkAnswer(spark.sql(s"SELECT * FROM b"), tc.expectedB) + checkAnswer(spark.sql(s"SELECT * FROM mv"), tc.expectedMV) + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org