Repository: spark
Updated Branches:
  refs/heads/master 921900fd0 -> 59a501359


[SPARK-11636][SQL] Support classes defined in the REPL with Encoders

Before this PR there were two things that would blow up if you called 
`df.as[MyClass]` when `MyClass` was defined in the REPL:
 - [x] Because `classForName` doesn't work on the munged names returned by 
`tpe.erasure.typeSymbol.asClass.fullName`
 - [x] Because we don't have anything to pass into the constructor for the 
`$outer` pointer.

Note that this PR is just adding the infrastructure for working with inner 
classes in encoders and is not yet sufficient to make them work in the REPL.  
Currently, the implementation shown in 
https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab
 is causing a bug that breaks code gen due to some interaction between janino 
and the `ExecutorClassLoader`.  This will be addressed in a follow-up PR.

Author: Michael Armbrust <[email protected]>

Closes #9602 from marmbrus/dataset-replClasses.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/59a50135
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/59a50135
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/59a50135

Branch: refs/heads/master
Commit: 59a501359a267fbdb7689058693aa788703e54b1
Parents: 921900f
Author: Michael Armbrust <[email protected]>
Authored: Wed Nov 18 16:48:09 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Wed Nov 18 16:48:09 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 81 +++++++++++---------
 .../catalyst/encoders/ExpressionEncoder.scala   | 26 ++++++-
 .../sql/catalyst/encoders/OuterScopes.scala     | 42 ++++++++++
 .../sql/catalyst/encoders/ProductEncoder.scala  |  6 +-
 .../expressions/codegen/CodegenFallback.scala   |  2 +-
 .../codegen/GenerateMutableProjection.scala     |  4 +-
 .../codegen/GenerateProjection.scala            | 10 +--
 .../codegen/GenerateSafeProjection.scala        |  4 +-
 .../codegen/GenerateUnsafeProjection.scala      |  4 +-
 .../codegen/GenerateUnsafeRowJoiner.scala       |  6 +-
 .../sql/catalyst/expressions/literals.scala     |  6 ++
 .../sql/catalyst/expressions/objects.scala      | 42 ++++++++--
 .../encoders/ExpressionEncoderSuite.scala       |  7 +-
 .../catalyst/encoders/ProductEncoderSuite.scala |  4 +
 .../scala/org/apache/spark/sql/Dataset.scala    |  4 +-
 .../org/apache/spark/sql/GroupedDataset.scala   |  8 +-
 .../aggregate/TypedAggregateExpression.scala    | 19 ++---
 17 files changed, 193 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 38828e5..59ccf35 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection {
   // class loader of the current thread.
   override def mirror: universe.Mirror =
     universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
-}
-
-/**
- * Support for generating catalyst schemas for scala objects.
- */
-trait ScalaReflection {
-  /** The universe we work in (runtime or macro) */
-  val universe: scala.reflect.api.Universe
-
-  /** The mirror used to access types in the universe */
-  def mirror: universe.Mirror
 
   import universe._
 
@@ -53,30 +42,6 @@ trait ScalaReflection {
   // Since the map values can be mutable, we explicitly import 
scala.collection.Map at here.
   import scala.collection.Map
 
-  case class Schema(dataType: DataType, nullable: Boolean)
-
-  /** Returns a Sequence of attributes for the given case class type. */
-  def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
-    case Schema(s: StructType, _) =>
-      s.toAttributes
-  }
-
-  /** Returns a catalyst DataType and its nullability for the given Scala Type 
using reflection. */
-  def schemaFor[T: TypeTag]: Schema =
-    ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) }
-
-  /**
-   * Return the Scala Type for `T` in the current classloader mirror.
-   *
-   * Use this method instead of the convenience method `universe.typeOf`, which
-   * assumes that all types can be found in the classloader that loaded 
scala-reflect classes.
-   * That's not necessarily the case when running using Eclipse launchers or 
even
-   * Sbt console or test (without `fork := true`).
-   *
-   * @see SPARK-5281
-   */
-  def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
-
   /**
    * Returns the Spark SQL DataType for a given scala type.  Where this is not 
an exact mapping
    * to a native type, an ObjectType is returned. Special handling is also 
used for Arrays including
@@ -114,7 +79,9 @@ trait ScalaReflection {
 
           }
           ObjectType(cls)
-        case other => ObjectType(Utils.classForName(className))
+        case other =>
+          val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+          ObjectType(clazz)
       }
   }
 
@@ -640,6 +607,48 @@ trait ScalaReflection {
       }
     }
   }
+}
+
+/**
+ * Support for generating catalyst schemas for scala objects.  Note that 
unlike its companion
+ * object, this trait is able to work in both the runtime and the compile time 
(macro) universe.
+ */
+trait ScalaReflection {
+  /** The universe we work in (runtime or macro) */
+  val universe: scala.reflect.api.Universe
+
+  /** The mirror used to access types in the universe */
+  def mirror: universe.Mirror
+
+  import universe._
+
+  // The Predef.Map is scala.collection.immutable.Map.
+  // Since the map values can be mutable, we explicitly import 
scala.collection.Map at here.
+  import scala.collection.Map
+
+  case class Schema(dataType: DataType, nullable: Boolean)
+
+  /** Returns a Sequence of attributes for the given case class type. */
+  def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
+    case Schema(s: StructType, _) =>
+      s.toAttributes
+  }
+
+  /** Returns a catalyst DataType and its nullability for the given Scala Type 
using reflection. */
+  def schemaFor[T: TypeTag]: Schema =
+    ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) }
+
+  /**
+   * Return the Scala Type for `T` in the current classloader mirror.
+   *
+   * Use this method instead of the convenience method `universe.typeOf`, which
+   * assumes that all types can be found in the classloader that loaded 
scala-reflect classes.
+   * That's not necessarily the case when running using Eclipse launchers or 
even
+   * Sbt console or test (without `fork := true`).
+   *
+   * @see SPARK-5281
+   */
+  def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
 
   /** Returns a catalyst DataType and its nullability for the given Scala Type 
using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index b977f27..456b595 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+import java.util.concurrent.ConcurrentMap
+
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.util.Utils
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{AnalysisException, Encoder}
 import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, 
UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
@@ -211,7 +213,9 @@ case class ExpressionEncoder[T](
    * Returns a new copy of this encoder, where the expressions used by 
`fromRow` are resolved to the
    * given schema.
    */
