http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala index 429cccb..570bdff 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala @@ -22,17 +22,16 @@ package org.apache.flink.table.sources * Adds support for projection push-down to a [[TableSource]]. * A [[TableSource]] extending this interface is able to project the fields of the return table. * - * @tparam T The return type of the [[ProjectableTableSource]]. + * @tparam T The return type of the [[TableSource]]. */ trait ProjectableTableSource[T] { /** - * Creates a copy of the [[ProjectableTableSource]] that projects its output on the specified - * fields. + * Creates a copy of the [[TableSource]] that projects its output on the specified fields. * * @param fields The indexes of the fields to return. - * @return A copy of the [[ProjectableTableSource]] that projects its output. + * @return A copy of the [[TableSource]] that projects its output. */ - def projectFields(fields: Array[Int]): ProjectableTableSource[T] + def projectFields(fields: Array[Int]): TableSource[T] }
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala index fe205f1..c41582e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala @@ -38,4 +38,6 @@ trait TableSource[T] { /** Returns the [[TypeInformation]] for the return type of the [[TableSource]]. */ def getReturnType: TypeInformation[T] + /** Describes the table source */ + def explainSource(): String = "" } http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 2c08d8d..fcfcf43 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -86,6 +86,7 @@ class FunctionCatalog { .getOrElse(throw ValidationException(s"Undefined scalar function: $name")) .asInstanceOf[ScalarSqlFunction] ScalarFunctionCall(scalarSqlFunction.getScalarFunction, children) + // user-defined table function call case tf if classOf[TableFunction[_]].isAssignableFrom(tf) => val tableSqlFunction = sqlFunctions @@ -105,7 +106,7 @@ class FunctionCatalog { case Success(expr) => expr case Failure(e) => 
throw new ValidationException(e.getMessage) } - case Failure(e) => + case Failure(_) => val childrenClass = Seq.fill(children.length)(classOf[Expression]) // try to find a constructor matching the exact number of children Try(funcClass.getDeclaredConstructor(childrenClass: _*)) match { @@ -114,7 +115,7 @@ class FunctionCatalog { case Success(expr) => expr case Failure(exception) => throw ValidationException(exception.getMessage) } - case Failure(exception) => + case Failure(_) => throw ValidationException(s"Invalid number of arguments for function $funcClass") } } http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala index 058eca7..97d4d59 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala @@ -22,8 +22,9 @@ import org.apache.flink.table.api.Types import org.apache.flink.table.api.scala._ import org.apache.flink.table.sources.{CsvTableSource, TableSource} import org.apache.flink.table.utils.TableTestUtil._ +import org.apache.flink.table.expressions.utils._ +import org.apache.flink.table.utils.{CommonTestData, TableTestBase, TestFilterableTableSource} import org.junit.{Assert, Test} -import org.apache.flink.table.utils.{CommonTestData, TableTestBase} class TableSourceTest extends TableTestBase { @@ -46,7 +47,7 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - projectableSourceBatchTableNode(tableName, projectedFields), + batchSourceTableNode(tableName, projectedFields), term("select", "UPPER(last) AS _c0", "FLOOR(id) AS _c1", "*(score, 
2) AS _c2") ) @@ -64,7 +65,7 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - projectableSourceBatchTableNode(tableName, projectedFields), + batchSourceTableNode(tableName, projectedFields), term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2") ) @@ -83,12 +84,37 @@ class TableSourceTest extends TableTestBase { .scan(tableName) .select('id, 'score, 'first) - val expected = projectableSourceBatchTableNode(tableName, noCalcFields) + val expected = batchSourceTableNode(tableName, noCalcFields) util.verifyTable(result, expected) } @Test - def testBatchFilterableSourceScanPlanTableApi(): Unit = { + def testBatchFilterableWithoutPushDown(): Unit = { + val (tableSource, tableName) = filterableTableSource + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('price, 'id, 'amount) + .where("price * 2 < 32") + + val expected = unaryNode( + "DataSetCalc", + batchSourceTableNode( + tableName, + Array("name", "id", "amount", "price")), + term("select", "price", "id", "amount"), + term("where", "<(*(price, 2), 32)") + ) + + util.verifyTable(result, expected) + } + + @Test + def testBatchFilterablePartialPushDown(): Unit = { val (tableSource, tableName) = filterableTableSource val util = batchTestUtil() val tEnv = util.tEnv @@ -97,18 +123,94 @@ class TableSourceTest extends TableTestBase { val result = tEnv .scan(tableName) - .select('price, 'id, 'amount) .where("amount > 2 && price * 2 < 32") + .select('price, 'name.lowerCase(), 'amount) val expected = unaryNode( "DataSetCalc", - filterableSourceBatchTableNode( + batchFilterableSourceTableNode( tableName, Array("name", "id", "amount", "price"), - ">(amount, 2)"), - term("select", "price", "id", "amount"), + "'amount > 2"), + term("select", "price", "LOWER(name) AS _c1", "amount"), term("where", "<(*(price, 2), 32)") ) + util.verifyTable(result, expected) + } + + @Test + 
def testBatchFilterableFullyPushedDown(): Unit = { + val (tableSource, tableName) = filterableTableSource + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('price, 'id, 'amount) + .where("amount > 2 && amount < 32") + + val expected = unaryNode( + "DataSetCalc", + batchFilterableSourceTableNode( + tableName, + Array("name", "id", "amount", "price"), + "'amount > 2 && 'amount < 32"), + term("select", "price", "id", "amount") + ) + util.verifyTable(result, expected) + } + + @Test + def testBatchFilterableWithUnconvertedExpression(): Unit = { + val (tableSource, tableName) = filterableTableSource + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('price, 'id, 'amount) + .where("amount > 2 && (amount < 32 || amount.cast(LONG) > 10)") // cast can not be converted + + val expected = unaryNode( + "DataSetCalc", + batchFilterableSourceTableNode( + tableName, + Array("name", "id", "amount", "price"), + "'amount > 2"), + term("select", "price", "id", "amount"), + term("where", "OR(<(amount, 32), >(CAST(amount), 10))") + ) + util.verifyTable(result, expected) + } + + @Test + def testBatchFilterableWithUDF(): Unit = { + val (tableSource, tableName) = filterableTableSource + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + val func = Func0 + tEnv.registerFunction("func0", func) + + val result = tEnv + .scan(tableName) + .select('price, 'id, 'amount) + .where("amount > 2 && func0(amount) < 32") + + val expected = unaryNode( + "DataSetCalc", + batchFilterableSourceTableNode( + tableName, + Array("name", "id", "amount", "price"), + "'amount > 2"), + term("select", "price", "id", "amount"), + term("where", s"<(${func.functionIdentifier}(amount), 32)") + ) util.verifyTable(result, expected) } @@ -129,7 +231,7 @@ class 
TableSourceTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - projectableSourceStreamTableNode(tableName, projectedFields), + streamSourceTableNode(tableName, projectedFields), term("select", "last", "FLOOR(id) AS _c1", "*(score, 2) AS _c2") ) @@ -147,7 +249,7 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - projectableSourceStreamTableNode(tableName, projectedFields), + streamSourceTableNode(tableName, projectedFields), term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2") ) @@ -166,7 +268,7 @@ class TableSourceTest extends TableTestBase { .scan(tableName) .select('id, 'score, 'first) - val expected = projectableSourceStreamTableNode(tableName, noCalcFields) + val expected = streamSourceTableNode(tableName, noCalcFields) util.verifyTable(result, expected) } @@ -185,10 +287,10 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - filterableSourceStreamTableNode( + streamFilterableSourceTableNode( tableName, Array("name", "id", "amount", "price"), - ">(amount, 2)"), + "'amount > 2"), term("select", "price", "id", "amount"), term("where", "<(*(price, 2), 32)") ) @@ -254,7 +356,7 @@ class TableSourceTest extends TableTestBase { // utils def filterableTableSource:(TableSource[_], String) = { - val tableSource = CommonTestData.getFilterableTableSource + val tableSource = new TestFilterableTableSource (tableSource, "filterableTable") } @@ -264,37 +366,27 @@ class TableSourceTest extends TableTestBase { (csvTable, tableName) } - def projectableSourceBatchTableNode( - sourceName: String, - fields: Array[String]): String = { - - "BatchTableSourceScan(" + - s"table=[[$sourceName]], fields=[${fields.mkString(", ")}])" + def batchSourceTableNode(sourceName: String, fields: Array[String]): String = { + s"BatchTableSourceScan(table=[[$sourceName]], fields=[${fields.mkString(", ")}])" } - def projectableSourceStreamTableNode( - sourceName: String, - 
fields: Array[String]): String = { - - "StreamTableSourceScan(" + - s"table=[[$sourceName]], fields=[${fields.mkString(", ")}])" + def streamSourceTableNode(sourceName: String, fields: Array[String] ): String = { + s"StreamTableSourceScan(table=[[$sourceName]], fields=[${fields.mkString(", ")}])" } - def filterableSourceBatchTableNode( - sourceName: String, - fields: Array[String], - exp: String): String = { - + def batchFilterableSourceTableNode( + sourceName: String, + fields: Array[String], + exp: String): String = { "BatchTableSourceScan(" + - s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], filter=[$exp])" + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], source=[filter=[$exp]])" } - def filterableSourceStreamTableNode( - sourceName: String, - fields: Array[String], - exp: String): String = { - + def streamFilterableSourceTableNode( + sourceName: String, + fields: Array[String], + exp: String): String = { "StreamTableSourceScan(" + - s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], filter=[$exp])" + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], source=[filter=[$exp]])" } } http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala index ca7cd8a..7e349cf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala @@ -23,7 +23,7 @@ import org.apache.flink.table.api.scala.batch.utils.TableProgramsCollectionTestB import 
org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.utils.CommonTestData +import org.apache.flink.table.utils.{CommonTestData, TestFilterableTableSource} import org.apache.flink.test.util.TestBaseUtils import org.junit.Test import org.junit.runner.RunWith @@ -107,7 +107,7 @@ class TableSourceITCase( val tableName = "MyTable" val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env, config) - tableEnv.registerTableSource(tableName, CommonTestData.getFilterableTableSource) + tableEnv.registerTableSource(tableName, new TestFilterableTableSource) val results = tableEnv .scan(tableName) .where("amount > 4 && price < 9") http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala index 973c2f3..66711cb 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala @@ -24,7 +24,7 @@ import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.utils.CommonTestData +import org.apache.flink.table.utils.{CommonTestData, TestFilterableTableSource} import org.apache.flink.types.Row import org.junit.Assert._ 
import org.junit.Test @@ -90,7 +90,7 @@ class TableSourceITCase extends StreamingMultipleProgramsTestBase { val tableName = "MyTable" val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) - tEnv.registerTableSource(tableName, CommonTestData.getFilterableTableSource) + tEnv.registerTableSource(tableName, new TestFilterableTableSource) tEnv.scan(tableName) .where("amount > 4 && price < 9") .select("id, name") http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala index 30da5ba..d8de554 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala @@ -199,7 +199,7 @@ abstract class ExpressionTestBase { // extract RexNode val calcProgram = dataSetCalc .asInstanceOf[DataSetCalc] - .calcProgram + .getProgram val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0)) testExprs += ((expanded, expected)) @@ -222,7 +222,7 @@ abstract class ExpressionTestBase { // extract RexNode val calcProgram = dataSetCalc .asInstanceOf[DataSetCalc] - .calcProgram + .getProgram val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0)) testExprs += ((expanded, expected)) http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala ---------------------------------------------------------------------- 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala deleted file mode 100644 index c4059d5..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.plan.util - -import java.math.BigDecimal - -import org.apache.calcite.adapter.java.JavaTypeFactory -import org.apache.calcite.plan._ -import org.apache.calcite.plan.volcano.VolcanoPlanner -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem} -import org.apache.calcite.rel.core.TableScan -import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder} -import org.apache.calcite.sql.`type`.SqlTypeName._ -import org.apache.calcite.sql.fun.SqlStdOperatorTable -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory} -import org.apache.flink.table.expressions.{Expression, ExpressionParser} -import org.apache.flink.table.plan.util.RexProgramExpressionExtractor._ -import org.apache.flink.table.plan.schema.CompositeRelDataType -import org.apache.flink.table.utils.CommonTestData -import org.junit.Test -import org.junit.Assert._ - -import scala.collection.JavaConverters._ - -class RexProgramExpressionExtractorTest { - - private val typeFactory = new FlinkTypeFactory(RelDataTypeSystem.DEFAULT) - private val allFieldTypes = List(VARCHAR, DECIMAL, INTEGER, DOUBLE).map(typeFactory.createSqlType) - private val allFieldTypeInfos: Array[TypeInformation[_]] = - Array(BasicTypeInfo.STRING_TYPE_INFO, - BasicTypeInfo.BIG_DEC_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.DOUBLE_TYPE_INFO) - private val allFieldNames = List("name", "id", "amount", "price") - - @Test - def testExtractExpression(): Unit = { - val builder: RexBuilder = new RexBuilder(typeFactory) - val program = buildRexProgram( - allFieldNames, allFieldTypes, typeFactory, builder) - val firstExp = ExpressionParser.parseExpression("id > 6") - val secondExp = ExpressionParser.parseExpression("amount * price < 100") - val expected: Array[Expression] = Array(firstExp, secondExp) - val actual = 
extractPredicateExpressions( - program, - builder, - CommonTestData.getMockTableEnvironment.getFunctionCatalog) - - assertEquals(expected.length, actual.length) - // todo - } - - @Test - def testRewriteRexProgramWithCondition(): Unit = { - val originalRexProgram = buildRexProgram( - allFieldNames, allFieldTypes, typeFactory, new RexBuilder(typeFactory)) - val array = Array( - "$0", - "$1", - "$2", - "$3", - "*($t2, $t3)", - "100", - "<($t4, $t5)", - "6", - ">($t1, $t7)", - "AND($t6, $t8)") - assertTrue(extractExprStrList(originalRexProgram) sameElements array) - - val tEnv = CommonTestData.getMockTableEnvironment - val builder = FlinkRelBuilder.create(tEnv.getFrameworkConfig) - val tableScan = new MockTableScan(builder.getRexBuilder) - val newExpression = ExpressionParser.parseExpression("amount * price < 100") - val newRexProgram = rewriteRexProgram( - originalRexProgram, - tableScan, - Array(newExpression) - )(builder) - - val newArray = Array( - "$0", - "$1", - "$2", - "$3", - "*($t2, $t3)", - "100", - "<($t4, $t5)") - assertTrue(extractExprStrList(newRexProgram) sameElements newArray) - } - -// @Test -// def testVerifyExpressions(): Unit = { -// val strPart = "f1 < 4" -// val part = parseExpression(strPart) -// -// val shortFalseOrigin = parseExpression(s"f0 > 10 || $strPart") -// assertFalse(verifyExpressions(shortFalseOrigin, part)) -// -// val longFalseOrigin = parseExpression(s"(f0 > 10 || (($strPart) > POWER(f0, f1))) && 2") -// assertFalse(verifyExpressions(longFalseOrigin, part)) -// -// val shortOkayOrigin = parseExpression(s"f0 > 10 && ($strPart)") -// assertTrue(verifyExpressions(shortOkayOrigin, part)) -// -// val longOkayOrigin = parseExpression(s"f0 > 10 && (($strPart) > POWER(f0, f1))") -// assertTrue(verifyExpressions(longOkayOrigin, part)) -// -// val longOkayOrigin2 = parseExpression(s"(f0 > 10 || (2 > POWER(f0, f1))) && $strPart") -// assertTrue(verifyExpressions(longOkayOrigin2, part)) -// } - - private def buildRexProgram( - fieldNames: 
List[String], - fieldTypes: Seq[RelDataType], - typeFactory: JavaTypeFactory, - rexBuilder: RexBuilder): RexProgram = { - - val inputRowType = typeFactory.createStructType(fieldTypes.asJava, fieldNames.asJava) - val builder = new RexProgramBuilder(inputRowType, rexBuilder) - - val t0 = rexBuilder.makeInputRef(fieldTypes(2), 2) - val t1 = rexBuilder.makeInputRef(fieldTypes(1), 1) - val t2 = rexBuilder.makeInputRef(fieldTypes(3), 3) - // t3 = t0 * t2 - val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2)) - val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) - val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L)) - // project: amount, amount * price - builder.addProject(t0, "amount") - builder.addProject(t3, "total") - // t6 = t3 < t4 - val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4)) - // t7 = t1 > t5 - val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5)) - val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava)) - // condition: t6 and t7 - // (t0 * t2 < t4 && t1 > t5) - builder.addCondition(t8) - builder.getProgram - } - - /** - * extract all expression string list from input RexProgram expression lists - * - * @param rexProgram input RexProgram instance to analyze - * @return all expression string list of input RexProgram expression lists - */ - private def extractExprStrList(rexProgram: RexProgram) = - rexProgram.getExprList.asScala.map(_.toString).toArray - - class MockTableScan( - rexBuilder: RexBuilder) - extends TableScan( - RelOptCluster.create(new VolcanoPlanner(), rexBuilder), - RelTraitSet.createEmpty, - new MockRelOptTable) - - class MockRelOptTable - extends RelOptAbstractTable( - null, - "mockRelTable", - new CompositeRelDataType( - new RowTypeInfo(allFieldTypeInfos, allFieldNames.toArray), typeFactory)) -} 
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala new file mode 100644 index 0000000..b0a5fcf --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.util + +import java.math.BigDecimal + +import org.apache.calcite.rex.{RexBuilder, RexProgramBuilder} +import org.apache.calcite.sql.SqlPostfixOperator +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.flink.table.expressions.{Expression, ExpressionParser} +import org.apache.flink.table.validate.FunctionCatalog +import org.junit.Assert.{assertArrayEquals, assertEquals} +import org.junit.Test + +import scala.collection.JavaConverters._ + +class RexProgramExtractorTest extends RexProgramTestBase { + + private val functionCatalog: FunctionCatalog = FunctionCatalog.withBuiltIns + + @Test + def testExtractRefInputFields(): Unit = { + val usedFields = RexProgramExtractor.extractRefInputFields(buildSimpleRexProgram()) + assertArrayEquals(usedFields, Array(2, 3, 1)) + } + + @Test + def testExtractSimpleCondition(): Unit = { + val builder: RexBuilder = new RexBuilder(typeFactory) + val program = buildSimpleRexProgram() + + val firstExp = ExpressionParser.parseExpression("id > 6") + val secondExp = ExpressionParser.parseExpression("amount * price < 100") + val expected: Array[Expression] = Array(firstExp, secondExp) + + val (convertedExpressions, unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + builder, + functionCatalog) + + assertExpressionArrayEquals(expected, convertedExpressions) + assertEquals(0, unconvertedRexNodes.length) + } + + @Test + def testExtractSingleCondition(): Unit = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + // amount + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + // id + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + + // a = amount >= id + val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1)) + builder.addCondition(a) + + val program = builder.getProgram + val relBuilder: 
RexBuilder = new RexBuilder(typeFactory) + val (convertedExpressions, unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + relBuilder, + functionCatalog) + + val expected: Array[Expression] = Array(ExpressionParser.parseExpression("amount >= id")) + assertExpressionArrayEquals(expected, convertedExpressions) + assertEquals(0, unconvertedRexNodes.length) + } + + // ((a AND b) OR c) AND (NOT d) => (a OR c) AND (b OR c) AND (NOT d) + @Test + def testExtractCnfCondition(): Unit = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + // amount + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + // id + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + // price + val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3) + // 100 + val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + + // a = amount < 100 + val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3)) + // b = id > 100 + val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3)) + // c = price == 100 + val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3)) + // d = amount <= id + val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)) + + // a AND b + val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava)) + // (a AND b) or c + val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava)) + // not d + val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava)) + + // (a AND b) OR c) AND (NOT d) + builder.addCondition(builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava))) + + val program = builder.getProgram + val relBuilder: RexBuilder = new RexBuilder(typeFactory) + val (convertedExpressions, 
unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + relBuilder, + functionCatalog) + + val expected: Array[Expression] = Array( + ExpressionParser.parseExpression("amount < 100 || price == 100"), + ExpressionParser.parseExpression("id > 100 || price == 100"), + ExpressionParser.parseExpression("!(amount <= id)")) + assertExpressionArrayEquals(expected, convertedExpressions) + assertEquals(0, unconvertedRexNodes.length) + } + + @Test + def testExtractArithmeticConditions(): Unit = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + // amount + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + // id + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + // 100 + val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + + val condition = List( + // amount < id + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t1)), + // amount <= id + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)), + // amount <> id + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS, t0, t1)), + // amount == id + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t0, t1)), + // amount >= id + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1)), + // amount > id + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t1)), + // amount + id == 100 + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + rexBuilder.makeCall(SqlStdOperatorTable.PLUS, t0, t1), t2)), + // amount - id == 100 + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + rexBuilder.makeCall(SqlStdOperatorTable.MINUS, t0, t1), t2)), + // amount * id == 100 + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t1), t2)), + // amount / id == 100 
+ builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, t0, t1), t2)), + // -amount == 100 + builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, t0), t2)) + ).asJava + + builder.addCondition(builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, condition))) + val program = builder.getProgram + val relBuilder: RexBuilder = new RexBuilder(typeFactory) + val (convertedExpressions, unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + relBuilder, + functionCatalog) + + val expected: Array[Expression] = Array( + ExpressionParser.parseExpression("amount < id"), + ExpressionParser.parseExpression("amount <= id"), + ExpressionParser.parseExpression("amount <> id"), + ExpressionParser.parseExpression("amount == id"), + ExpressionParser.parseExpression("amount >= id"), + ExpressionParser.parseExpression("amount > id"), + ExpressionParser.parseExpression("amount + id == 100"), + ExpressionParser.parseExpression("amount - id == 100"), + ExpressionParser.parseExpression("amount * id == 100"), + ExpressionParser.parseExpression("amount / id == 100"), + ExpressionParser.parseExpression("-amount == 100") + ) + assertExpressionArrayEquals(expected, convertedExpressions) + assertEquals(0, unconvertedRexNodes.length) + } + + @Test + def testExtractPostfixConditions(): Unit = { + testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NULL, "('flag).isNull") + // IS_NOT_NULL will be eliminated since flag is not nullable + // testExtractSinglePostfixCondition(SqlStdOperatorTable.IS_NOT_NULL, "('flag).isNotNull") + testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_TRUE, "('flag).isTrue") + testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_TRUE, "('flag).isNotTrue") + testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_FALSE, "('flag).isFalse") + testExtractSinglePostfixCondition(4, 
SqlStdOperatorTable.IS_NOT_FALSE, "('flag).isNotFalse") + } + + @Test + def testExtractConditionWithFunctionCalls(): Unit = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + // amount + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + // id + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + // 100 + val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + + // sum(amount) > 100 + val condition1 = builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, + rexBuilder.makeCall(SqlStdOperatorTable.SUM, t0), t2)) + + // min(id) == 100 + val condition2 = builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + rexBuilder.makeCall(SqlStdOperatorTable.MIN, t1), t2)) + + builder.addCondition(builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2))) + + val program = builder.getProgram + val relBuilder: RexBuilder = new RexBuilder(typeFactory) + val (convertedExpressions, unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + relBuilder, + functionCatalog) + + val expected: Array[Expression] = Array( + ExpressionParser.parseExpression("sum(amount) > 100"), + ExpressionParser.parseExpression("min(id) == 100") + ) + assertExpressionArrayEquals(expected, convertedExpressions) + assertEquals(0, unconvertedRexNodes.length) + } + + @Test + def testExtractWithUnsupportedConditions(): Unit = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + // amount + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + // id + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + // 100 + val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + + // unsupported now: amount.cast(BigInteger) + val cast = builder.addExpr(rexBuilder.makeCast(allFieldTypes.get(1), 
t0)) + + // unsupported now: amount.cast(BigInteger) > 100 + val condition1 = builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, cast, t2)) + + // amount <= id + val condition2 = builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)) + + // contains unsupported condition: (amount.cast(BigInteger) > 100 OR amount <= id) + val condition3 = builder.addExpr( + rexBuilder.makeCall(SqlStdOperatorTable.OR, condition1, condition2)) + + // only condition2 can be translated + builder.addCondition( + rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2, condition3)) + + val program = builder.getProgram + val relBuilder: RexBuilder = new RexBuilder(typeFactory) + val (convertedExpressions, unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + relBuilder, + functionCatalog) + + val expected: Array[Expression] = Array( + ExpressionParser.parseExpression("amount <= id") + ) + assertExpressionArrayEquals(expected, convertedExpressions) + assertEquals(2, unconvertedRexNodes.length) + assertEquals(">(CAST($2):BIGINT NOT NULL, 100)", unconvertedRexNodes(0).toString) + assertEquals("OR(>(CAST($2):BIGINT NOT NULL, 100), <=($2, $1))", + unconvertedRexNodes(1).toString) + } + + private def testExtractSinglePostfixCondition( + fieldIndex: Integer, + op: SqlPostfixOperator, + expr: String) : Unit = { + + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + rexBuilder = new RexBuilder(typeFactory) + + // flag + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(fieldIndex), fieldIndex) + builder.addCondition(builder.addExpr(rexBuilder.makeCall(op, t0))) + + val program = builder.getProgram(false) + val relBuilder: RexBuilder = new RexBuilder(typeFactory) + val (convertedExpressions, unconvertedRexNodes) = + RexProgramExtractor.extractConjunctiveConditions( + program, + relBuilder, + functionCatalog) 
+ + assertEquals(1, convertedExpressions.length) + assertEquals(expr, convertedExpressions.head.toString) + assertEquals(0, unconvertedRexNodes.length) + } + + private def assertExpressionArrayEquals( + expected: Array[Expression], + actual: Array[Expression]) = { + val sortedExpected = expected.sortBy(e => e.toString) + val sortedActual = actual.sortBy(e => e.toString) + + assertEquals(sortedExpected.length, sortedActual.length) + sortedExpected.zip(sortedActual).foreach { + case (l, r) => assertEquals(l.toString, r.toString) + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala deleted file mode 100644 index cea9eee..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.util - -import java.math.BigDecimal - -import org.apache.calcite.adapter.java.JavaTypeFactory -import org.apache.calcite.jdbc.JavaTypeFactoryImpl -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem} -import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder} -import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR} -import org.apache.calcite.sql.fun.SqlStdOperatorTable -import org.apache.flink.table.plan.util.RexProgramProjectExtractor._ -import org.junit.Assert.{assertArrayEquals, assertTrue} -import org.junit.{Before, Test} - -import scala.collection.JavaConverters._ - -/** - * This class is responsible for testing RexProgramProjectExtractor. - */ -class RexProgramProjectExtractorTest { - private var typeFactory: JavaTypeFactory = _ - private var rexBuilder: RexBuilder = _ - private var allFieldTypes: Seq[RelDataType] = _ - private val allFieldNames = List("name", "id", "amount", "price") - - @Before - def setUp(): Unit = { - typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT) - rexBuilder = new RexBuilder(typeFactory) - allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_)) - } - - @Test - def testExtractRefInputFields(): Unit = { - val usedFields = extractRefInputFields(buildRexProgram()) - assertArrayEquals(usedFields, Array(2, 3, 1)) - } - - @Test - def testRewriteRexProgram(): Unit = { - val originRexProgram = buildRexProgram() - assertTrue(extractExprStrList(originRexProgram).sameElements(Array( - "$0", - "$1", - "$2", - "$3", - "*($t2, $t3)", - "100", - "<($t4, $t5)", - "6", - ">($t1, $t7)", - "AND($t6, $t8)"))) - // use amount, id, price fields to create a new RexProgram - val usedFields = Array(2, 3, 1) - val types = usedFields.map(allFieldTypes(_)).toList.asJava - val names = 
usedFields.map(allFieldNames(_)).toList.asJava - val inputRowType = typeFactory.createStructType(types, names) - val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder) - assertTrue(extractExprStrList(newRexProgram).sameElements(Array( - "$0", - "$1", - "$2", - "*($t0, $t1)", - "100", - "<($t3, $t4)", - "6", - ">($t2, $t6)", - "AND($t5, $t7)"))) - } - - private def buildRexProgram(): RexProgram = { - val types = allFieldTypes.asJava - val names = allFieldNames.asJava - val inputRowType = typeFactory.createStructType(types, names) - val builder = new RexProgramBuilder(inputRowType, rexBuilder) - val t0 = rexBuilder.makeInputRef(types.get(2), 2) - val t1 = rexBuilder.makeInputRef(types.get(1), 1) - val t2 = rexBuilder.makeInputRef(types.get(3), 3) - val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2)) - val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) - val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L)) - // project: amount, amount * price - builder.addProject(t0, "amount") - builder.addProject(t3, "total") - // condition: amount * price < 100 and id > 6 - val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4)) - val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5)) - val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava)) - builder.addCondition(t8) - builder.getProgram - } - - /** - * extract all expression string list from input RexProgram expression lists - * - * @param rexProgram input RexProgram instance to analyze - * @return all expression string list of input RexProgram expression lists - */ - private def extractExprStrList(rexProgram: RexProgram) = { - rexProgram.getExprList.asScala.map(_.toString) - } - -} http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala 
---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala new file mode 100644 index 0000000..899eed2 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.util + +import org.junit.Assert.assertTrue +import org.junit.Test + +import scala.collection.JavaConverters._ + +class RexProgramRewriterTest extends RexProgramTestBase { + + @Test + def testRewriteRexProgram(): Unit = { + val rexProgram = buildSimpleRexProgram() + assertTrue(extractExprStrList(rexProgram) == wrapRefArray(Array( + "$0", + "$1", + "$2", + "$3", + "$4", + "*($t2, $t3)", + "100", + "<($t5, $t6)", + "6", + ">($t1, $t8)", + "AND($t7, $t9)"))) + + // use amount, id, price fields to create a new RexProgram + val usedFields = Array(2, 3, 1) + val types = usedFields.map(allFieldTypes.get).toList.asJava + val names = usedFields.map(allFieldNames.get).toList.asJava + val inputRowType = typeFactory.createStructType(types, names) + val newRexProgram = RexProgramRewriter.rewriteWithFieldProjection( + rexProgram, inputRowType, rexBuilder, usedFields) + assertTrue(extractExprStrList(newRexProgram) == wrapRefArray(Array( + "$0", + "$1", + "$2", + "*($t0, $t1)", + "100", + "<($t3, $t4)", + "6", + ">($t2, $t6)", + "AND($t5, $t7)"))) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala new file mode 100644 index 0000000..6ef3d82 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.util + +import java.math.BigDecimal +import java.util + +import org.apache.calcite.adapter.java.JavaTypeFactory +import org.apache.calcite.jdbc.JavaTypeFactoryImpl +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem} +import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder} +import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR, BOOLEAN} +import org.apache.calcite.sql.fun.SqlStdOperatorTable + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +abstract class RexProgramTestBase { + + val typeFactory: JavaTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT) + + val allFieldNames: util.List[String] = List("name", "id", "amount", "price", "flag").asJava + + val allFieldTypes: util.List[RelDataType] = + List(VARCHAR, BIGINT, INTEGER, DOUBLE, BOOLEAN).map(typeFactory.createSqlType).asJava + + var rexBuilder: RexBuilder = new RexBuilder(typeFactory) + + /** + * extract all expression string list from input RexProgram expression lists + * + * @param rexProgram input RexProgram instance to analyze + * @return all expression string list of input RexProgram expression lists + */ + protected def extractExprStrList(rexProgram: RexProgram): mutable.Buffer[String] = { + rexProgram.getExprList.asScala.map(_.toString) + } + + // select amount, amount * price as 
total where amount * price < 100 and id > 6 + protected def buildSimpleRexProgram(): RexProgram = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3) + val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2)) + val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L)) + + // project: amount, amount * price as total + builder.addProject(t0, "amount") + builder.addProject(t3, "total") + + // condition: amount * price < 100 and id > 6 + val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4)) + val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5)) + val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava)) + builder.addCondition(t8) + + builder.getProgram + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala index a720f02..2364f23 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala @@ -21,21 +21,11 @@ package org.apache.flink.table.utils import java.io.{File, FileOutputStream, OutputStreamWriter} import java.util -import org.apache.flink.api.java.typeutils.TypeExtractor -import 
org.apache.flink.table.sources.{BatchTableSource, CsvTableSource} -import org.apache.calcite.tools.RuleSet import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} -import org.apache.flink.streaming.api.datastream.DataStream -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment -import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment} -import org.apache.flink.table.expressions._ -import org.apache.flink.table.sinks.TableSink -import org.apache.flink.table.sources._ -import org.apache.flink.types.Row - -import scala.collection.JavaConverters._ +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource} object CommonTestData { @@ -108,110 +98,4 @@ object CommonTestData { def getMockTableEnvironment: TableEnvironment = new MockTableEnvironment - def getFilterableTableSource = new TestFilterableTableSource -} - -class MockTableEnvironment extends TableEnvironment(new TableConfig) { - - override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ??? - - override protected def checkValidTableName(name: String): Unit = ??? - - override def sql(query: String): Table = ??? - - override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = ??? - - override protected def getBuiltInNormRuleSet: RuleSet = ??? - - override protected def getBuiltInOptRuleSet: RuleSet = ??? 
-} - -class TestFilterableTableSource - extends BatchTableSource[Row] - with StreamTableSource[Row] - with FilterableTableSource - with DefinedFieldNames { - - import org.apache.flink.table.api.Types._ - - val fieldNames = Array("name", "id", "amount", "price") - val fieldTypes = Array[TypeInformation[_]](STRING, LONG, INT, DOUBLE) - - private var filterLiteral: Literal = _ - private var filterPredicates: Array[Expression] = Array.empty - - /** Returns the data of the table as a [[DataSet]]. */ - override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = { - execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType) - } - - /** Returns the data of the table as a [[DataStream]]. */ - def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = { - execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType) - } - - private def generateDynamicCollection(num: Int): Seq[Row] = { - - if (filterLiteral == null) { - throw new RuntimeException("filter expression was not set") - } - - val filterValue = filterLiteral.value.asInstanceOf[Number].intValue() - - def shouldCreateRow(value: Int): Boolean = { - value > filterValue - } - - for { - cnt <- 0 until num - if shouldCreateRow(cnt) - } yield { - val row = new Row(fieldNames.length) - fieldNames.zipWithIndex.foreach { case (name, index) => - name match { - case "name" => - row.setField(index, s"Record_$cnt") - case "id" => - row.setField(index, cnt.toLong) - case "amount" => - row.setField(index, cnt.toInt) - case "price" => - row.setField(index, cnt.toDouble) - } - } - row - } - } - - /** Returns the [[TypeInformation]] for the return type. */ - override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames) - - /** Returns the names of the table fields. */ - override def getFieldNames: Array[String] = fieldNames - - /** Returns the indices of the table fields. 
*/ - override def getFieldIndices: Array[Int] = fieldNames.indices.toArray - - override def getPredicate: Array[Expression] = filterPredicates - - /** Return an unsupported predicates expression. */ - override def setPredicate(predicates: Array[Expression]): Array[Expression] = { - predicates(0) match { - case gt: GreaterThan => - gt.left match { - case f: ResolvedFieldReference => - gt.right match { - case l: Literal => - if (f.name.equals("amount")) { - filterLiteral = l - filterPredicates = Array(predicates(0)) - Array(predicates(1)) - } else predicates - case _ => predicates - } - case _ => predicates - } - case _ => predicates - } - } } http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala new file mode 100644 index 0000000..6a86ace --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.utils + +import org.apache.calcite.tools.RuleSet +import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment} +import org.apache.flink.table.sinks.TableSink +import org.apache.flink.table.sources.TableSource + +class MockTableEnvironment extends TableEnvironment(new TableConfig) { + + override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ??? + + override protected def checkValidTableName(name: String): Unit = ??? + + override def sql(query: String): Table = ??? + + override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = ??? + + override protected def getBuiltInNormRuleSet: RuleSet = ??? + + override protected def getBuiltInOptRuleSet: RuleSet = ??? +} http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala new file mode 100644 index 0000000..dcf2acd --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.utils + +import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.table.api.Types._ +import org.apache.flink.table.expressions._ +import org.apache.flink.table.sources.{BatchTableSource, FilterableTableSource, StreamTableSource, TableSource} +import org.apache.flink.types.Row +import org.apache.flink.util.Preconditions + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * This source can only handle simple comparison with field "amount". + * Supports ">, <, >=, <=, =, <>" with an integer. 
+ */
+class TestFilterableTableSource(
+    val recordNum: Int = 33)
+  extends BatchTableSource[Row]
+    with StreamTableSource[Row]
+    with FilterableTableSource[Row] {
+
+  /** Reported by [[isFilterPushedDown]]; set to true on the copy built by [[applyPredicate]]. */
+  var filterPushedDown: Boolean = false
+
+  val fieldNames: Array[String] = Array("name", "id", "amount", "price")
+
+  val fieldTypes: Array[TypeInformation[_]] = Array(STRING, LONG, INT, DOUBLE)
+
+  // all predicates with field "amount"; kept index-aligned with filterValues
+  private val filterPredicates = new mutable.ArrayBuffer[Expression]
+
+  // all comparing values for field "amount"; kept index-aligned with filterPredicates
+  private val filterValues = new mutable.ArrayBuffer[Int]
+
+  /** Returns the generated (and already filtered) rows as a [[DataSet]]. */
+  override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection().asJava, getReturnType)
+  }
+
+  /** Returns the generated (and already filtered) rows as a [[DataStream]]. */
+  override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection().asJava, getReturnType)
+  }
+
+  /**
+    * Describes the accepted predicates as a single conjunction, e.g. `filter=[...]`.
+    *
+    * @return The description, or an empty string if no predicate has been accepted.
+    */
+  override def explainSource(): String = {
+    if (filterPredicates.nonEmpty) {
+      s"filter=[${filterPredicates.reduce((l, r) => And(l, r)).toString}]"
+    } else {
+      ""
+    }
+  }
+
+  override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
+
+  /**
+    * Creates a copy of this source that evaluates the supported predicates itself.
+    *
+    * A predicate is accepted if it is a [[BinaryComparison]] whose left side is a
+    * [[ResolvedFieldReference]] named "amount" and whose right side is a [[Literal]].
+    * Accepted predicates are removed from `predicates`; every other predicate is left
+    * in the list untouched.
+    *
+    * The literal value is cast to [[Number]], so a non-numeric literal compared with
+    * "amount" fails with a [[ClassCastException]].
+    *
+    * @param predicates Mutable list of candidate predicates; accepted ones are removed.
+    * @return A new source with the same `recordNum` and `filterPushedDown` set to true.
+    */
+  override def applyPredicate(predicates: JList[Expression]): TableSource[Row] = {
+    val newSource = new TestFilterableTableSource(recordNum)
+    newSource.filterPushedDown = true
+
+    val iterator = predicates.iterator()
+    while (iterator.hasNext) {
+      iterator.next() match {
+        case expr: BinaryComparison =>
+          (expr.left, expr.right) match {
+            case (f: ResolvedFieldReference, v: Literal) if f.name.equals("amount") =>
+              newSource.filterPredicates += expr
+              newSource.filterValues += v.value.asInstanceOf[Number].intValue()
+              iterator.remove()
+            case (_, _) =>
+          }
+        // Anything that is not a binary comparison is not handled by this source and must
+        // stay in the list; without this case such a predicate raises a scala.MatchError.
+        case _ =>
+      }
+    }
+
+    newSource
+  }
+
+  override def isFilterPushedDown: Boolean = filterPushedDown
+
+  /**
+    * Builds one row per value in `0 until recordNum` that passes [[shouldCreateRow]].
+    * Each row holds ("Record_" + cnt, cnt as Long, cnt as Int, cnt as Double).
+    */
+  private def generateDynamicCollection(): Seq[Row] = {
+    Preconditions.checkArgument(filterPredicates.length == filterValues.length)
+
+    for {
+      cnt <- 0 until recordNum
+      if shouldCreateRow(cnt)
+    } yield {
+      Row.of(
+        s"Record_$cnt",
+        Long.box(cnt.toLong),
+        Int.box(cnt),
+        Double.box(cnt.toDouble))
+    }
+  }
+
+  /**
+    * Checks `value` against every accepted predicate and its comparison value.
+    *
+    * @param value The candidate value of the "amount" field.
+    * @return True if all accepted predicates hold (trivially true if there are none).
+    * @throws RuntimeException if an accepted predicate is not one of the six supported
+    *                          comparison types.
+    */
+  private def shouldCreateRow(value: Int): Boolean = {
+    filterPredicates.zip(filterValues).forall {
+      case (_: GreaterThan, v) =>
+        value > v
+      case (_: LessThan, v) =>
+        value < v
+      case (_: GreaterThanOrEqual, v) =>
+        value >= v
+      case (_: LessThanOrEqual, v) =>
+        value <= v
+      case (_: EqualTo, v) =>
+        value == v
+      case (_: NotEqualTo, v) =>
+        value != v
+      case (expr, _) =>
+        throw new RuntimeException(expr + " not supported!")
+    }
+  }
+}
+
