Hi,
I've recently come across some behaviour that looks buggy when working with
phoenix-spark and arrays.
Take a look at these unit tests:
test("Can save arrays from custom dataframes back to phoenix") {
val dataSet = List(Row(2L, Array("String1", "String2", "String3")))
val sqlContext = new SQLContext(sc)
val schema = StructType(
Seq(StructField("ID", LongType, nullable = false),
StructField("VCARRAY", ArrayType(StringType))))
val rowRDD = sc.parallelize(dataSet)
// Apply the schema to the RDD.
val df = sqlContext.createDataFrame(rowRDD, schema)
df.write
.format("org.apache.phoenix.spark")
.options(Map("table" -> "ARRAY_TEST_TABLE", "zkUrl" -> quorumAddress))
.mode(SaveMode.Overwrite)
.save()
}
test("Can save arrays of AnyVal type back to phoenix") {
val dataSet = List((2L, Array(1, 2, 3), Array(1L, 2L, 3L)))
sc
.parallelize(dataSet)
.saveToPhoenix(
"ARRAY_ANYVAL_TEST_TABLE",
Seq("ID", "INTARRAY", "BIGINTARRAY"),
zkUrl = Some(quorumAddress)
)
// Load the results back
val stmt = conn.createStatement()
val rs = stmt.executeQuery("SELECT INTARRAY, BIGINTARRAY FROM
ARRAY_ANYVAL_TEST_TABLE WHERE ID = 2")
rs.next()
val intArray = rs.getArray(1).getArray().asInstanceOf[Array[Int]]
val longArray = rs.getArray(2).getArray().asInstanceOf[Array[Long]]
// Verify the arrays are equal
intArray shouldEqual dataSet(0)._2
longArray shouldEqual dataSet(0)._3
}
Both tests fail with ClassCastExceptions.
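To illustrate the root cause as I understand it, here is a minimal, standalone
sketch (not part of the patch): on the JVM an Array[Int] is an int[] and an
Array[Long] is a long[], and the DataFrame path apparently hands the column
value over as a Scala collection such as ArrayBuffer, so none of these match
the single Array[AnyRef] case in PhoenixRecordWritable.write and they fall
through to statement.setObject:

  object ArrayMatchSketch extends App {
    import scala.collection.mutable.ArrayBuffer

    // Mirrors the shape of the original match in PhoenixRecordWritable.write:
    // only reference arrays take the java.sql.Array branch.
    def route(value: Any): String = value match {
      case _: Array[AnyRef] => "handled as java.sql.Array"
      case _                => "falls through to setObject"
    }

    println(route(Array("String1", "String2", "String3"))) // handled as java.sql.Array
    println(route(Array(1, 2, 3)))                         // falls through to setObject (int[])
    println(route(Array(1L, 2L, 3L)))                      // falls through to setObject (long[])
    println(route(ArrayBuffer("String1", "String2")))      // falls through to setObject
  }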
In the attached patch I've proposed a solution. The tricky part is Array[Byte],
as the same JVM type is used for both VARBINARY and TINYINT[] (see the sketch
below).
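Just to make that ambiguity concrete, here is a hypothetical helper (untested,
not part of the patch) showing how a byte[] could be routed once the target
column's JDBC type is known, e.g. taken from the column metadata; the names
setByteArray, index and baseTypeName are only illustrative:

  object ByteArraySketch {
    import java.sql.{PreparedStatement, Types}

    // A byte[] is the natural JVM representation of both VARBINARY and
    // TINYINT[], so the column's SQL type has to decide the route.
    def setByteArray(stmt: PreparedStatement, index: Int, bytes: Array[Byte],
                     sqlType: Int, baseTypeName: String): Unit = sqlType match {
      case Types.ARRAY =>
        // TINYINT[]: box each byte and build a java.sql.Array of the base type
        val boxed: Array[AnyRef] = bytes.map(_.asInstanceOf[AnyRef])
        stmt.setArray(index, stmt.getConnection.createArrayOf(baseTypeName, boxed))
      case _ =>
        // VARBINARY: the raw byte[] is already what setObject expects
        stmt.setObject(index, bytes)
    }
  }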
Let me know if I should create an issue for this, and whether the solution
looks good to you.
Regards
Dawid Wysakowicz
From 5d24874cd0b2d15618843ada221634fa2a371d35 Mon Sep 17 00:00:00 2001
From: dawid <[email protected]>
Date: Mon, 30 Nov 2015 18:54:40 +0100
Subject: [PATCH] Phoenix-spark arrays
---
phoenix-spark/src/it/resources/setup.sql | 2 +
.../org/apache/phoenix/spark/PhoenixSparkIT.scala | 48 +++++++++++++++++++++-
.../phoenix/spark/PhoenixRecordWritable.scala | 27 ++++++++----
3 files changed, 67 insertions(+), 10 deletions(-)
diff --git a/phoenix-spark/src/it/resources/setup.sql b/phoenix-spark/src/it/resources/setup.sql
index 154a996..e97148c 100644
--- a/phoenix-spark/src/it/resources/setup.sql
+++ b/phoenix-spark/src/it/resources/setup.sql
@@ -30,6 +30,8 @@ UPSERT INTO "table3" ("id", "col1") VALUES (1, 'foo')
UPSERT INTO "table3" ("id", "col1") VALUES (2, 'bar')
CREATE TABLE ARRAY_TEST_TABLE (ID BIGINT NOT NULL PRIMARY KEY, VCARRAY VARCHAR[])
UPSERT INTO ARRAY_TEST_TABLE (ID, VCARRAY) VALUES (1, ARRAY['String1', 'String2', 'String3'])
+CREATE TABLE ARRAY_ANYVAL_TEST_TABLE (ID BIGINT NOT NULL PRIMARY KEY, INTARRAY INTEGER[], BIGINTARRAY BIGINT[])
+UPSERT INTO ARRAY_ANYVAL_TEST_TABLE (ID, INTARRAY, BIGINTARRAY) VALUES (1, ARRAY[1, 2, 3], ARRAY[1, 2, 3])
CREATE TABLE DATE_PREDICATE_TEST_TABLE (ID BIGINT NOT NULL, TIMESERIES_KEY TIMESTAMP NOT NULL CONSTRAINT pk PRIMARY KEY (ID, TIMESERIES_KEY))
UPSERT INTO DATE_PREDICATE_TEST_TABLE (ID, TIMESERIES_KEY) VALUES (1, CAST(CURRENT_TIME() AS TIMESTAMP))
CREATE TABLE OUTPUT_TEST_TABLE (id BIGINT NOT NULL PRIMARY KEY, col1 VARCHAR, col2 INTEGER, col3 DATE)
diff --git a/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala b/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
index e1c9df4..86769f6 100644
--- a/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
+++ b/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
@@ -23,8 +23,8 @@ import org.apache.phoenix.query.BaseTest
import org.apache.phoenix.schema.{TableNotFoundException, ColumnNotFoundException}
import org.apache.phoenix.schema.types.PVarchar
import org.apache.phoenix.util.{SchemaUtil, ColumnInfo}
-import org.apache.spark.sql.{SaveMode, execution, SQLContext}
-import org.apache.spark.sql.types.{LongType, DataType, StringType, StructField}
+import org.apache.spark.sql.{Row, SaveMode, execution, SQLContext}
+import org.apache.spark.sql.types._
import org.apache.spark.{SparkConf, SparkContext}
import org.joda.time.DateTime
import org.scalatest._
@@ -448,4 +448,48 @@ class PhoenixSparkIT extends FunSuite with Matchers with BeforeAndAfterAll {
count shouldEqual 1L
}
+
+ test("Can save arrays from custom dataframes back to phoenix") {
+ val dataSet = List(Row(2L, Array("String1", "String2", "String3")))
+
+ val sqlContext = new SQLContext(sc)
+
+ val schema = StructType(
+ Seq(StructField("ID", LongType, nullable = false),
+ StructField("VCARRAY", ArrayType(StringType))))
+
+ val rowRDD = sc.parallelize(dataSet)
+
+ // Apply the schema to the RDD.
+ val df = sqlContext.createDataFrame(rowRDD, schema)
+
+ df.write
+ .format("org.apache.phoenix.spark")
+ .options(Map("table" -> "ARRAY_TEST_TABLE", "zkUrl" -> quorumAddress))
+ .mode(SaveMode.Overwrite)
+ .save()
+ }
+
+ test("Can save arrays of AnyVal type back to phoenix") {
+ val dataSet = List((2L, Array(1, 2, 3), Array(1L, 2L, 3L)))
+
+ sc
+ .parallelize(dataSet)
+ .saveToPhoenix(
+ "ARRAY_ANYVAL_TEST_TABLE",
+ Seq("ID", "INTARRAY", "BIGINTARRAY"),
+ zkUrl = Some(quorumAddress)
+ )
+
+ // Load the results back
+ val stmt = conn.createStatement()
+ val rs = stmt.executeQuery("SELECT INTARRAY, BIGINTARRAY FROM ARRAY_ANYVAL_TEST_TABLE WHERE ID = 2")
+ rs.next()
+ val intArray = rs.getArray(1).getArray().asInstanceOf[Array[Int]]
+ val longArray = rs.getArray(2).getArray().asInstanceOf[Array[Long]]
+
+ // Verify the arrays are equal
+ intArray shouldEqual dataSet(0)._2
+ longArray shouldEqual dataSet(0)._3
+ }
}
\ No newline at end of file
diff --git a/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRecordWritable.scala b/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRecordWritable.scala
index f11f9cc..c91d105 100644
--- a/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRecordWritable.scala
+++ b/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRecordWritable.scala
@@ -19,6 +19,7 @@ import org.apache.hadoop.mapreduce.lib.db.DBWritable
import org.apache.phoenix.schema.types.{PDataType, PDate, PhoenixArray}
import org.apache.phoenix.util.ColumnInfo
import org.joda.time.DateTime
+import scala.collection.mutable.ArrayBuffer
import scala.collection.{immutable, mutable}
import scala.collection.JavaConversions._
@@ -53,15 +54,25 @@ class PhoenixRecordWritable(columnMetaDataList: List[ColumnInfo]) extends DBWrit
}
// Save as array or object
+ def setArrayInStatement(obj: Array[AnyRef]): Unit = {
+ // Create a java.sql.Array, need to lookup the base sql type name
+ val sqlArray = statement.getConnection.createArrayOf(
+ PDataType.arrayBaseType(finalType).getSqlTypeName,
+ obj
+ )
+ statement.setArray(i + 1, sqlArray)
+ }
+
finalObj match {
- case obj: Array[AnyRef] => {
- // Create a java.sql.Array, need to lookup the base sql type name
- val sqlArray = statement.getConnection.createArrayOf(
- PDataType.arrayBaseType(finalType).getSqlTypeName,
- obj
- )
- statement.setArray(i + 1, sqlArray)
- }
+ case obj: Array[AnyRef] => setArrayInStatement(obj)
+ case obj: ArrayBuffer[AnyRef] => setArrayInStatement(obj.toArray)
+ case obj: ArrayBuffer[AnyVal] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]).toArray)
+ case obj: Array[Int] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]))
+ case obj: Array[Long] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]))
+ case obj: Array[Char] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]))
+ case obj: Array[Short] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]))
+ case obj: Array[Float] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]))
+ case obj: Array[Double] => setArrayInStatement(obj.map(_.asInstanceOf[AnyRef]))
case _ => statement.setObject(i + 1, finalObj)
}
} else {
--
1.9.1