This is an automated email from the ASF dual-hosted git repository.

granthenke pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git


The following commit(s) were added to refs/heads/master by this push:
     new d9be1f6  KUDU-2672: [spark] Optionally repartition to match Kudu partitions
d9be1f6 is described below

commit d9be1f6623c068524e9bd65a89e25146d9b70dd5
Author: Grant Henke <[email protected]>
AuthorDate: Fri Feb 8 17:12:00 2019 -0600

    KUDU-2672: [spark] Optionally repartition to match Kudu partitions
    
    Adds a write option to repartition the data to match
    the target Kudu partitions. Additionally, it provides an
    option to sort while repartitioning.
    
    Repartitioning ensures that each task/client writes to
    only a single tablet. This improves throughput through
    better batching, especially for tables with a large
    number of partitions.
    
    Additionally, sorting before writing to Kudu reduces the
    number of compactions needed and can improve
    sustained throughput.
    
    Change-Id: I8763615997bccc08901235841149fc3bacb321e7
    Reviewed-on: http://gerrit.cloudera.org:8080/12484
    Tested-by: Kudu Jenkins
    Reviewed-by: Adar Dembo <[email protected]>
---
 .../org/apache/kudu/spark/kudu/DefaultSource.scala | 11 ++-
 .../org/apache/kudu/spark/kudu/KuduContext.scala   | 98 +++++++++++++++++++---
 .../apache/kudu/spark/kudu/KuduWriteOptions.scala  | 10 ++-
 .../org/apache/kudu/spark/kudu/RowConverter.scala  |  7 +-
 .../apache/kudu/spark/kudu/DefaultSourceTest.scala | 80 +++++++++++++++++-
 5 files changed, 184 insertions(+), 22 deletions(-)

diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
index 139184b..fddfe40 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
@@ -61,6 +61,8 @@ class DefaultSource
   val SCAN_LOCALITY = "kudu.scanLocality"
   val IGNORE_NULL = "kudu.ignoreNull"
   val IGNORE_DUPLICATE_ROW_ERRORS = "kudu.ignoreDuplicateRowErrors"
+  val REPARTITION = "kudu.repartition"
+  val REPARTITION_SORT = "kudu.repartition.sort"
   val SCAN_REQUEST_TIMEOUT_MS = "kudu.scanRequestTimeoutMs"
   val SOCKET_READ_TIMEOUT_MS = "kudu.socketReadTimeoutMs"
   val BATCH_SIZE = "kudu.batchSize"
@@ -193,10 +195,11 @@ class DefaultSource
     Try(parameters(OPERATION) == "insert-ignore").getOrElse(false)
     val ignoreNull =
       parameters.get(IGNORE_NULL).map(_.toBoolean).getOrElse(defaultIgnoreNull)
