KUDU-1493: Implement SchemaRelationProvider Implement SchemaRelationProvider in org.apache.kudu.spark.kudu.DefaultSource to allow for specifying a user schema on read and thus allow for DataFrames of different column orderings
Change-Id: I8b6073256b61a174f898be222058277be976273c Reviewed-on: http://gerrit.cloudera.org:8080/5167 Tested-by: Kudu Jenkins Reviewed-by: Dan Burkert <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/kudu/repo Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/69f2619d Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/69f2619d Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/69f2619d Branch: refs/heads/master Commit: 69f2619d059e6aeb45f947b8367069b12f3b1bb9 Parents: 0bd10cb Author: Drew Manlove <[email protected]> Authored: Mon Nov 21 14:14:06 2016 -0700 Committer: Dan Burkert <[email protected]> Committed: Thu Feb 23 01:36:37 2017 +0000 ---------------------------------------------------------------------- .../apache/kudu/spark/kudu/DefaultSource.scala | 71 +++++++++++++------- .../kudu/spark/kudu/DefaultSourceTest.scala | 42 ++++++++++-- 2 files changed, 85 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kudu/blob/69f2619d/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala index af1d573..5a6c178 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala @@ -24,7 +24,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} -import org.apache.kudu.Type +import org.apache.kudu.{Type, ColumnSchema} + import org.apache.kudu.annotations.InterfaceStability import org.apache.kudu.client.KuduPredicate.ComparisonOp import org.apache.kudu.client._ 
@@ -38,7 +39,8 @@ import org.apache.kudu.client._ * operations through [[org.apache.spark.sql.DataFrameReader.format]]. */ @InterfaceStability.Unstable -class DefaultSource extends RelationProvider with CreatableRelationProvider { +class DefaultSource extends RelationProvider with CreatableRelationProvider + with SchemaRelationProvider { val TABLE_KEY = "kudu.table" val KUDU_MASTER = "kudu.master" @@ -58,18 +60,9 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider { throw new IllegalArgumentException( s"Kudu table name must be specified in create options using key '$TABLE_KEY'")) val kuduMaster = parameters.getOrElse(KUDU_MASTER, "localhost") + val operationType = getOperationType(parameters.getOrElse(OPERATION, "upsert")) - val opParam = parameters.getOrElse(OPERATION, "upsert") - val operationType = opParam.toLowerCase match { - case "insert" => Insert - case "insert-ignore" => InsertIgnore - case "upsert" => Upsert - case "update" => Update - case "delete" => Delete - case _ => throw new IllegalArgumentException(s"Unsupported operation type '$opParam'") - } - - new KuduRelation(tableName, kuduMaster, operationType)(sqlContext) + new KuduRelation(tableName, kuduMaster, operationType, None)(sqlContext) } /** @@ -93,6 +86,28 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider { kuduRelation } + + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], + schema: StructType): BaseRelation = { + val tableName = parameters.getOrElse(TABLE_KEY, + throw new IllegalArgumentException(s"Kudu table name must be specified in create options " + + s"using key '$TABLE_KEY'")) + val kuduMaster = parameters.getOrElse(KUDU_MASTER, "localhost") + val operationType = getOperationType(parameters.getOrElse(OPERATION, "upsert")) + + new KuduRelation(tableName, kuduMaster, operationType, Some(schema))(sqlContext) + } + + private def getOperationType(opParam: String): OperationType = { + opParam.toLowerCase 
match { + case "insert" => Insert + case "insert-ignore" => InsertIgnore + case "upsert" => Upsert + case "update" => Update + case "delete" => Delete + case _ => throw new IllegalArgumentException(s"Unsupported operation type '$opParam'") + } + } } /** @@ -101,16 +116,18 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider { * @param tableName Kudu table that we plan to read from * @param masterAddrs Kudu master addresses * @param operationType The default operation type to perform when writing to the relation + * @param userSchema A schema used to select columns for the relation * @param sqlContext SparkSQL context */ @InterfaceStability.Unstable class KuduRelation(private val tableName: String, private val masterAddrs: String, - private val operationType: OperationType)( - val sqlContext: SQLContext) -extends BaseRelation -with PrunedFilteredScan -with InsertableRelation { + private val operationType: OperationType, + private val userSchema: Option[StructType])( + val sqlContext: SQLContext) + extends BaseRelation + with PrunedFilteredScan + with InsertableRelation { import KuduRelation._ @@ -127,13 +144,19 @@ with InsertableRelation { * @return schema generated from the Kudu table's schema */ override def schema: StructType = { - val fields: Array[StructField] = - table.getSchema.getColumns.asScala.map { columnSchema => - val sparkType = kuduTypeToSparkType(columnSchema.getType) - StructField(columnSchema.getName, sparkType, columnSchema.isNullable) - }.toArray + userSchema match { + case Some(x) => + StructType(x.fields.map(uf => table.getSchema.getColumn(uf.name)) + .map(kuduColumnToSparkField)) + case None => + StructType(table.getSchema.getColumns.asScala.map(kuduColumnToSparkField).toArray) + } + } - new StructType(fields) + def kuduColumnToSparkField: (ColumnSchema) => StructField = { + columnSchema => + val sparkType = kuduTypeToSparkType(columnSchema.getType) + new StructField(columnSchema.getName, sparkType, 
columnSchema.isNullable) } /** http://git-wip-us.apache.org/repos/asf/kudu/blob/69f2619d/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala index 3dbd088..9790e81 100644 --- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala +++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala @@ -25,19 +25,21 @@ import scala.collection.immutable.IndexedSeq import scala.util.control.NonFatal import com.google.common.collect.ImmutableList -import org.apache.spark.sql.SQLContext + import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.{DataTypes, StructField, StructType}; import org.junit.Assert._ import org.junit.runner.RunWith +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.junit.JUnitRunner -import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder import org.apache.kudu.{Schema, Type} -import org.apache.kudu.client.CreateTableOptions; +import org.apache.kudu.client.CreateTableOptions @RunWith(classOf[JUnitRunner]) -class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter { +class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter with Matchers { test("timestamp conversion") { val epoch = new Timestamp(0) @@ -450,4 +452,36 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter { assertTrue(kuduContext.tableExists(newTable)) assert(checkDf.count == 10) } + + test("create relation with schema") { + + // user-supplied schema that is compatible with actual schema, but with the key at the end + val userSchema: StructType = StructType(List( + 
StructField("c4_long", DataTypes.LongType), + StructField("key", DataTypes.IntegerType) + )) + + val dfDefaultSchema = sqlContext.read.options(kuduOptions).kudu + assertEquals(11, dfDefaultSchema.schema.fields.length) + + val dfWithUserSchema = sqlContext.read.options(kuduOptions).schema(userSchema).kudu + assertEquals(2, dfWithUserSchema.schema.fields.length) + + dfWithUserSchema.limit(10).collect() + assertTrue(dfWithUserSchema.columns.deep == Array("c4_long", "key").deep) + } + + test("create relation with invalid schema") { + + // user-supplied schema that is NOT compatible with actual schema + val userSchema: StructType = StructType(List( + StructField("foo", DataTypes.LongType), + StructField("bar", DataTypes.IntegerType) + )) + + intercept[IllegalArgumentException] { + sqlContext.read.options(kuduOptions).schema(userSchema).kudu + }.getMessage should include ("Unknown column: foo") + + } }
