aakash-db commented on code in PR #51003:
URL: https://github.com/apache/spark/pull/51003#discussion_r2116701769


##########
sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala:
##########
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique
+import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils
+
+/** Validations performed on a [[DataflowGraph]]. */
+trait GraphValidations extends Logging {
+  this: DataflowGraph =>
+
+  /**
+   * Validate multi query table correctness. Exposed for Python unit testing, which currently cannot
+   * run anything which invokes the flow function as there's no persistent Python to run it.
+   *
+   * @return the multi-query tables by destination
+   */
+  protected[pipelines] def validateMultiQueryTables(): Map[TableIdentifier, Seq[Flow]] = {
+    val multiQueryTables = flowsTo.filter(_._2.size > 1)
+    // Non-streaming tables do not support multiflow.
+    multiQueryTables
+      .find {
+        case (dest, flows) =>
+          flows.exists(f => !resolvedFlow(f.identifier).df.isStreaming) &&
+          table.contains(dest)
+      }
+      .foreach {
+        case (dest, flows) =>
+          throw new AnalysisException(
+            "MATERIALIZED_VIEW_WITH_MULTIPLE_QUERIES",
+            Map(
+              "tableName" -> dest.unquotedString,
+              "queries" -> flows.map(_.identifier).mkString(",")
+            )
+          )
+      }
+
+    multiQueryTables
+  }
+
+  /** Throws an exception if the flows in this graph are not topologically sorted. */
+  protected[graph] def validateGraphIsTopologicallySorted(): Unit = {
+    val visitedNodes = mutable.Set.empty[TableIdentifier] // Set of visited nodes
+    val visitedEdges = mutable.Set.empty[TableIdentifier] // Set of visited edges
+    flows.foreach { f =>
+      // Unvisited inputs of the current flow
+      val unvisitedInputNodes =
+        resolvedFlow(f.identifier).inputs -- visitedNodes
+      unvisitedInputNodes.headOption match {
+        case None =>
+          visitedEdges.add(f.identifier)
+          if (flowsTo(f.destinationIdentifier).map(_.identifier).forall(visitedEdges.contains)) {
+            // A node is marked visited if all its inputs are visited
+            visitedNodes.add(f.destinationIdentifier)
+          }
+        case Some(unvisitedInput) =>
+          throw new AnalysisException(
+            "PIPELINE_GRAPH_NOT_TOPOLOGICALLY_SORTED",
+            Map(
+              "flowName" -> f.identifier.unquotedString,
+              "inputName" -> unvisitedInput.unquotedString
+            )
+          )
+      }
+    }
+  }
+
+  /**
+   * Validate that all tables are resettable. This is a best-effort check that will only catch
+   * upstream tables that are resettable but have a non-resettable downstream dependency.
+   */
+  protected def validateTablesAreResettable(): Seq[GraphValidationWarning] = {
+    validateTablesAreResettable(tables)
+  }
+
+  /** Validate that all specified tables are resettable. */
+  protected def validateTablesAreResettable(tables: Seq[Table]): Seq[GraphValidationWarning] = {
+    val tableLookup = mapUnique(tables, "table")(_.identifier)
+    val nonResettableTables =
+      tables.filter(t => !PipelinesTableProperties.resetAllowed.fromMap(t.properties))
+    val upstreamResettableTables = upstreamDatasets(nonResettableTables.map(_.identifier))
+      .collect {
+        // Filter for upstream datasets that are tables with downstream streaming tables
+        case (upstreamDataset, nonResettableDownstreams) if table.contains(upstreamDataset) =>
+          nonResettableDownstreams
+            .filter(
+              t => flowsTo(t).exists(f => resolvedFlow(f.identifier).df.isStreaming)
+            )
+            .map(id => (tableLookup(upstreamDataset), tableLookup(id).displayName))
+      }
+      .flatten
+      .toSeq
+      .filter {
+        case (t, _) => PipelinesTableProperties.resetAllowed.fromMap(t.properties)
+      } // Filter for resettable
+
+    upstreamResettableTables
+      .groupBy(_._2) // Group-by non-resettable downstream tables
+      .view
+      .mapValues(_.map(_._1))
+      .toSeq
+      .sortBy(_._2.size) // Output errors from largest to smallest
+      .reverse
+      .map {
+        case (nameForEvent, tables) =>
+          InvalidResettableDependencyException(nameForEvent, tables)
+      }
+  }
+
+  /**
+   * Validate whether there are any append only flows that write into a streaming table but were
+   * created from a batch query.
+   */
+  protected def validateAppendOnceFlows(): Seq[GraphValidationWarning] = {
+    flows
+      .filter {
+        case af: AppendOnceFlow => !af.definedAsOnce
+        case _ => false
+      }
+      .groupBy(_.destinationIdentifier)
+      .flatMap {
+        case (destination, flows) =>
+          table
+            .get(destination)
+            .map(t => AppendOnceFlowCreatedFromBatchQueryException(t, flows.map(_.identifier)))
+      }
+      .toSeq
+  }
+
+  protected def validateUserSpecifiedSchemas(): Unit = {
+    flows.flatMap(f => tableInput(f.identifier)).foreach { t: TableInput =>
+      // The output inferred schema of a table is the declared schema merged with the
+      // schema of all incoming flows. This must be equivalent to the declared schema.
+      val inferredSchema = SchemaInferenceUtils
+        .inferSchemaFromFlows(
+          flowsTo(t.identifier).map(f => resolvedFlow(f.identifier)),
+          userSpecifiedSchema = t.specifiedSchema
+        )
+
+      t.specifiedSchema.foreach { ss =>
+        // Check the inferred schema matches the specified schema. Used to catch errors where the
+        // inferred user-facing schema has columns that are not in the specified one.
+        if (inferredSchema != ss) {
+          val datasetType = GraphElementTypeUtils
+            .getDatasetTypeForMaterializedViewOrStreamingTable(
+              flowsTo(t.identifier).map(f => resolvedFlow(f.identifier))
+            )
+          throw GraphErrors.incompatibleUserSpecifiedAndInferredSchemasError(
+            t.identifier,
+            datasetType,
+            ss,
+            inferredSchema
+          )
+        }
+      }
+    }
+  }
+
+  /**
+   * Validates that all flows are resolved. If there are unresolved flows,
+   * detects a possible cyclic dependency and throws the appropriate exception.
+   */
+  protected def validateSuccessfulFlowAnalysis(): Unit = {
+    // all failed flows with their errors
+    val flowAnalysisFailures = resolutionFailedFlows.flatMap(
+      f => f.failure.headOption.map(err => (f.identifier, err))
+    )
+    // only proceed if there are unresolved flows
+    if (flowAnalysisFailures.nonEmpty) {
+      val failedFlowIdentifiers = flowAnalysisFailures.map(_._1).toSet
+      // used to collect the subgraph of only the unresolved flows
+      // maps every unresolved flow to the set of unresolved flows writing to one of its inputs
+      val failedFlowsSubgraph = mutable.Map[TableIdentifier, Seq[TableIdentifier]]()
+      val (downstreamFailures, directFailures) = flowAnalysisFailures.partition {
+        case (flowIdentifier, _) =>
+          // If a failed flow writes to any of the requested datasets, we mark this flow as a
+          // downstream failure
+          val failedFlowsWritingToRequestedDatasets =
+            resolutionFailedFlow(flowIdentifier).funcResult.requestedInputs
+              .flatMap(d => flowsTo.getOrElse(d, Seq()))
+              .map(_.identifier)
+              .intersect(failedFlowIdentifiers)
+              .toSeq
+          failedFlowsSubgraph += (flowIdentifier -> failedFlowsWritingToRequestedDatasets)
+          failedFlowsWritingToRequestedDatasets.nonEmpty
+      }
+      // if there are flows that failed due to unresolved upstream flows, check for a cycle
+      if (failedFlowsSubgraph.nonEmpty) {
+        detectCycle(failedFlowsSubgraph.toMap).foreach {
+          case (upstream, downstream) =>
+            val upstreamDataset = flow(upstream).destinationIdentifier
+            val downstreamDataset = flow(downstream).destinationIdentifier
+            throw CircularDependencyException(
+              downstreamDataset,
+              upstreamDataset
+            )
+        }
+      }
+      // otherwise report which flows failed directly vs. which failed due to a dependency on a failed flow
+      throw UnresolvedPipelineException(
+        this,
+        directFailures.map { case (id, value) => (id, value) }.toMap,
+        downstreamFailures.map { case (id, value) => (id, value) }.toMap
+      )
+    }
+  }
+
+  /**
+   * Generic method to detect a cycle in a directed graph via DFS traversal.
+   * The graph is given as a reverse adjacency map, that is, a map from
+   * each node to its ancestors.
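+   *
+   * For example (illustrative only, with hypothetical identifiers): given the ancestors map
+   * Map(a -> Seq(b), b -> Seq(c), c -> Seq(a)), a DFS starting at `a` walks through b and c and
+   * then finds `a` already on the current path, so the method returns Some((c, a)).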
+   * @return the start and end node of a cycle if found, None otherwise
+   */
+  private def detectCycle(ancestors: Map[TableIdentifier, Seq[TableIdentifier]])
+      : Option[(TableIdentifier, TableIdentifier)] = {
+    var cycle: Option[(TableIdentifier, TableIdentifier)] = None
+    val visited = mutable.Set[TableIdentifier]()
+    def visit(f: TableIdentifier, currentPath: List[TableIdentifier]): Unit = {
+      if (cycle.isEmpty && !visited.contains(f)) {
+        if (currentPath.contains(f)) {
+          cycle = Option((currentPath.head, f))
+        } else {
+          ancestors(f).foreach(visit(_, f :: currentPath))
+          visited += f
+        }
+      }
+    }
+    ancestors.keys.foreach(visit(_, Nil))
+    cycle
+  }
+
+  /** Validates that persisted views don't read from invalid sources */
+  protected[graph] def validatePersistedViewSources(): Unit = {
+    val viewToFlowMap = ViewHelpers.persistedViewIdentifierToFlow(graph = this)
+
+    persistedViews
+      .foreach { persistedView =>
+        val flow = viewToFlowMap(persistedView.identifier)
+        val funcResult = resolvedFlow(flow.identifier).funcResult
+        val inputIdentifiers = (funcResult.batchInputs ++ funcResult.streamingInputs)
+          .map(_.input.identifier)
+
+        inputIdentifiers
+          .flatMap(view.get)
+          .foreach {
+            case tempView: TemporaryView =>

Review Comment:
   Because temp views are not registered in the catalog and thus cannot be accessed outside of pipeline scope.
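   
   To illustrate the motivation, here is a minimal, self-contained sketch; the types and names below are illustrative assumptions, not the PR's actual `TemporaryView` / `ViewHelpers` API. Because a temp view never reaches the catalog, a persisted view that selects from it would dangle as soon as the pipeline run ends, so the validation has to reject such sources up front.
   
   ```scala
   // Standalone sketch of the check's intent; types here are hypothetical stand-ins.
   object PersistedViewSourceSketch {
     sealed trait ViewDef { def name: String }
     // Pipeline-scoped: never registered in the catalog, gone after the run.
     case class TempViewDef(name: String) extends ViewDef
     // Catalog-backed: must remain resolvable outside the pipeline.
     case class PersistedViewDef(name: String, inputs: Seq[String]) extends ViewDef
   
     /** Names of temp views a persisted view reads from; any hit should fail validation. */
     def invalidSources(pv: PersistedViewDef, defs: Map[String, ViewDef]): Seq[String] =
       pv.inputs.filter(in => defs.get(in).exists(_.isInstanceOf[TempViewDef]))
   
     def main(args: Array[String]): Unit = {
       val defs: Map[String, ViewDef] = Map(
         "staged" -> TempViewDef("staged"),
         "daily" -> PersistedViewDef("daily", inputs = Seq("staged", "events"))
       )
       val bad = invalidSources(defs("daily").asInstanceOf[PersistedViewDef], defs)
       // Prints: persisted view 'daily' reads temp views: staged
       if (bad.nonEmpty) {
         println(s"persisted view 'daily' reads temp views: ${bad.mkString(", ")}")
       }
     }
   }
   ```
   
   In the PR itself the equivalent information comes from the resolved flow's `funcResult.batchInputs ++ funcResult.streamingInputs`, matched against the graph's view definitions.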



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

