PHOENIX-2288: PDecimal precision and scale aren't carried through to Spark DataFrame (Navis Ryu)


Project: http://git-wip-us.apache.org/repos/asf/phoenix/repo
Commit: http://git-wip-us.apache.org/repos/asf/phoenix/commit/93437431
Tree: http://git-wip-us.apache.org/repos/asf/phoenix/tree/93437431
Diff: http://git-wip-us.apache.org/repos/asf/phoenix/diff/93437431

Branch: refs/heads/txn
Commit: 9343743157f1cc03b1b8b815289b6127a30d740f
Parents: 9bf9535
Author: Josh Mahonin <[email protected]>
Authored: Fri Nov 13 12:44:08 2015 -0500
Committer: Josh Mahonin <[email protected]>
Committed: Fri Nov 13 12:44:08 2015 -0500

----------------------------------------------------------------------
 .../org/apache/phoenix/util/ColumnInfo.java     | 89 ++++++++++++++++++--
 .../org/apache/phoenix/util/PhoenixRuntime.java | 16 ++--
 .../org/apache/phoenix/util/ColumnInfoTest.java | 27 ++++++
 phoenix-spark/src/it/resources/setup.sql        |  2 +-
 .../apache/phoenix/spark/PhoenixSparkIT.scala   |  8 +-
 .../org/apache/phoenix/spark/PhoenixRDD.scala   | 16 ++--
 6 files changed, 127 insertions(+), 31 deletions(-)
----------------------------------------------------------------------
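
In short: DECIMAL columns now reach Spark with their declared precision and scale instead of Spark's catch-all DecimalType(38, 18). A minimal sketch of the intended behavior (Spark 1.x API; assumes the TEST_DECIMAL table from phoenix-spark/src/it/resources/setup.sql below, an existing SparkContext sc, and a placeholder ZooKeeper quorum):

    import org.apache.spark.sql.SQLContext

    // "localhost:2181" is a placeholder; use your cluster's quorum address.
    val sqlContext = new SQLContext(sc)
    val df = sqlContext.load("org.apache.phoenix.spark",
      Map("table" -> "TEST_DECIMAL", "zkUrl" -> "localhost:2181"))

    // COL1 is declared DECIMAL(9, 6); with this change the Catalyst schema
    // carries that through instead of defaulting to DecimalType(38, 18).
    df.schema("COL1").dataType  // DecimalType(9, 6)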


http://git-wip-us.apache.org/repos/asf/phoenix/blob/93437431/phoenix-core/src/main/java/org/apache/phoenix/util/ColumnInfo.java
----------------------------------------------------------------------
diff --git a/phoenix-core/src/main/java/org/apache/phoenix/util/ColumnInfo.java b/phoenix-core/src/main/java/org/apache/phoenix/util/ColumnInfo.java
index 3f94b92..0755ef7 100644
--- a/phoenix-core/src/main/java/org/apache/phoenix/util/ColumnInfo.java
+++ b/phoenix-core/src/main/java/org/apache/phoenix/util/ColumnInfo.java
@@ -10,10 +10,13 @@
 
 package org.apache.phoenix.util;
 
+import java.sql.Types;
 import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
 import org.apache.phoenix.query.QueryConstants;
-import org.apache.phoenix.schema.types.PDataType;
+import org.apache.phoenix.schema.types.*;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -28,8 +31,31 @@ public class ColumnInfo {
 
     private final String columnName;
     private final int sqlType;
-
+  
+    private final Integer precision;
+    private final Integer scale;
+    
+    public static ColumnInfo create(String columnName, int sqlType, Integer maxLength, Integer scale) {
+        if(scale != null) {
+            assert(maxLength != null); // If we have a scale, we should always have a maxLength
+            scale = Math.min(maxLength, scale);
+            return new ColumnInfo(columnName, sqlType, maxLength, scale);
+        }
+        if (maxLength != null) {
+            return new ColumnInfo(columnName, sqlType, maxLength);
+        }
+        return new ColumnInfo(columnName, sqlType);
+    }
+    
     public ColumnInfo(String columnName, int sqlType) {
+        this(columnName, sqlType, null);
+    }
+    
+    public ColumnInfo(String columnName, int sqlType, Integer maxLength) {
+        this(columnName, sqlType, maxLength, null);
+    }
+
+    public ColumnInfo(String columnName, int sqlType, Integer precision, Integer scale) {
         Preconditions.checkNotNull(columnName, "columnName cannot be null");
         Preconditions.checkArgument(!columnName.isEmpty(), "columnName cannot be empty");
         if(!columnName.startsWith(SchemaUtil.ESCAPE_CHARACTER)) {
@@ -37,6 +63,8 @@ public class ColumnInfo {
         }
         this.columnName = columnName;
         this.sqlType = sqlType;
+        this.precision = precision;
+        this.scale = scale;
     }
 
     public String getColumnName() {
@@ -64,9 +92,14 @@ public class ColumnInfo {
         return unescapedColumnName.substring(index+1).trim();
     }
 
+    // Return the proper SQL type string, taking into account possible array, length and scale parameters
+    public String toTypeString() {
+        return PhoenixRuntime.getSqlTypeName(getPDataType(), getMaxLength(), getScale());
+    }
+
     @Override
     public String toString() {
-        return getPDataType().getSqlTypeName() + STR_SEPARATOR + columnName ;
+        return toTypeString() + STR_SEPARATOR + columnName ;
     }
 
     @Override
@@ -77,6 +110,8 @@ public class ColumnInfo {
         ColumnInfo that = (ColumnInfo) o;
 
         if (sqlType != that.sqlType) return false;
+        if (precision != null ? !precision.equals(that.precision) : that.precision != null) return false;
+        if (scale != null ? !scale.equals(that.scale) : that.scale != null) return false;
         if (!columnName.equals(that.columnName)) return false;
 
         return true;
@@ -85,7 +120,7 @@ public class ColumnInfo {
     @Override
     public int hashCode() {
         int result = columnName.hashCode();
-        result = 31 * result + sqlType;
+        result = 31 * result + (precision != null ? precision << 2 : 0) + (scale != null ? scale << 1 : 0) + sqlType;
         return result;
     }
 
@@ -100,15 +135,51 @@ public class ColumnInfo {
      */
     public static ColumnInfo fromString(String stringRepresentation) {
         List<String> components =
-                Lists.newArrayList(stringRepresentation.split(":",2));
-        
+                Lists.newArrayList(stringRepresentation.split(":", 2));
+
         if (components.size() != 2) {
             throw new IllegalArgumentException("Unparseable string: " + 
stringRepresentation);
         }
 
-        return new ColumnInfo(
-                components.get(1),
-                PDataType.fromSqlTypeName(components.get(0)).getSqlType());
+        String[] typeParts = components.get(0).split(" ");
+        String columnName = components.get(1);
+
+        Integer maxLength = null;
+        Integer scale = null;
+        if (typeParts[0].contains("(")) {
+            Matcher matcher = Pattern.compile("([^\\(]+)\\((\\d+)(?:,(\\d+))?\\)").matcher(typeParts[0]);
+            if (!matcher.matches() || matcher.groupCount() > 3) {
+                throw new IllegalArgumentException("Unparseable type string: " + typeParts[0]);
+            }
+            maxLength = Integer.valueOf(matcher.group(2));
+            if (matcher.group(3) != null) {
+                scale = Integer.valueOf(matcher.group(3));
+            }
+            // Drop the (N) or (N,N) from the original type
+            typeParts[0] = matcher.group(1);
+        }
+
+        // Create the PDataType from the sql type name, including the second 'ARRAY' part if present
+        PDataType dataType;
+        if(typeParts.length < 2) {
+            dataType = PDataType.fromSqlTypeName(typeParts[0]);
+        }
+        else {
+            dataType = PDataType.fromSqlTypeName(typeParts[0] + " " + typeParts[1]);
+        }
+                
+        return ColumnInfo.create(columnName, dataType.getSqlType(), maxLength, scale);
+    }
+    
+    public Integer getMaxLength() {
+        return precision;
     }
 
+    public Integer getPrecision() {
+        return precision;
+    }
+    
+    public Integer getScale() {
+        return scale;
+    }
 }
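
For reference, a rough round-trip sketch of the extended ColumnInfo API (Scala REPL style; the column name mirrors the unit tests below, and the commented output assumes the DECIMAL formatting those tests expect):

    import java.sql.Types
    import org.apache.phoenix.util.ColumnInfo

    // create() clamps scale to maxLength and picks the matching constructor.
    val ci = ColumnInfo.create("a.myColumn", Types.DECIMAL, 10, 2)
    ci.toString                             // DECIMAL(10,2):"a"."myColumn"

    // fromString() parses the optional (precision[,scale]) suffix back out,
    // so precision and scale now survive the string form used in job configs.
    val parsed = ColumnInfo.fromString(ci.toString)
    (parsed.getPrecision, parsed.getScale)  // (10, 2)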

http://git-wip-us.apache.org/repos/asf/phoenix/blob/93437431/phoenix-core/src/main/java/org/apache/phoenix/util/PhoenixRuntime.java
----------------------------------------------------------------------
diff --git a/phoenix-core/src/main/java/org/apache/phoenix/util/PhoenixRuntime.java b/phoenix-core/src/main/java/org/apache/phoenix/util/PhoenixRuntime.java
index 2e6142a..d4a45f6 100644
--- a/phoenix-core/src/main/java/org/apache/phoenix/util/PhoenixRuntime.java
+++ b/phoenix-core/src/main/java/org/apache/phoenix/util/PhoenixRuntime.java
@@ -389,8 +389,7 @@ public class PhoenixRuntime {
                int offset = (table.getBucketNum() == null ? 0 : 1);
                for (int i = offset; i < table.getColumns().size(); i++) {
                   PColumn pColumn = table.getColumns().get(i);
-               int sqlType = pColumn.getDataType().getSqlType();
-               columnInfoList.add(new ColumnInfo(pColumn.toString(), sqlType)); 
+               columnInfoList.add(PhoenixRuntime.getColumnInfo(pColumn)); 
             }
         } else {
             // Leave "null" as indication to skip b/c it doesn't exist
@@ -459,19 +458,18 @@ public class PhoenixRuntime {
         return getColumnInfo(pColumn);
     }
 
-   /**
+    /**
      * Constructs a column info for the supplied pColumn
      * @param pColumn
      * @return columnInfo
      * @throws SQLException if the parameter is null.
      */
     public static ColumnInfo getColumnInfo(PColumn pColumn) throws SQLException {
-        if (pColumn==null) {
+        if (pColumn == null) {
             throw new SQLException("pColumn must not be null.");
         }
-        int sqlType = pColumn.getDataType().getSqlType();
-        ColumnInfo columnInfo = new ColumnInfo(pColumn.toString(),sqlType);
-        return columnInfo;
+        return ColumnInfo.create(pColumn.toString(), pColumn.getDataType().getSqlType(),
+                pColumn.getMaxLength(), pColumn.getScale());
     }
 
    /**
@@ -784,6 +782,10 @@ public class PhoenixRuntime {
         PDataType dataType = pCol.getDataType();
         Integer maxLength = pCol.getMaxLength();
         Integer scale = pCol.getScale();
+        return getSqlTypeName(dataType, maxLength, scale);
+    }
+
+    public static String getSqlTypeName(PDataType dataType, Integer maxLength, Integer scale) {
         return dataType.isArrayType() ? getArraySqlTypeName(maxLength, scale, dataType) : appendMaxLengthAndScale(maxLength, scale, dataType.getSqlTypeName());
     }
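
The new three-argument getSqlTypeName overload makes the formatting reusable when only the type, length and scale are in hand; roughly (Scala REPL style, outputs per the ColumnInfoTest expectations below):

    import org.apache.phoenix.schema.types.{PDecimal, PVarchar}
    import org.apache.phoenix.util.PhoenixRuntime

    PhoenixRuntime.getSqlTypeName(PDecimal.INSTANCE, 10, 2)     // DECIMAL(10,2)
    PhoenixRuntime.getSqlTypeName(PVarchar.INSTANCE, 100, null) // VARCHAR(100)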
     

http://git-wip-us.apache.org/repos/asf/phoenix/blob/93437431/phoenix-core/src/test/java/org/apache/phoenix/util/ColumnInfoTest.java
----------------------------------------------------------------------
diff --git a/phoenix-core/src/test/java/org/apache/phoenix/util/ColumnInfoTest.java b/phoenix-core/src/test/java/org/apache/phoenix/util/ColumnInfoTest.java
index 7f460cd..3bc26f2 100644
--- a/phoenix-core/src/test/java/org/apache/phoenix/util/ColumnInfoTest.java
+++ b/phoenix-core/src/test/java/org/apache/phoenix/util/ColumnInfoTest.java
@@ -24,6 +24,7 @@ import java.sql.SQLException;
 import java.sql.Types;
 
 import org.apache.phoenix.exception.SQLExceptionCode;
+import org.apache.phoenix.schema.types.*;
 import org.junit.Test;
 
 public class ColumnInfoTest {
@@ -55,4 +56,30 @@ public class ColumnInfoTest {
         ColumnInfo columnInfo = new ColumnInfo(":myColumn", Types.INTEGER);
         assertEquals(columnInfo, ColumnInfo.fromString(columnInfo.toString()));
     }
+    
+    @Test
+    public void testOptionalDescriptionType() {
+        testType(new ColumnInfo("a.myColumn", Types.CHAR), 
"CHAR:\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", Types.CHAR, 100), 
"CHAR(100):\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", Types.VARCHAR), 
"VARCHAR:\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", Types.VARCHAR, 100), 
"VARCHAR(100):\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", Types.DECIMAL), 
"DECIMAL:\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", Types.DECIMAL, 100, 10), 
"DECIMAL(100,10):\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", Types.BINARY, 5), 
"BINARY(5):\"a\".\"myColumn\"");
+
+        // Array types
+        testType(new ColumnInfo("a.myColumn", 
PCharArray.INSTANCE.getSqlType(), 3), "CHAR(3) ARRAY:\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", 
PDecimalArray.INSTANCE.getSqlType(), 10, 2), "DECIMAL(10,2) 
ARRAY:\"a\".\"myColumn\"");
+        testType(new ColumnInfo("a.myColumn", 
PVarcharArray.INSTANCE.getSqlType(), 4), "VARCHAR(4) ARRAY:\"a\".\"myColumn\"");
+    }
+
+    private void testType(ColumnInfo columnInfo, String expected) {
+        assertEquals(expected, columnInfo.toString());
+        ColumnInfo reverted = ColumnInfo.fromString(columnInfo.toString());
+        assertEquals(reverted.getColumnName(), columnInfo.getColumnName());
+        assertEquals(reverted.getDisplayName(), columnInfo.getDisplayName());
+        assertEquals(reverted.getSqlType(), columnInfo.getSqlType());
+        assertEquals(reverted.getMaxLength(), columnInfo.getMaxLength());
+        assertEquals(reverted.getScale(), columnInfo.getScale());
+    }
 }

http://git-wip-us.apache.org/repos/asf/phoenix/blob/93437431/phoenix-spark/src/it/resources/setup.sql
----------------------------------------------------------------------
diff --git a/phoenix-spark/src/it/resources/setup.sql b/phoenix-spark/src/it/resources/setup.sql
index db46a92..d6dbe20 100644
--- a/phoenix-spark/src/it/resources/setup.sql
+++ b/phoenix-spark/src/it/resources/setup.sql
@@ -35,5 +35,5 @@ UPSERT INTO DATE_PREDICATE_TEST_TABLE (ID, TIMESERIES_KEY) VALUES (1, CAST(CURRE
 CREATE TABLE OUTPUT_TEST_TABLE (id BIGINT NOT NULL PRIMARY KEY, col1 VARCHAR, col2 INTEGER, col3 DATE)
 CREATE TABLE CUSTOM_ENTITY."z02"(id BIGINT NOT NULL PRIMARY KEY)
 UPSERT INTO CUSTOM_ENTITY."z02" (id) VALUES(1)
-CREATE TABLE TEST_DECIMAL (ID BIGINT NOT NULL PRIMARY KEY, COL1 DECIMAL)
+CREATE TABLE TEST_DECIMAL (ID BIGINT NOT NULL PRIMARY KEY, COL1 DECIMAL(9, 6))
 UPSERT INTO TEST_DECIMAL VALUES (1, 123.456789)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/phoenix/blob/93437431/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
----------------------------------------------------------------------
diff --git a/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala b/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
index 7f97cc7..31104ba 100644
--- a/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
+++ b/phoenix-spark/src/it/scala/org/apache/phoenix/spark/PhoenixSparkIT.scala
@@ -511,11 +511,9 @@ class PhoenixSparkIT extends FunSuite with Matchers with BeforeAndAfterAll {
     res9.count() shouldEqual 2
   }
 
-
-  // We can load the type, but it defaults to Spark's default (precision 38, scale 10)
-  ignore("Can load decimal types with accurate precision and scale (PHOENIX-2288)") {
+  test("Can load decimal types with accurate precision and scale (PHOENIX-2288)") {
     val sqlContext = new SQLContext(sc)
     val df = sqlContext.load("org.apache.phoenix.spark", Map("table" -> "TEST_DECIMAL", "zkUrl" -> quorumAddress))
-    assert(df.select("COL1").first().getDecimal(0) == BigDecimal("123.456789"))
+    assert(df.select("COL1").first().getDecimal(0) == BigDecimal("123.456789").bigDecimal)
   }
-}
\ No newline at end of file
+}
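
The switch to comparing against .bigDecimal is needed because Row.getDecimal returns a java.math.BigDecimal, whose equals is never true against Scala's BigDecimal wrapper; a quick REPL sketch:

    val jd = new java.math.BigDecimal("123.456789")   // what Row.getDecimal returns
    jd == BigDecimal("123.456789")                    // false: different classes
    jd == BigDecimal("123.456789").bigDecimal         // true: same class, value and scale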

http://git-wip-us.apache.org/repos/asf/phoenix/blob/93437431/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRDD.scala
----------------------------------------------------------------------
diff --git a/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRDD.scala b/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRDD.scala
index e2d96cb..ac60ceb 100644
--- a/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRDD.scala
+++ b/phoenix-spark/src/main/scala/org/apache/phoenix/spark/PhoenixRDD.scala
@@ -13,15 +13,13 @@
  */
 package org.apache.phoenix.spark
 
-import java.text.DecimalFormat
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hbase.{HBaseConfiguration, HConstants}
 import org.apache.hadoop.io.NullWritable
 import org.apache.phoenix.mapreduce.PhoenixInputFormat
 import org.apache.phoenix.mapreduce.util.PhoenixConfigurationUtil
 import org.apache.phoenix.schema.types._
-import org.apache.phoenix.util.{PhoenixRuntime, ColumnInfo}
+import org.apache.phoenix.util.ColumnInfo
 import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
@@ -124,22 +122,22 @@ class PhoenixRDD(sc: SparkContext, table: String, columns: Seq[String],
 
   def phoenixSchemaToCatalystSchema(columnList: Seq[ColumnInfo]) = {
     columnList.map(ci => {
-      val structType = phoenixTypeToCatalystType(ci.getPDataType)
+      val structType = phoenixTypeToCatalystType(ci)
       StructField(ci.getDisplayName, structType)
     })
   }
 
 
   // Lookup table for Phoenix types to Spark catalyst types
-  def phoenixTypeToCatalystType(phoenixType: PDataType[_]): DataType = phoenixType match {
+  def phoenixTypeToCatalystType(columnInfo: ColumnInfo): DataType = columnInfo.getPDataType match {
     case t if t.isInstanceOf[PVarchar] || t.isInstanceOf[PChar] => StringType
     case t if t.isInstanceOf[PLong] || t.isInstanceOf[PUnsignedLong] => LongType
     case t if t.isInstanceOf[PInteger] || t.isInstanceOf[PUnsignedInt] => IntegerType
     case t if t.isInstanceOf[PFloat] || t.isInstanceOf[PUnsignedFloat] => FloatType
     case t if t.isInstanceOf[PDouble] || t.isInstanceOf[PUnsignedDouble] => DoubleType
-    // TODO: support custom precision / scale.
     // Use Spark system default precision for now (explicit to work with < 1.5)
-    case t if t.isInstanceOf[PDecimal] => DecimalType(38, 18)
+    case t if t.isInstanceOf[PDecimal] =>
+      if (columnInfo.getPrecision == null || columnInfo.getPrecision < 0) DecimalType(38, 18) else DecimalType(columnInfo.getPrecision, columnInfo.getScale)
     case t if t.isInstanceOf[PTimestamp] || t.isInstanceOf[PUnsignedTimestamp] => TimestampType
     case t if t.isInstanceOf[PTime] || t.isInstanceOf[PUnsignedTime] => TimestampType
     case t if t.isInstanceOf[PDate] || t.isInstanceOf[PUnsignedDate] => TimestampType
@@ -154,8 +152,8 @@ class PhoenixRDD(sc: SparkContext, table: String, columns: Seq[String],
     case t if t.isInstanceOf[PTinyintArray] || t.isInstanceOf[PUnsignedTinyintArray] => ArrayType(ByteType, containsNull = true)
     case t if t.isInstanceOf[PFloatArray] || t.isInstanceOf[PUnsignedFloatArray] => ArrayType(FloatType, containsNull = true)
     case t if t.isInstanceOf[PDoubleArray] || t.isInstanceOf[PUnsignedDoubleArray] => ArrayType(DoubleType, containsNull = true)
-    // TODO: support custom precision / scale
-    case t if t.isInstanceOf[PDecimalArray] => { ArrayType(DecimalType(38, 18), containsNull = true) }
+    case t if t.isInstanceOf[PDecimalArray] => ArrayType(
+      if (columnInfo.getPrecision == null || columnInfo.getPrecision < 0) DecimalType(38, 18) else DecimalType(columnInfo.getPrecision, columnInfo.getScale), containsNull = true)
     case t if t.isInstanceOf[PTimestampArray] || t.isInstanceOf[PUnsignedTimestampArray] => ArrayType(TimestampType, containsNull = true)
     case t if t.isInstanceOf[PDateArray] || t.isInstanceOf[PUnsignedDateArray] => ArrayType(TimestampType, containsNull = true)
     case t if t.isInstanceOf[PTimeArray] || t.isInstanceOf[PUnsignedTimeArray] => ArrayType(TimestampType, containsNull = true)
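
Pulled out as a standalone sketch, the decimal mapping above reads roughly as follows (DecimalType(38, 18) stays as the explicit pre-Spark-1.5 default; the null handling is defensive, since ColumnInfo reports null precision/scale when the column declares none):

    import org.apache.phoenix.util.ColumnInfo
    import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType}

    // Fall back to the old catch-all only when Phoenix reports no precision.
    def decimalTypeFor(ci: ColumnInfo): DecimalType =
      if (ci.getPrecision == null || ci.getPrecision < 0) DecimalType(38, 18)
      // SQL semantics: DECIMAL(p) with no declared scale means scale 0.
      else DecimalType(ci.getPrecision, if (ci.getScale == null) 0 else ci.getScale.intValue)

    def decimalArrayTypeFor(ci: ColumnInfo): DataType =
      ArrayType(decimalTypeFor(ci), containsNull = true)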
