Repository: spark
Updated Branches:
  refs/heads/master 6afe6f32c -> 18b6ec147


[SPARK-24748][SS] Support for reporting custom metrics via StreamingQuery Progress

## What changes were proposed in this pull request?

Currently, Structured Streaming sources and sinks do not have a way to report 
custom metrics. Providing an option to report custom metrics and making them 
available via Streaming Query progress enables sources and sinks to report 
custom progress information (e.g. the lag metrics for a Kafka source).
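
As a rough illustration of the reader-side hook (the `LagMetrics` and `LaggyReader` 
names and the `lag` key are made up for this sketch and are not part of the patch; 
a real streaming source such as the Kafka reader would also implement 
`MicroBatchReader` and return real input partitions):

```scala
import java.util.{Collections, List => JList}

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.CustomMetrics
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.SupportsCustomReaderMetrics
import org.apache.spark.sql.types.StructType

// Hypothetical metrics payload; CustomMetrics only asks for a JSON string.
class LagMetrics(var lag: Long) extends CustomMetrics {
  private implicit val formats = Serialization.formats(NoTypeHints)
  override def json(): String = Serialization.write(Map("lag" -> lag))
}

// Minimal reader mixing in the new interface; the scan methods are stubbed out
// because only the metrics plumbing is of interest here.
class LaggyReader extends SupportsCustomReaderMetrics {
  private val metrics = new LagMetrics(0L)

  override def readSchema(): StructType = new StructType().add("value", "string")

  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Collections.emptyList[InputPartition[InternalRow]]()

  // Surfaces under SourceProgress as "customMetrics" once the query makes progress.
  override def getCustomMetrics: CustomMetrics = metrics

  // Optional: fail fast instead of silently dropping malformed metric JSON.
  override def onInvalidMetrics(ex: Exception): Unit = throw ex
}
```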

Similar metrics can be reported for sinks as well, but I would like to get 
initial feedback before proceeding further.
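
On the sink side, the reported JSON surfaces as `SinkProgress.customMetrics`. A 
rough sketch of reading it back out of the query progress (the `sinkNumRows` 
helper is hypothetical; `numRows` is the key the updated MemorySinkV2 emits):

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.sql.streaming.StreamingQuery

// Pulls the sink's "numRows" metric out of the latest progress, if the sink
// reported any custom metrics (SinkProgress.customMetrics is null otherwise).
def sinkNumRows(query: StreamingQuery): Option[Long] = {
  implicit val formats: Formats = DefaultFormats
  Option(query.lastProgress)
    .flatMap(p => Option(p.sink.customMetrics)) // raw JSON from CustomMetrics.json()
    .map(json => (parse(json) \ "numRows").extract[Long])
}
```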

## How was this patch tested?

New and existing unit tests.


Closes #21721 from arunmahadevan/SPARK-24748.

Authored-by: Arun Mahadevan <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/18b6ec14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/18b6ec14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/18b6ec14

Branch: refs/heads/master
Commit: 18b6ec14716bfafc25ae281b190547ea58b59af1
Parents: 6afe6f3
Author: Arun Mahadevan <[email protected]>
Authored: Tue Aug 7 10:28:26 2018 +0800
Committer: hyukjinkwon <[email protected]>
Committed: Tue Aug 7 10:28:26 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/sources/v2/CustomMetrics.java     | 33 ++++++++++
 .../streaming/SupportsCustomReaderMetrics.java  | 47 +++++++++++++++
 .../streaming/SupportsCustomWriterMetrics.java  | 47 +++++++++++++++
 .../execution/streaming/ProgressReporter.scala  | 63 ++++++++++++++++++--
 .../streaming/sources/MicroBatchWriter.scala    |  2 +-
 .../execution/streaming/sources/memoryV2.scala  | 32 ++++++++--
 .../apache/spark/sql/streaming/progress.scala   | 46 ++++++++++++--
 .../execution/streaming/MemorySinkV2Suite.scala | 22 +++++++
 .../sql/streaming/StreamingQuerySuite.scala     | 28 +++++++++
 9 files changed, 306 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java
new file mode 100644
index 0000000..7011a70
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface for reporting custom metrics from streaming sources and sinks
+ */
[email protected]
+public interface CustomMetrics {
+  /**
+   * Returns a JSON serialized representation of custom metrics
+   *
+   * @return JSON serialized representation of custom metrics
+   */
+  String json();
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java
new file mode 100644
index 0000000..3b293d9
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.CustomMetrics;
+import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
+
+/**
+ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this
+ * interface to report custom metrics that gets reported under the
+ * {@link org.apache.spark.sql.streaming.SourceProgress}
+ *
+ */
[email protected]
+public interface SupportsCustomReaderMetrics extends DataSourceReader {
+  /**
+   * Returns custom metrics specific to this data source.
+   */
+  CustomMetrics getCustomMetrics();
+
+  /**
+   * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid
+   * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that
+   * your custom metrics work right and correct values are reported always. The default action
+   * on invalid metrics is to ignore it.
+   *
+   * @param ex the exception
+   */
+  default void onInvalidMetrics(Exception ex) {
+    // default is to ignore invalid custom metrics
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java
new file mode 100644
index 0000000..0cd3650
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources.v2.writer.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.CustomMetrics;
+import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
+
+/**
+ * A mix in interface for {@link DataSourceWriter}. Data source writers can implement this
+ * interface to report custom metrics that gets reported under the
+ * {@link org.apache.spark.sql.streaming.SinkProgress}
+ *
+ */
[email protected]
+public interface SupportsCustomWriterMetrics extends DataSourceWriter {
+  /**
+   * Returns custom metrics specific to this data source.
+   */
+  CustomMetrics getCustomMetrics();
+
+  /**
+   * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid
+   * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that
+   * your custom metrics work right and correct values are reported always. The default action
+   * on invalid metrics is to ignore it.
+   *
+   * @param ex the exception
+   */
+  default void onInvalidMetrics(Exception ex) {
+    // default is to ignore invalid custom metrics
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 47f4b52..1e15832 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -22,14 +22,22 @@ import java.util.{Date, UUID}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import org.json4s.JsonAST.JValue
+import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
+import org.apache.spark.sql.sources.v2.CustomMetrics
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, SupportsCustomReaderMetrics}
+import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.SupportsCustomWriterMetrics
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
 import org.apache.spark.util.Clock
@@ -156,7 +164,31 @@ trait ProgressReporter extends Logging {
     }
     logDebug(s"Execution stats: $executionStats")
 
+    // extracts and validates custom metrics from readers and writers
+    def extractMetrics(
+        getMetrics: () => Option[CustomMetrics],
+        onInvalidMetrics: (Exception) => Unit): Option[String] = {
+      try {
+        getMetrics().map(m => {
+          val json = m.json()
+          parse(json)
+          json
+        })
+      } catch {
+        case ex: Exception if NonFatal(ex) =>
+          onInvalidMetrics(ex)
+          None
+      }
+    }
+
     val sourceProgress = sources.distinct.map { source =>
+      val customReaderMetrics = source match {
+        case s: SupportsCustomReaderMetrics =>
+          extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics)
+
+        case _ => None
+      }
+
       val numRecords = executionStats.inputRows.getOrElse(source, 0L)
       new SourceProgress(
         description = source.toString,
@@ -164,10 +196,19 @@ trait ProgressReporter extends Logging {
         endOffset = currentTriggerEndOffsets.get(source).orNull,
         numInputRows = numRecords,
         inputRowsPerSecond = numRecords / inputTimeSec,
-        processedRowsPerSecond = numRecords / processingTimeSec
+        processedRowsPerSecond = numRecords / processingTimeSec,
+        customReaderMetrics.orNull
       )
     }
-    val sinkProgress = new SinkProgress(sink.toString)
+
+    val customWriterMetrics = dataSourceWriter match {
+      case Some(s: SupportsCustomWriterMetrics) =>
+        extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics)
+
+      case _ => None
+    }
+
+    val sinkProgress = new SinkProgress(sink.toString, customWriterMetrics.orNull)
 
     val newProgress = new StreamingQueryProgress(
       id = id,
@@ -196,6 +237,18 @@ trait ProgressReporter extends Logging {
     currentStatus = currentStatus.copy(isTriggerActive = false)
   }
 
+  /** Extract writer from the executed query plan. */
+  private def dataSourceWriter: Option[DataSourceWriter] = {
+    if (lastExecution == null) return None
+    lastExecution.executedPlan.collect {
+      case p if p.isInstanceOf[WriteToDataSourceV2Exec] =>
+        p.asInstanceOf[WriteToDataSourceV2Exec].writer
+    }.headOption match {
+      case Some(w: MicroBatchWriter) => Some(w.writer)
+      case _ => None
+    }
+  }
+
   /** Extract statistics about stateful operators from the executed query plan. */
   private def extractStateOperatorMetrics(hasNewData: Boolean): Seq[StateOperatorProgress] = {
     if (lastExecution == null) return Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
index d023a35..2d43a7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
  * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped
  * streaming writer.
  */
-class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter {
+class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter {
   override def commit(messages: Array[WriterCommitMessage]): Unit = {
     writer.commit(batchId, messages)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index afacb2f..2a5d21f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -23,6 +23,9 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -32,9 +35,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
 import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -114,14 +117,25 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
     batches.clear()
   }
 
+  def numRows: Int = synchronized {
+    batches.foldLeft(0)(_ + _.data.length)
+  }
+
   override def toString(): String = "MemorySinkV2"
 }
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
   extends WriterCommitMessage {}
 
+class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  override def json(): String = Serialization.write(Map("numRows" -> sink.numRows))
+}
+
 class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType)
-  extends DataSourceWriter with Logging {
+  extends DataSourceWriter with SupportsCustomWriterMetrics with Logging {
+
+  private val memoryV2CustomMetrics = new MemoryV2CustomMetrics(sink)
 
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)
 
@@ -135,10 +149,16 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, sc
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
+
+  override def getCustomMetrics: CustomMetrics = {
+    memoryV2CustomMetrics
+  }
 }
 
 class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType)
-  extends StreamWriter {
+  extends StreamWriter with SupportsCustomWriterMetrics {
+
+  private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink)
 
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)
 
@@ -152,6 +172,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema:
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
+
+  override def getCustomMetrics: CustomMetrics = {
+    customMemoryV2Metrics
+  }
 }
 
 case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType)

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 0dcb666..2fb8796 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -163,7 +163,27 @@ class SourceProgress protected[sql](
   val endOffset: String,
   val numInputRows: Long,
   val inputRowsPerSecond: Double,
-  val processedRowsPerSecond: Double) extends Serializable {
+  val processedRowsPerSecond: Double,
+  val customMetrics: String) extends Serializable {
+
+  /** SourceProgress without custom metrics. */
+  protected[sql] def this(
+      description: String,
+      startOffset: String,
+      endOffset: String,
+      numInputRows: Long,
+      inputRowsPerSecond: Double,
+      processedRowsPerSecond: Double) {
+
+    this(
+      description,
+      startOffset,
+      endOffset,
+      numInputRows,
+      inputRowsPerSecond,
+      processedRowsPerSecond,
+      null)
+  }
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -178,12 +198,18 @@ class SourceProgress protected[sql](
       if (value.isNaN || value.isInfinity) JNothing else JDouble(value)
     }
 
-    ("description" -> JString(description)) ~
+    val jsonVal = ("description" -> JString(description)) ~
       ("startOffset" -> tryParse(startOffset)) ~
       ("endOffset" -> tryParse(endOffset)) ~
       ("numInputRows" -> JInt(numInputRows)) ~
       ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
       ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond))
+
+    if (customMetrics != null) {
+      jsonVal ~ ("customMetrics" -> parse(customMetrics))
+    } else {
+      jsonVal
+    }
   }
 
   private def tryParse(json: String) = try {
@@ -202,7 +228,13 @@ class SourceProgress protected[sql](
  */
 @InterfaceStability.Evolving
 class SinkProgress protected[sql](
-    val description: String) extends Serializable {
+    val description: String,
+    val customMetrics: String) extends Serializable {
+
+  /** SinkProgress without custom metrics. */
+  protected[sql] def this(description: String) {
+    this(description, null)
+  }
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -213,6 +245,12 @@ class SinkProgress protected[sql](
   override def toString: String = prettyJson
 
   private[sql] def jsonValue: JValue = {
-    ("description" -> JString(description))
+    val jsonVal = ("description" -> JString(description))
+
+    if (customMetrics != null) {
+      jsonVal ~ ("customMetrics" -> parse(customMetrics))
+    } else {
+      jsonVal
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
index b4d9b68..1efaead 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
@@ -84,4 +84,26 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
     assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
   }
+
+  test("writer metrics") {
+    val sink = new MemorySinkV2
+    val schema = new StructType().add("i", "int")
+    // batch 0
+    var writer = new MemoryWriter(sink, 0, OutputMode.Append(), schema)
+    writer.commit(
+      Array(
+        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
+        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
+        MemoryWriterCommitMessage(2, Seq(Row(5), Row(6)))
+      ))
+    assert(writer.getCustomMetrics.json() == "{\"numRows\":6}")
+    // batch 1
+    writer = new MemoryWriter(sink, 1, OutputMode.Append(), schema
+    )
+    writer.commit(
+      Array(
+        MemoryWriterCommitMessage(0, Seq(Row(7), Row(8)))
+      ))
+    assert(writer.getCustomMetrics.json() == "{\"numRows\":8}")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/18b6ec14/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 9cceec9..a379569 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -24,6 +24,9 @@ import java.util.concurrent.CountDownLatch
 import scala.collection.mutable
 
 import org.apache.commons.lang3.RandomStringUtils
+import org.json4s.NoTypeHints
+import org.json4s.jackson.JsonMethods._
+import org.json4s.jackson.Serialization
 import org.scalactic.TolerantNumerics
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
@@ -475,6 +478,31 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
+  test("Check if custom metrics are reported") {
+    val streamInput = MemoryStream[Int]
+    implicit val formats = Serialization.formats(NoTypeHints)
+    testStream(streamInput.toDF(), useV2Sink = true)(
+      AddData(streamInput, 1, 2, 3),
+      CheckAnswer(1, 2, 3),
+      AssertOnQuery { q =>
+        val lastProgress = getLastProgressWithData(q)
+        assert(lastProgress.nonEmpty)
+        assert(lastProgress.get.numInputRows == 3)
+        assert(lastProgress.get.sink.customMetrics == "{\"numRows\":3}")
+        true
+      },
+      AddData(streamInput, 4, 5, 6, 7),
+      CheckAnswer(1, 2, 3, 4, 5, 6, 7),
+      AssertOnQuery { q =>
+        val lastProgress = getLastProgressWithData(q)
+        assert(lastProgress.nonEmpty)
+        assert(lastProgress.get.numInputRows == 4)
+        assert(lastProgress.get.sink.customMetrics == "{\"numRows\":7}")
+        true
+      }
+    )
+  }
+
   test("input row calculation with same V1 source used twice in self-join") {
     val streamingTriggerDF = spark.createDataset(1 to 10).toDF
     val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value")

