sryza commented on code in PR #51507:
URL: https://github.com/apache/spark/pull/51507#discussion_r2211609713


##########
python/pyspark/pipelines/cli.py:
##########
@@ -217,8 +217,30 @@ def change_dir(path: Path) -> Generator[None, None, None]:
         os.chdir(prev)
 
 
-def run(spec_path: Path) -> None:
-    """Run the pipeline defined with the given spec."""
+def run(
+    spec_path: Path,
+    full_refresh: Optional[Sequence[str]] = None,
+    full_refresh_all: bool = False,
+    refresh: Optional[Sequence[str]] = None,
+) -> None:
+    """Run the pipeline defined with the given spec.
+
+    :param spec_path: Path to the pipeline specification file.
+    :param full_refresh: List of datasets to reset and recompute.
+    :param full_refresh_all: Perform a full graph reset and recompute.
+    :param refresh: List of datasets to update.
+    """
+    # Validate conflicting arguments
+    if full_refresh_all:
+        if full_refresh:

Review Comment:
   Should these be consolidated into `if full_refresh or refresh`?



##########
python/pyspark/pipelines/spark_connect_pipeline.py:
##########
@@ -65,12 +65,26 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) 
-> None:
             log_with_provided_timestamp(event.message, dt)
 
 
-def start_run(spark: SparkSession, dataflow_graph_id: str) -> 
Iterator[Dict[str, Any]]:
+def start_run(
+    spark: SparkSession,
+    dataflow_graph_id: str,
+    full_refresh: Optional[Sequence[str]] = None,

Review Comment:
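   Same comment as on `run` in `cli.py` about default args: this is never invoked without these parameters specified, so they don't need default values.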



##########
python/pyspark/pipelines/cli.py:
##########
@@ -217,8 +217,30 @@ def change_dir(path: Path) -> Generator[None, None, None]:
         os.chdir(prev)
 
 
-def run(spec_path: Path) -> None:
-    """Run the pipeline defined with the given spec."""
+def run(
+    spec_path: Path,
+    full_refresh: Optional[Sequence[str]] = None,

Review Comment:
   This is never invoked without these parameters specified, so there's no need to give them default values.



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala:
##########
@@ -0,0 +1,367 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.pipelines
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.connect.proto.{DatasetType, PipelineCommand, 
PipelineEvent}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}
+import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, 
TestPipelineUpdateContextMixin}
+
+/**
+ * Comprehensive test suite that validates pipeline refresh functionality by 
running actual
+ * pipelines with different refresh parameters and validating the results.
+ */
+class PipelineRefreshFunctionalSuite
+    extends SparkDeclarativePipelinesServerTest
+    with TestPipelineUpdateContextMixin
+    with EventVerificationTestHelpers {
+
+  private val externalSourceTable = TableIdentifier(
+    catalog = Some("spark_catalog"),
+    database = Some("default"),
+    table = "source_data")
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    // Create source table to simulate streaming updates
+    spark.sql(s"CREATE TABLE $externalSourceTable AS SELECT * FROM RANGE(1, 
2)")
+  }
+
+  override def afterEach(): Unit = {
+    super.afterEach()
+    // Clean up the source table after each test
+    spark.sql(s"DROP TABLE IF EXISTS $externalSourceTable")
+  }
+
+  private def createTestPipeline(graphId: String): TestPipelineDefinition = {
+    new TestPipelineDefinition(graphId) {
+      // Create tables that depend on the mv
+      createTable(
+        name = "a",
+        datasetType = DatasetType.TABLE,
+        sql = Some(s"SELECT id FROM STREAM $externalSourceTable"))
+      createTable(
+        name = "b",
+        datasetType = DatasetType.TABLE,
+        sql = Some(s"SELECT id FROM STREAM $externalSourceTable"))
+      createTable(
+        name = "mv",
+        datasetType = DatasetType.MATERIALIZED_VIEW,
+        sql = Some(s"SELECT id FROM a"))
+    }
+  }
+
+  /**
+   * Helper method to run refresh tests with common setup and verification 
logic. This reduces
+   * code duplication across the refresh test cases.
+   */
+  private def runRefreshTest(
+      refreshConfigBuilder: String => Option[PipelineCommand.StartRun] = _ => 
None,
+      expectedContentAfterRefresh: Map[String, Set[Map[String, Any]]],
+      eventValidation: Option[ArrayBuffer[PipelineEvent] => Unit] = None): 
Unit = {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      // First run to populate tables
+      startPipelineAndWaitForCompletion(graphId)
+
+      // Verify initial data - all tests expect the same initial state
+      verifyMultipleTableContent(
+        tableNames =
+          Set("spark_catalog.default.a", "spark_catalog.default.b", 
"spark_catalog.default.mv"),
+        columnsToVerify = Map(
+          "spark_catalog.default.a" -> Seq("id"),
+          "spark_catalog.default.b" -> Seq("id"),
+          "spark_catalog.default.mv" -> Seq("id")),
+        expectedContent = Map(
+          "spark_catalog.default.a" -> Set(Map("id" -> 1)),
+          "spark_catalog.default.b" -> Set(Map("id" -> 1)),
+          "spark_catalog.default.mv" -> Set(Map("id" -> 1))))
+
+      // Clear cached pipeline execution before starting new run
+      SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(defaultUserId, 
defaultSessionId))
+        .foreach(_.removeAllPipelineExecutions())
+
+      // Replace source data to simulate a streaming update
+      spark.sql(
+        "INSERT OVERWRITE TABLE spark_catalog.default.source_data " +
+          "SELECT * FROM VALUES (2), (3) AS t(id)")
+
+      // Run with specified refresh configuration
+      val capturedEvents = refreshConfigBuilder(graphId) match {
+        case Some(startRun) => startPipelineAndWaitForCompletion(graphId, 
Some(startRun))
+        case None => startPipelineAndWaitForCompletion(graphId)
+      }
+
+      // Additional validation if provided
+      eventValidation.foreach(_(capturedEvents))
+
+      // Verify final content
+      verifyMultipleTableContent(
+        tableNames =
+          Set("spark_catalog.default.a", "spark_catalog.default.b", 
"spark_catalog.default.mv"),
+        columnsToVerify = Map(
+          "spark_catalog.default.a" -> Seq("id"),
+          "spark_catalog.default.b" -> Seq("id"),
+          "spark_catalog.default.mv" -> Seq("id")),
+        expectedContent = expectedContentAfterRefresh)
+    }
+  }
+
+  test("pipeline runs selective full_refresh") {
+    runRefreshTest(
+      refreshConfigBuilder = { graphId =>
+        Some(
+          PipelineCommand.StartRun
+            .newBuilder()
+            .setDataflowGraphId(graphId)
+            .addAllFullRefresh(List("a").asJava)
+            .build())
+      },
+      expectedContentAfterRefresh = Map(
+        "spark_catalog.default.a" -> Set(
+          Map("id" -> 2), // a is fully refreshed and only contains the new 
values
+          Map("id" -> 3)),
+        "spark_catalog.default.b" -> Set(
+          Map("id" -> 1) // b is not refreshed, so it retains the old value
+        ),
+        "spark_catalog.default.mv" -> Set(
+          Map("id" -> 1) // mv is not refreshed, so it retains the old value
+        )),
+      eventValidation = Some { capturedEvents =>
+        // assert that table_b is excluded
+        assert(
+          capturedEvents.exists(
+            _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is 
EXCLUDED.")))
+        // assert that table_a ran to completion
+        assert(
+          capturedEvents.exists(
+            _.getMessage.contains(s"Flow spark_catalog.default.a has 
COMPLETED.")))
+        // assert that mv is excluded
+        assert(
+          capturedEvents.exists(
+            _.getMessage.contains(s"Flow \'spark_catalog.default.mv\' is 
EXCLUDED.")))
+        // Verify completion event
+        assert(capturedEvents.exists(_.getMessage.contains("Run is 
COMPLETED")))
+      })
+  }
+
+  test("pipeline runs selective full_refresh and selective refresh") {
+    runRefreshTest(
+      refreshConfigBuilder = { graphId =>
+        Some(
+          PipelineCommand.StartRun
+            .newBuilder()
+            .setDataflowGraphId(graphId)
+            .addAllFullRefresh(Seq("a", "mv").asJava)
+            .addRefresh("b")
+            .build())
+      },
+      expectedContentAfterRefresh = Map(
+        "spark_catalog.default.a" -> Set(
+          Map("id" -> 2), // a is fully refreshed and only contains the new 
values
+          Map("id" -> 3)),
+        "spark_catalog.default.b" -> Set(
+          Map("id" -> 1), // b is refreshed, so it retains the old value and 
adds the new ones
+          Map("id" -> 2),
+          Map("id" -> 3)),
+        "spark_catalog.default.mv" -> Set(
+          Map("id" -> 2), // mv is fully refreshed and only contains the new 
values
+          Map("id" -> 3))))
+  }
+
+  test("pipeline runs refresh by default") {
+    runRefreshTest(expectedContentAfterRefresh =
+      Map(
+        "spark_catalog.default.a" -> Set(
+          Map(
+            "id" -> 1
+          ), // a is refreshed by default, retains the old value and adds the 
new ones
+          Map("id" -> 2),
+          Map("id" -> 3)),
+        "spark_catalog.default.b" -> Set(
+          Map(
+            "id" -> 1
+          ), // b is refreshed by default, retains the old value and adds the 
new ones
+          Map("id" -> 2),
+          Map("id" -> 3)),
+        "spark_catalog.default.mv" -> Set(
+          Map("id" -> 1),
+          Map("id" -> 2), // mv is refreshed from table a, retains all values
+          Map("id" -> 3))))
+  }
+
+  test("pipeline runs full refresh all") {
+    runRefreshTest(
+      refreshConfigBuilder = { graphId =>
+        Some(
+          PipelineCommand.StartRun
+            .newBuilder()
+            .setDataflowGraphId(graphId)
+            .setFullRefreshAll(true)
+            .build())
+      },
+      // full refresh all causes all tables to lose the initial value
+      // and only contain the new values after the source data is updated
+      expectedContentAfterRefresh = Map(
+        "spark_catalog.default.a" -> Set(Map("id" -> 2), Map("id" -> 3)),
+        "spark_catalog.default.b" -> Set(Map("id" -> 2), Map("id" -> 3)),
+        "spark_catalog.default.mv" -> Set(Map("id" -> 2), Map("id" -> 3))))
+  }
+
+  test("validation: cannot specify subset refresh when full_refresh_all is 
true") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      val startRun = PipelineCommand.StartRun
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .setFullRefreshAll(true)
+        .addRefresh("a")
+        .build()
+
+      val exception = intercept[IllegalArgumentException] {
+        startPipelineAndWaitForCompletion(graphId, Some(startRun))
+      }
+      assert(
+        exception.getMessage.contains(
+          "Cannot specify a subset to full refresh when full refresh all is 
set to true"))
+    }
+  }
+
+  test("validation: cannot specify subset full_refresh when full_refresh_all 
is true") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      val startRun = PipelineCommand.StartRun
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .setFullRefreshAll(true)
+        .addFullRefresh("a")
+        .build()
+
+      val exception = intercept[IllegalArgumentException] {
+        startPipelineAndWaitForCompletion(graphId, Some(startRun))
+      }
+      assert(
+        exception.getMessage.contains(
+          "Cannot specify a subset to refresh when full refresh all is set to 
true"))
+    }
+  }
+
+  test("validation: refresh and full_refresh cannot overlap") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      val startRun = PipelineCommand.StartRun
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .addRefresh("a")
+        .addFullRefresh("a")
+        .build()
+
+      val exception = intercept[IllegalArgumentException] {
+        startPipelineAndWaitForCompletion(graphId, Some(startRun))
+      }
+      assert(
+        exception.getMessage.contains(
+          "Datasets specified for refresh and full refresh cannot overlap"))
+      assert(exception.getMessage.contains("a"))
+    }
+  }
+
+  test("validation: multiple overlapping tables in refresh and full_refresh 
not allowed") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      val startRun = PipelineCommand.StartRun
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .addRefresh("a")
+        .addRefresh("b")
+        .addFullRefresh("a")
+        .build()
+
+      val exception = intercept[IllegalArgumentException] {
+        startPipelineAndWaitForCompletion(graphId, Some(startRun))
+      }
+      assert(
+        exception.getMessage.contains(
+          "Datasets specified for refresh and full refresh cannot overlap"))
+      assert(exception.getMessage.contains("a"))
+    }
+  }
+
+  test("validation: fully qualified table names in validation") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      val startRun = PipelineCommand.StartRun
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .addRefresh("spark_catalog.default.a")
+        .addFullRefresh("a") // This should be treated as the same table
+        .build()
+
+      val exception = intercept[IllegalArgumentException] {
+        startPipelineAndWaitForCompletion(graphId, Some(startRun))
+      }
+      assert(
+        exception.getMessage.contains(
+          "Datasets specified for refresh and full refresh cannot overlap"))
+    }
+  }
+
+  private def verifyMultipleTableContent(

Review Comment:
   Are we handling each table independently? If so, would it make sense to expose a function that verifies a single table and call it with `foreach` wherever this is invoked?



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala:
##########
@@ -125,15 +128,25 @@ class SparkDeclarativePipelinesServerTest extends 
SparkConnectServerTest {
   def createPlanner(): SparkConnectPlanner =
     new 
SparkConnectPlanner(SparkConnectTestUtils.createDummySessionHolder(spark))
 
-  def startPipelineAndWaitForCompletion(graphId: String): Unit = {
+  def startPipelineAndWaitForCompletion(
+      graphId: String,
+      customStartRunCommand: Option[PipelineCommand.StartRun] = None)

Review Comment:
   Accepting both of these arguments at the same time allows an ill-defined situation where the custom `StartRun` command carries a different dataflow graph ID than the value supplied for the `graphId` argument.
   
   Could we instead factor it like this:
   
   ```scala
   def startPipelineAndWaitForCompletion(startRunCommand: PipelineCommand.StartRun): Unit = ...
   
   def startPipelineAndWaitForCompletion(graphId: String): Unit = {
     startPipelineAndWaitForCompletion(PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())
   }
   ```



##########
python/pyspark/pipelines/tests/test_cli.py:
##########
@@ -359,6 +360,139 @@ def test_python_import_current_directory(self):
                 )
 
 
[email protected](
+    not should_test_connect or not have_yaml,
+    connect_requirement_message or yaml_requirement_message,
+)
+class CLIValidationTests(unittest.TestCase):

Review Comment:
   Is there a meaningful difference between the kinds of tests included in this class and the kinds of tests included in the other class in this file?



##########
python/pyspark/pipelines/cli.py:
##########
@@ -28,7 +28,7 @@
 import yaml
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Generator, Mapping, Optional, Sequence
+from typing import Any, Generator, Mapping, Optional, Sequence, List

Review Comment:
   These imports are out of alphabetical order; you may need to run `dev/reformat-python` to format them.



##########
python/pyspark/pipelines/cli.py:
##########
@@ -217,8 +217,30 @@ def change_dir(path: Path) -> Generator[None, None, None]:
         os.chdir(prev)
 
 
-def run(spec_path: Path) -> None:
-    """Run the pipeline defined with the given spec."""
+def run(

Review Comment:
   If we expect it to vary across runs for the same pipeline, it should be a CLI arg. If we expect it to be static for a pipeline, it should live in the spec. I would expect selections to vary across runs.



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala:
##########
@@ -0,0 +1,367 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.pipelines
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.connect.proto.{DatasetType, PipelineCommand, 
PipelineEvent}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}
+import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, 
TestPipelineUpdateContextMixin}
+
+/**
+ * Comprehensive test suite that validates pipeline refresh functionality by 
running actual
+ * pipelines with different refresh parameters and validating the results.
+ */
+class PipelineRefreshFunctionalSuite
+    extends SparkDeclarativePipelinesServerTest
+    with TestPipelineUpdateContextMixin
+    with EventVerificationTestHelpers {
+
+  private val externalSourceTable = TableIdentifier(
+    catalog = Some("spark_catalog"),
+    database = Some("default"),
+    table = "source_data")
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    // Create source table to simulate streaming updates
+    spark.sql(s"CREATE TABLE $externalSourceTable AS SELECT * FROM RANGE(1, 
2)")
+  }
+
+  override def afterEach(): Unit = {
+    super.afterEach()
+    // Clean up the source table after each test
+    spark.sql(s"DROP TABLE IF EXISTS $externalSourceTable")
+  }
+
+  private def createTestPipeline(graphId: String): TestPipelineDefinition = {
+    new TestPipelineDefinition(graphId) {
+      // Create tables that depend on the mv
+      createTable(
+        name = "a",
+        datasetType = DatasetType.TABLE,
+        sql = Some(s"SELECT id FROM STREAM $externalSourceTable"))
+      createTable(
+        name = "b",
+        datasetType = DatasetType.TABLE,
+        sql = Some(s"SELECT id FROM STREAM $externalSourceTable"))
+      createTable(
+        name = "mv",
+        datasetType = DatasetType.MATERIALIZED_VIEW,
+        sql = Some(s"SELECT id FROM a"))
+    }
+  }
+
+  /**
+   * Helper method to run refresh tests with common setup and verification 
logic. This reduces
+   * code duplication across the refresh test cases.
+   */
+  private def runRefreshTest(
+      refreshConfigBuilder: String => Option[PipelineCommand.StartRun] = _ => 
None,
+      expectedContentAfterRefresh: Map[String, Set[Map[String, Any]]],
+      eventValidation: Option[ArrayBuffer[PipelineEvent] => Unit] = None): 
Unit = {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipeline = createTestPipeline(graphId)
+      registerPipelineDatasets(pipeline)
+
+      // First run to populate tables
+      startPipelineAndWaitForCompletion(graphId)
+
+      // Verify initial data - all tests expect the same initial state
+      verifyMultipleTableContent(
+        tableNames =
+          Set("spark_catalog.default.a", "spark_catalog.default.b", 
"spark_catalog.default.mv"),
+        columnsToVerify = Map(
+          "spark_catalog.default.a" -> Seq("id"),
+          "spark_catalog.default.b" -> Seq("id"),
+          "spark_catalog.default.mv" -> Seq("id")),
+        expectedContent = Map(
+          "spark_catalog.default.a" -> Set(Map("id" -> 1)),
+          "spark_catalog.default.b" -> Set(Map("id" -> 1)),
+          "spark_catalog.default.mv" -> Set(Map("id" -> 1))))
+
+      // Clear cached pipeline execution before starting new run
+      SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(defaultUserId, 
defaultSessionId))
+        .foreach(_.removeAllPipelineExecutions())
+
+      // Replace source data to simulate a streaming update
+      spark.sql(
+        "INSERT OVERWRITE TABLE spark_catalog.default.source_data " +
+          "SELECT * FROM VALUES (2), (3) AS t(id)")
+
+      // Run with specified refresh configuration
+      val capturedEvents = refreshConfigBuilder(graphId) match {
+        case Some(startRun) => startPipelineAndWaitForCompletion(graphId, 
Some(startRun))
+        case None => startPipelineAndWaitForCompletion(graphId)
+      }
+
+      // Additional validation if provided
+      eventValidation.foreach(_(capturedEvents))
+
+      // Verify final content
+      verifyMultipleTableContent(

Review Comment:
   Would it make sense to use `QueryTest.checkAnswer` here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to