This is an automated email from the ASF dual-hosted git repository.

snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 29e49482a82 [FLINK-33213][table] Flink SQL calculate SqlMonotonicity for Calc
29e49482a82 is described below

commit 29e49482a82e8c1cd404e42c3aae0944188d956e
Author: ParyshevSergey <67409218+paryshevser...@users.noreply.github.com>
AuthorDate: Mon Nov 13 00:07:57 2023 +0700

    [FLINK-33213][table] Flink SQL calculate SqlMonotonicity for Calc
    
    Co-authored-by: Sergey <gerald@MacBook-Air-Gerald.local>
---
 .../metadata/FlinkRelMdModifiedMonotonicity.scala  |  76 ++++++++++++++-
 .../FlinkRelMdModifiedMonotonicityTest.scala       | 102 ++++++++++++++++++++-
 .../runtime/stream/sql/AggregateITCase.scala       |  49 ++++++++++
 .../runtime/stream/table/AggregateITCase.scala     |  57 ++++++++++++
 4 files changed, 281 insertions(+), 3 deletions(-)
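
For context: this change makes the monotonicity metadata handler take a Calc's filter condition into account. Below is a minimal illustration of the affected query shape, assembled from the new ITCases in this patch (`T(id, price)` is the test table; annotation only, not part of the commit):

    -- `c` is INCREASING within each group, so `c > 0` can only start
    -- matching, while `c < 3` can stop matching once a group's count
    -- reaches 3; a downstream MIN/MAX must then retract emitted values.
    SELECT MAX(price) FROM (
      SELECT id, count(*) AS c, price FROM T GROUP BY id, price)
    WHERE c > 0 AND c < 3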

diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
index 6c21d3a32c6..f25e2cebdd2 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
@@ -39,6 +39,7 @@ import org.apache.calcite.rel.metadata._
 import org.apache.calcite.rex.{RexCall, RexCallBinding, RexInputRef, RexNode}
 import org.apache.calcite.sql.{SqlKind, SqlOperatorBinding}
 import org.apache.calcite.sql.fun.{SqlCountAggFunction, SqlMinMaxAggFunction, SqlSumAggFunction, SqlSumEmptyIsZeroAggFunction}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable.{AND, EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUAL, IN, IS_NOT_NULL, IS_NOT_TRUE, IS_NULL, IS_TRUE, LESS_THAN, LESS_THAN_OR_EQUAL, NOT, NOT_EQUALS, NOT_IN, OR, SEARCH}
 import org.apache.calcite.sql.validate.SqlMonotonicity
 import org.apache.calcite.sql.validate.SqlMonotonicity._
 import org.apache.calcite.util.Util
@@ -47,6 +48,7 @@ import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Time, Timestamp}
 import java.util.Collections
 
+import scala.annotation.tailrec
 import scala.collection.JavaConversions._
 
 /**
@@ -88,8 +90,32 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
   }
 
   def getRelModifiedMonotonicity(rel: Calc, mq: RelMetadataQuery): RelModifiedMonotonicity = {
-    val projects = rel.getProgram.getProjectList.map(rel.getProgram.expandLocalRef)
-    getProjectMonotonicity(projects, rel.getInput, mq)
+    val program = rel.getProgram
+    val projects = program.getProjectList.map(rel.getProgram.expandLocalRef)
+    val result = getProjectMonotonicity(projects, rel.getInput, mq)
+
+    // check that the `where` clause exists
+    if (program.getCondition != null && result != null) {
+      val inputMonotonicity = FlinkRelMetadataQuery
+        .reuseOrCreate(mq)
+        .getRelModifiedMonotonicity(rel.getInput)
+      val inputProjects = program.getExprList.filter(expr => expr.isInstanceOf[RexInputRef])
+      assert(inputMonotonicity.fieldMonotonicities.length == inputProjects.size)
+      val notConstantProjects = inputProjects.indices
+        .map(
+          index =>
+            (
+              inputProjects(index).asInstanceOf[RexInputRef],
+              inputMonotonicity.fieldMonotonicities(index)))
+        .filter { case (_, monotonicity) => monotonicity != CONSTANT }
+        .toArray
+      val condition = program.expandLocalRef(program.getCondition)
+      if (isNeedRetract(condition, notConstantProjects)) {
+        program.getProjectList.indices
+          .foreach(index => result.fieldMonotonicities(index) = NOT_MONOTONIC)
+      }
+    }
+    result
   }
 
   private def getProjectMonotonicity(
@@ -638,6 +664,52 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
     udf.getMonotonicity(binding)
   }
 
+  private def isNeedRetract(
+      rexNode: RexNode,
+      projects: Array[(RexInputRef, SqlMonotonicity)]): Boolean = {
+    rexNode match {
+      case inputRef: RexInputRef =>
+        projects.exists { case (projectInput, _) => projectInput == inputRef }
+
+      case rexCall: RexCall =>
+        val operands = rexCall.getOperands.map(operand => removeAsAndCast(operand))
+        rexCall.getOperator match {
+          case AND | OR =>
+            val left = isNeedRetract(operands(0), projects)
+            val right = isNeedRetract(operands(1), projects)
+            left || right
+
+          case GREATER_THAN | GREATER_THAN_OR_EQUAL =>
+            projects
+              .find { case (inputRef, _) => operands.contains(inputRef) }
+              .exists { case (_, monotonicity) => monotonicity.unstrict() != INCREASING }
+
+          case LESS_THAN | LESS_THAN_OR_EQUAL =>
+            projects
+              .find { case (inputRef, _) => operands.contains(inputRef) }
+              .exists { case (_, monotonicity) => monotonicity.unstrict() != DECREASING }
+
+          case SEARCH | IN | EQUALS | NOT_EQUALS | NOT_IN
+              if projects.exists(x => operands.contains(x._1)) =>
+            true
+
+          case NOT | IS_NOT_TRUE | IS_TRUE | IS_NOT_NULL | IS_NULL if operands.size() == 1 =>
+            isNeedRetract(operands.head, projects)
+
+          case _ => false
+        }
+
+      case _ => false
+    }
+  }
+
+  @tailrec
+  private def removeAsAndCast(rexNode: RexNode): RexNode = rexNode match {
+    case r: RexCall if r.getKind == SqlKind.AS || r.getKind == SqlKind.CAST =>
+      removeAsAndCast(r.getOperands.get(0))
+    case _ => rexNode
+  }
+
   private def isValueGreaterThanZero[T](value: Comparable[T]): Int = {
     value match {
       case i: Integer => i.compareTo(0)
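
A self-contained sketch of the rule the new isNeedRetract above encodes (plain Scala, independent of Flink/Calcite; the Expr and Monotonicity types are illustrative stand-ins, and unary operators such as NOT are omitted). The idea: a predicate over a non-constant field forces retraction unless, given the field's direction, it can only ever start matching and never stop.

    object RetractionSketch {
      sealed trait Monotonicity
      case object Increasing extends Monotonicity
      case object Decreasing extends Monotonicity
      case object Constant extends Monotonicity

      sealed trait Expr
      final case class Field(name: String) extends Expr
      final case class Lit(value: Int) extends Expr
      final case class Cmp(op: String, field: Field, lit: Lit) extends Expr
      final case class Bool(op: String, left: Expr, right: Expr) extends Expr // AND / OR

      // True when the predicate may stop matching a row it matched before,
      // i.e. a downstream MIN/MAX over the filtered stream needs retraction.
      def needsRetract(e: Expr, mono: Map[String, Monotonicity]): Boolean = e match {
        case Bool(_, l, r) => needsRetract(l, mono) || needsRetract(r, mono)
        case Cmp(op, Field(f), _) if mono.get(f).exists(_ != Constant) =>
          (op, mono(f)) match {
            case (">" | ">=", Increasing) => false // once satisfied, stays satisfied
            case ("<" | "<=", Decreasing) => false
            case _                        => true  // =, <>, IN, or the "wrong" direction
          }
        case _ => false // constant or untracked fields cannot flip the predicate
      }

      def main(args: Array[String]): Unit = {
        val mono = Map("c" -> (Increasing: Monotonicity))
        // c > 1: an increasing c can only enter the filter, never leave it
        println(needsRetract(Cmp(">", Field("c"), Lit(1)), mono)) // false
        // c > 1 AND c < 3: the upper bound can be crossed, so retraction is needed
        println(needsRetract(
          Bool("AND", Cmp(">", Field("c"), Lit(1)), Cmp("<", Field("c"), Lit(3))), mono)) // true
      }
    }
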
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala
index 9af23241c64..860dc52d325 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala
@@ -21,16 +21,23 @@ import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity
 import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalRank, FlinkLogicalTableAggregate}
 import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType}
 
-import org.apache.calcite.rel.`type`.RelDataTypeFieldImpl
+import org.apache.calcite.rel.`type`.{RelDataTypeFieldImpl, RelRecordType}
 import org.apache.calcite.rel.RelCollations
 import org.apache.calcite.rel.core.JoinRelType
+import org.apache.calcite.rel.hint.RelHint
+import org.apache.calcite.rel.logical.LogicalCalc
+import org.apache.calcite.rex.{RexNode, RexProgram}
 import org.apache.calcite.sql.fun.SqlStdOperatorTable._
+import org.apache.calcite.sql.validate.SqlMonotonicity
 import org.apache.calcite.sql.validate.SqlMonotonicity._
 import org.apache.calcite.util.ImmutableBitSet
 import org.junit.Assert._
 import org.junit.Test
 
+import java.util
+
 import scala.collection.JavaConversions._
+import scala.language.postfixOps
 
 class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase {
 
@@ -41,6 +48,99 @@ class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase {
       mq.getRelModifiedMonotonicity(studentLogicalScan))
   }
 
+  @Test
+  def testMonotonicityWithCondition(): Unit = {
+    // select id, age, count() from student group by id, age
+    val inputAgg = relBuilder
+      .scan("student")
+      .aggregate(
+        relBuilder.groupKey(relBuilder.field("id"), relBuilder.field("age")),
+        relBuilder.count().as("count"))
+      .build()
+    relBuilder.push(inputAgg)
+
+    // project the `age` field and build the corresponding output type
+    val projection = List(relBuilder.field("age"))
+    val ageFieldType = inputAgg.getRowType.getFieldList.filter(x => x.getName.equals("age"))
+    val outputType = new RelRecordType(ageFieldType)
+
+    // select age from (select id, age, count() from student group by id, age) where ...
+    // sub-query monotonicity is [CONSTANT, CONSTANT, INCREASING]
+    // some conditions can break monotonicity, because agg functions
+    // like max/min depend on input monotonicity
+    def createCalc(condition: RexNode): LogicalCalc = {
+      val program =
+        RexProgram.create(inputAgg.getRowType, projection, condition, outputType, rexBuilder)
+      new LogicalCalc(cluster, logicalTraits, new util.ArrayList[RelHint](), inputAgg, program)
+    }
+    def assertMonotonicity(monotonicity: SqlMonotonicity, condition: RexNode): Unit =
+      assertEquals(
+        new RelModifiedMonotonicity(Array(monotonicity)),
+        mq.getRelModifiedMonotonicity(createCalc(condition))
+      )
+
+    // where count > 1 and count < 3
+    var condition = relBuilder
+      .and(
+        relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(1)),
+        relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(3))
+      )
+    assertMonotonicity(NOT_MONOTONIC, condition)
+
+    // where count > 1
+    condition = relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(1))
+    assertMonotonicity(CONSTANT, condition)
+
+    // where count < 3
+    condition = relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(3))
+    assertMonotonicity(NOT_MONOTONIC, condition)
+
+    // where count > 1 or count < 3
+    condition = relBuilder
+      .or(
+        relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(1)),
+        relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(3))
+      )
+    // the correct answer is CONSTANT, but this always-true condition is
+    // removed by SimplifyFilterConditionRule, so the answer becomes null
+    assertMonotonicity(NOT_MONOTONIC, condition)
+    assertMonotonicity(CONSTANT, null)
+
+    // where count in (1,2,3)
+    // where count not in (1,2,3)
+    condition = relBuilder
+      .in(
+        relBuilder.field("count"),
+        relBuilder.literal(1),
+        relBuilder.literal(2),
+        relBuilder.literal(3))
+    assertMonotonicity(NOT_MONOTONIC, condition)
+    assertMonotonicity(NOT_MONOTONIC, relBuilder.not(condition))
+
+    // where count = 10
+    condition = relBuilder
+      .equals(relBuilder.field("count"), relBuilder.literal(10))
+    assertMonotonicity(NOT_MONOTONIC, condition)
+
+    // where count <> 10
+    condition = relBuilder
+      .notEquals(relBuilder.field("count"), relBuilder.literal(10))
+    assertMonotonicity(NOT_MONOTONIC, condition)
+
+    // where count > 5 or count < 2
+    condition = relBuilder
+      .or(
+        relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(5)),
+        relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(2))
+      )
+    assertMonotonicity(NOT_MONOTONIC, condition)
+
+    // where age is not null
+    condition = relBuilder
+      .isNotNull(relBuilder.field("age"))
+    assertMonotonicity(CONSTANT, condition)
+  }
+
   @Test
   def testGetRelMonotonicityOnProject(): Unit = {
     // test monotonicity pass on
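
A compact summary of what testMonotonicityWithCondition above asserts (annotation only, not part of the patch; in the sub-query `count` is INCREASING and `id`/`age` are CONSTANT):

    condition on the sub-query                 resulting monotonicity of `age`
    count > 1                                  CONSTANT      (rows only enter the filter)
    count < 3                                  NOT_MONOTONIC (rows can leave it)
    count > 1 AND count < 3                    NOT_MONOTONIC
    count > 1 OR count < 3                     NOT_MONOTONIC (always true; simplified away in practice)
    count IN (1,2,3), NOT IN, = 10, <> 10      NOT_MONOTONIC
    age IS NOT NULL                            CONSTANT      (age itself is CONSTANT)
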
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 4df20e1615d..ed1748cda3e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -89,6 +89,55 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
     assertEquals(expected, sink.getRetractResults)
   }
 
+  @Test
+  def testMaxAggRetractWithCondition(): Unit = {
+    val data = new mutable.MutableList[(Int, Int)]
+    data.+=((1, 10))
+    data.+=((1, 10))
+    data.+=((2, 5))
+    data.+=((1, 10))
+
+    val t = failingDataSource(data).toTable(tEnv, 'id, 'price)
+    tEnv.createTemporaryView("T", t)
+
+    val sql =
+      """
+        |SELECT MAX(price) FROM (
+        |   SELECT id, count(*) as c, price FROM T GROUP BY id, price)
+        |WHERE c > 0 AND c < 3""".stripMargin
+
+    val sink = new TestingRetractSink
+    tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+
+    val expected = List("5")
+    assertEquals(expected.sorted, sink.getRetractResults.sorted)
+  }
+
+  @Test
+  def testMinAggRetractWithCondition(): Unit = {
+    val data = new mutable.MutableList[(Int, Int)]
+    data.+=((1, 5))
+    data.+=((2, 6))
+    data.+=((1, 5))
+
+    val t = failingDataSource(data).toTable(tEnv, 'id, 'price)
+    tEnv.createTemporaryView("T", t)
+
+    val sql =
+      """
+        |SELECT MIN(price) FROM (
+        |   SELECT id, count(*) as c, price FROM T GROUP BY id, price)
+        |WHERE c < 2""".stripMargin
+
+    val sink = new TestingRetractSink
+    tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+
+    val expected = List("6")
+    assertEquals(expected.sorted, sink.getRetractResults.sorted)
+  }
+
   @Test
   def testShufflePojo(): Unit = {
     val data = new mutable.MutableList[(Int, Int)]
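
Why testMaxAggRetractWithCondition above expects "5" (annotation, not part of the patch): the per-group count c eventually violates the upper bound c < 3, so group (1, 10) leaves the filter and its contribution to MAX must be retracted.

    // input    sub-query row (id, price) -> c   rows passing c > 0 AND c < 3    MAX(price)
    // (1, 10)  (1, 10) -> 1                     {(1, 10)}                       10
    // (1, 10)  (1, 10) -> 2                     {(1, 10)}                       10
    // (2, 5)   (2, 5)  -> 1                     {(1, 10), (2, 5)}               10
    // (1, 10)  (1, 10) -> 3                     {(2, 5)}  <- (1, 10) retracted  5
    //
    // Without the retraction enabled by this patch, MAX would stay at 10
    // even though (1, 10) no longer satisfies the condition.
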
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
index 6971ee26a8e..bd0db2495b3 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
@@ -63,6 +63,63 @@ class AggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase
     assertEquals(expected.sorted, sink.getRetractResults.sorted)
   }
 
+  @Test
+  def testMaxAggRetractWithCondition(): Unit = {
+    val data = new mutable.MutableList[(Int, Int)]
+    data.+=((1, 10))
+    data.+=((1, 10))
+    data.+=((2, 5))
+    data.+=((1, 10))
+
+    // select id, price, count() as c from table group by id, price
+    val subQuery = failingDataSource(data)
+      .toTable(tEnv, 'id, 'price)
+      .groupBy('id, 'price)
+      .aggregate('id.count().as("c"))
+      .select('id, 'price, 'c)
+
+    // select max(price) from subQuery where c > 0 and c < 3
+    val topQuery = subQuery
+      .where(and('c.isGreater(0), 'c.isLess(3)))
+      .select('price)
+      .aggregate('price.max().as("max_price"))
+      .select('max_price)
+
+    val sink = new TestingRetractSink()
+    topQuery.toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+
+    assertEquals(List("5"), sink.getRetractResults.sorted)
+  }
+
+  @Test
+  def testMinAggRetractWithCondition(): Unit = {
+    val data = new mutable.MutableList[(Int, Int)]
+    data.+=((1, 5))
+    data.+=((2, 6))
+    data.+=((1, 5))
+
+    // select id, price, count() as c from table group by id, price
+    val subQuery = failingDataSource(data)
+      .toTable(tEnv, 'id, 'price)
+      .groupBy('id, 'price)
+      .aggregate('id.count().as("c"))
+      .select('id, 'price, 'c)
+
+    // select min(price) from subQuery where c < 2
+    val topQuery = subQuery
+      .where('c.isLess(2))
+      .select('price)
+      .aggregate('price.min().as("min_price"))
+      .select('min_price)
+
+    val sink = new TestingRetractSink()
+    topQuery.toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+
+    assertEquals(List("6"), sink.getRetractResults.sorted)
+  }
+
   @Test
   def testDistinctUDAGGMixedWithNonDistinctUsage(): Unit = {
     val testAgg = new WeightedAvg
