Repository: spark
Updated Branches:
  refs/heads/master 38700ea40 -> 35a19f335


[SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext

Instead of relying on `DataFrame`s to verify our answers, we can just use
simple arrays. This significantly simplifies the test logic for `LocalNode`s
and removes a lot of code that was previously duplicated from `SparkPlanTest`.
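
For illustration, a minimal sketch of the new pattern, assuming the helpers added
by this patch (`DummyNode`, `kvIntAttributes`, `conf`, and `resolveExpressions`
from `LocalNodeTest`); it mirrors the rewritten `FilterNodeSuite` in the diff below:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._

// Feed a plain array through a DummyNode, run the operator under test,
// and compare the collected rows against a plain expected array.
private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
  val cond = 'k % 2 === 0
  val inputNode = new DummyNode(kvIntAttributes, inputData)
  val filterNode = resolveExpressions(new FilterNode(conf, cond, inputNode))
  val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
  val actualOutput = filterNode.collect().map { row => (row.getInt(0), row.getInt(1)) }
  assert(actualOutput === expectedOutput)
}
```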

This also fixes an additional issue 
[SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624) where the 
output of `TakeOrderedAndProjectNode` is not actually ordered.
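
The fix itself is the one-line change to `TakeOrderedAndProjectNode` in the diff
below (`iterator = queue.toArray.sorted(ord).iterator`): the bounded priority queue
retains the top `limit` rows, but iterating it does not yield them in sorted order.
A standalone sketch of the underlying pitfall, using Scala's mutable `PriorityQueue`
in place of Spark's bounded queue:

```scala
import scala.collection.mutable.PriorityQueue

// A priority queue keeps the top elements efficiently, but its iterator walks the
// backing heap rather than sorted order, hence the explicit sort in the fix.
val queue = PriorityQueue(5, 1, 4, 2, 3)
val heapOrder = queue.iterator.toSeq // arbitrary heap order, not necessarily sorted
val sorted = queue.toArray.sorted    // Array(1, 2, 3, 4, 5), analogous to .sorted(ord)
```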

Author: Andrew Or <and...@databricks.com>

Closes #8764 from andrewor14/sql-local-tests-cleanup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35a19f33
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35a19f33
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35a19f33

Branch: refs/heads/master
Commit: 35a19f3357d2ec017cfefb90f1018403e9617de4
Parents: 38700ea
Author: Andrew Or <and...@databricks.com>
Authored: Tue Sep 15 17:24:32 2015 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Tue Sep 15 17:24:32 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/local/LocalNode.scala   |   8 +-
 .../spark/sql/execution/local/SampleNode.scala  |  16 +-
 .../local/TakeOrderedAndProjectNode.scala       |   2 +-
 .../spark/sql/execution/SparkPlanTest.scala     |   2 +-
 .../spark/sql/execution/local/DummyNode.scala   |  68 ++++
 .../sql/execution/local/ExpandNodeSuite.scala   |  54 ++--
 .../sql/execution/local/FilterNodeSuite.scala   |  34 +-
 .../sql/execution/local/HashJoinNodeSuite.scala | 141 ++++-----
 .../execution/local/IntersectNodeSuite.scala    |  24 +-
 .../sql/execution/local/LimitNodeSuite.scala    |  28 +-
 .../sql/execution/local/LocalNodeSuite.scala    |  73 +----
 .../sql/execution/local/LocalNodeTest.scala     | 165 +++-------
 .../local/NestedLoopJoinNodeSuite.scala         | 316 +++++++------------
 .../sql/execution/local/ProjectNodeSuite.scala  |  39 ++-
 .../sql/execution/local/SampleNodeSuite.scala   |  35 +-
 .../local/TakeOrderedAndProjectNodeSuite.scala  |  50 ++-
 .../sql/execution/local/UnionNodeSuite.scala    |  49 +--
 17 files changed, 468 insertions(+), 636 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 569cff5..f96b62a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
  * Before consuming the iterator, open function must be called.
  * After consuming the iterator, close function must be called.
  */
-abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
+abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {
 
   protected val codegenEnabled: Boolean = conf.codegenEnabled
 
   protected val unsafeEnabled: Boolean = conf.unsafeEnabled
 
-  lazy val schema: StructType = StructType.fromAttributes(output)
-
   private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")
 
-  def output: Seq[Attribute]
-
   /**
    * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume
    * any input data.

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
index abf3df1..7937008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
@@ -17,13 +17,12 @@
 
 package org.apache.spark.sql.execution.local
 
-import java.util.Random
-
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
+
 /**
  * Sample the dataset.
  *
@@ -51,18 +50,15 @@ case class SampleNode(
 
   override def open(): Unit = {
     child.open()
-    val (sampler, _seed) = if (withReplacement) {
-        val random = new Random(seed)
+    val sampler =
+      if (withReplacement) {
         // Disable gap sampling since the gap sampling method buffers two rows internally,
         // requiring us to copy the row, which is more expensive than the random number generator.
-        (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
-          // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
-          // of DataFrame
-          random.nextLong())
+        new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
       } else {
-        (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+        new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
       }
-    sampler.setSeed(_seed)
+    sampler.setSeed(seed)
     iterator = sampler.sample(child.asIterator)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
index 53f1dcc..ae672fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
     }
     // Close it eagerly since we don't need it.
     child.close()
-    iterator = queue.iterator
+    iterator = queue.toArray.sorted(ord).iterator
   }
 
   override def next(): Boolean = {

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index de45ae4..3d218f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -238,7 +238,7 @@ object SparkPlanTest {
       outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
-          plan.transformExpressions {
+          plan transformExpressions {
             case UnresolvedAttribute(Seq(u)) =>
               inputMap.getOrElse(u,
                 sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala
new file mode 100644
index 0000000..efc3227
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala
@@ -0,0 +1,68 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
+ */
+private[local] case class DummyNode(
+    output: Seq[Attribute],
+    relation: LocalRelation,
+    conf: SQLConf)
+  extends LocalNode(conf) {
+
+  import DummyNode._
+
+  private var index: Int = CLOSED
+  private val input: Seq[InternalRow] = relation.data
+
+  def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
+    this(output, LocalRelation.fromProduct(output, data), conf)
+  }
+
+  def isOpen: Boolean = index != CLOSED
+
+  override def children: Seq[LocalNode] = Seq.empty
+
+  override def open(): Unit = {
+    index = -1
+  }
+
+  override def next(): Boolean = {
+    index += 1
+    index < input.size
+  }
+
+  override def fetch(): InternalRow = {
+    assert(index >= 0 && index < input.size)
+    input(index)
+  }
+
+  override def close(): Unit = {
+    index = CLOSED
+  }
+}
+
+private object DummyNode {
+  val CLOSED: Int = Int.MinValue
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
index cfa7f3f..bbd94d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
@@ -17,35 +17,33 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+
 class ExpandNodeSuite extends LocalNodeTest {
 
-  import testImplicits._
-
-  test("expand") {
-    val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
-    checkAnswer(
-      input,
-      node =>
-        ExpandNode(conf, Seq(
-          Seq(
-            input.col("key") + input.col("value"), input.col("key") - input.col("value")
-          ).map(_.expr),
-          Seq(
-            input.col("key") * input.col("value"), input.col("key") / input.col("value")
-          ).map(_.expr)
-        ), node.output, node),
-      Seq(
-        (2, 0),
-        (1, 1),
-        (4, 0),
-        (4, 1),
-        (6, 0),
-        (9, 1),
-        (8, 0),
-        (16, 1),
-        (10, 0),
-        (25, 1)
-      ).toDF().collect()
-    )
+  private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+    val inputNode = new DummyNode(kvIntAttributes, inputData)
+    val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
+    val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
+    val resolvedNode = resolveExpressions(expandNode)
+    val expectedOutput = {
+      val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
+      val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
+      firstHalf ++ secondHalf
+    }
+    val actualOutput = resolvedNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput.toSet === expectedOutput.toSet)
+  }
+
+  test("empty") {
+    testExpand()
   }
+
+  test("basic") {
+    testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
index a12670e..4eadce6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -17,25 +17,29 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.catalyst.dsl.expressions._
 
-class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
 
-  test("basic") {
-    val condition = (testData.col("key") % 2) === 0
-    checkAnswer(
-      testData,
-      node => FilterNode(conf, condition.expr, node),
-      testData.filter(condition).collect()
-    )
+class FilterNodeSuite extends LocalNodeTest {
+
+  private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+    val cond = 'k % 2 === 0
+    val inputNode = new DummyNode(kvIntAttributes, inputData)
+    val filterNode = new FilterNode(conf, cond, inputNode)
+    val resolvedNode = resolveExpressions(filterNode)
+    val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
+    val actualOutput = resolvedNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
 
   test("empty") {
-    val condition = (emptyTestData.col("key") % 2) === 0
-    checkAnswer(
-      emptyTestData,
-      node => FilterNode(conf, condition.expr, node),
-      emptyTestData.filter(condition).collect()
-    )
+    testFilter()
+  }
+
+  test("basic") {
+    testFilter((1 to 100).map { i => (i, i) }.toArray)
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index 78d8913..5c1bdb0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -18,99 +18,80 @@
 package org.apache.spark.sql.execution.local
 
 import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
+
 
 class HashJoinNodeSuite extends LocalNodeTest {
 
-  import testImplicits._
+  // Test all combinations of the two dimensions: with/out unsafe and build sides
+  private val maybeUnsafeAndCodegen = Seq(false, true)
+  private val buildSides = Seq(BuildLeft, BuildRight)
+  maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
+    buildSides.foreach { buildSide =>
+      testJoin(unsafeAndCodegen, buildSide)
+    }
+  }
 
-  def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
-    test(s"$suiteName: inner join with one match per row") {
-      withSQLConf(confPairs: _*) {
-        checkAnswer2(
-          upperCaseData,
-          lowerCaseData,
-          wrapForUnsafe(
-            (node1, node2) => HashJoinNode(
-              conf,
-              Seq(upperCaseData.col("N").expr),
-              Seq(lowerCaseData.col("n").expr),
-              joins.BuildLeft,
-              node1,
-              node2)
-          ),
-          upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
-        )
+  /**
+   * Test inner hash join with varying degrees of matches.
+   */
+  private def testJoin(
+      unsafeAndCodegen: Boolean,
+      buildSide: BuildSide): Unit = {
+    val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe"
+    val testNamePrefix = s"$simpleOrUnsafe / $buildSide"
+    val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray
+    val conf = new SQLConf
+    conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
+    conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)
+
+    // Actual test body
+    def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = {
+      val rightInputMap = rightInput.toMap
+      val leftNode = new DummyNode(joinNameAttributes, leftInput)
+      val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
+      val makeNode = (node1: LocalNode, node2: LocalNode) => {
+        resolveExpressions(new HashJoinNode(
+          conf, Seq('id1), Seq('id2), buildSide, node1, node2))
+      }
+      val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+      val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+      val expectedOutput = leftInput
+        .filter { case (k, _) => rightInputMap.contains(k) }
+        .map { case (k, v) => (k, v, k, rightInputMap(k)) }
+      val actualOutput = hashJoinNode.collect().map { row =>
+        // (id, name, id, nickname)
+        (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
       }
+      assert(actualOutput === expectedOutput)
     }
 
-    test(s"$suiteName: inner join with multiple matches") {
-      withSQLConf(confPairs: _*) {
-        val x = testData2.where($"a" === 1).as("x")
-        val y = testData2.where($"a" === 1).as("y")
-        checkAnswer2(
-          x,
-          y,
-          wrapForUnsafe(
-            (node1, node2) => HashJoinNode(
-              conf,
-              Seq(x.col("a").expr),
-              Seq(y.col("a").expr),
-              joins.BuildLeft,
-              node1,
-              node2)
-          ),
-          x.join(y).where($"x.a" === $"y.a").collect()
-        )
-      }
+    test(s"$testNamePrefix: empty") {
+      runTest(Array.empty, Array.empty)
+      runTest(someData, Array.empty)
+      runTest(Array.empty, someData)
     }
 
-    test(s"$suiteName: inner join, no matches") {
-      withSQLConf(confPairs: _*) {
-        val x = testData2.where($"a" === 1).as("x")
-        val y = testData2.where($"a" === 2).as("y")
-        checkAnswer2(
-          x,
-          y,
-          wrapForUnsafe(
-            (node1, node2) => HashJoinNode(
-              conf,
-              Seq(x.col("a").expr),
-              Seq(y.col("a").expr),
-              joins.BuildLeft,
-              node1,
-              node2)
-          ),
-          Nil
-        )
-      }
+    test(s"$testNamePrefix: no matches") {
+      val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray
+      runTest(someData, Array.empty)
+      runTest(Array.empty, someData)
+      runTest(someData, someIrrelevantData)
+      runTest(someIrrelevantData, someData)
     }
 
-    test(s"$suiteName: big inner join, 4 matches per row") {
-      withSQLConf(confPairs: _*) {
-        val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
-        val bigDataX = bigData.as("x")
-        val bigDataY = bigData.as("y")
+    test(s"$testNamePrefix: partial matches") {
+      val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
+      runTest(someData, someOtherData)
+      runTest(someOtherData, someData)
+    }
 
-        checkAnswer2(
-          bigDataX,
-          bigDataY,
-          wrapForUnsafe(
-            (node1, node2) =>
-              HashJoinNode(
-                conf,
-                Seq(bigDataX.col("key").expr),
-                Seq(bigDataY.col("key").expr),
-                joins.BuildLeft,
-                node1,
-                node2)
-          ),
-          bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
-      }
+    test(s"$testNamePrefix: full matches") {
+      val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray
+      runTest(someData, someSuperRelevantData)
+      runTest(someSuperRelevantData, someData)
     }
   }
 
-  joinSuite(
-    "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
-  joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
index 7deaa37..c0ad202 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
@@ -17,19 +17,21 @@
 
 package org.apache.spark.sql.execution.local
 
-class IntersectNodeSuite extends LocalNodeTest {
 
-  import testImplicits._
+class IntersectNodeSuite extends LocalNodeTest {
 
   test("basic") {
-    val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
-    val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")
-
-    checkAnswer2(
-      input1,
-      input2,
-      (node1, node2) => IntersectNode(conf, node1, node2),
-      input1.intersect(input2).collect()
-    )
+    val n = 100
+    val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray
+    val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray
+    val leftNode = new DummyNode(kvIntAttributes, leftData)
+    val rightNode = new DummyNode(kvIntAttributes, rightData)
+    val intersectNode = new IntersectNode(conf, leftNode, rightNode)
+    val expectedOutput = leftData.intersect(rightData)
+    val actualOutput = intersectNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
index 3b18390..fb79063 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -17,23 +17,25 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.test.SharedSQLContext
 
-class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
+class LimitNodeSuite extends LocalNodeTest {
 
-  test("basic") {
-    checkAnswer(
-      testData,
-      node => LimitNode(conf, 10, node),
-      testData.limit(10).collect()
-    )
+  private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = {
+    val inputNode = new DummyNode(kvIntAttributes, inputData)
+    val limitNode = new LimitNode(conf, limit, inputNode)
+    val expectedOutput = inputData.take(limit)
+    val actualOutput = limitNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
 
   test("empty") {
-    checkAnswer(
-      emptyTestData,
-      node => LimitNode(conf, 10, node),
-      emptyTestData.limit(10).collect()
-    )
+    testLimit()
   }
+
+  test("basic") {
+    testLimit((1 to 100).map { i => (i, i) }.toArray, 20)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
index b89fa46..0d1ed99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
@@ -17,28 +17,24 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.IntegerType
 
-class LocalNodeSuite extends SparkFunSuite {
-  private val data = (1 to 100).toArray
+class LocalNodeSuite extends LocalNodeTest {
+  private val data = (1 to 100).map { i => (i, i) }.toArray
 
   test("basic open, next, fetch, close") {
-    val node = new DummyLocalNode(data)
+    val node = new DummyNode(kvIntAttributes, data)
     assert(!node.isOpen)
     node.open()
     assert(node.isOpen)
-    data.foreach { i =>
+    data.foreach { case (k, v) =>
       assert(node.next())
       // fetch should be idempotent
       val fetched = node.fetch()
       assert(node.fetch() === fetched)
       assert(node.fetch() === fetched)
-      assert(node.fetch().numFields === 1)
-      assert(node.fetch().getInt(0) === i)
+      assert(node.fetch().numFields === 2)
+      assert(node.fetch().getInt(0) === k)
+      assert(node.fetch().getInt(1) === v)
     }
     assert(!node.next())
     node.close()
@@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite {
   }
 
   test("asIterator") {
-    val node = new DummyLocalNode(data)
+    val node = new DummyNode(kvIntAttributes, data)
     val iter = node.asIterator
     node.open()
-    data.foreach { i =>
+    data.foreach { case (k, v) =>
       // hasNext should be idempotent
       assert(iter.hasNext)
       assert(iter.hasNext)
       val item = iter.next()
-      assert(item.numFields === 1)
-      assert(item.getInt(0) === i)
+      assert(item.numFields === 2)
+      assert(item.getInt(0) === k)
+      assert(item.getInt(1) === v)
     }
     intercept[NoSuchElementException] {
       iter.next()
@@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite {
   }
 
   test("collect") {
-    val node = new DummyLocalNode(data)
+    val node = new DummyNode(kvIntAttributes, data)
     node.open()
     val collected = node.collect()
     assert(collected.size === data.size)
-    assert(collected.forall(_.size === 1))
-    assert(collected.map(_.getInt(0)) === data)
+    assert(collected.forall(_.size === 2))
+    assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data)
     node.close()
   }
 
 }
-
-/**
- * A dummy [[LocalNode]] that just returns one row per integer in the input.
- */
-private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
-  private var index = Int.MinValue
-
-  def this(input: Array[Int]) {
-    this(new SQLConf, input)
-  }
-
-  def isOpen: Boolean = {
-    index != Int.MinValue
-  }
-
-  override def output: Seq[Attribute] = {
-    Seq(AttributeReference("something", IntegerType)())
-  }
-
-  override def children: Seq[LocalNode] = Seq.empty
-
-  override def open(): Unit = {
-    index = -1
-  }
-
-  override def next(): Boolean = {
-    index += 1
-    index < input.size
-  }
-
-  override def fetch(): InternalRow = {
-    assert(index >= 0 && index < input.size)
-    val values = Array(input(index).asInstanceOf[Any])
-    new GenericInternalRow(values)
-  }
-
-  override def close(): Unit = {
-    index = Int.MinValue
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 86dd280..098050b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -17,147 +17,54 @@
 
 package org.apache.spark.sql.execution.local
 
-import scala.util.control.NonFatal
-
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLConf}
-import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.types.{IntegerType, StringType}
 
-class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
 
-  def conf: SQLConf = sqlContext.conf
+class LocalNodeTest extends SparkFunSuite {
 
-  protected def wrapForUnsafe(
-      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
-    if (conf.unsafeEnabled) {
-      (left: LocalNode, right: LocalNode) => {
-        val _left = ConvertToUnsafeNode(conf, left)
-        val _right = ConvertToUnsafeNode(conf, right)
-        val r = f(_left, _right)
-        ConvertToSafeNode(conf, r)
-      }
-    } else {
-      f
-    }
-  }
-
-  /**
-   * Runs the LocalNode and makes sure the answer matches the expected result.
-   * @param input the input data to be used.
-   * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
-   *                     the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
-   */
-  protected def checkAnswer(
-      input: DataFrame,
-      nodeFunction: LocalNode => LocalNode,
-      expectedAnswer: Seq[Row],
-      sortAnswers: Boolean = true): Unit = {
-    doCheckAnswer(
-      input :: Nil,
-      nodes => nodeFunction(nodes.head),
-      expectedAnswer,
-      sortAnswers)
-  }
-
-  /**
-   * Runs the LocalNode and makes sure the answer matches the expected result.
-   * @param left the left input data to be used.
-   * @param right the right input data to be used.
-   * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
-   *                     the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
-   */
-  protected def checkAnswer2(
-      left: DataFrame,
-      right: DataFrame,
-      nodeFunction: (LocalNode, LocalNode) => LocalNode,
-      expectedAnswer: Seq[Row],
-      sortAnswers: Boolean = true): Unit = {
-    doCheckAnswer(
-      left :: right :: Nil,
-      nodes => nodeFunction(nodes(0), nodes(1)),
-      expectedAnswer,
-      sortAnswers)
-  }
+  protected val conf: SQLConf = new SQLConf
+  protected val kvIntAttributes = Seq(
+    AttributeReference("k", IntegerType)(),
+    AttributeReference("v", IntegerType)())
+  protected val joinNameAttributes = Seq(
+    AttributeReference("id1", IntegerType)(),
+    AttributeReference("name", StringType)())
+  protected val joinNicknameAttributes = Seq(
+    AttributeReference("id2", IntegerType)(),
+    AttributeReference("nickname", StringType)())
 
   /**
-   * Runs the `LocalNode`s and makes sure the answer matches the expected result.
-   * @param input the input data to be used.
-   * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to
-   *                     instantiate the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
+   * Wrap a function processing two [[LocalNode]]s such that:
+   *   (1) all input rows are automatically converted to unsafe rows
+   *   (2) all output rows are automatically converted back to safe rows
    */
-  protected def doCheckAnswer(
-    input: Seq[DataFrame],
-    nodeFunction: Seq[LocalNode] => LocalNode,
-    expectedAnswer: Seq[Row],
-    sortAnswers: Boolean = true): Unit = {
-    LocalNodeTest.checkAnswer(
-      input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
+  protected def wrapForUnsafe(
+      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
+    (left: LocalNode, right: LocalNode) => {
+      val _left = ConvertToUnsafeNode(conf, left)
+      val _right = ConvertToUnsafeNode(conf, right)
+      val r = f(_left, _right)
+      ConvertToSafeNode(conf, r)
     }
   }
 
-  protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
-    new SeqScanNode(
-      conf,
-      df.queryExecution.sparkPlan.output,
-      df.queryExecution.toRdd.map(_.copy()).collect())
-  }
-
-}
-
-/**
- * Helper methods for writing tests of individual local physical operators.
- */
-object LocalNodeTest {
-
   /**
-   * Runs the `LocalNode`s and makes sure the answer matches the expected result.
-   * @param input the input data to be used.
-   * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to
-   *                     instantiate the local physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
-   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
-   *                    to being compared.
+   * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes.
    */
-  def checkAnswer(
-    input: Seq[SeqScanNode],
-    nodeFunction: Seq[LocalNode] => LocalNode,
-    expectedAnswer: Seq[Row],
-    sortAnswers: Boolean): Option[String] = {
-
-    val outputNode = nodeFunction(input)
-
-    val outputResult: Seq[Row] = try {
-      outputNode.collect()
-    } catch {
-      case NonFatal(e) =>
-        val errorMessage =
-          s"""
-              | Exception thrown while executing local plan:
-              | $outputNode
-              | == Exception ==
-              | $e
-              | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
-          """.stripMargin
-        return Some(errorMessage)
-    }
-
-    SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
-      s"""
-          | Results do not match for local plan:
-          | $outputNode
-          | $errorMessage
-       """.stripMargin
+  protected def resolveExpressions(outputNode: LocalNode): LocalNode = {
+    outputNode transform {
+      case node: LocalNode =>
+        val inputMap = node.output.map { a => (a.name, a) }.toMap
+        node transformExpressions {
+          case UnresolvedAttribute(Seq(u)) =>
+            inputMap.getOrElse(u,
+              sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+        }
     }
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
index b1ef26b..40299d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -18,222 +18,128 @@
 package org.apache.spark.sql.execution.local
 
 import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 
+
 class NestedLoopJoinNodeSuite extends LocalNodeTest {
 
-  import testImplicits._
-
-  private def joinSuite(
-      suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = {
-    test(s"$suiteName: left outer join") {
-      withSQLConf(confPairs: _*) {
-        checkAnswer2(
-          upperCaseData,
-          lowerCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              LeftOuter,
-              Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
-          ),
-          upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect())
-
-        checkAnswer2(
-          upperCaseData,
-          lowerCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              LeftOuter,
-              Some(
-                (upperCaseData.col("N") === lowerCaseData.col("n") &&
-                  lowerCaseData.col("n") > 1).expr))
-          ),
-          upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect())
-
-        checkAnswer2(
-          upperCaseData,
-          lowerCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              LeftOuter,
-              Some(
-                (upperCaseData.col("N") === lowerCaseData.col("n") &&
-                  upperCaseData.col("N") > 1).expr))
-          ),
-          upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect())
-
-        checkAnswer2(
-          upperCaseData,
-          lowerCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              LeftOuter,
-              Some(
-                (upperCaseData.col("N") === lowerCaseData.col("n") &&
-                  lowerCaseData.col("l") > upperCaseData.col("L")).expr))
-          ),
-          upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect())
+  // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types
+  private val maybeUnsafeAndCodegen = Seq(false, true)
+  private val buildSides = Seq(BuildLeft, BuildRight)
+  private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter)
+  maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
+    buildSides.foreach { buildSide =>
+      joinTypes.foreach { joinType =>
+        testJoin(unsafeAndCodegen, buildSide, joinType)
       }
     }
+  }
 
-    test(s"$suiteName: right outer join") {
-      withSQLConf(confPairs: _*) {
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              RightOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect())
-
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              RightOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
-                lowerCaseData.col("n") > 1).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect())
-
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              RightOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
-                upperCaseData.col("N") > 1).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect())
-
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              RightOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
-                lowerCaseData.col("l") > upperCaseData.col("L")).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect())
+  /**
+   * Test outer nested loop joins with varying degrees of matches.
+   */
+  private def testJoin(
+      unsafeAndCodegen: Boolean,
+      buildSide: BuildSide,
+      joinType: JoinType): Unit = {
+    val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe"
+    val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType"
+    val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray
+    val conf = new SQLConf
+    conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
+    conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)
+
+    // Actual test body
+    def runTest(
+        joinType: JoinType,
+        leftInput: Array[(Int, String)],
+        rightInput: Array[(Int, String)]): Unit = {
+      val leftNode = new DummyNode(joinNameAttributes, leftInput)
+      val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
+      val cond = 'id1 === 'id2
+      val makeNode = (node1: LocalNode, node2: LocalNode) => {
+        resolveExpressions(
+          new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond)))
       }
+      val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+      val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+      val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
+      val actualOutput = hashJoinNode.collect().map { row =>
+        // (id, name, id, nickname)
+        (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+      }
+      assert(actualOutput.toSet === expectedOutput.toSet)
     }
 
-    test(s"$suiteName: full outer join") {
-      withSQLConf(confPairs: _*) {
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              FullOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect())
-
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              FullOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
-                lowerCaseData.col("n") > 1).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect())
-
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              FullOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
-                upperCaseData.col("N") > 1).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect())
-
-        checkAnswer2(
-          lowerCaseData,
-          upperCaseData,
-          wrapForUnsafe(
-            (node1, node2) => NestedLoopJoinNode(
-              conf,
-              node1,
-              node2,
-              buildSide,
-              FullOuter,
-              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
-                lowerCaseData.col("l") > upperCaseData.col("L")).expr))
-          ),
-          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect())
-      }
+    test(s"$testNamePrefix: empty") {
+      runTest(joinType, Array.empty, Array.empty)
+    }
+
+    test(s"$testNamePrefix: no matches") {
+      val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray
+      runTest(joinType, someData, Array.empty)
+      runTest(joinType, Array.empty, someData)
+      runTest(joinType, someData, someIrrelevantData)
+      runTest(joinType, someIrrelevantData, someData)
+    }
+
+    test(s"$testNamePrefix: partial matches") {
+      val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
+      runTest(joinType, someData, someOtherData)
+      runTest(joinType, someOtherData, someData)
+    }
+
+    test(s"$testNamePrefix: full matches") {
+      val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }
+      runTest(joinType, someData, someSuperRelevantData)
+      runTest(joinType, someSuperRelevantData, someData)
+    }
+  }
+
+  /**
+   * Helper method to generate the expected output of a test based on the join type.
+   */
+  private def generateExpectedOutput(
+      leftInput: Array[(Int, String)],
+      rightInput: Array[(Int, String)],
+      joinType: JoinType): Array[(Int, String, Int, String)] = {
+    joinType match {
+      case LeftOuter =>
+        val rightInputMap = rightInput.toMap
+        leftInput.map { case (k, v) =>
+          val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
+          val rightValue = rightInputMap.getOrElse(k, null)
+          (k, v, rightKey, rightValue)
+        }
+
+      case RightOuter =>
+        val leftInputMap = leftInput.toMap
+        rightInput.map { case (k, v) =>
+          val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
+          val leftValue = leftInputMap.getOrElse(k, null)
+          (leftKey, leftValue, k, v)
+        }
+
+      case FullOuter =>
+        val leftInputMap = leftInput.toMap
+        val rightInputMap = rightInput.toMap
+        val leftOutput = leftInput.map { case (k, v) =>
+          val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
+          val rightValue = rightInputMap.getOrElse(k, null)
+          (k, v, rightKey, rightValue)
+        }
+        val rightOutput = rightInput.map { case (k, v) =>
+          val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
+          val leftValue = leftInputMap.getOrElse(k, null)
+          (leftKey, leftValue, k, v)
+        }
+        (leftOutput ++ rightOutput).distinct
+
+      case other =>
+        throw new IllegalArgumentException(s"Join type $other is not applicable")
     }
   }
 
-  joinSuite(
-    "general-build-left",
-    BuildLeft,
-    SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
-  joinSuite(
-    "general-build-right",
-    BuildRight,
-    SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
-  joinSuite(
-    "tungsten-build-left",
-    BuildLeft,
-    SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
-  joinSuite(
-    "tungsten-build-right",
-    BuildRight,
-    SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
index 38e0a23..02ecb23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -17,28 +17,33 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
+import org.apache.spark.sql.types.{IntegerType, StringType}
 
-class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
 
-  test("basic") {
-    val output = testData.queryExecution.sparkPlan.output
-    val columns = Seq(output(1), output(0))
-    checkAnswer(
-      testData,
-      node => ProjectNode(conf, columns, node),
-      testData.select("value", "key").collect()
-    )
+class ProjectNodeSuite extends LocalNodeTest {
+  private val pieAttributes = Seq(
+    AttributeReference("id", IntegerType)(),
+    AttributeReference("age", IntegerType)(),
+    AttributeReference("name", StringType)())
+
+  private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
+    val inputNode = new DummyNode(pieAttributes, inputData)
+    val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2))
+    val projectNode = new ProjectNode(conf, columns, inputNode)
+    val expectedOutput = inputData.map { case (id, age, name) => (id, name) }
+    val actualOutput = projectNode.collect().map { case row =>
+      (row.getInt(0), row.getString(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
 
   test("empty") {
-    val output = emptyTestData.queryExecution.sparkPlan.output
-    val columns = Seq(output(1), output(0))
-    checkAnswer(
-      emptyTestData,
-      node => ProjectNode(conf, columns, node),
-      emptyTestData.select("value", "key").collect()
-    )
+    testProject()
+  }
+
+  test("basic") {
+    testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
index 87a7da4..a3e83bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
@@ -17,21 +17,32 @@
 
 package org.apache.spark.sql.execution.local
 
-class SampleNodeSuite extends LocalNodeTest {
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
 
-  import testImplicits._
+class SampleNodeSuite extends LocalNodeTest {
 
   private def testSample(withReplacement: Boolean): Unit = {
-    test(s"withReplacement: $withReplacement") {
-      val seed = 0L
-      val input = sqlContext.sparkContext.
-        parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
-        toDF("key", "value")
-      checkAnswer(
-        input,
-        node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
-        input.sample(withReplacement, 0.3, seed).collect()
-      )
+    val seed = 0L
+    val lowerb = 0.0
+    val upperb = 0.3
+    val maybeOut = if (withReplacement) "" else "out"
+    test(s"with$maybeOut replacement") {
+      val inputData = (1 to 1000).map { i => (i, i) }.toArray
+      val inputNode = new DummyNode(kvIntAttributes, inputData)
+      val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode)
+      val sampler =
+        if (withReplacement) {
+          new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false)
+        } else {
+          new BernoulliCellSampler[(Int, Int)](lowerb, upperb)
+        }
+      sampler.setSeed(seed)
+      val expectedOutput = sampler.sample(inputData.iterator).toArray
+      val actualOutput = sampleNode.collect().map { case row =>
+        (row.getInt(0), row.getInt(1))
+      }
+      assert(actualOutput === expectedOutput)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
index ff28b24..42ebc7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
@@ -17,38 +17,34 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
+import scala.util.Random
 
-class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SortOrder
 
-  import testImplicits._
 
-  private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
-    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
-      col.expr match {
-        case expr: SortOrder =>
-          expr
-        case expr: Expression =>
-          SortOrder(expr, Ascending)
-      }
-    }
-    sortOrder
-  }
+class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
 
-  private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
-    val testCaseName = if (desc) "desc" else "asc"
-    test(testCaseName) {
-      val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
-      val sortColumn = if (desc) input.col("key").desc else input.col("key")
-      checkAnswer(
-        input,
-        node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
-        input.sort(sortColumn).limit(5).collect()
-      )
+  private def testTakeOrderedAndProject(desc: Boolean): Unit = {
+    val limit = 10
+    val ascOrDesc = if (desc) "desc" else "asc"
+    test(ascOrDesc) {
+      val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray
+      val inputNode = new DummyNode(kvIntAttributes, inputData)
+      val firstColumn = inputNode.output(0)
+      val sortDirection = if (desc) Descending else Ascending
+      val sortOrder = SortOrder(firstColumn, sortDirection)
+      val takeOrderAndProjectNode = new TakeOrderedAndProjectNode(
+        conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode)
+      val expectedOutput = inputData
+        .map { case (k, _) => k }
+        .sortBy { k => k * (if (desc) -1 else 1) }
+        .take(limit)
+      val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) }
+      assert(actualOutput === expectedOutput)
     }
   }
 
-  testTakeOrderedAndProjectNode(desc = false)
-  testTakeOrderedAndProjectNode(desc = true)
+  testTakeOrderedAndProject(desc = false)
+  testTakeOrderedAndProject(desc = true)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35a19f33/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
index eedd732..666b023 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -17,36 +17,39 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.test.SharedSQLContext
 
-class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
+class UnionNodeSuite extends LocalNodeTest {
 
-  test("basic") {
-    checkAnswer2(
-      testData,
-      testData,
-      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
-      testData.unionAll(testData).collect()
-    )
+  private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = {
+    val inputNodes = inputData.map { data =>
+      new DummyNode(kvIntAttributes, data)
+    }
+    val unionNode = new UnionNode(conf, inputNodes)
+    val expectedOutput = inputData.flatten
+    val actualOutput = unionNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
 
   test("empty") {
-    checkAnswer2(
-      emptyTestData,
-      emptyTestData,
-      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
-      emptyTestData.unionAll(emptyTestData).collect()
-    )
+    testUnion(Seq(Array.empty))
+    testUnion(Seq(Array.empty, Array.empty))
+  }
+
+  test("self") {
+    val data = (1 to 100).map { i => (i, i) }.toArray
+    testUnion(Seq(data))
+    testUnion(Seq(data, data))
+    testUnion(Seq(data, data, data))
   }
 
-  test("complicated union") {
-    val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
-      emptyTestData, emptyTestData, testData, emptyTestData)
-    doCheckAnswer(
-      dfs,
-      nodes => UnionNode(conf, nodes),
-      dfs.reduce(_.unionAll(_)).collect()
-    )
+  test("basic") {
+    val zero = Array.empty[(Int, Int)]
+    val one = (1 to 100).map { i => (i, i) }.toArray
+    val two = (50 to 150).map { i => (i, i) }.toArray
+    val three = (800 to 900).map { i => (i, i) }.toArray
+    testUnion(Seq(zero, one, two, three))
   }
 
 }