-
-    Try(parameters.getOrElse(IGNORE_NULL, "false").toBoolean).getOrElse(false)
-
-    KuduWriteOptions(ignoreDuplicateRowErrors, ignoreNull)
+    val repartition =
+      
parameters.get(REPARTITION).map(_.toBoolean).getOrElse(defaultRepartition)
+    val repartitionSort =
+      
parameters.get(REPARTITION_SORT).map(_.toBoolean).getOrElse(defaultRepartitionSort)
+    KuduWriteOptions(ignoreDuplicateRowErrors, ignoreNull, repartition, 
repartitionSort)
   }
 
   private def getMasterAddrs(parameters: Map[String, String]): String = {
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index 4fa59c1..777e46f 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -19,6 +19,7 @@ package org.apache.kudu.spark.kudu
 
 import java.security.AccessController
 import java.security.PrivilegedAction
+
 import javax.security.auth.Subject
 import javax.security.auth.login.AppConfigurationEntry
 import javax.security.auth.login.Configuration
@@ -26,25 +27,22 @@ import javax.security.auth.login.LoginContext
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-
 import org.apache.hadoop.util.ShutdownHookManager
+import org.apache.spark.Partitioner
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.types.DataTypes
-import org.apache.spark.sql.types.DecimalType
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.AccumulatorV2
+import org.apache.spark.util.CollectionAccumulator
 import org.apache.spark.util.LongAccumulator
 import org.apache.yetus.audience.InterfaceAudience
 import org.apache.yetus.audience.InterfaceStability
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-
 import org.apache.kudu.client.SessionConfiguration.FlushMode
 import org.apache.kudu.client._
 import org.apache.kudu.spark.kudu.SparkUtil._
@@ -66,6 +64,12 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
 
   def this(kuduMaster: String, sc: SparkContext) = this(kuduMaster, sc, None)
 
+  // An accumulator that collects all the rows written to Kudu for testing 
only.
+  // Enabled by setting captureRows = true.
+  private[kudu] var captureRows = false
+  private[kudu] var rowsAccumulator: CollectionAccumulator[Row] =
+    sc.collectionAccumulator[Row]("kudu.rows")
+
   /**
    * A collection of accumulator metrics describing the usage of a KuduContext.
    */
@@ -329,7 +333,21 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
     val schema = data.schema
     // Get the client's last propagated timestamp on the driver.
     val lastPropagatedTimestamp = syncClient.getLastPropagatedTimestamp
-    data.queryExecution.toRdd.foreachPartition(iterator => {
+
+    // Convert to an RDD and map the InternalRows to Rows.
+    // This avoids any corruption as reported in SPARK-26880.
+    var rdd = data.queryExecution.toRdd.mapPartitions { rows =>
+      val table = syncClient.openTable(tableName)
+      val converter = new RowConverter(table.getSchema, schema, 
writeOptions.ignoreNull)
+      rows.map(converter.toRow)
+    }
+
+    if (writeOptions.repartition) {
+      rdd = repartitionRows(rdd, tableName, schema, writeOptions)
+    }
+
+    // Write the rows for each Spark partition.
+    rdd.foreachPartition(iterator => {
       val pendingErrors = writePartitionRows(
         iterator,
         schema,
@@ -348,8 +366,55 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
     log.info(s"completed $operation ops: duration histogram: 
$durationHistogram")
   }
 
+  private def repartitionRows(
+      rdd: RDD[Row],
+      tableName: String,
+      schema: StructType,
+      writeOptions: KuduWriteOptions): RDD[Row] = {
+    val partitionCount = getPartitionCount(tableName)
+    val sparkPartitioner = new Partitioner {
+      override def numPartitions: Int = partitionCount
+      override def getPartition(key: Any): Int = {
+        key.asInstanceOf[(Int, Row)]._1
+      }
+    }
+
+    // Key the rows by the Kudu partition index using the KuduPartitioner and 
the
+    // table's primary key. This allows us to re-partition and sort the 
columns.
+    val keyedRdd = rdd.mapPartitions { rows =>
+      val table = syncClient.openTable(tableName)
+      val converter = new RowConverter(table.getSchema, schema, 
writeOptions.ignoreNull)
+      val partitioner = new 
KuduPartitioner.KuduPartitionerBuilder(table).build()
+      rows.map { row =>
+        val partialRow = converter.toPartialRow(row)
+        val partitionIndex = partitioner.partitionRow(partialRow)
+        ((partitionIndex, partialRow.encodePrimaryKey()), row)
+      }
+    }
+
+    // Define an implicit Ordering trait for the encoded primary key
+    // to enable rdd sorting functions below.
+    implicit val byteArrayOrdering: Ordering[Array[Byte]] = new 
Ordering[Array[Byte]] {
+      def compare(x: Array[Byte], y: Array[Byte]): Int = {
+        TypeUtils.compareBinary(x, y)
+      }
+    }
+
+    // Partition the rows by the Kudu partition index to ensure the Spark 
partitions
+    // match the Kudu partitions. This will make the number of Spark tasks 
match the number
+    // of Kudu partitions. Optionally sort while repartitioning.
+    // TODO: At some point we may want to support more or less tasks while 
still partitioning.
+    val shuffledRDD = if (writeOptions.repartitionSort) {
+      keyedRdd.repartitionAndSortWithinPartitions(sparkPartitioner)
+    } else {
+      keyedRdd.partitionBy(sparkPartitioner)
+    }
+    // Drop the partitioning key.
+    shuffledRDD.map { case (_, row) => row }
+  }
+
   private def writePartitionRows(
-      rows: Iterator[InternalRow],
+      rows: Iterator[Row],
       schema: StructType,
       tableName: String,
       opType: OperationType,
@@ -367,8 +432,11 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
     log.info(s"applying operations of type '${opType.toString}' to table 
'$tableName'")
     val startTime = System.currentTimeMillis()
     try {
-      for (internalRow <- rows) {
-        val partialRow = rowConverter.toPartialRow(internalRow)
+      for (row <- rows) {
+        if (captureRows) {
+          rowsAccumulator.add(row)
+        }
+        val partialRow = rowConverter.toPartialRow(row)
         val operation = opType.operation(table)
         operation.setRow(partialRow)
         session.apply(operation)
@@ -386,6 +454,12 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
     }
     session.getPendingErrors
   }
+
+  private def getPartitionCount(tableName: String): Int = {
+    val table = syncClient.openTable(tableName)
+    val partitioner = new KuduPartitioner.KuduPartitionerBuilder(table).build()
+    partitioner.numPartitions()
+  }
 }
 
 private object KuduContext {
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduWriteOptions.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduWriteOptions.scala
index 22a6886..24d4fad 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduWriteOptions.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduWriteOptions.scala
@@ -28,14 +28,22 @@ import org.apache.kudu.spark.kudu.KuduWriteOptions._
  * @param ignoreDuplicateRowErrors when inserting, ignore any new rows that
  *                                 have a primary key conflict with existing 
rows
  * @param ignoreNull update only non-Null columns if set true
+ * @param repartition if set to true, the data will be repartitioned to match 
the
+ *                   partitioning of the target Kudu table
+ * @param repartitionSort if set to true, the data will also be sorted while 
being
+ *                   repartitioned. This is only used if repartition is true.
  */
 @InterfaceAudience.Public
 @InterfaceStability.Evolving
 case class KuduWriteOptions(
     ignoreDuplicateRowErrors: Boolean = defaultIgnoreDuplicateRowErrors,
-    ignoreNull: Boolean = defaultIgnoreNull)
+    ignoreNull: Boolean = defaultIgnoreNull,
+    repartition: Boolean = defaultRepartition,
+    repartitionSort: Boolean = defaultRepartitionSort)
 
 object KuduWriteOptions {
   val defaultIgnoreDuplicateRowErrors: Boolean = false
   val defaultIgnoreNull: Boolean = false
+  val defaultRepartition: Boolean = false
+  val defaultRepartitionSort: Boolean = true
 }
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala
index f23767e..e9f16e4 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala
@@ -40,11 +40,10 @@ class RowConverter(kuduSchema: Schema, schema: StructType, ignoreNull: Boolean)
   })
 
   /**
-   * Converts a Spark internal row to a Kudu PartialRow.
+   * Converts a Spark internalRow to a Spark Row.
    */
-  def toPartialRow(internalRow: InternalRow): PartialRow = {
-    val row = typeConverter(internalRow).asInstanceOf[Row]
-    toPartialRow(row)
+  def toRow(internalRow: InternalRow): Row = {
+    typeConverter(internalRow).asInstanceOf[Row]
   }
 
   /**
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index cbc315e..6afd32c 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -20,6 +20,7 @@ import scala.collection.JavaConverters._
 import scala.collection.immutable.IndexedSeq
 import scala.util.control.NonFatal
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataTypes
@@ -31,12 +32,16 @@ import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder
 import org.apache.kudu.client.CreateTableOptions
 import org.apache.kudu.Schema
 import org.apache.kudu.Type
+import org.apache.kudu.test.RandomUtils
+import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.SparkListenerTaskEnd
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.junit.Before
 import org.junit.Test
 
-class DefaultSourceTest extends KuduTestSuite with Matchers {
+import scala.util.Random
 
+class DefaultSourceTest extends KuduTestSuite with Matchers {
   val rowCount = 10
   var sqlContext: SQLContext = _
   var rows: IndexedSeq[(Int, Int, String, Long)] = _
@@ -403,6 +408,79 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
   }
 
   @Test
+  def testRepartition(): Unit = {
+    runRepartitionTest(false)
+  }
+
+  @Test
+  def testRepartitionAndSort(): Unit = {
+    runRepartitionTest(true)
+  }
+
+  def runRepartitionTest(repartitionSort: Boolean): Unit = {
+    // Create a simple table with 2 range partitions split on the value 100.
+    val tableName = "testRepartition"
+    val splitValue = 100
+    val split = simpleSchema.newPartialRow()
+    split.addInt("key", splitValue)
+    val options = new CreateTableOptions()
+    options.setRangePartitionColumns(List("key").asJava)
+    options.addSplitRow(split)
+    val table = kuduClient.createTable(tableName, simpleSchema, options)
+
+    // Add a SparkListener to count the number of tasks that end.
+    var actualNumTasks = 0
+    val listener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        actualNumTasks += 1
+      }
+    }
+    ss.sparkContext.addSparkListener(listener)
+
+    val random = Random.javaRandomToRandom(RandomUtils.getRandom)
+    val data = random.shuffle(
+      Seq(
+        Row.fromSeq(Seq(0, "0")),
+        Row.fromSeq(Seq(25, "25")),
+        Row.fromSeq(Seq(50, "50")),
+        Row.fromSeq(Seq(75, "75")),
+        Row.fromSeq(Seq(99, "99")),
+        Row.fromSeq(Seq(100, "100")),
+        Row.fromSeq(Seq(101, "101")),
+        Row.fromSeq(Seq(125, "125")),
+        Row.fromSeq(Seq(150, "150")),
+        Row.fromSeq(Seq(175, "175")),
+        Row.fromSeq(Seq(199, "199"))
+      ))
+    val dataRDD = ss.sparkContext.parallelize(data, numSlices = 2)
+    val schema = SparkUtil.sparkSchema(table.getSchema)
+    val dataDF = ss.sqlContext.createDataFrame(dataRDD, schema)
+
+    // Capture the rows so we can validate the insert order.
+    kuduContext.captureRows = true
+
+    kuduContext.insertRows(
+      dataDF,
+      tableName,
+      new KuduWriteOptions(repartition = true, repartitionSort = 
repartitionSort))
+    // 2 tasks from the parallelize call, and 2 from the repartitioning.
+    assertEquals(4, actualNumTasks)
+    val rows = kuduContext.rowsAccumulator.value.asScala
+    assertEquals(data.size, rows.size)
+    assertEquals(data.map(_.getInt(0)).sorted, rows.map(_.getInt(0)).sorted)
+
+    // If repartitionSort is true, verify the rows were sorted while 
repartitioning.
+    if (repartitionSort) {
+      def isSorted(rows: Seq[Int]): Boolean = {
+        rows.sliding(2).forall(p => (p.size == 1) || p.head < p.tail.head)
+      }
+      val (bottomRows, topRows) = rows.map(_.getInt(0)).partition(_ < 
splitValue)
+      assertTrue(isSorted(bottomRows))
+      assertTrue(isSorted(topRows))
+    }
+  }
+
+  @Test
   def testDeleteRows() {
     val df = sqlContext.read.options(kuduOptions).format("kudu").load
     val deleteDF = df.filter("key = 0").select("key")

Reply via email to