[FLINK-7227] [table] Fix push-down of disjunctive predicates with more than two terms.

This closes #4608.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ab55a436
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ab55a436
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ab55a436

Branch: refs/heads/release-1.3
Commit: ab55a436f79ee315ee5ea072f30095e313c9464f
Parents: 50f6c75
Author: Usman Younas <usmanyou...@usmans-mbp.fritz.box>
Authored: Mon Aug 28 13:44:02 2017 +0000
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Tue Sep 5 15:57:31 2017 +0200

----------------------------------------------------------------------
 .../table/plan/util/RexProgramExtractor.scala   |  7 +-
 .../plan/util/RexProgramExtractorTest.scala     | 69 ++++++++++++++++++--
 .../table/utils/TestFilterableTableSource.scala |  1 +
 3 files changed, 69 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/ab55a436/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala
index ba8713d..d484f1d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala
@@ -20,10 +20,11 @@ package org.apache.flink.table.plan.util
 
 import org.apache.calcite.plan.RelOptUtil
 import org.apache.calcite.rex._
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.sql.{SqlFunction, SqlPostfixOperator}
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.expressions.{Expression, Literal, 
ResolvedFieldReference}
+import org.apache.flink.table.expressions.{And, Expression, Literal, Or, 
ResolvedFieldReference}
 import org.apache.flink.table.validate.FunctionCatalog
 import org.apache.flink.util.Preconditions
 
@@ -167,6 +168,10 @@ class RexNodeToExpressionConverter(
       None
     } else {
         call.getOperator match {
+          case SqlStdOperatorTable.OR =>
+            Option(operands.reduceLeft(Or))
+          case SqlStdOperatorTable.AND =>
+            Option(operands.reduceLeft(And))
           case function: SqlFunction =>
             lookupFunction(replace(function.getName), operands)
           case postfix: SqlPostfixOperator =>

http://git-wip-us.apache.org/repos/asf/flink/blob/ab55a436/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
index 5d5eece..7e1fa12 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.util
 
 import java.math.BigDecimal
 
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.rex._
 import org.apache.calcite.sql.SqlPostfixOperator
 import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR}
 import org.apache.calcite.sql.fun.SqlStdOperatorTable
@@ -32,6 +32,7 @@ import org.junit.Assert.{assertArrayEquals, assertEquals, 
assertThat}
 import org.junit.Test
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 class RexProgramExtractorTest extends RexProgramTestBase {
 
@@ -103,6 +104,8 @@ class RexProgramExtractorTest extends RexProgramTestBase {
     val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
     // 100
     val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    // 200
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(200L))
 
     // a = amount < 100
     val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, 
t0, t3))
@@ -112,15 +115,17 @@ class RexProgramExtractorTest extends RexProgramTestBase {
     val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, 
t2, t3))
     // d = amount <= id
     val d = 
builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, 
t1))
+    // e = price == 200
+    val e = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, 
t2, t4))
 
     // a AND b
     val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, 
List(a, b).asJava))
-    // (a AND b) or c
-    val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, 
List(and, c).asJava))
-    // not d
+    // (a AND b) OR c OR e
+    val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, 
List(and, c, e).asJava))
+    // NOT d
     val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, 
List(d).asJava))
 
-    // (a AND b) OR c) AND (NOT d)
+    // (a AND b) OR c OR e) AND (NOT d)
     builder.addCondition(builder.addExpr(
       rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava)))
 
@@ -133,14 +138,64 @@ class RexProgramExtractorTest extends RexProgramTestBase {
         functionCatalog)
 
     val expected: Array[Expression] = Array(
-      ExpressionParser.parseExpression("amount < 100 || price == 100"),
-      ExpressionParser.parseExpression("id > 100 || price == 100"),
+      ExpressionParser.parseExpression("amount < 100 || price == 100 || price 
=== 200"),
+      ExpressionParser.parseExpression("id > 100 || price == 100 || price === 
200"),
       ExpressionParser.parseExpression("!(amount <= id)"))
     assertExpressionArrayEquals(expected, convertedExpressions)
     assertEquals(0, unconvertedRexNodes.length)
   }
 
   @Test
+  def testExtractANDExpressions(): Unit = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, 
allFieldNames)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // amount
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    // id
+    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    // price
+    val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
+    // 100
+    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // a = amount < 100
+    val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, 
t0, t3))
+    // b = id > 100
+    val b = 
builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3))
+    // c = price == 100
+    val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, 
t2, t3))
+    // d = amount <= id
+    val d = 
builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, 
t1))
+
+    // a AND b AND c AND d
+    val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, 
List(a, b, c, d).asJava))
+
+    builder.addCondition(builder.addExpr(and))
+
+    val program = builder.getProgram
+
+    val expanded = program.expandLocalRef(program.getCondition)
+
+    var convertedExpressions = new mutable.ArrayBuffer[Expression]
+    val unconvertedRexNodes = new mutable.ArrayBuffer[RexNode]
+    val inputNames = program.getInputRowType.getFieldNames.asScala.toArray
+    val converter = new RexNodeToExpressionConverter(inputNames, 
functionCatalog)
+
+    expanded.accept(converter) match {
+      case Some(expression) =>
+        convertedExpressions += expression
+      case None => unconvertedRexNodes += expanded
+    }
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount < 100 && id > 100 && price === 
100 && amount <= id"))
+
+    assertExpressionArrayEquals(expected, convertedExpressions.toArray)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
   def testExtractArithmeticConditions(): Unit = {
     val inputRowType = typeFactory.createStructType(allFieldTypes, 
allFieldNames)
     val builder = new RexProgramBuilder(inputRowType, rexBuilder)

http://git-wip-us.apache.org/repos/asf/flink/blob/ab55a436/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
index dcf2acd..fb99864 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
@@ -89,6 +89,7 @@ class TestFilterableTableSource(
               iterator.remove()
             case (_, _) =>
           }
+        case _ =>
       }
     }
 

Reply via email to