This is an automated email from the ASF dual-hosted git repository.

granthenke pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git

commit 61546c02517663bd4ac5879a9a4189e0a81e963b
Author: Grant Henke <[email protected]>
AuthorDate: Fri Feb 8 10:52:42 2019 -0600

    [spark-tools] DistributedDataGenerator repartition support
    
    This patch adds support to the DistributedDataGenerator
    to repartition the data to match the Kudu partitioning.
    
    Because data generation is now decoupled from data
    loading, this patch changes the collision handling
    behavior. Instead of generating new data on collision,
    now the collision is only tracked in the metrics.
    
    Additionally this patch changes the default generation
    type from random to sequential, given that sequential
    has been shown to be the more common option and the
    type of workload Kudu is better suited for.
    
    Change-Id: I57bcc68d645c52b429ac6cf8bcdf0551a8244995
    Reviewed-on: http://gerrit.cloudera.org:8080/12411
    Tested-by: Kudu Jenkins
    Reviewed-by: Adar Dembo <[email protected]>
---
 .../spark/tools/DistributedDataGenerator.scala     | 204 ++++++++++++++-------
 .../spark/tools/DistributedDataGeneratorTest.scala |  74 +++++++-
 .../org/apache/kudu/spark/kudu/KuduContext.scala   |   2 +-
 3 files changed, 209 insertions(+), 71 deletions(-)

diff --git 
a/java/kudu-spark-tools/src/main/scala/org/apache/kudu/spark/tools/DistributedDataGenerator.scala
 
b/java/kudu-spark-tools/src/main/scala/org/apache/kudu/spark/tools/DistributedDataGenerator.scala
index ee0954f..01af3ce 100644
--- 
a/java/kudu-spark-tools/src/main/scala/org/apache/kudu/spark/tools/DistributedDataGenerator.scala
+++ 
b/java/kudu-spark-tools/src/main/scala/org/apache/kudu/spark/tools/DistributedDataGenerator.scala
@@ -20,12 +20,17 @@ import java.math.BigDecimal
 import java.math.BigInteger
 import java.nio.charset.StandardCharsets
 
+import org.apache.kudu.Schema
 import org.apache.kudu.Type
 import org.apache.kudu.client.PartialRow
 import org.apache.kudu.client.SessionConfiguration
 import org.apache.kudu.spark.kudu.KuduContext
+import org.apache.kudu.spark.kudu.KuduWriteOptions
+import org.apache.kudu.spark.kudu.RowConverter
+import org.apache.kudu.spark.kudu.SparkUtil
 import org.apache.kudu.spark.tools.DistributedDataGeneratorOptions._
 import org.apache.kudu.util.DataGenerator
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.LongAccumulator
 import org.apache.spark.SparkConf
