Repository: flink Updated Branches: refs/heads/master ab014ef94 -> 78f22aaec
http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
new file mode 100644
index 0000000..c4059d5
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.plan._
+import org.apache.calcite.plan.volcano.VolcanoPlanner
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rel.core.TableScan
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
+import org.apache.flink.table.expressions.{Expression, ExpressionParser}
+import org.apache.flink.table.plan.util.RexProgramExpressionExtractor._
+import org.apache.flink.table.plan.schema.CompositeRelDataType
+import org.apache.flink.table.utils.CommonTestData
+import org.junit.Test
+import org.junit.Assert._
+
+import scala.collection.JavaConverters._
+
+/**
+  * Tests `extractPredicateExpressions` and `rewriteRexProgram`, both imported from
+  * [[RexProgramExpressionExtractor]].
+  */
+class RexProgramExpressionExtractorTest {
+
+  private val typeFactory = new FlinkTypeFactory(RelDataTypeSystem.DEFAULT)
+  private val allFieldTypes = List(VARCHAR, DECIMAL, INTEGER, DOUBLE).map(typeFactory.createSqlType)
+  private val allFieldTypeInfos: Array[TypeInformation[_]] =
+    Array(BasicTypeInfo.STRING_TYPE_INFO,
+      BasicTypeInfo.BIG_DEC_TYPE_INFO,
+      BasicTypeInfo.INT_TYPE_INFO,
+      BasicTypeInfo.DOUBLE_TYPE_INFO)
+  private val allFieldNames = List("name", "id", "amount", "price")
+
+  /**
+    * Extracts the predicates of the program built by [[buildRexProgram]] and checks how many
+    * expressions are returned.
+    */
+  @Test
+  def testExtractExpression(): Unit = {
+    val builder: RexBuilder = new RexBuilder(typeFactory)
+    val program = buildRexProgram(
+      allFieldNames, allFieldTypes, typeFactory, builder)
+    val firstExp = ExpressionParser.parseExpression("id > 6")
+    val secondExp = ExpressionParser.parseExpression("amount * price < 100")
+    val expected: Array[Expression] = Array(firstExp, secondExp)
+    val actual = extractPredicateExpressions(
+      program,
+      builder,
+      CommonTestData.getMockTableEnvironment.getFunctionCatalog)
+
+    // TODO: only the number of extracted predicates is verified so far; the extracted
+    // expressions themselves still have to be compared with `expected`.
+    assertEquals(expected.length, actual.length)
+  }
+
+  /**
+    * Rewrites the program built by [[buildRexProgram]] for the single predicate
+    * `amount * price < 100` and compares the expression lists before and after the rewrite.
+    */
+  @Test
+  def testRewriteRexProgramWithCondition(): Unit = {
+    val originalRexProgram = buildRexProgram(
+      allFieldNames, allFieldTypes, typeFactory, new RexBuilder(typeFactory))
+    assertExprStrings(
+      Seq(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "*($t2, $t3)",
+        "100",
+        "<($t4, $t5)",
+        "6",
+        ">($t1, $t7)",
+        "AND($t6, $t8)"),
+      originalRexProgram)
+
+    val tEnv = CommonTestData.getMockTableEnvironment
+    val builder = FlinkRelBuilder.create(tEnv.getFrameworkConfig)
+    val tableScan = new MockTableScan(builder.getRexBuilder)
+    val newExpression = ExpressionParser.parseExpression("amount * price < 100")
+    val newRexProgram = rewriteRexProgram(
+      originalRexProgram,
+      tableScan,
+      Array(newExpression)
+    )(builder)
+
+    // the rewritten program no longer contains the `id > 6` comparison and the AND
+    assertExprStrings(
+      Seq(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "*($t2, $t3)",
+        "100",
+        "<($t4, $t5)"),
+      newRexProgram)
+  }
+
+//  @Test
+//  def testVerifyExpressions(): Unit = {
+//    val strPart = "f1 < 4"
+//    val part = parseExpression(strPart)
+//
+//    val shortFalseOrigin = parseExpression(s"f0 > 10 || $strPart")
+//    assertFalse(verifyExpressions(shortFalseOrigin, part))
+//
+//    val longFalseOrigin = parseExpression(s"(f0 > 10 || (($strPart) > POWER(f0, f1))) && 2")
+//    assertFalse(verifyExpressions(longFalseOrigin, part))
+//
+//    val shortOkayOrigin = parseExpression(s"f0 > 10 && ($strPart)")
+//    assertTrue(verifyExpressions(shortOkayOrigin, part))
+//
+//    val longOkayOrigin = parseExpression(s"f0 > 10 && (($strPart) > POWER(f0, f1))")
+//    assertTrue(verifyExpressions(longOkayOrigin, part))
+//
+//    val longOkayOrigin2 = parseExpression(s"(f0 > 10 || (2 > POWER(f0, f1))) && $strPart")
+//    assertTrue(verifyExpressions(longOkayOrigin2, part))
+//  }
+
+  /**
+    * Builds a program that projects input field 2 as "amount" and the product of input
+    * fields 2 and 3 as "total", filtered by `(field 2 * field 3) < 100 AND field 1 > 6`.
+    * With [[allFieldNames]], fields 1, 2 and 3 are `id`, `amount` and `price`.
+    */
+  private def buildRexProgram(
+      fieldNames: List[String],
+      fieldTypes: Seq[RelDataType],
+      typeFactory: JavaTypeFactory,
+      rexBuilder: RexBuilder): RexProgram = {
+
+    val inputRowType = typeFactory.createStructType(fieldTypes.asJava, fieldNames.asJava)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    val t0 = rexBuilder.makeInputRef(fieldTypes(2), 2)
+    val t1 = rexBuilder.makeInputRef(fieldTypes(1), 1)
+    val t2 = rexBuilder.makeInputRef(fieldTypes(3), 3)
+    // t3 = t0 * t2
+    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+    // project: amount, amount * price
+    builder.addProject(t0, "amount")
+    builder.addProject(t3, "total")
+    // t6 = t3 < t4
+    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+    // t7 = t1 > t5
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
+    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+    // condition: t6 and t7
+    // (t0 * t2 < t4 && t1 > t5)
+    builder.addCondition(t8)
+    builder.getProgram
+  }
+
+  /**
+    * extract all expression string list from input RexProgram expression lists
+    *
+    * @param rexProgram input RexProgram instance to analyze
+    * @return all expression string list of input RexProgram expression lists
+    */
+  private def extractExprStrList(rexProgram: RexProgram) =
+    rexProgram.getExprList.asScala.map(_.toString).toArray
+
+  /**
+    * Asserts that the expressions of `rexProgram`, rendered as strings, equal `expected`.
+    * assertArrayEquals is used so that a failure reports the first differing element.
+    */
+  private def assertExprStrings(expected: Seq[String], rexProgram: RexProgram): Unit =
+    assertArrayEquals(
+      expected.toArray[AnyRef],
+      extractExprStrList(rexProgram).toArray[AnyRef])
+
+  /** [[TableScan]] of a [[MockRelOptTable]] on a new VolcanoPlanner cluster, empty trait set. */
+  class MockTableScan(
+      rexBuilder: RexBuilder)
+    extends TableScan(
+      RelOptCluster.create(new VolcanoPlanner(), rexBuilder),
+      RelTraitSet.createEmpty,
+      new MockRelOptTable)
+
+  /** Table "mockRelTable" whose row type is built from allFieldTypeInfos and allFieldNames. */
+  class MockRelOptTable
+    extends RelOptAbstractTable(
+      null,
+      "mockRelTable",
+      new CompositeRelDataType(
+        new RowTypeInfo(allFieldTypeInfos, allFieldNames.toArray), typeFactory))
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
new file mode 100644
index 0000000..cea9eee
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.plan.util.RexProgramProjectExtractor._
+import org.junit.Assert.{assertArrayEquals, assertTrue}
+import org.junit.{Before, Test}
+
+import scala.collection.JavaConverters._
+
+/**
+  * This class is responsible for testing RexProgramProjectExtractor.
+  */
+class RexProgramProjectExtractorTest {
+  private var typeFactory: JavaTypeFactory = _
+  private var rexBuilder: RexBuilder = _
+  private var allFieldTypes: Seq[RelDataType] = _
+  private val allFieldNames = List("name", "id", "amount", "price")
+
+  @Before
+  def setUp(): Unit = {
+    typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+    rexBuilder = new RexBuilder(typeFactory)
+    allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_))
+  }
+
+  /** The program references amount (2), price (3) and id (1), in that order. */
+  @Test
+  def testExtractRefInputFields(): Unit = {
+    val usedFields = extractRefInputFields(buildRexProgram())
+    // JUnit's argument order is (expected, actual)
+    assertArrayEquals(Array(2, 3, 1), usedFields)
+  }
+
+  /**
+    * Rewrites the program for an input row that only contains the used fields and checks
+    * the expression list of the rewritten program.
+    */
+  @Test
+  def testRewriteRexProgram(): Unit = {
+    val originRexProgram = buildRexProgram()
+    assertExprStrings(
+      Seq(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "*($t2, $t3)",
+        "100",
+        "<($t4, $t5)",
+        "6",
+        ">($t1, $t7)",
+        "AND($t6, $t8)"),
+      originRexProgram)
+    // use amount, price, id fields (in this order) to create a new RexProgram
+    val usedFields = Array(2, 3, 1)
+    val types = usedFields.map(allFieldTypes(_)).toList.asJava
+    val names = usedFields.map(allFieldNames(_)).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder)
+    assertExprStrings(
+      Seq(
+        "$0",
+        "$1",
+        "$2",
+        "*($t0, $t1)",
+        "100",
+        "<($t3, $t4)",
+        "6",
+        ">($t2, $t6)",
+        "AND($t5, $t7)"),
+      newRexProgram)
+  }
+
+  /**
+    * Builds a program over all fields that projects `amount` and `amount * price` ("total")
+    * under the condition `amount * price < 100 AND id > 6`.
+    */
+  private def buildRexProgram(): RexProgram = {
+    val types = allFieldTypes.asJava
+    val names = allFieldNames.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+    val t0 = rexBuilder.makeInputRef(types.get(2), 2)
+    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+    val t2 = rexBuilder.makeInputRef(types.get(3), 3)
+    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+    // project: amount, amount * price
+    builder.addProject(t0, "amount")
+    builder.addProject(t3, "total")
+    // condition: amount * price < 100 and id > 6
+    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
+    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+    builder.addCondition(t8)
+    builder.getProgram
+  }
+
+  /**
+    * extract all expression string list from input RexProgram expression lists
+    *
+    * @param rexProgram input RexProgram instance to analyze
+    * @return all expression string list of input RexProgram expression lists
+    */
+  private def extractExprStrList(rexProgram: RexProgram) = {
+    rexProgram.getExprList.asScala.map(_.toString)
+  }
+
+  /**
+    * Asserts that the expressions of `rexProgram`, rendered as strings, equal `expected`.
+    * assertArrayEquals is used so that a failure reports the first differing element.
+    */
+  private def assertExprStrings(expected: Seq[String], rexProgram: RexProgram): Unit =
+    assertArrayEquals(
+      expected.toArray[AnyRef],
+      extractExprStrList(rexProgram).toArray[AnyRef])
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
----------------------------------------------------------------------
diff --git
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
index 6e4859b..a720f02 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
@@ -21,14 +21,21 @@ package org.apache.flink.table.utils
 import java.io.{File, FileOutputStream, OutputStreamWriter}
 import java.util
 
