This is an automated email from the ASF dual-hosted git repository. snuyanzin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 29e49482a82 [FLINK-33213][table] Flink SQL calculate SqlMonotonicity for Calc 29e49482a82 is described below commit 29e49482a82e8c1cd404e42c3aae0944188d956e Author: ParyshevSergey <67409218+paryshevser...@users.noreply.github.com> AuthorDate: Mon Nov 13 00:07:57 2023 +0700 [FLINK-33213][table] Flink SQL calculate SqlMonotonicity for Calc Co-authored-by: Sergey <gerald@MacBook-Air-Gerald.local> --- .../metadata/FlinkRelMdModifiedMonotonicity.scala | 76 ++++++++++++++- .../FlinkRelMdModifiedMonotonicityTest.scala | 102 ++++++++++++++++++++- .../runtime/stream/sql/AggregateITCase.scala | 49 ++++++++++ .../runtime/stream/table/AggregateITCase.scala | 57 ++++++++++++ 4 files changed, 281 insertions(+), 3 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala index 6c21d3a32c6..f25e2cebdd2 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala @@ -39,6 +39,7 @@ import org.apache.calcite.rel.metadata._ import org.apache.calcite.rex.{RexCall, RexCallBinding, RexInputRef, RexNode} import org.apache.calcite.sql.{SqlKind, SqlOperatorBinding} import org.apache.calcite.sql.fun.{SqlCountAggFunction, SqlMinMaxAggFunction, SqlSumAggFunction, SqlSumEmptyIsZeroAggFunction} +import org.apache.calcite.sql.fun.SqlStdOperatorTable.{AND, EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUAL, IN, IS_NOT_NULL, IS_NOT_TRUE, IS_NULL, IS_TRUE, LESS_THAN, LESS_THAN_OR_EQUAL, NOT, NOT_EQUALS, NOT_IN, OR, SEARCH} import org.apache.calcite.sql.validate.SqlMonotonicity import 
org.apache.calcite.sql.validate.SqlMonotonicity._ import org.apache.calcite.util.Util @@ -47,6 +48,7 @@ import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Time, Timestamp} import java.util.Collections +import scala.annotation.tailrec import scala.collection.JavaConversions._ /** @@ -88,8 +90,32 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity(rel: Calc, mq: RelMetadataQuery): RelModifiedMonotonicity = { - val projects = rel.getProgram.getProjectList.map(rel.getProgram.expandLocalRef) - getProjectMonotonicity(projects, rel.getInput, mq) + val program = rel.getProgram + val projects = program.getProjectList.map(rel.getProgram.expandLocalRef) + val result = getProjectMonotonicity(projects, rel.getInput, mq) + + // check that the `where` condition exists + if (program.getCondition != null && result != null) { + val inputMonotonicity = FlinkRelMetadataQuery + .reuseOrCreate(mq) + .getRelModifiedMonotonicity(rel.getInput) + val inputProjects = program.getExprList.filter(expr => expr.isInstanceOf[RexInputRef]) + assert(inputMonotonicity.fieldMonotonicities.length == inputProjects.size) + val notConstantProjects = inputProjects.indices + .map( + index => + ( + inputProjects(index).asInstanceOf[RexInputRef], + inputMonotonicity.fieldMonotonicities(index))) + .filter { case (_, monotonicity) => monotonicity != CONSTANT } + .toArray + val condition = program.expandLocalRef(program.getCondition) + if (isNeedRetract(condition, notConstantProjects)) { + program.getProjectList.indices + .foreach(index => result.fieldMonotonicities(index) = NOT_MONOTONIC) + } + } + result } private def getProjectMonotonicity( @@ -638,6 +664,52 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon udf.getMonotonicity(binding) } + private def isNeedRetract( + rexNode: RexNode, + projects: Array[(RexInputRef, SqlMonotonicity)]): Boolean = { + rexNode match { + case inputRef: RexInputRef => +
projects.exists { case (projectInput, _) => projectInput == inputRef } + + case rexCall: RexCall => + val operands = rexCall.getOperands.map(operand => removeAsAndCast(operand)) + rexCall.getOperator match { + case AND | OR => + val left = isNeedRetract(operands(0), projects) + val right = isNeedRetract(operands(1), projects) + left || right + + case GREATER_THAN | GREATER_THAN_OR_EQUAL => + projects + .find { case (inputRef, _) => operands.contains(inputRef) } + .exists { case (_, monotonicity) => monotonicity.unstrict() != INCREASING } + + case LESS_THAN | LESS_THAN_OR_EQUAL => + projects + .find { case (inputRef, _) => operands.contains(inputRef) } + .exists { case (_, monotonicity) => monotonicity.unstrict() != DECREASING } + + case SEARCH | IN | EQUALS | NOT_EQUALS | NOT_IN + if projects.exists(x => operands.contains(x._1)) => + true + + case NOT | IS_NOT_TRUE | IS_TRUE | IS_NOT_NULL | IS_NULL if operands.size() == 1 => + isNeedRetract(operands.head, projects) + + case _ => false + } + + case _ => false + } + } + + @tailrec + private def removeAsAndCast(rexNode: RexNode): RexNode = rexNode match { + case r: RexCall if r.getKind == SqlKind.AS || r.getKind == SqlKind.CAST => + removeAsAndCast(r.getOperands.get(0)) + case _ => rexNode + } + private def isValueGreaterThanZero[T](value: Comparable[T]): Int = { value match { case i: Integer => i.compareTo(0) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala index 9af23241c64..860dc52d325 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala @@ -21,16 +21,23 @@ 
import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalRank, FlinkLogicalTableAggregate} import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType} -import org.apache.calcite.rel.`type`.RelDataTypeFieldImpl +import org.apache.calcite.rel.`type`.{RelDataTypeFieldImpl, RelRecordType} import org.apache.calcite.rel.RelCollations import org.apache.calcite.rel.core.JoinRelType +import org.apache.calcite.rel.hint.RelHint +import org.apache.calcite.rel.logical.LogicalCalc +import org.apache.calcite.rex.{RexNode, RexProgram} import org.apache.calcite.sql.fun.SqlStdOperatorTable._ +import org.apache.calcite.sql.validate.SqlMonotonicity import org.apache.calcite.sql.validate.SqlMonotonicity._ import org.apache.calcite.util.ImmutableBitSet import org.junit.Assert._ import org.junit.Test +import java.util + import scala.collection.JavaConversions._ +import scala.language.postfixOps class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase { @@ -41,6 +48,99 @@ class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase { mq.getRelModifiedMonotonicity(studentLogicalScan)) } + @Test + def testMonotonicityWithCondition(): Unit = { + // select id, age, count() from student group by id, age + val inputAgg = relBuilder + .scan("student") + .aggregate( + relBuilder.groupKey(relBuilder.field("id"), relBuilder.field("age")), + relBuilder.count().as("count")) + .build() + relBuilder.push(inputAgg) + + // project `age` field and corresponding output type + val projection = List(relBuilder.field("age")) + val ageFieldType = inputAgg.getRowType.getFieldList.filter(x => x.getName.equals("age")) + val outputType = new RelRecordType(ageFieldType) + + // select age from (select id, age, count() from student group by id, age) where ...
+ // sub-query monotonicity is [CONSTANT, CONSTANT, INCREASING] + // some conditions can break monotonicity because agg funcs + // like max/min depend on input monotonicity + def createCalc(condition: RexNode): LogicalCalc = { + val program = + RexProgram.create(inputAgg.getRowType, projection, condition, outputType, rexBuilder) + new LogicalCalc(cluster, logicalTraits, new util.ArrayList[RelHint](), inputAgg, program) + } + def assertMonotonicity(monotonicity: SqlMonotonicity, condition: RexNode): Unit = + assertEquals( + new RelModifiedMonotonicity(Array(monotonicity)), + mq.getRelModifiedMonotonicity(createCalc(condition)) + ) + + // where count > 1 and count < 3 + var condition = relBuilder + .and( + relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(1)), + relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(3)) + ) + assertMonotonicity(NOT_MONOTONIC, condition) + + // where count > 1 + condition = relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(1)) + assertMonotonicity(CONSTANT, condition) + + // where count < 3 + condition = relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(3)) + assertMonotonicity(NOT_MONOTONIC, condition) + + // where count > 1 or count < 3 + condition = relBuilder + .or( + relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(1)), + relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(3)) + ) + // the correct answer is CONSTANT, but in practice this condition is + // eliminated by SimplifyFilterConditionRule, leaving a null condition + assertMonotonicity(NOT_MONOTONIC, condition) + assertMonotonicity(CONSTANT, null) + + // where count in (1,2,3) + // where count not in (1,2,3) + condition = relBuilder + .in( + relBuilder.field("count"), + relBuilder.literal(1), + relBuilder.literal(2), + relBuilder.literal(3)) + assertMonotonicity(NOT_MONOTONIC, condition) + assertMonotonicity(NOT_MONOTONIC, relBuilder.not(condition)) + + // where count = 10 + condition = relBuilder
.equals(relBuilder.field("count"), relBuilder.literal(10)) + assertMonotonicity(NOT_MONOTONIC, condition) + + // where count <> 10 + condition = relBuilder + .notEquals(relBuilder.field("count"), relBuilder.literal(10)) + assertMonotonicity(NOT_MONOTONIC, condition) + + // where count > 5 or count < 2 + condition = relBuilder + .or( + relBuilder.greaterThan(relBuilder.field("count"), relBuilder.literal(5)), + relBuilder.lessThan(relBuilder.field("count"), relBuilder.literal(2)) + ) + assertMonotonicity(NOT_MONOTONIC, condition) + + // where age is not null + condition = relBuilder + .isNotNull(relBuilder.field("age")) + assertMonotonicity(CONSTANT, condition) + } + @Test def testGetRelMonotonicityOnProject(): Unit = { // test monotonicity pass on diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala index 4df20e1615d..ed1748cda3e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala @@ -89,6 +89,55 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State assertEquals(expected, sink.getRetractResults) } + @Test + def testMaxAggRetractWithCondition(): Unit = { + val data = new mutable.MutableList[(Int, Int)] + data.+=((1, 10)) + data.+=((1, 10)) + data.+=((2, 5)) + data.+=((1, 10)) + + val t = failingDataSource(data).toTable(tEnv, 'id, 'price) + tEnv.createTemporaryView("T", t) + + val sql = + """ + |SELECT MAX(price) FROM( + | SELECT id, count(*) as c, price FROM T GROUP BY id, price) + |WHERE c > 0 and c < 3""".stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + 
env.execute() + + val expected = List("5") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + @Test + def testMinAggRetractWithCondition(): Unit = { + val data = new mutable.MutableList[(Int, Int)] + data.+=((1, 5)) + data.+=((2, 6)) + data.+=((1, 5)) + + val t = failingDataSource(data).toTable(tEnv, 'id, 'price) + tEnv.createTemporaryView("T", t) + + val sql = + """ + |SELECT MIN(price) FROM( + | SELECT id, count(*) as c, price FROM T GROUP BY id, price) + |WHERE c < 2""".stripMargin + + val sink = new TestingRetractSink + tEnv.sqlQuery(sql).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = List("6") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + @Test def testShufflePojo(): Unit = { val data = new mutable.MutableList[(Int, Int)] diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala index 6971ee26a8e..bd0db2495b3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala @@ -63,6 +63,63 @@ class AggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase assertEquals(expected.sorted, sink.getRetractResults.sorted) } + @Test + def testMaxAggRetractWithCondition(): Unit = { + val data = new mutable.MutableList[(Int, Int)] + data.+=((1, 10)) + data.+=((1, 10)) + data.+=((2, 5)) + data.+=((1, 10)) + + // select id, price, count() as c, from table group by id, price + val subQuery = failingDataSource(data) + .toTable(tEnv, 'id, 'price) + .groupBy('id, 'price) + .aggregate('id.count().as("c")) + .select('id, 'price, 'c) + + // select max(price) from subQuery where c > 0 and c < 3 
+ val topQuery = subQuery + .where(and('c.isGreater(0), 'c.isLess(3))) + .select('price) + .aggregate('price.max().as("max_price")) + .select('max_price) + + val sink = new TestingRetractSink() + topQuery.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + assertEquals(List("5"), sink.getRetractResults.sorted) + } + + @Test + def testMinAggRetractWithCondition(): Unit = { + val data = new mutable.MutableList[(Int, Int)] + data.+=((1, 5)) + data.+=((2, 6)) + data.+=((1, 5)) + + // select id, price, count() as c, from table group by id, price + val subQuery = failingDataSource(data) + .toTable(tEnv, 'id, 'price) + .groupBy('id, 'price) + .aggregate('id.count().as("c")) + .select('id, 'price, 'c) + + // select min(price) from subQuery where c < 2 + val topQuery = subQuery + .where('c.isLess(2)) + .select('price) + .aggregate('price.min().as("min_price")) + .select('min_price) + + val sink = new TestingRetractSink() + topQuery.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + assertEquals(List("6"), sink.getRetractResults.sorted) + } + @Test def testDistinctUDAGGMixedWithNonDistinctUsage(): Unit = { val testAgg = new WeightedAvg