Repository: spark
Updated Branches:
  refs/heads/master dd9ae7945 -> 75438422c


[SPARK-9369][SQL] Support IntervalType in UnsafeRow

Author: Wenchen Fan <[email protected]>

Closes #7688 from cloud-fan/interval and squashes the following commits:

5b36b17 [Wenchen Fan] fix codegen
a99ed50 [Wenchen Fan] address comment
9e6d319 [Wenchen Fan] Support IntervalType in UnsafeRow


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/75438422
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/75438422
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/75438422

Branch: refs/heads/master
Commit: 75438422c2cd90dca53f84879cddecfc2ee0e957
Parents: dd9ae79
Author: Wenchen Fan <[email protected]>
Authored: Mon Jul 27 11:28:22 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Mon Jul 27 11:28:22 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     | 23 +++++++++++++++-----
 .../catalyst/expressions/UnsafeRowWriters.java  | 19 +++++++++++++++-
 .../apache/spark/sql/catalyst/InternalRow.scala |  4 +++-
 .../catalyst/expressions/BoundAttribute.scala   |  1 +
 .../spark/sql/catalyst/expressions/Cast.scala   |  2 +-
 .../expressions/codegen/CodeGenerator.scala     |  7 +++---
 .../codegen/GenerateUnsafeProjection.scala      |  6 +++++
 .../expressions/ExpressionEvalHelper.scala      |  2 --
 8 files changed, 50 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 0fb33dd..fb084dd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -29,6 +29,7 @@ import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.types.Interval;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import static org.apache.spark.sql.types.DataTypes.*;
@@ -90,7 +91,8 @@ public final class UnsafeRow extends MutableRow {
     final Set<DataType> _readableFieldTypes = new HashSet<>(
       Arrays.asList(new DataType[]{
         StringType,
-        BinaryType
+        BinaryType,
+        IntervalType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -333,11 +335,6 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public String getString(int ordinal) {
-    return getUTF8String(ordinal).toString();
-  }
-
-  @Override
   public byte[] getBinary(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
@@ -359,6 +356,20 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
+  public Interval getInterval(int ordinal) {
+    if (isNullAt(ordinal)) {
+      return null;
+    } else {
+      final long offsetAndSize = getLong(ordinal);
+      final int offset = (int) (offsetAndSize >> 32);
+      final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+      final long microseconds =
+        PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8);
+      return new Interval(months, microseconds);
+    }
+  }
+
+  @Override
   public UnsafeRow getStruct(int ordinal, int numFields) {
     if (isNullAt(ordinal)) {
       return null;

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 87521d1..0ba31d3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
+import org.apache.spark.unsafe.types.Interval;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -54,7 +55,7 @@ public class UnsafeRowWriters {
     }
   }
 
-  /** Writer for bianry (byte array) type. */
+  /** Writer for binary (byte array) type. */
   public static class BinaryWriter {
 
     public static int getSize(byte[] input) {
@@ -80,4 +81,20 @@ public class UnsafeRowWriters {
     }
   }
 
+  /** Writer for interval type. */
+  public static class IntervalWriter {
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) {
+      final long offset = target.getBaseOffset() + cursor;
+
+      // Write the months and microseconds fields of Interval to the variable length portion.
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months);
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, ((long) cursor) << 32);
+      return 16;
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index ad39772..9a11de3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{Interval, UTF8String}
 
 /**
 * An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable {
 
   def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
 
+  def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType)
+
   // This is only use for test and will throw a null pointer exception if the position is null.
   def getString(ordinal: Int): String = getUTF8String(ordinal).toString
 

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 6b5c450..41a877f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
         case DoubleType => input.getDouble(ordinal)
         case StringType => input.getUTF8String(ordinal)
         case BinaryType => input.getBinary(ordinal)
+        case IntervalType => input.getInterval(ordinal)
         case t: StructType => input.getStruct(ordinal, t.size)
         case dataType => input.get(ordinal, dataType)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index e208262..bd8b017 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType)
   private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
     case StringType =>
       (c, evPrim, evNull) =>
-        s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());"
+        s"$evPrim = Interval.fromString($c.toString());"
   }
 
   private[this] def decimalToTimestampCode(d: String): String =

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2a1e288..2f02c90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -79,7 +79,6 @@ class CodeGenContext {
     mutableStates += ((javaType, variableName, initCode))
   }
 
-  final val intervalType: String = classOf[Interval].getName
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -109,6 +108,7 @@ class CodeGenContext {
       case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
       case StringType => s"$row.getUTF8String($ordinal)"
       case BinaryType => s"$row.getBinary($ordinal)"
+      case IntervalType => s"$row.getInterval($ordinal)"
       case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
       case _ => s"($jt)$row.get($ordinal)"
     }
@@ -150,7 +150,7 @@ class CodeGenContext {
     case dt: DecimalType => "Decimal"
     case BinaryType => "byte[]"
     case StringType => "UTF8String"
-    case IntervalType => intervalType
+    case IntervalType => "Interval"
     case _: StructType => "InternalRow"
     case _: ArrayType => s"scala.collection.Seq"
     case _: MapType => s"scala.collection.Map"
@@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
       classOf[InternalRow].getName,
       classOf[UnsafeRow].getName,
       classOf[UTF8String].getName,
-      classOf[Decimal].getName
+      classOf[Decimal].getName,
+      classOf[Interval].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index afd0d9c..9d21619 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
   private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
   private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
+  private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
 
   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = dataType match {
     case t: AtomicType if !t.isInstanceOf[DecimalType] => true
+    case _: IntervalType => true
     case NullType => true
     case _ => false
   }
@@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
         case BinaryType =>
           s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
+        case IntervalType =>
+          s" + (${exprs(i).isNull} ? 0 : 16)"
         case _ => ""
       }
     }.mkString("")
@@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
         case BinaryType =>
           s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+        case IntervalType =>
+          s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
         case NullType => ""
         case _ =>
           throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")

http://git-wip-us.apache.org/repos/asf/spark/blob/75438422/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 8b0f90c..ab0cdc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -78,8 +78,6 @@ trait ExpressionEvalHelper {
       generator
     } catch {
       case e: Throwable =>
-        val ctx = new CodeGenContext
-        val evaluated = expression.gen(ctx)
         fail(
           s"""
             |Code generation of $expression failed:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to