Repository: spark
Updated Branches:
  refs/heads/master 5828f41a5 -> 10f1f1965


[SPARK-21274][SQL] Implement EXCEPT ALL clause.

## What changes were proposed in this pull request?
Implements the EXCEPT ALL clause through query rewrites using existing operators in 
Spark. This PR adds an internal UDTF (`replicate_rows`) to aid in preserving 
duplicate rows. Please refer to the 
[design document](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE).

**Note:** The proposed UDTF is kept as an internal function, used purely to support 
this particular rewrite; this gives us the flexibility to change to a more 
generalized UDTF in the future.

Input Query
```sql
SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2
```
Rewritten Query
```sql
SELECT c1
    FROM (
     SELECT replicate_rows(sum_val, c1)
       FROM (
         SELECT c1, sum_val
           FROM (
             SELECT c1, sum(vcol) AS sum_val
               FROM (
                 SELECT 1L as vcol, c1 FROM ut1
                 UNION ALL
                 SELECT -1L as vcol, c1 FROM ut2
              ) AS union_all
            GROUP BY union_all.c1
          )
        WHERE sum_val > 0
       )
   )
```
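For intuition, here is a minimal spark-shell sketch (not part of this patch; the data and view names are invented) of the multiset arithmetic the rewrite relies on: per distinct value, `sum(vcol)` is the left count minus the right count, and only positive sums survive the filter and get replicated.

```scala
// Hypothetical spark-shell session; assumes a SparkSession named `spark`.
import spark.implicits._

Seq(1, 1, 2, 2, 3).toDF("c1").createOrReplaceTempView("ut1")
Seq(1, 2, 2, 4).toDF("c1").createOrReplaceTempView("ut2")

// Per value: 1 -> 2 - 1 = 1 copy, 2 -> 2 - 2 = 0 copies, 3 -> 1 - 0 = 1 copy,
// 4 -> 0 - 1 = -1 (filtered out). Expected output: one row 1 and one row 3.
spark.sql("SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2").show()
```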

## How was this patch tested?
Added test cases in SQLQueryTestSuite, DataFrameSuite, and SetOperationSuite.

Author: Dilip Biswal <dbis...@us.ibm.com>

Closes #21857 from dilipbiswal/dkb_except_all_final.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/10f1f196
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/10f1f196
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/10f1f196

Branch: refs/heads/master
Commit: 10f1f196595df66cb82d1fb9e27cc7ef0a176766
Parents: 5828f41
Author: Dilip Biswal <dbis...@us.ibm.com>
Authored: Fri Jul 27 13:47:33 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Fri Jul 27 13:47:33 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 |  25 ++
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   5 +-
 .../sql/catalyst/analysis/TypeCoercion.scala    |  12 +-
 .../analysis/UnsupportedOperationChecker.scala  |   2 +-
 .../sql/catalyst/expressions/generators.scala   |  26 ++
 .../sql/catalyst/optimizer/Optimizer.scala      |  61 +++-
 .../optimizer/ReplaceExceptWithFilter.scala     |   2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala  |   2 +-
 .../plans/logical/basicLogicalOperators.scala   |   7 +-
 .../catalyst/optimizer/SetOperationSuite.scala  |  24 +-
 .../sql/catalyst/parser/ErrorParserSuite.scala  |   3 -
 .../sql/catalyst/parser/PlanParserSuite.scala   |   1 -
 .../scala/org/apache/spark/sql/Dataset.scala    |  16 +
 .../spark/sql/execution/SparkStrategies.scala   |   6 +-
 .../resources/sql-tests/inputs/except-all.sql   | 146 +++++++++
 .../sql-tests/results/except-all.sql.out        | 319 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   |  70 +++-
 17 files changed, 708 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c40aea9..b2e0a5b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -293,6 +293,31 @@ class DataFrame(object):
         else:
             print(self._jdf.queryExecution().simpleString())
 
+    @since(2.4)
+    def exceptAll(self, other):
+        """Return a new :class:`DataFrame` containing rows in this 
:class:`DataFrame` but
+        not in another :class:`DataFrame` while preserving duplicates.
+
+        This is equivalent to `EXCEPT ALL` in SQL.
+
+        >>> df1 = spark.createDataFrame(
+        ...         [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b",  3), ("c", 
4)], ["C1", "C2"])
+        >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"])
+
+        >>> df1.exceptAll(df2).show()
+        +---+---+
+        | C1| C2|
+        +---+---+
+        |  a|  1|
+        |  a|  1|
+        |  a|  2|
+        |  c|  4|
+        +---+---+
+
+        Also as standard in SQL, this function resolves columns by position (not by name).
+        """
+        return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
+
     @since(1.3)
     def isLocal(self):
         """Returns ``True`` if the :func:`collect` and :func:`take` methods 
can be run locally

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d18509f..8abb1c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -916,9 +916,8 @@ class Analyzer(
         j.copy(right = dedupRight(left, right))
       case i @ Intersect(left, right) if !i.duplicateResolved =>
         i.copy(right = dedupRight(left, right))
-      case i @ Except(left, right) if !i.duplicateResolved =>
-        i.copy(right = dedupRight(left, right))
-
+      case e @ Except(left, right, _) if !e.duplicateResolved =>
+        e.copy(right = dedupRight(left, right))
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
       // we still have chance to resolve it based on its descendants
       case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 6bdb639..f9edca5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -319,11 +319,17 @@ object TypeCoercion {
   object WidenSetOperationTypes extends Rule[LogicalPlan] {
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case s @ SetOperation(left, right) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
+      case s @ Except(left, right, isAll) if s.childrenResolved &&
+        left.output.length == right.output.length && !s.resolved =>
         val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
         assert(newChildren.length == 2)
-        s.makeCopy(Array(newChildren.head, newChildren.last))
+        Except(newChildren.head, newChildren.last, isAll)
+
+      case s @ Intersect(left, right) if s.childrenResolved &&
+        left.output.length == right.output.length && !s.resolved =>
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+        assert(newChildren.length == 2)
+        Intersect(newChildren.head, newChildren.last)
 
       case s: Union if s.childrenResolved &&
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index f68df5d..c9a3ee4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -306,7 +306,7 @@ object UnsupportedOperationChecker {
         case u: Union if u.children.map(_.isStreaming).distinct.size == 2 =>
           throwError("Union between streaming and batch DataFrames/Datasets is 
not supported")
 
-        case Except(left, right) if right.isStreaming =>
+        case Except(left, right, _) if right.isStreaming =>
           throwError("Except on a streaming DataFrame/Dataset on the right is 
not supported")
 
         case Intersect(left, right) if left.isStreaming && right.isStreaming =>

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index b6e0d36..d6e67b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -224,6 +224,32 @@ case class Stack(children: Seq[Expression]) extends Generator {
 }
 
 /**
+ * Replicate the row N times. N is specified as the first argument to the function.
+ * This is an internal function solely used by the optimizer to rewrite EXCEPT ALL and
+ * INTERSECT ALL queries.
+ */
+case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback {
+  private lazy val numColumns = children.length - 1 // remove the multiplier value from output.
+
+  override def elementSchema: StructType =
+    StructType(children.tail.zipWithIndex.map {
+      case (e, index) => StructField(s"col$index", e.dataType)
+    })
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val numRows = children.head.eval(input).asInstanceOf[Long]
+    val values = children.tail.map(_.eval(input)).toArray
+    Range.Long(0, numRows, 1).map { _ =>
+      val fields = new Array[Any](numColumns)
+      for (col <- 0 until numColumns) {
+        fields.update(col, values(col))
+      }
+      InternalRow(fields: _*)
+    }
+  }
+}
+
+/**
  * Wrapper around another generator to specify outer behavior. This is used to implement functions
  * such as explode_outer. This expression gets replaced during analysis.
  */

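As an editorial aside (not part of the patch), the contract of the new generator can be modeled in plain Scala; `replicateRows` below is an illustrative stand-in for `ReplicateRows.eval`, taking the multiplier first and the column values after it.

```scala
// Plain-Scala model of the generator: emit the column values n times,
// where n is the first argument. Names here are illustrative only.
def replicateRows(n: Long, columns: Any*): Seq[Seq[Any]] =
  Seq.fill(n.toInt)(columns.toList)

// replicateRows(2L, "a", 1) == Seq(List("a", 1), List("a", 1))
```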
http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3c264eb..193f659 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -135,6 +135,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     Batch("Subquery", Once,
       OptimizeSubqueries) ::
     Batch("Replace Operators", fixedPoint,
+      RewriteExceptAll,
       ReplaceIntersectWithSemiJoin,
       ReplaceExceptWithFilter,
       ReplaceExceptWithAntiJoin,
@@ -1422,7 +1423,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
  */
 object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case Except(left, right) =>
+    case Except(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
       Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
@@ -1430,6 +1431,64 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
 }
 
 /**
+ * Replaces the logical [[Except]] operator using a combination of Union, Aggregate
+ * and Generate operators.
+ *
+ * Input Query :
+ * {{{
+ *    SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ *   SELECT c1
+ *   FROM (
+ *     SELECT replicate_rows(sum_val, c1)
+ *       FROM (
+ *         SELECT c1, sum_val
+ *           FROM (
+ *             SELECT c1, sum(vcol) AS sum_val
+ *               FROM (
+ *                 SELECT 1L as vcol, c1 FROM ut1
+ *                 UNION ALL
+ *                 SELECT -1L as vcol, c1 FROM ut2
+ *              ) AS union_all
+ *            GROUP BY union_all.c1
+ *          )
+ *        WHERE sum_val > 0
+ *       )
+ *   )
+ * }}}
+ */
+
+object RewriteExceptAll extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Except(left, right, true) =>
+      assert(left.output.size == right.output.size)
+
+      val newColumnLeft = Alias(Literal(1L), "vcol")()
+      val newColumnRight = Alias(Literal(-1L), "vcol")()
+      val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left)
+      val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right)
+      val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan)
+      val aggSumCol =
+        Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")()
+      val aggOutputColumns = left.output ++ Seq(aggSumCol)
+      val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan)
+      val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan)
+      val genRowPlan = Generate(
+        ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output),
+        unrequiredChildIndex = Nil,
+        outer = false,
+        qualifier = None,
+        left.output,
+        filteredAggPlan
+      )
+      Project(left.output, genRowPlan)
+  }
+}
+
+/**
  * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
  * but only makes the grouping key bigger.
  */

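For intuition (again, an editorial sketch, not part of the patch), the rule encodes the same computation as this multiset difference over in-memory collections; all names are illustrative.

```scala
// Tag left rows +1 and right rows -1 (UNION ALL with vcol), sum per key
// (GROUP BY + sum(vcol)), keep positive sums (WHERE sum_val > 0), and emit
// each surviving key sum-many times (replicate_rows).
def exceptAll[T](left: Seq[T], right: Seq[T]): Seq[T] = {
  val tagged = left.map(v => (v, 1L)) ++ right.map(v => (v, -1L))
  val sums = tagged.groupBy(_._1).map { case (v, ps) => (v, ps.map(_._2).sum) }
  sums.toSeq.flatMap {
    case (v, n) if n > 0 => Seq.fill(n.toInt)(v)
    case _ => Nil
  }
}

// exceptAll(Seq(1, 1, 2, 3), Seq(1, 2, 2)) yields 1 and 3 (order unspecified)
```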
http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index 45edf26..efd3944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -46,7 +46,7 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
     }
 
     plan.transform {
-      case e @ Except(left, right) if isEligible(left, right) =>
+      case e @ Except(left, right, false) if isEligible(left, right) =>
         val newCondition = transformCondition(left, skipProject(right))
         newCondition.map { c =>
           Distinct(Filter(Not(c), left))

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 49f578a..8b3c068 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -537,7 +537,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case SqlBaseParser.INTERSECT =>
         Intersect(left, right)
       case SqlBaseParser.EXCEPT if all =>
-        throw new ParseException("EXCEPT ALL is not supported.", ctx)
+        Except(left, right, isAll = true)
       case SqlBaseParser.EXCEPT =>
         Except(left, right)
       case SqlBaseParser.SETMINUS if all =>

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ea5a9b8..498a13a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -183,8 +183,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
   }
 }
 
-case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
-
+case class Except(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    isAll: Boolean = false) extends SetOperation(left, right) {
+  override def nodeName: String = getClass.getSimpleName + (if (isAll) "All" else "")
   /** We don't use right.output because those rows get excluded from the set. */
   override def output: Seq[Attribute] = left.output
 

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index aa88411..f002aa3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -144,4 +144,26 @@ class SetOperationSuite extends PlanTest {
             Distinct(Union(query3 :: query4 :: Nil))).analyze
     comparePlans(distinctUnionCorrectAnswer2, optimized2)
   }
+
+  test("EXCEPT ALL rewrite") {
+    val input = Except(testRelation, testRelation2, isAll = true)
+    val rewrittenPlan = RewriteExceptAll(input)
+
+    val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c)
+      .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f))
+      .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum"))
+      .where(GreaterThan('sum, Literal(0L))).analyze
+    val multiplierAttr = planFragment.output.last
+    val output = planFragment.output.dropRight(1)
+    val expectedPlan = Project(output,
+      Generate(
+        ReplicateRows(Seq(multiplierAttr) ++ output),
+        Nil,
+        false,
+        None,
+        output,
+        planFragment
+      ))
+    comparePlans(expectedPlan, rewrittenPlan)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
index f67697e..baaf018 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
@@ -58,8 +58,5 @@ class ErrorParserSuite extends SparkFunSuite {
     intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
       "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not 
supported",
       "^^^")
-    intercept("select * from r except all select * from t", 1, 0,
-      "EXCEPT ALL is not supported",
-      "^^^")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index fb51376..629e3c4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -65,7 +65,6 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual("select * from a union distinct select * from b", 
Distinct(a.union(b)))
     assertEqual("select * from a union all select * from b", a.union(b))
     assertEqual("select * from a except select * from b", a.except(b))
-    intercept("select * from a except all select * from b", "EXCEPT ALL is not 
supported.")
     assertEqual("select * from a except distinct select * from b", a.except(b))
     assertEqual("select * from a minus select * from b", a.except(b))
     intercept("select * from a minus all select * from b", "MINUS ALL is not 
supported.")

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b63235e..e6a3b0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1949,6 +1949,22 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Returns a new Dataset containing rows in this Dataset but not in another Dataset while
+   * preserving the duplicates.
+   * This is equivalent to `EXCEPT ALL` in SQL.
+   *
+   * @note Equality checking is performed directly on the encoded representation of the data
+   * and thus is not affected by a custom `equals` function defined on `T`. Also as standard in
+   * SQL, this function resolves columns by position (not by name).
+   *
+   * @group typedrel
+   * @since 2.4.0
+   */
+  def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator {
+    Except(planWithBarrier, other.planWithBarrier, isAll = true)
+  }
+
+  /**
    * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
    * using a user-supplied seed.
    *

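A quick usage sketch of the new API (editorial, assuming a `SparkSession` in scope as `spark`, e.g. in spark-shell), contrasting it with the distinct-based `except`:

```scala
import spark.implicits._

val left  = Seq("a", "a", "a", "b").toDS()
val right = Seq("a").toDS()

left.except(right).show()     // EXCEPT DISTINCT semantics: only "b" remains
left.exceptAll(right).show()  // one "a" is cancelled: "a", "a", "b" remain
```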
http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0c4ea85..3f5fd3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -532,9 +532,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Intersect(left, right) =>
         throw new IllegalStateException(
           "logical intersect operator should have been replaced by semi-join 
in the optimizer")
-      case logical.Except(left, right) =>
+      case logical.Except(left, right, false) =>
         throw new IllegalStateException(
           "logical except operator should have been replaced by anti-join in 
the optimizer")
+      case logical.Except(left, right, true) =>
+        throw new IllegalStateException(
+          "logical except (all) operator should have been replaced by union, 
aggregate" +
+            "and generate operators in the optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>
         execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
new file mode 100644
index 0000000..08b9a43
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
@@ -0,0 +1,146 @@
+CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
+    (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1);
+CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
+    (1), (2), (2), (3), (5), (5), (null) AS tab2(c1);
+CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES
+    (1, 2), 
+    (1, 2),
+    (1, 3),
+    (2, 3),
+    (2, 2)
+    AS tab3(k, v);
+CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES
+    (1, 2), 
+    (2, 3),
+    (2, 2),
+    (2, 2),
+    (2, 20)
+    AS tab4(k, v);
+
+-- Basic ExceptAll
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT * FROM tab2;
+
+-- ExceptAll same table in both branches
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT * FROM tab2 WHERE c1 IS NOT NULL;
+
+-- Empty left relation
+SELECT * FROM tab1 WHERE c1 > 5
+EXCEPT ALL
+SELECT * FROM tab2;
+
+-- Empty right relation
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT * FROM tab2 WHERE c1 > 6;
+
+-- Type Coerced ExceptAll
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT CAST(1 AS BIGINT);
+
+-- Error as types of two side are not compatible
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT array(1);
+
+-- Basic
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4;
+
+-- Basic
+SELECT * FROM tab4
+EXCEPT ALL
+SELECT * FROM tab3;
+
+-- ExceptAll + Intersect
+SELECT * FROM tab4
+EXCEPT ALL
+SELECT * FROM tab3
+INTERSECT DISTINCT
+SELECT * FROM tab4;
+
+-- ExceptAll + Except
+SELECT * FROM tab4
+EXCEPT ALL
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4;
+
+-- Chain of set operations
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+UNION ALL
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4;
+
+-- Mismatch on number of columns across both branches
+SELECT k FROM tab3
+EXCEPT ALL
+SELECT k, v FROM tab4;
+
+-- Chain of set operations
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+UNION
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4;
+
+-- Chain of set operations
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+EXCEPT DISTINCT
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4;
+
+-- Join under except all. Should produce an empty result set since both the left
+-- and right sets are the same.
+SELECT * 
+FROM   (SELECT tab3.k, 
+               tab4.v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k)
+EXCEPT ALL 
+SELECT * 
+FROM   (SELECT tab3.k, 
+               tab4.v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k);
+
+-- Join under except all (2)
+SELECT * 
+FROM   (SELECT tab3.k, 
+               tab4.v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k) 
+EXCEPT ALL 
+SELECT * 
+FROM   (SELECT tab4.v AS k, 
+               tab3.k AS v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k);
+
+-- Group by under ExceptAll
+SELECT v FROM tab3 GROUP BY v
+EXCEPT ALL
+SELECT k FROM tab4 GROUP BY k;
+
+-- Clean-up 
+DROP VIEW IF EXISTS tab1;
+DROP VIEW IF EXISTS tab2;
+DROP VIEW IF EXISTS tab3;
+DROP VIEW IF EXISTS tab4;

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/core/src/test/resources/sql-tests/results/except-all.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out
new file mode 100644
index 0000000..2a21c15
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out
@@ -0,0 +1,319 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 25
+
+
+-- !query 0
+CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
+    (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
+    (1), (2), (2), (3), (5), (5), (null) AS tab2(c1)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES
+    (1, 2), 
+    (1, 2),
+    (1, 3),
+    (2, 3),
+    (2, 2)
+    AS tab3(k, v)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES
+    (1, 2), 
+    (2, 3),
+    (2, 2),
+    (2, 2),
+    (2, 20)
+    AS tab4(k, v)
+-- !query 3 schema
+struct<>
+-- !query 3 output
+
+
+
+-- !query 4
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT * FROM tab2
+-- !query 4 schema
+struct<c1:int>
+-- !query 4 output
+0
+2
+2
+NULL
+
+
+-- !query 5
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT * FROM tab2 WHERE c1 IS NOT NULL
+-- !query 5 schema
+struct<c1:int>
+-- !query 5 output
+0
+2
+2
+NULL
+NULL
+
+
+-- !query 6
+SELECT * FROM tab1 WHERE c1 > 5
+EXCEPT ALL
+SELECT * FROM tab2
+-- !query 6 schema
+struct<c1:int>
+-- !query 6 output
+
+
+
+-- !query 7
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT * FROM tab2 WHERE c1 > 6
+-- !query 7 schema
+struct<c1:int>
+-- !query 7 output
+0
+1
+2
+2
+2
+2
+3
+NULL
+NULL
+
+
+-- !query 8
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT CAST(1 AS BIGINT)
+-- !query 8 schema
+struct<c1:bigint>
+-- !query 8 output
+0
+2
+2
+2
+2
+3
+NULL
+NULL
+
+
+-- !query 9
+SELECT * FROM tab1
+EXCEPT ALL
+SELECT array(1)
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+ExceptAll can only be performed on tables with the compatible column types. array<int> <> int at the first column of the second table;
+
+
+-- !query 10
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+-- !query 10 schema
+struct<k:int,v:int>
+-- !query 10 output
+1      2
+1      3
+
+
+-- !query 11
+SELECT * FROM tab4
+EXCEPT ALL
+SELECT * FROM tab3
+-- !query 11 schema
+struct<k:int,v:int>
+-- !query 11 output
+2      2
+2      20
+
+
+-- !query 12
+SELECT * FROM tab4
+EXCEPT ALL
+SELECT * FROM tab3
+INTERSECT DISTINCT
+SELECT * FROM tab4
+-- !query 12 schema
+struct<k:int,v:int>
+-- !query 12 output
+2      2
+2      20
+
+
+-- !query 13
+SELECT * FROM tab4
+EXCEPT ALL
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4
+-- !query 13 schema
+struct<k:int,v:int>
+-- !query 13 output
+
+
+
+-- !query 14
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+UNION ALL
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4
+-- !query 14 schema
+struct<k:int,v:int>
+-- !query 14 output
+1      3
+
+
+-- !query 15
+SELECT k FROM tab3
+EXCEPT ALL
+SELECT k, v FROM tab4
+-- !query 15 schema
+struct<>
+-- !query 15 output
+org.apache.spark.sql.AnalysisException
+ExceptAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns;
+
+
+-- !query 16
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+UNION
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4
+-- !query 16 schema
+struct<k:int,v:int>
+-- !query 16 output
+1      3
+
+
+-- !query 17
+SELECT * FROM tab3
+EXCEPT ALL
+SELECT * FROM tab4
+EXCEPT DISTINCT
+SELECT * FROM tab3
+EXCEPT DISTINCT
+SELECT * FROM tab4
+-- !query 17 schema
+struct<k:int,v:int>
+-- !query 17 output
+
+
+
+-- !query 18
+SELECT * 
+FROM   (SELECT tab3.k, 
+               tab4.v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k)
+EXCEPT ALL 
+SELECT * 
+FROM   (SELECT tab3.k, 
+               tab4.v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k)
+-- !query 18 schema
+struct<k:int,v:int>
+-- !query 18 output
+
+
+
+-- !query 19
+SELECT * 
+FROM   (SELECT tab3.k, 
+               tab4.v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k) 
+EXCEPT ALL 
+SELECT * 
+FROM   (SELECT tab4.v AS k, 
+               tab3.k AS v 
+        FROM   tab3 
+               JOIN tab4 
+                 ON tab3.k = tab4.k)
+-- !query 19 schema
+struct<k:int,v:int>
+-- !query 19 output
+1      2
+1      2
+1      2
+2      20
+2      20
+2      3
+2      3
+
+
+-- !query 20
+SELECT v FROM tab3 GROUP BY v
+EXCEPT ALL
+SELECT k FROM tab4 GROUP BY k
+-- !query 20 schema
+struct<v:int>
+-- !query 20 output
+3
+
+
+-- !query 21
+DROP VIEW IF EXISTS tab1
+-- !query 21 schema
+struct<>
+-- !query 21 output
+
+
+
+-- !query 22
+DROP VIEW IF EXISTS tab2
+-- !query 22 schema
+struct<>
+-- !query 22 output
+
+
+
+-- !query 23
+DROP VIEW IF EXISTS tab3
+-- !query 23 schema
+struct<>
+-- !query 23 output
+
+
+
+-- !query 24
+DROP VIEW IF EXISTS tab4
+-- !query 24 schema
+struct<>
+-- !query 24 output
+

http://git-wip-us.apache.org/repos/asf/spark/blob/10f1f196/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9cf8c47..af07359 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
-import org.apache.spark.sql.test.SQLTestData.TestData2
+import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -629,6 +629,74 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df4.schema.forall(!_.nullable))
   }
 
+  test("except all") {
+    checkAnswer(
+      lowerCaseData.exceptAll(upperCaseData),
+      Row(1, "a") ::
+      Row(2, "b") ::
+      Row(3, "c") ::
+      Row(4, "d") :: Nil)
+    checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil)
+    checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil)
+
+    // check null equality
+    checkAnswer(
+      nullInts.exceptAll(nullInts.filter("0 = 1")),
+      nullInts)
+    checkAnswer(
+      nullInts.exceptAll(nullInts),
+      Nil)
+
+    // check that duplicate values are preserved
+    checkAnswer(
+      allNulls.exceptAll(allNulls.filter("0 = 1")),
+      Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil)
+    checkAnswer(
+      allNulls.exceptAll(allNulls.limit(2)),
+      Row(null) :: Row(null) :: Nil)
+
+    // check that duplicates are retained.
+    val df = spark.sparkContext.parallelize(
+      NullStrings(1, "id1") ::
+      NullStrings(1, "id1") ::
+      NullStrings(2, "id1") ::
+      NullStrings(3, null) :: Nil).toDF("id", "value")
+
+    checkAnswer(
+      df.exceptAll(df.filter("0 = 1")),
+      Row(1, "id1") ::
+      Row(1, "id1") ::
+      Row(2, "id1") ::
+      Row(3, null) :: Nil)
+
+    // check if the empty set on the left side works
+    checkAnswer(
+      allNulls.filter("0 = 1").exceptAll(allNulls),
+      Nil)
+
+  }
+
+  test("exceptAll - nullability") {
+    val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF()
+    assert(nonNullableInts.schema.forall(!_.nullable))
+
+    val df1 = nonNullableInts.exceptAll(nullInts)
+    checkAnswer(df1, Row(11) :: Nil)
+    assert(df1.schema.forall(!_.nullable))
+
+    val df2 = nullInts.exceptAll(nonNullableInts)
+    checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil)
+    assert(df2.schema.forall(_.nullable))
+
+    val df3 = nullInts.exceptAll(nullInts)
+    checkAnswer(df3, Nil)
+    assert(df3.schema.forall(_.nullable))
+
+    val df4 = nonNullableInts.exceptAll(nonNullableInts)
+    checkAnswer(df4, Nil)
+    assert(df4.schema.forall(!_.nullable))
+  }
+
   test("intersect") {
     checkAnswer(
       lowerCaseData.intersect(lowerCaseData),

