KUDU-1493: Implement SchemaRelationProvider

Implement SchemaRelationProvider in org.apache.kudu.spark.kudu.DefaultSource
to allow specifying a user schema on read, and thus allow for DataFrames
with different column orderings.

Change-Id: I8b6073256b61a174f898be222058277be976273c
Reviewed-on: http://gerrit.cloudera.org:8080/5167
Tested-by: Kudu Jenkins
Reviewed-by: Dan Burkert <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/kudu/repo
Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/69f2619d
Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/69f2619d
Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/69f2619d

Branch: refs/heads/master
Commit: 69f2619d059e6aeb45f947b8367069b12f3b1bb9
Parents: 0bd10cb
Author: Drew Manlove <[email protected]>
Authored: Mon Nov 21 14:14:06 2016 -0700
Committer: Dan Burkert <[email protected]>
Committed: Thu Feb 23 01:36:37 2017 +0000

----------------------------------------------------------------------
 .../apache/kudu/spark/kudu/DefaultSource.scala  | 71 +++++++++++++-------
 .../kudu/spark/kudu/DefaultSourceTest.scala     | 42 ++++++++++--
 2 files changed, 85 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kudu/blob/69f2619d/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
----------------------------------------------------------------------
diff --git 
a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala 
b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
index af1d573..5a6c178 100644
--- 
a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
+++ 
b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
@@ -24,7 +24,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
-import org.apache.kudu.Type
+import org.apache.kudu.{Type, ColumnSchema}
+
 import org.apache.kudu.annotations.InterfaceStability
 import org.apache.kudu.client.KuduPredicate.ComparisonOp
 import org.apache.kudu.client._
@@ -38,7 +39,8 @@ import org.apache.kudu.client._
   * operations through [[org.apache.spark.sql.DataFrameReader.format]].
   */
 @InterfaceStability.Unstable