@@ -50,56 +55,138 @@ object GeneratorMetrics {
 object DistributedDataGenerator {
   val log: Logger = LoggerFactory.getLogger(getClass)
 
-  def generateRows(
-      context: KuduContext,
-      options: DistributedDataGeneratorOptions,
-      taskNum: Int,
-      metrics: GeneratorMetrics) {
-
-    val kuduClient = context.syncClient
-    val session = kuduClient.newSession()
-    session.setFlushMode(SessionConfiguration.FlushMode.AUTO_FLUSH_BACKGROUND)
-    val kuduTable = kuduClient.openTable(options.tableName)
-
-    val generator = new DataGenerator.DataGeneratorBuilder()
-    // Add taskNum to the seed otherwise each task will try to generate the 
same rows.
-      .random(new java.util.Random(options.seed + taskNum))
-      .stringLength(options.stringLength)
-      .binaryLength(options.binaryLength)
-      .build()
-
-    val rowsToWrite = options.numRows / options.numTasks
-    var currentRow: Long = rowsToWrite * taskNum
-    var rowsWritten: Long = 0
-    while (rowsWritten < rowsToWrite) {
-      val insert = kuduTable.newInsert()
-      if (options.generatorType == SequentialGenerator) {
-        setRow(insert.getRow, currentRow)
-      } else if (options.generatorType == RandomGenerator) {
-        generator.randomizeRow(insert.getRow)
-      }
-      session.apply(insert)
+  def run(options: DistributedDataGeneratorOptions, ss: SparkSession): Unit = {
+    log.info(s"Running a DistributedDataGenerator with options: $options")
+    val sc = ss.sparkContext
+    val context = new KuduContext(options.masterAddresses, sc)
+    val metrics = GeneratorMetrics(sc)
+
+    // Generate the Inserts.
+    var rdd = sc
+      .parallelize(0 until options.numTasks, numSlices = options.numTasks)
+      .mapPartitions(
+        { taskNumIter =>
+          // We know there is only 1 task per partition because numSlices = 
options.numTasks above.
+          val taskNum = taskNumIter.next()
+          val generator = new DataGenerator.DataGeneratorBuilder()
+          // Add taskNum to the seed otherwise each task will try to generate 
the same rows.
+            .random(new java.util.Random(options.seed + taskNum))
+            .stringLength(options.stringLength)
+            .binaryLength(options.binaryLength)
+            .build()
+          val table = context.syncClient.openTable(options.tableName)
+          val schema = table.getSchema
+          val numRows = options.numRows / options.numTasks
+          val startRow: Long = numRows * taskNum
+          new GeneratedRowIterator(generator, options.generatorType, schema, 
startRow, numRows)
+        },
+        true
+      )
+
+    if (options.repartition) {
+      val table = context.syncClient.openTable(options.tableName)
+      val sparkSchema = SparkUtil.sparkSchema(table.getSchema)
+      rdd = context
+        .repartitionRows(rdd, options.tableName, sparkSchema, 
KuduWriteOptions(ignoreNull = true))
+    }
 
-      // Synchronously flush on potentially the last iteration of the
-      // loop, so we can check whether we need to retry any collisions.
-      if (rowsWritten + 1 == rowsToWrite) session.flush()
+    // Write the rows to Kudu.
+    // TODO: Use context.writeRows while still tracking inserts/collisions.
+    rdd.foreachPartition { rows =>
+      val kuduClient = context.syncClient
+      val table = kuduClient.openTable(options.tableName)
+      val kuduSchema = table.getSchema
+      val sparkSchema = SparkUtil.sparkSchema(kuduSchema)
+      val converter = new RowConverter(kuduSchema, sparkSchema, ignoreNull = 
true)
 
+      val session = kuduClient.newSession()
+      
session.setFlushMode(SessionConfiguration.FlushMode.AUTO_FLUSH_BACKGROUND)
+
+      var rowsWritten = 0
+      rows.foreach { row =>
+        val insert = table.newInsert()
+        val partialRow = converter.toPartialRow(row)
+        insert.setRow(partialRow)
+        session.apply(insert)
+        rowsWritten += 1
+      }
+      // Synchronously flush after the last record is written.
+      session.flush()
+
+      // Track the collisions.
+      var collisions = 0
       for (error <- session.getPendingErrors.getRowErrors) {
         if (error.getErrorStatus.isAlreadyPresent) {
           // Because we can't check for collisions every time, but instead
           // only when the rows are flushed, we subtract any rows that may
           // have failed from the counter.
           rowsWritten -= 1
-          metrics.collisions.add(1)
+          collisions += 1
         } else {
           throw new RuntimeException("Kudu write error: " + 
error.getErrorStatus.toString)
         }
       }
-      currentRow += 1
-      rowsWritten += 1
+      metrics.rowsWritten.add(rowsWritten)
+      metrics.collisions.add(collisions)
+      session.close()
     }
-    metrics.rowsWritten.add(rowsWritten)
-    session.close()
+
+    log.info(s"Rows written: ${metrics.rowsWritten.value}")
+    log.info(s"Collisions: ${metrics.collisions.value}")
+  }
+
+  /**
+   * Entry point for testing. SparkContext is a singleton,
+   * so tests must create and manage their own.
+   */
+  @InterfaceAudience.LimitedPrivate(Array("Test"))
+  def testMain(args: Array[String], ss: SparkSession): Unit = {
+    DistributedDataGeneratorOptions.parse(args) match {
+      case None => throw new IllegalArgumentException("Could not parse 
arguments")
+      case Some(config) => run(config, ss)
+    }
+  }
+
+  def main(args: Array[String]): Unit = {
+    val conf = new SparkConf().setAppName("DistributedDataGenerator")
+    val ss = SparkSession.builder().config(conf).getOrCreate()
+    testMain(args, ss)
+  }
+}
+
+private class GeneratedRowIterator(
+    generator: DataGenerator,
+    generatorType: String,
+    schema: Schema,
+    startRow: Long,
+    numRows: Long)
+    extends Iterator[Row] {
+
+  val sparkSchema = SparkUtil.sparkSchema(schema)
+  // ignoreNull values so unset/defaulted rows can be passed through.
+  val converter = new RowConverter(schema, sparkSchema, ignoreNull = true)
+
+  var currentRow: Long = startRow
+  var rowsGenerated: Long = 0
+
+  override def hasNext: Boolean = rowsGenerated < numRows
+
+  override def next(): Row = {
+    if (rowsGenerated >= numRows) {
+      throw new IllegalStateException("Already generated all of the rows.")
+    }
+
+    val partialRow = schema.newPartialRow()
+    if (generatorType == SequentialGenerator) {
+      setRow(partialRow, currentRow)
+    } else if (generatorType == RandomGenerator) {
+      generator.randomizeRow(partialRow)
+    } else {
+      throw new IllegalArgumentException(s"Generator type of $generatorType is 
unsupported")
+    }
+    currentRow += 1
+    rowsGenerated += 1
+    converter.toRow(partialRow)
   }
 
   /**
@@ -142,35 +229,6 @@ object DistributedDataGenerator {
       }
     }
   }
-
-  def run(options: DistributedDataGeneratorOptions, ss: SparkSession): Unit = {
-    log.info(s"Running a DistributedDataGenerator with options: $options")
-    val sc = ss.sparkContext
-    val context = new KuduContext(options.masterAddresses, sc)
-    val metrics = GeneratorMetrics(sc)
-    sc.parallelize(0 until options.numTasks, numSlices = options.numTasks)
-      .foreachPartition(taskNum => generateRows(context, options, 
taskNum.next(), metrics))
-    log.info(s"Rows written: ${metrics.rowsWritten.value}")
-    log.info(s"Collisions: ${metrics.collisions.value}")
-  }
-
-  /**
-   * Entry point for testing. SparkContext is a singleton,
-   * so tests must create and manage their own.
-   */
-  @InterfaceAudience.LimitedPrivate(Array("Test"))
-  def testMain(args: Array[String], ss: SparkSession): Unit = {
-    DistributedDataGeneratorOptions.parse(args) match {
-      case None => throw new IllegalArgumentException("Could not parse 
arguments")
-      case Some(config) => run(config, ss)
-    }
-  }
-
-  def main(args: Array[String]): Unit = {
-    val conf = new SparkConf().setAppName("DistributedDataGenerator")
-    val ss = SparkSession.builder().config(conf).getOrCreate()
-    testMain(args, ss)
-  }
 }
 
 @InterfaceAudience.Private
@@ -183,7 +241,8 @@ case class DistributedDataGeneratorOptions(
     numTasks: Int = DistributedDataGeneratorOptions.DefaultNumTasks,
     stringLength: Int = DistributedDataGeneratorOptions.DefaultStringLength,
     binaryLength: Int = DistributedDataGeneratorOptions.DefaultStringLength,
-    seed: Long = System.currentTimeMillis())
+    seed: Long = System.currentTimeMillis(),
+    repartition: Boolean = DistributedDataGeneratorOptions.DefaultRepartition)
 
 @InterfaceAudience.Private
 @InterfaceStability.Unstable
@@ -194,7 +253,8 @@ object DistributedDataGeneratorOptions {
   val DefaultBinaryLength: Int = 128
   val RandomGenerator: String = "random"
   val SequentialGenerator: String = "sequential"
-  val DefaultGeneratorType: String = RandomGenerator
+  val DefaultGeneratorType: String = SequentialGenerator
+  val DefaultRepartition: Boolean = false
 
   private val parser: OptionParser[DistributedDataGeneratorOptions] =
     new OptionParser[DistributedDataGeneratorOptions]("LoadRandomData") {
@@ -220,7 +280,8 @@ object DistributedDataGeneratorOptions {
 
       opt[Int]("num-tasks")
         .action((v, o) => o.copy(numTasks = v))
-        .text(s"The total number of Spark tasks to generate. Default: 
${DefaultNumTasks}")
+        .text(s"The total number of Spark tasks to use when generating data. " 
+
+          s"Default: ${DefaultNumTasks}")
         .optional()
 
       opt[Int]("string-length")
@@ -237,6 +298,11 @@ object DistributedDataGeneratorOptions {
         .action((v, o) => o.copy(seed = v))
         .text(s"The seed to use in the random data generator. " +
           s"Default: `System.currentTimeMillis()`")
+
+      opt[Boolean]("repartition")
+        .action((v, o) => o.copy(repartition = v))
+        .text(s"Repartition the data to ensure each spark task talks to a 
minimal " +
+          s"set of tablet servers.")
     }
 
   def parse(args: Seq[String]): Option[DistributedDataGeneratorOptions] = {
diff --git 
a/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
 
b/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
index e8c2e0d..37e2624 100644
--- 
a/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
+++ 
b/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
@@ -17,11 +17,14 @@
 package org.apache.kudu.spark.tools
 
 import org.apache.kudu.Type
+import org.apache.kudu.client.KuduPartitioner
 import org.apache.kudu.spark.kudu.KuduTestSuite
 import org.apache.kudu.test.RandomUtils
 import org.apache.kudu.util.DecimalUtil
 import org.apache.kudu.util.SchemaGenerator
 import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.SparkListenerTaskEnd
 import org.apache.spark.sql.Row
 import org.junit.Test
 import org.junit.Assert.assertEquals
@@ -51,7 +54,9 @@ class DistributedDataGeneratorTest extends KuduTestSuite {
       randomTableName,
       harness.getMasterAddressesAsString)
     val rdd = runGeneratorTest(args)
-    assertEquals(numRows, rdd.collect.length)
+    val collisions = ss.sparkContext.longAccumulator("row_collisions").value
+    // Collisions could cause the number of rows to be less than the number set.
+    assertEquals(numRows - collisions, rdd.collect.length)
   }
 
   @Test
@@ -67,6 +72,73 @@ class DistributedDataGeneratorTest extends KuduTestSuite {
     assertEquals(numRows, rdd.collect.length)
   }
 
+  @Test
+  def testRepartitionData() {
+    val numRows = 100
+    val args = Array(
+      s"--num-rows=$numRows",
+      "--num-tasks=10",
+      "--type=sequential",
+      "--repartition=true",
+      randomTableName,
+      harness.getMasterAddressesAsString)
+    val rdd = runGeneratorTest(args)
+    assertEquals(numRows, rdd.collect.length)
+  }
+
+  @Test
+  def testNumTasks() {
+    // Add a SparkListener to count the number of tasks that end.
+    var actualNumTasks = 0
+    val listener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        actualNumTasks += 1
+      }
+    }
+    ss.sparkContext.addSparkListener(listener)
+
+    val numTasks = 8
+    val numRows = 100
+    val args = Array(
+      s"--num-rows=$numRows",
+      s"--num-tasks=$numTasks",
+      randomTableName,
+      harness.getMasterAddressesAsString)
+    runGeneratorTest(args)
+
+    assertEquals(numTasks, actualNumTasks)
+  }
+
+  @Test
+  def testNumTasksRepartition(): Unit = {
+    // Add a SparkListener to count the number of tasks that end.
+    var actualNumTasks = 0
+    val listener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        actualNumTasks += 1
+      }
+    }
+    ss.sparkContext.addSparkListener(listener)
+
+    val numTasks = 8
+    val numRows = 100
+    val args = Array(
+      s"--num-rows=$numRows",
+      s"--num-tasks=$numTasks",
+      "--repartition=true",
+      randomTableName,
+      harness.getMasterAddressesAsString)
+    runGeneratorTest(args)
+
+    val table = kuduContext.syncClient.openTable(randomTableName)
+    val numPartitions = new 
KuduPartitioner.KuduPartitionerBuilder(table).build().numPartitions()
+
+    // We expect the number of tasks to be equal to numTasks + numPartitions 
because numTasks tasks
+    // are run to generate the data then we repartition the data to match the 
table partitioning
+    // and numPartitions tasks load the data.
+    assertEquals(numTasks + numPartitions, actualNumTasks)
+  }
+
   def runGeneratorTest(args: Array[String]): RDD[Row] = {
     val schema = generator.randomSchema()
     val options = generator.randomCreateTableOptions(schema)
diff --git 
a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala 
b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index 777e46f..28aec32 100644
--- 
a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ 
b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -366,7 +366,7 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, 
val socketReadTimeou
     log.info(s"completed $operation ops: duration histogram: 
$durationHistogram")
   }
 
-  private def repartitionRows(
+  private[spark] def repartitionRows(
       rdd: RDD[Row],
       tableName: String,
       schema: StructType,

Reply via email to