-import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource}
+import org.apache.calcite.tools.RuleSet
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.common.typeutils.TypeSerializer
-import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, TypeExtractor}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
-import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
-import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource, TableSource}
-import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
+import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.sinks.TableSink
+import org.apache.flink.table.sources._
+import org.apache.flink.types.Row
+
+import scala.collection.JavaConverters._
 
 object CommonTestData {
 
@@ -98,4 +105,136 @@ object CommonTestData {
       this(null, null)
     }
   }
+
+  def getMockTableEnvironment: TableEnvironment = new MockTableEnvironment
+
+  def getFilterableTableSource = new TestFilterableTableSource
+}
+
+/**
+  * [[TableEnvironment]] stub for tests that only need an environment instance; every
+  * member overridden here is left unimplemented (`???`).
+  */
+class MockTableEnvironment extends TableEnvironment(new TableConfig) {
+
+  override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ???
+
+  override protected def checkValidTableName(name: String): Unit = ???
+
+  override def sql(query: String): Table = ???
+
+  override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = ???
+
+  override protected def getBuiltInNormRuleSet: RuleSet = ???
+
+  override protected def getBuiltInOptRuleSet: RuleSet = ???
+}
+
+/**
+  * Test table source with the fields (name, id, amount, price) that accepts a pushed-down
+  * predicate of the form `amount > literal` and only produces rows whose counter exceeds
+  * that literal.
+  */
+class TestFilterableTableSource
+    extends BatchTableSource[Row]
+    with StreamTableSource[Row]
+    with FilterableTableSource
+    with DefinedFieldNames {
+
+  import org.apache.flink.table.api.Types._
+
+  val fieldNames = Array("name", "id", "amount", "price")
+  val fieldTypes = Array[TypeInformation[_]](STRING, LONG, INT, DOUBLE)
+
+  // literal of the accepted `amount > literal` predicate; null until setPredicate accepts one
+  private var filterLiteral: Literal = _
+  private var filterPredicates: Array[Expression] = Array.empty
+
+  /** Returns the data of the table as a [[DataSet]]. */
+  override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType)
+  }
+
+  /** Returns the data of the table as a [[DataStream]]. */
+  def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType)
+  }
+
+  /**
+    * Creates one row per counter in `[0, num)` that is greater than the pushed-down literal.
+    *
+    * @param num exclusive upper bound of the row counter
+    * @return rows holding ("Record_" + counter, counter as Long, counter, counter as Double)
+    * @throws RuntimeException if no filter literal has been set via [[setPredicate]]
+    */
+  private def generateDynamicCollection(num: Int): Seq[Row] = {
+
+    if (filterLiteral == null) {
+      throw new RuntimeException("filter expression was not set")
+    }
+
+    val filterValue = filterLiteral.value.asInstanceOf[Number].intValue()
+
+    def shouldCreateRow(value: Int): Boolean = {
+      value > filterValue
+    }
+
+    for {
+      cnt <- 0 until num
+      if shouldCreateRow(cnt)
+    } yield {
+      val row = new Row(fieldNames.length)
+      fieldNames.zipWithIndex.foreach { case (name, index) =>
+        name match {
+          case "name" =>
+            row.setField(index, s"Record_$cnt")
+          case "id" =>
+            row.setField(index, cnt.toLong)
+          case "amount" =>
+            row.setField(index, cnt.toInt)
+          case "price" =>
+            row.setField(index, cnt.toDouble)
+        }
+      }
+      row
+    }
+  }
+
+  /** Returns the [[TypeInformation]] for the return type. */
+  override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
+
+  /** Returns the names of the table fields. */
+  override def getFieldNames: Array[String] = fieldNames
+
+  /** Returns the indices of the table fields. */
+  override def getFieldIndices: Array[Int] = fieldNames.indices.toArray
+
+  /** Returns the predicates accepted by [[setPredicate]] (empty if none was accepted). */
+  override def getPredicate: Array[Expression] = filterPredicates
+
+  /**
+    * Takes over the first predicate if it has the form `amount > literal`.
+    *
+    * In that case the literal is remembered for row generation, the predicate is reported by
+    * [[getPredicate]] and all remaining predicates are returned as unsupported. In every
+    * other case (including an empty array) nothing is accepted and the input is returned.
+    *
+    * @param predicates candidate predicates; only the first one is inspected
+    * @return the predicates that are not handled by this source
+    */
+  override def setPredicate(predicates: Array[Expression]): Array[Expression] = {
+    predicates.headOption match {
+      case Some(gt: GreaterThan) =>
+        (gt.left, gt.right) match {
+          case (field: ResolvedFieldReference, literal: Literal) if field.name == "amount" =>
+            filterLiteral = literal
+            filterPredicates = Array[Expression](gt)
+            // everything after the accepted predicate is left to the caller; `tail` also
+            // covers the case of a single predicate, where nothing remains
+            predicates.tail
+          case _ => predicates
+        }
+      case _ => predicates
+    }
+  }
 }