-class DefaultSource extends RelationProvider with CreatableRelationProvider {
+class DefaultSource extends RelationProvider with CreatableRelationProvider
+  with SchemaRelationProvider {
 
   val TABLE_KEY = "kudu.table"
   val KUDU_MASTER = "kudu.master"
@@ -58,18 +60,9 @@ class DefaultSource extends RelationProvider with 
CreatableRelationProvider {
       throw new IllegalArgumentException(
         s"Kudu table name must be specified in create options using key 
'$TABLE_KEY'"))
     val kuduMaster = parameters.getOrElse(KUDU_MASTER, "localhost")
+    val operationType = getOperationType(parameters.getOrElse(OPERATION, 
"upsert"))
 
-    val opParam = parameters.getOrElse(OPERATION, "upsert")
-    val operationType = opParam.toLowerCase match {
-      case "insert" => Insert
-      case "insert-ignore" => InsertIgnore
-      case "upsert" => Upsert
-      case "update" => Update
-      case "delete" => Delete
-      case _ => throw new IllegalArgumentException(s"Unsupported operation 
type '$opParam'")
-    }
-
-    new KuduRelation(tableName, kuduMaster, operationType)(sqlContext)
+    new KuduRelation(tableName, kuduMaster, operationType, None)(sqlContext)
   }
 
   /**
@@ -93,6 +86,28 @@ class DefaultSource extends RelationProvider with 
CreatableRelationProvider {
 
     kuduRelation
   }
+
+  override def createRelation(sqlContext: SQLContext, parameters: Map[String, 
String],
+                              schema: StructType): BaseRelation = {
+    val tableName = parameters.getOrElse(TABLE_KEY,
+      throw new IllegalArgumentException(s"Kudu table name must be specified 
in create options " +
+        s"using key '$TABLE_KEY'"))
+    val kuduMaster = parameters.getOrElse(KUDU_MASTER, "localhost")
+    val operationType = getOperationType(parameters.getOrElse(OPERATION, 
"upsert"))
+
+    new KuduRelation(tableName, kuduMaster, operationType, 
Some(schema))(sqlContext)
+  }
+
+  private def getOperationType(opParam: String): OperationType = {
+    opParam.toLowerCase match {
+      case "insert" => Insert
+      case "insert-ignore" => InsertIgnore
+      case "upsert" => Upsert
+      case "update" => Update
+      case "delete" => Delete
+      case _ => throw new IllegalArgumentException(s"Unsupported operation 
type '$opParam'")
+    }
+  }
 }
 
 /**
@@ -101,16 +116,18 @@ class DefaultSource extends RelationProvider with 
CreatableRelationProvider {
   * @param tableName Kudu table that we plan to read from
   * @param masterAddrs Kudu master addresses
   * @param operationType The default operation type to perform when writing to 
the relation
+  * @param userSchema A schema used to select columns for the relation
   * @param sqlContext SparkSQL context
   */
 @InterfaceStability.Unstable
 class KuduRelation(private val tableName: String,
                    private val masterAddrs: String,
-                   private val operationType: OperationType)(
-                   val sqlContext: SQLContext)
-extends BaseRelation
-with PrunedFilteredScan
-with InsertableRelation {
+                   private val operationType: OperationType,
+                   private val userSchema: Option[StructType])(
+                    val sqlContext: SQLContext)
+  extends BaseRelation
+    with PrunedFilteredScan
+    with InsertableRelation {
 
   import KuduRelation._
 
@@ -127,13 +144,19 @@ with InsertableRelation {
     * @return schema generated from the Kudu table's schema
     */
   override def schema: StructType = {
-    val fields: Array[StructField] =
-      table.getSchema.getColumns.asScala.map { columnSchema =>
-        val sparkType = kuduTypeToSparkType(columnSchema.getType)
-        StructField(columnSchema.getName, sparkType, columnSchema.isNullable)
-      }.toArray
+    userSchema match {
+      case Some(x) =>
+        StructType(x.fields.map(uf => table.getSchema.getColumn(uf.name))
+          .map(kuduColumnToSparkField))
+      case None =>
+        
StructType(table.getSchema.getColumns.asScala.map(kuduColumnToSparkField).toArray)
+    }
+  }
 
-    new StructType(fields)
+  def kuduColumnToSparkField: (ColumnSchema) => StructField = {
+    columnSchema =>
+      val sparkType = kuduTypeToSparkType(columnSchema.getType)
+      new StructField(columnSchema.getName, sparkType, columnSchema.isNullable)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/kudu/blob/69f2619d/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
----------------------------------------------------------------------
diff --git 
a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
 
b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 3dbd088..9790e81 100644
--- 
a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ 
b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -25,19 +25,21 @@ import scala.collection.immutable.IndexedSeq
 import scala.util.control.NonFatal
 
 import com.google.common.collect.ImmutableList
-import org.apache.spark.sql.SQLContext
+
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType};
 import org.junit.Assert._
 import org.junit.runner.RunWith
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
 import org.scalatest.junit.JUnitRunner
-import org.scalatest.{BeforeAndAfter, FunSuite}
 
 import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder
 import org.apache.kudu.{Schema, Type}
-import org.apache.kudu.client.CreateTableOptions;
+import org.apache.kudu.client.CreateTableOptions
 
 @RunWith(classOf[JUnitRunner])
-class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter {
+class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter 
with Matchers {
 
   test("timestamp conversion") {
     val epoch = new Timestamp(0)
@@ -450,4 +452,36 @@ class DefaultSourceTest extends FunSuite with TestContext 
with BeforeAndAfter {
     assertTrue(kuduContext.tableExists(newTable))
     assert(checkDf.count == 10)
   }
+
+  test("create relation with schema") {
+
+    // user-supplied schema that is compatible with actual schema, but with 
the key at the end
+    val userSchema: StructType = StructType(List(
+      StructField("c4_long", DataTypes.LongType),
+      StructField("key", DataTypes.IntegerType)
+    ))
+
+    val dfDefaultSchema = sqlContext.read.options(kuduOptions).kudu
+    assertEquals(11, dfDefaultSchema.schema.fields.length)
+
+    val dfWithUserSchema = 
sqlContext.read.options(kuduOptions).schema(userSchema).kudu
+    assertEquals(2, dfWithUserSchema.schema.fields.length)
+
+    dfWithUserSchema.limit(10).collect()
+    assertTrue(dfWithUserSchema.columns.deep == Array("c4_long", "key").deep)
+  }
+
+  test("create relation with invalid schema") {
+
+    // user-supplied schema that is NOT compatible with actual schema
+    val userSchema: StructType = StructType(List(
+      StructField("foo", DataTypes.LongType),
+      StructField("bar", DataTypes.IntegerType)
+    ))
+
+    intercept[IllegalArgumentException] {
+      sqlContext.read.options(kuduOptions).schema(userSchema).kudu
+    }.getMessage should include ("Unknown column: foo")
+
+  }
 }

Reply via email to