Repository: spark
Updated Branches:
  refs/heads/master 46bb2b512 -> 1b08c4393


[SPARK-23584][SQL] NewInstance should support interpreted execution

## What changes were proposed in this pull request?
This PR adds support for interpreted (non-codegen) execution of `NewInstance`.

## How was this patch tested?
Added tests in `ObjectExpressionsSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes #20778 from maropu/SPARK-23584.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1b08c439
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1b08c439
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1b08c439

Branch: refs/heads/master
Commit: 1b08c4393cf48e21fea9914d130d8d3bf544061d
Parents: 46bb2b5
Author: Takeshi Yamamuro <[email protected]>
Authored: Thu Apr 19 14:38:26 2018 +0200
Committer: Herman van Hovell <[email protected]>
Committed: Thu Apr 19 14:38:26 2018 +0200

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 13 +++++++
 .../catalyst/expressions/objects/objects.scala  | 28 +++++++++++++--
 .../expressions/ObjectExpressionsSuite.scala    | 36 ++++++++++++++++++++
 3 files changed, 75 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1b08c439/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index e4274aa..818cc2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst
 
+import java.lang.reflect.Constructor
+
+import org.apache.commons.lang3.reflect.ConstructorUtils
+
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, 
UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
@@ -782,6 +786,15 @@ object ScalaReflection extends ScalaReflection {
   }
 
   /**
+   * Finds an accessible constructor with compatible parameters. This is a 
more flexible search
+   * than the exact matching algorithm in `Class.getConstructor`. The first 
assignment-compatible
+   * matching constructor is returned. Otherwise, it returns `None`.
+   */
+  def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): 
Option[Constructor[_]] = {
+    Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: 
_*))
+  }
+
+  /**
    * Whether the fields of the given type is defined entirely by its 
constructor parameters.
    */
   def definedByConstructorParams(tpe: Type): Boolean = 
cleanUpReflectionObjects {

http://git-wip-us.apache.org/repos/asf/spark/blob/1b08c439/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 72b202b..1645bd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -449,8 +449,32 @@ case class NewInstance(
     childrenResolved && !needOuterPointer
   }
 
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is 
supported.")
+  @transient private lazy val constructor: (Seq[AnyRef]) => Any = {
+    val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
+    val getConstructor = (paramClazz: Seq[Class[_]]) => {
+      ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
+        sys.error(s"Couldn't find a valid constructor on $cls")
+      }
+    }
+    outerPointer.map { p =>
+      val outerObj = p()
+      val d = outerObj.getClass +: paramTypes
+      val c = getConstructor(outerObj.getClass +: paramTypes)
+      (args: Seq[AnyRef]) => {
+        c.newInstance(outerObj +: args: _*)
+      }
+    }.getOrElse {
+      val c = getConstructor(paramTypes)
+      (args: Seq[AnyRef]) => {
+        c.newInstance(args: _*)
+      }
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val argValues = arguments.map(_.eval(input))
+    constructor(argValues.map(_.asInstanceOf[AnyRef]))
+  }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val javaType = CodeGenerator.javaType(dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/1b08c439/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index b0188b0..bf805f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass {
   override def binOp(e1: Int, e2: Double): Double = e1 - e2
 }
 
+// Tests for NewInstance
+class Outer extends Serializable {
+  class Inner(val value: Int) {
+    override def hashCode(): Int = super.hashCode()
+    override def equals(other: Any): Boolean = {
+      if (other.isInstanceOf[Inner]) {
+        value == other.asInstanceOf[Inner].value
+      } else {
+        false
+      }
+    }
+  }
+}
+
 class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("SPARK-16622: The returned value of the called method in Invoke can be 
null") {
@@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-23584 NewInstance should support interpreted execution") {
+    // Normal case test
+    val newInst1 = NewInstance(
+      cls = classOf[GenericArrayData],
+      arguments = Literal.fromObject(List(1, 2, 3)) :: Nil,
+      propagateNull = false,
+      dataType = ArrayType(IntegerType),
+      outerPointer = None)
+    checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3)))
+
+    // Inner class case test
+    val outerObj = new Outer()
+    val newInst2 = NewInstance(
+      cls = classOf[outerObj.Inner],
+      arguments = Literal(1) :: Nil,
+      propagateNull = false,
+      dataType = ObjectType(classOf[outerObj.Inner]),
+      outerPointer = Some(() => outerObj))
+    checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))
+  }
+
   test("LambdaVariable should support interpreted execution") {
     def genSchema(dt: DataType): Seq[StructType] = {
       Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
@@ -421,6 +456,7 @@ class TestBean extends Serializable {
   private var x: Int = 0
 
   def setX(i: Int): Unit = x = i
+
   def setNonPrimitive(i: AnyRef): Unit =
     assert(i != null, "this setter should not be called with null.")
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to