hvanhovell commented on code in PR #45150:
URL: https://github.com/apache/spark/pull/45150#discussion_r1494986285


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala:
##########
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd}
+
+/**
+ * A listener that tracks the execution of jobs and stages for a given set of tags. This is used
+ * to track the progress of a job that is being executed through the connect API.
+ *
+ * The listener is instantiated once for the SparkConnectService and then used to track all the
+ * current query executions.
+ */
+private[connect] class ConnectProgressExecutionListener extends SparkListener with Logging {
+
+  /**
+   * A tracker for a given tag. This is used to track the progress of an operation that is
+   * being executed through the connect API.
+   */
+  class ExecutionTracker(var tag: String) {
+    private[ConnectProgressExecutionListener] var jobs: Set[Int] = Set()
+    private[ConnectProgressExecutionListener] var stages: Set[Int] = Set()
+    private[ConnectProgressExecutionListener] var totalTasks = 0
+    private[ConnectProgressExecutionListener] var completedTasks = 0
+    private[ConnectProgressExecutionListener] var completedStages = 0
+    private[ConnectProgressExecutionListener] var inputBytesRead = 0L
+    // The tracker is marked as dirty if it has new progress to report. This variable does
+    // not need to be protected by a mutex: even if multiple threads read the same dirty
+    // state, the output is expected to be identical.
+    @volatile private[ConnectProgressExecutionListener] var dirty = false
+
+    /**
+     * Yield the current state of the tracker if it is dirty. A consumer of the tracker can
+     * provide a callback that will be called with the current state of the tracker if the
+     * tracker has new progress to report.
+     *
+     * If the tracker was marked as dirty, the dirty flag is reset afterwards.
+     */
+    def yieldWhenDirty(thunk: (Int, Int, Int, Int, Long) => Unit): Unit = {
+      if (dirty) {
+        thunk(totalTasks, completedTasks, stages.size, completedStages, inputBytesRead)
+        dirty = false
+      }
+    }
+
+    /**
+     * Add a job to the tracker. This will add the job to the list of jobs that are being
+     * tracked.
+     */
+    def addJob(job: SparkListenerJobStart): Unit = {
+      jobs = jobs + job.jobId
+      stages = stages ++ job.stageIds
+      totalTasks += job.stageInfos.map(_.numTasks).sum
+      dirty = true
+    }
+
+    def jobCount(): Int = {
+      jobs.size
+    }
+
+    def stageCount(): Int = {
+      stages.size
+    }
+  }
+
+  val trackedTags = collection.mutable.Map[String, ExecutionTracker]()
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    val tags = jobStart.properties.getProperty("spark.job.tags")
+    if (tags != null) {
+      val thisJobTags = tags.split(",").map(_.trim).toSet
+      thisJobTags.foreach { tag =>
+        if (trackedTags.contains(tag)) {
+          trackedTags(tag).addJob(jobStart)
+        }
+      }
+    }
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    // Check if the task belongs to a job that we are tracking.
+    trackedTags.foreach({ case (tag, tracker) =>
+      if (tracker.stages.contains(taskEnd.stageId)) {
+        tracker.completedTasks += 1
+        tracker.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
+        tracker.dirty = true
+      }
+    })
+  }
+
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+    trackedTags.foreach({ case (tag, tracker) =>
+      if (tracker.stages.contains(stageCompleted.stageInfo.stageId)) {
+        tracker.completedStages += 1
+        tracker.dirty = true
+      }
+    })
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {

Review Comment:
   Should we include cancelled jobs/stages/tasks?
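
   One possible shape for the cancellation case, as a minimal sketch rather than the PR's actual change (the `onJobEnd` body is truncated above; the snippet assumes it lives inside the listener, so `trackedTags` and the tracker fields come from the diff): Spark reports a cancelled job to listeners as a `SparkListenerJobEnd` whose `jobResult` is `JobFailed`, so the listener could at least mark the tracker dirty and let the consumer observe the final, partial progress.

   ```scala
   import org.apache.spark.scheduler.{JobFailed, SparkListenerJobEnd}

   override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
     trackedTags.foreach { case (_, tracker) =>
       if (tracker.jobs.contains(jobEnd.jobId)) {
         jobEnd.jobResult match {
           // Cancellation surfaces as JobFailed (the DAGScheduler posts it with a
           // "cancelled" SparkException), so this branch covers both failure and
           // cancellation of a tracked job.
           case JobFailed(_) =>
             tracker.dirty = true
           case _ =>
             // JobSucceeded: the per-task and per-stage events have already
             // updated completedTasks and completedStages.
         }
       }
     }
   }
   ```

   Counting the tasks and stages that never ran would need extra bookkeeping (the tracker only ever increments `completedTasks`/`completedStages`), which is why this sketch merely flags the tracker as dirty.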



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