-  def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+  def resolve(
+      schema: Seq[Attribute],
+      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
     val positionToAttribute = AttributeMap.toIndex(schema)
     val unbound = fromRowExpression transform {
       case b: BoundReference => positionToAttribute(b.ordinal)
@@ -219,7 +223,23 @@ case class ExpressionEncoder[T](
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
-    copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
+
+    // In order to construct instances of inner classes (for example those 
declared in a REPL cell),
+    // we need an instance of the outer scope.  This rule substitutes those 
outer objects into
+    // expressions that are missing them by looking up the name in the 
SQLContexts `outerScopes`
+    // registry.
+    copy(fromRowExpression = analyzedPlan.expressions.head.children.head 
transform {
+      case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+        val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
+        if (outer == null) {
+          throw new AnalysisException(
+            s"Unable to generate an encoder for inner class `${n.cls.getName}` 
without access " +
+            s"to the scope that this class was defined in. " + "" +
+             "Try moving this class out of its parent class.")
+        }
+
+        n.copy(outerPointer = Some(Literal.fromObject(outer)))
+    })
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
new file mode 100644
index 0000000..a753b18
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.collect.MapMaker
+
+object OuterScopes {
+  @transient
+  lazy val outerScopes: ConcurrentMap[String, AnyRef] =
+    new MapMaker().weakValues().makeMap()
+
+  /**
+   * Adds a new outer scope to this context that can be used when 
instantiating an `inner class`
+   * during deserialization. Inner classes are created when a case class is 
defined in the
+   * Spark REPL and registering the outer scope that this class was defined in 
allows us to create
+   * new instances on the spark executors.  In normal use, users should not 
need to call this
+   * function.
+   *
+   * Warning: this function operates on the assumption that there is only ever 
one instance of any
+   * given wrapper class.
+   */
+  def addOuterScope(outer: AnyRef): Unit = {
+    outerScopes.putIfAbsent(outer.getClass.getName, outer)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 55c4ee1..2914c6e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -31,6 +31,7 @@ import scala.reflect.ClassTag
 
 object ProductEncoder {
   import ScalaReflection.universe._
+  import ScalaReflection.mirror
   import ScalaReflection.localTypeOf
   import ScalaReflection.dataTypeFor
   import ScalaReflection.Schema
@@ -420,8 +421,7 @@ object ProductEncoder {
           }
         }
 
-        val className: String = t.erasure.typeSymbol.asClass.fullName
-        val cls = Utils.classForName(className)
+        val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
 
         val arguments = params.head.zipWithIndex.map { case (p, i) =>
           val fieldName = p.name.toString
@@ -429,7 +429,7 @@ object ProductEncoder {
           val dataType = schemaFor(fieldType).dataType
 
           // For tuples, we based grab the inner fields by ordinal instead of 
name.
-          if (className startsWith "scala.Tuple") {
+          if (cls.getName startsWith "scala.Tuple") {
             constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
           } else {
             constructorFor(fieldType, Some(addToPath(fieldName)))

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index d51a8de..a31574c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -34,7 +34,7 @@ trait CodegenFallback extends Expression {
     val objectTerm = ctx.freshName("obj")
     s"""
       /* expression: ${this} */
-      Object $objectTerm = expressions[${ctx.references.size - 
1}].eval(${ctx.INPUT_ROW});
+      java.lang.Object $objectTerm = expressions[${ctx.references.size - 
1}].eval(${ctx.INPUT_ROW});
       boolean ${ev.isNull} = $objectTerm == null;
       ${ctx.javaType(this.dataType)} ${ev.value} = 
${ctx.defaultValue(this.dataType)};
       if (!${ev.isNull}) {

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 4b66069..40189f0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -82,7 +82,7 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
     val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
 
     val code = s"""
-      public Object generate($exprType[] expr) {
+      public java.lang.Object generate($exprType[] expr) {
         return new SpecificMutableProjection(expr);
       }
 
@@ -109,7 +109,7 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
           return (InternalRow) mutableRow;
         }
 
-        public Object apply(Object _i) {
+        public java.lang.Object apply(java.lang.Object _i) {
           InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
           $allProjections
           // copy all the results into MutableRow

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c0d313b..f229f20 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -167,7 +167,7 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] {
         ${initMutableStates(ctx)}
       }
 
-      public Object apply(Object r) {
+      public java.lang.Object apply(java.lang.Object r) {
         // GenerateProjection does not work with UnsafeRows.
         assert(!(r instanceof ${classOf[UnsafeRow].getName}));
         return new SpecificRow((InternalRow) r);
@@ -186,14 +186,14 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] {
         public void setNullAt(int i) { nullBits[i] = true; }
         public boolean isNullAt(int i) { return nullBits[i]; }
 
-        public Object genericGet(int i) {
+        public java.lang.Object genericGet(int i) {
           if (isNullAt(i)) return null;
           switch (i) {
           $getCases
           }
           return null;
         }
-        public void update(int i, Object value) {
+        public void update(int i, java.lang.Object value) {
           if (value == null) {
             setNullAt(i);
             return;
@@ -212,7 +212,7 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] {
           return result;
         }
 
-        public boolean equals(Object other) {
+        public boolean equals(java.lang.Object other) {
           if (other instanceof SpecificRow) {
             SpecificRow row = (SpecificRow) other;
             $columnChecks
@@ -222,7 +222,7 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] {
         }
 
         public InternalRow copy() {
-          Object[] arr = new Object[${expressions.length}];
+          java.lang.Object[] arr = new java.lang.Object[${expressions.length}];
           ${copyColumns}
           return new ${classOf[GenericInternalRow].getName}(arr);
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f0ed864..b7926bd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -148,7 +148,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
     }
     val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
     val code = s"""
-      public Object generate($exprType[] expr) {
+      public java.lang.Object generate($exprType[] expr) {
         return new SpecificSafeProjection(expr);
       }
 
@@ -165,7 +165,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
           ${initMutableStates(ctx)}
         }
 
-        public Object apply(Object _i) {
+        public java.lang.Object apply(java.lang.Object _i) {
           InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
           $allExpressions
           return mutableRow;

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 4c17d02..7b6c937 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
 
     val code = s"""
-      public Object generate($exprType[] exprs) {
+      public java.lang.Object generate($exprType[] exprs) {
         return new SpecificUnsafeProjection(exprs);
       }
 
@@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         }
 
         // Scala.Function1 need this
-        public Object apply(Object row) {
+        public java.lang.Object apply(java.lang.Object row) {
           return apply((InternalRow) row);
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index da91ff2..da602d9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
 
     // ------------------------ Finally, put everything together  
--------------------------- //
     val code = s"""
-       |public Object generate($exprType[] exprs) {
+       |public java.lang.Object generate($exprType[] exprs) {
        |  return new SpecificUnsafeRowJoiner();
        |}
        |
@@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
        |      buf = new byte[sizeInBytes];
        |    }
        |
-       |    final Object obj1 = row1.getBaseObject();
+       |    final java.lang.Object obj1 = row1.getBaseObject();
        |    final long offset1 = row1.getBaseOffset();
-       |    final Object obj2 = row2.getBaseObject();
+       |    final java.lang.Object obj2 = row2.getBaseObject();
        |    final long offset2 = row2.getBaseOffset();
        |
        |    $copyBitset

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 455fa24..e34fd49 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -48,6 +48,12 @@ object Literal {
       throw new RuntimeException("Unsupported literal type " + v.getClass + " 
" + v)
   }
 
+  /**
+   * Constructs a [[Literal]] of [[ObjectType]], for example when you need to 
pass an object
+   * into code generation.
+   */
+  def fromObject(obj: AnyRef): Literal = new Literal(obj, 
ObjectType(obj.getClass))
+
   def create(v: Any, dataType: DataType): Literal = {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index acf0da2..f865a94 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.encoders.ProductEncoder
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.catalyst.InternalRow
@@ -178,6 +179,15 @@ case class Invoke(
   }
 }
 
+object NewInstance {
+  def apply(
+      cls: Class[_],
+      arguments: Seq[Expression],
+      propagateNull: Boolean = false,
+      dataType: DataType): NewInstance =
+    new NewInstance(cls, arguments, propagateNull, dataType, None)
+}
+
 /**
  * Constructs a new instance of the given class, using the result of 
evaluating the specified
  * expressions as arguments.
@@ -189,12 +199,15 @@ case class Invoke(
  * @param dataType The type of object being constructed, as a Spark SQL 
datatype.  This allows you
  *                 to manually specify the type when the object in question is 
a valid internal
  *                 representation (i.e. ArrayData) instead of an object.
+ * @param outerPointer If the object being constructed is an inner class the 
outerPointer
+ *                     for the containing class must be specified.
  */
 case class NewInstance(
     cls: Class[_],
     arguments: Seq[Expression],
-    propagateNull: Boolean = true,
-    dataType: DataType) extends Expression {
+    propagateNull: Boolean,
+    dataType: DataType,
+    outerPointer: Option[Literal]) extends Expression {
   private val className = cls.getName
 
   override def nullable: Boolean = propagateNull
@@ -209,30 +222,43 @@ case class NewInstance(
     val argGen = arguments.map(_.gen(ctx))
     val argString = argGen.map(_.value).mkString(", ")
 
+    val outer = outerPointer.map(_.gen(ctx))
+
+    val setup =
+      s"""
+         ${argGen.map(_.code).mkString("\n")}
+         ${outer.map(_.code.mkString("")).getOrElse("")}
+       """.stripMargin
+
+    val constructorCall = outer.map { gen =>
+      s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
+    }.getOrElse {
+      s"new $className($argString)"
+    }
+
     if (propagateNull) {
       val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
         s"${ev.isNull} = ${ev.value} == null;"
       } else {
         ""
       }
-
       val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
+
       s"""
-        ${argGen.map(_.code).mkString("\n")}
+        $setup
 
         boolean ${ev.isNull} = true;
         $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-
         if ($argsNonNull) {
-          ${ev.value} = new $className($argString);
+          ${ev.value} = $constructorCall;
           ${ev.isNull} = false;
         }
        """
     } else {
       s"""
-        ${argGen.map(_.code).mkString("\n")}
+        $setup
 
-        $javaType ${ev.value} = new $className($argString);
+        $javaType ${ev.value} = $constructorCall;
         final boolean ${ev.isNull} = ${ev.value} == null;
       """
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 9fe64b4..cde0364 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -18,6 +18,9 @@
 package org.apache.spark.sql.catalyst.encoders
 
 import java.util.Arrays
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.collect.MapMaker
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types.ArrayType
 
 abstract class ExpressionEncoderSuite extends SparkFunSuite {
+  val outers: ConcurrentMap[String, AnyRef] = new 
MapMaker().weakValues().makeMap()
+
   protected def encodeDecodeTest[T](
       input: T,
       encoder: ExpressionEncoder[T],
@@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite {
     test(s"encode/decode for $testName: $input") {
       val row = encoder.toRow(input)
       val schema = encoder.schema.toAttributes
-      val boundEncoder = encoder.resolve(schema).bind(schema)
+      val boundEncoder = encoder.resolve(schema, outers).bind(schema)
       val convertedBack = try boundEncoder.fromRow(row) catch {
         case e: Exception =>
           fail(

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index bc539d6..1798514 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -53,6 +53,10 @@ case class RepeatedData(
 case class SpecificCollection(l: List[Int])
 
 class ProductEncoderSuite extends ExpressionEncoderSuite {
+  outers.put(getClass.getName, this)
+
+  case class InnerClass(i: Int)
+  productTest(InnerClass(1))
 
   productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b644f6a..bdcdc5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,7 +74,7 @@ class Dataset[T] private[sql](
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
-    unresolvedTEncoder.resolve(queryExecution.analyzed.output)
+    unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
 
   private implicit def classTag = resolvedTEncoder.clsTag
 
@@ -375,7 +375,7 @@ class Dataset[T] private[sql](
       sqlContext,
       Project(
         c1.withInputType(
-          resolvedTEncoder,
+          resolvedTEncoder.bind(queryExecution.analyzed.output),
           queryExecution.analyzed.output).named :: Nil,
         logicalPlan))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 3f84e22..7e5acbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes}
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
@@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql](
   private implicit val unresolvedKEncoder = encoderFor(kEncoder)
   private implicit val unresolvedTEncoder = encoderFor(tEncoder)
 
-  private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
-  private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
+  private val resolvedKEncoder =
+    unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
+  private val resolvedTEncoder =
+    unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
 
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext

http://git-wip-us.apache.org/repos/asf/spark/blob/59a50135/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 3f27758..6ce41aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -52,8 +52,8 @@ object TypedAggregateExpression {
  */
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
-    aEncoder: Option[ExpressionEncoder[Any]],
-    bEncoder: ExpressionEncoder[Any],
+    aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
+    bEncoder: ExpressionEncoder[Any], // Should be bound.
     cEncoder: ExpressionEncoder[Any],
     children: Seq[Attribute],
     mutableAggBufferOffset: Int,
@@ -92,9 +92,6 @@ case class TypedAggregateExpression(
   // We let the dataset do the binding for us.
   lazy val boundA = aEncoder.get
 
-  val bAttributes = bEncoder.schema.toAttributes
-  lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
-
   private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
     // todo: need a more neat way to assign the value.
     var i = 0
@@ -114,24 +111,24 @@ case class TypedAggregateExpression(
 
   override def update(buffer: MutableRow, input: InternalRow): Unit = {
     val inputA = boundA.fromRow(input)
-    val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+    val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
     val merged = aggregator.reduce(currentB, inputA)
-    val returned = boundB.toRow(merged)
+    val returned = bEncoder.toRow(merged)
 
     updateBuffer(buffer, returned)
   }
 
   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
-    val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+    val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
+    val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
     val merged = aggregator.merge(b1, b2)
-    val returned = boundB.toRow(merged)
+    val returned = bEncoder.toRow(merged)
 
     updateBuffer(buffer1, returned)
   }
 
   override def eval(buffer: InternalRow): Any = {
-    val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+    val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
     val result = cEncoder.toRow(aggregator.finish(b))
     dataType match {
       case _: StructType => result


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to