[FLINK-7657] [table] Add time types FilterableTableSource push down This closes #4746.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3b333b28 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3b333b28 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3b333b28 Branch: refs/heads/master Commit: 3b333b289cbbe43f722edef2c36c370ff4550128 Parents: ad8ef6d Author: Kent Murra <[email protected]> Authored: Wed Sep 27 13:48:55 2017 -0700 Committer: twalthr <[email protected]> Committed: Mon Nov 13 14:20:25 2017 +0100 ---------------------------------------------------------------------- .../flink/table/expressions/literals.scala | 34 +-- .../table/plan/util/RexProgramExtractor.scala | 33 ++- .../flink/table/api/TableSourceTest.scala | 65 +++++- .../table/plan/RexProgramExtractorTest.scala | 53 ++++- .../flink/table/plan/RexProgramTestBase.scala | 5 + .../runtime/batch/table/TableSourceITCase.scala | 117 +++++++++- .../stream/table/TableSourceITCase.scala | 2 +- .../flink/table/utils/TableTestBase.scala | 2 - .../table/utils/TestFilterableTableSource.scala | 226 +++++++++++++++++++ .../flink/table/utils/testTableSources.scala | 101 --------- 10 files changed, 511 insertions(+), 127 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala index eb9c4f5..d797cc4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala @@ -17,21 +17,22 @@ */ package org.apache.flink.table.expressions -import 
java.sql.{Date, Time, Timestamp} -import java.util.{Calendar, TimeZone} - import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.SqlIntervalQualifier import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.parser.SqlParserPos import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.{DateString, TimeString, TimestampString} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} +import java.sql.{Date, Time, Timestamp} +import java.util.{Calendar, TimeZone} + object Literal { - private[flink] val GMT = TimeZone.getTimeZone("GMT") + private[flink] val UTC = TimeZone.getTimeZone("UTC") private[flink] def apply(l: Any): Literal = l match { case i: Int => Literal(i, BasicTypeInfo.INT_TYPE_INFO) @@ -52,7 +53,7 @@ object Literal { } case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpression { - override def toString = resultType match { + override def toString: String = resultType match { case _: BasicTypeInfo[_] => value.toString case _@SqlTimeTypeInfo.DATE => value.toString + ".toDate" case _@SqlTimeTypeInfo.TIME => value.toString + ".toTime" @@ -77,11 +78,14 @@ case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpre // date/time case SqlTimeTypeInfo.DATE => - relBuilder.getRexBuilder.makeDateLiteral(dateToCalendar) + val datestr = DateString.fromCalendarFields(valueAsCalendar) + relBuilder.getRexBuilder.makeDateLiteral(datestr) case SqlTimeTypeInfo.TIME => - relBuilder.getRexBuilder.makeTimeLiteral(dateToCalendar, 0) + val timestr = TimeString.fromCalendarFields(valueAsCalendar) + relBuilder.getRexBuilder.makeTimeLiteral(timestr, 0) case SqlTimeTypeInfo.TIMESTAMP => - relBuilder.getRexBuilder.makeTimestampLiteral(dateToCalendar, 3) + val timestampstr = 
TimestampString.fromCalendarFields(valueAsCalendar) + relBuilder.getRexBuilder.makeTimestampLiteral(timestampstr, 3) case TimeIntervalTypeInfo.INTERVAL_MONTHS => val interval = java.math.BigDecimal.valueOf(value.asInstanceOf[Int]) @@ -103,12 +107,16 @@ case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpre } } - private def dateToCalendar: Calendar = { + /** + * Convert a date value to a calendar. Calcite fromCalendarField functions use the Calendar.get + * methods, so the raw values of the individual fields are preserved when converted to the + * string formats. + * @return Get the Calendar value + */ + private def valueAsCalendar: Calendar = { val date = value.asInstanceOf[java.util.Date] - val cal = Calendar.getInstance(Literal.GMT) - val t = date.getTime - // according to Calcite's SqlFunctions.internalToXXX methods - cal.setTimeInMillis(t + TimeZone.getDefault.getOffset(t)) + val cal = Calendar.getInstance + cal.setTime(date) cal } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala index 53bf8e7..d11a43d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala @@ -22,12 +22,16 @@ import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.rex._ import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.sql.{SqlFunction, SqlPostfixOperator} +import org.apache.calcite.util.{DateString, TimeString, TimestampString} +import 
org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo} import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.expressions.{And, Expression, Literal, Or, ResolvedFieldReference} import org.apache.flink.table.validate.FunctionCatalog import org.apache.flink.util.Preconditions +import java.sql.{Date, Time, Timestamp} + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable @@ -158,7 +162,33 @@ class RexNodeToExpressionConverter( } override def visitLiteral(literal: RexLiteral): Option[Expression] = { - Some(Literal(literal.getValue, FlinkTypeFactory.toTypeInfo(literal.getType))) + val literalType = FlinkTypeFactory.toTypeInfo(literal.getType) + + val literalValue = literalType match { + // Chrono use cases. + case _@SqlTimeTypeInfo.DATE => + val rexValue = literal.getValueAs(classOf[DateString]) + Date.valueOf(rexValue.toString) + case _@SqlTimeTypeInfo.TIME => + val rexValue = literal.getValueAs(classOf[TimeString]) + Time.valueOf(rexValue.toString(0)) + case _@SqlTimeTypeInfo.TIMESTAMP => + val rexValue = literal.getValueAs(classOf[TimestampString]) + Timestamp.valueOf(rexValue.toString(3)) + + case _@BasicTypeInfo.INT_TYPE_INFO => + /* + Force integer conversion. RelDataType is INTEGER and SqlTypeName is DECIMAL, + meaning that it will assume that we are using a BigDecimal + and refuse to convert to Integer. 
+ */ + val rexValue = literal.getValueAs(classOf[java.math.BigDecimal]) + rexValue.intValue() + + case _ => literal.getValue + } + + Some(Literal(literalValue, literalType)) } override def visitCall(call: RexCall): Option[Expression] = { @@ -209,7 +239,6 @@ class RexNodeToExpressionConverter( private def replace(str: String): String = { str.replaceAll("\\s|_", "") } - } /** http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala index 4b88bc3..dc84c19 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala @@ -18,13 +18,19 @@ package org.apache.flink.table.api +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation} import org.apache.flink.table.api.scala._ +import org.apache.flink.table.expressions.{BinaryComparison, Expression, Literal} import org.apache.flink.table.expressions.utils._ import org.apache.flink.table.runtime.utils.CommonTestData import org.apache.flink.table.sources.{CsvTableSource, TableSource} import org.apache.flink.table.utils.TableTestUtil._ import org.apache.flink.table.utils.{TableTestBase, TestFilterableTableSource} import org.junit.{Assert, Test} +import _root_.java.sql.{Date, Time, Timestamp} + +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.types.Row class TableSourceTest extends TableTestBase { @@ -374,13 +380,69 @@ class TableSourceTest extends TableTestBase { Assert.assertEquals(source1, source2) } + @Test + def testTimeLiteralExpressionPushdown(): Unit = { + val 
(tableSource, tableName) = filterableTableSourceTimeTypes + val util = batchTestUtil() + val tableEnv = util.tableEnv + + tableEnv.registerTableSource(tableName, tableSource) + + val sqlQuery = + s""" + |SELECT id from $tableName + |WHERE + | tv > TIME '14:25:02' AND + | dv > DATE '2017-02-03' AND + | tsv > TIMESTAMP '2017-02-03 14:25:02.000' + """.stripMargin + + val result = tableEnv.sqlQuery(sqlQuery) + + val expectedFilter = + "'tv > 14:25:02.toTime && " + + "'dv > 2017-02-03.toDate && " + + "'tsv > 2017-02-03 14:25:02.0.toTimestamp" + val expected = unaryNode( + "DataSetCalc", + batchFilterableSourceTableNode( + tableName, + Array("id", "dv", "tv", "tsv"), + expectedFilter), + term("select", "id") + ) + util.verifyTable(result, expected) + } + // utils def filterableTableSource:(TableSource[_], String) = { - val tableSource = new TestFilterableTableSource + val tableSource = TestFilterableTableSource() (tableSource, "filterableTable") } + def filterableTableSourceTimeTypes:(TableSource[_], String) = { + val rowTypeInfo = new RowTypeInfo( + Array[TypeInformation[_]]( + BasicTypeInfo.INT_TYPE_INFO, + SqlTimeTypeInfo.DATE, + SqlTimeTypeInfo.TIME, + SqlTimeTypeInfo.TIMESTAMP + ), + Array("id", "dv", "tv", "tsv") + ) + + val row = new Row(4) + row.setField(0, 1) + row.setField(1, Date.valueOf("2017-01-23")) + row.setField(2, Time.valueOf("14:23:02")) + row.setField(3, Timestamp.valueOf("2017-01-24 12:45:01.234")) + + val tableSource = TestFilterableTableSource(rowTypeInfo, Seq(row), Set("dv", "tv", "tsv")) + (tableSource, "filterableTable") + } + + def csvTable: (CsvTableSource, String) = { val csvTable = CommonTestData.getCsvTableSource val tableName = "csvTable" @@ -414,4 +476,5 @@ class TableSourceTest extends TableTestBase { "StreamTableSourceScan(" + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], source=[filter=[$exp]])" } + } 
http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala index c2a01c6..6ed9455 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala @@ -19,12 +19,15 @@ package org.apache.flink.table.plan import java.math.BigDecimal +import java.sql.{Date, Time, Timestamp} -import org.apache.calcite.plan.RelOptUtil +import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex._ import org.apache.calcite.sql.SqlPostfixOperator +import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR} import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.util.{DateString, TimeString, TimestampString} import org.apache.flink.table.expressions._ import org.apache.flink.table.plan.util.{RexNodeToExpressionConverter, RexProgramExtractor} import org.apache.flink.table.utils.InputTypeBuilder.inputOf @@ -199,6 +202,54 @@ class RexProgramExtractorTest extends RexProgramTestBase { } @Test + def testLiteralConversions(): Unit = { + val fieldNames = List("timestamp_col", "date_col", "time_col").asJava + val fieldTypes = makeTypes(SqlTypeName.TIMESTAMP, SqlTypeName.DATE, SqlTypeName.TIME) + + val inputRowType = typeFactory.createStructType(fieldTypes, fieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + val timestampString = new TimestampString("2017-09-10 14:23:01.245") + val rexTimestamp = 
rexBuilder.makeTimestampLiteral(timestampString, 3) + val rexDate = rexBuilder.makeDateLiteral(new DateString("2017-09-12")) + val rexTime = rexBuilder.makeTimeLiteral(new TimeString("14:23:01"), 0) + + val allRexNodes = List(rexTimestamp, rexDate, rexTime) + + val condition = fieldTypes.asScala.zipWithIndex + .map((t: (RelDataType, Int)) => rexBuilder.makeInputRef(t._1, t._2)) + .zip(allRexNodes) + .map(t => rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t._1, t._2)) + .map(builder.addExpr) + .asJava + + builder.addCondition(builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, condition))) + + val (converted, _) = RexProgramExtractor.extractConjunctiveConditions( + builder.getProgram, + new RexBuilder(typeFactory), + functionCatalog) + + + val expected = Array[Expression]( + EqualTo( + UnresolvedFieldReference("timestamp_col"), + Literal(Timestamp.valueOf("2017-09-10 14:23:01.245")) + ), + EqualTo( + UnresolvedFieldReference("date_col"), + Literal(Date.valueOf("2017-09-12")) + ), + EqualTo( + UnresolvedFieldReference("time_col"), + Literal(Time.valueOf("14:23:01")) + ) + ) + + assertExpressionArrayEquals(expected, converted) + } + + @Test def testExtractArithmeticConditions(): Unit = { val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) val builder = new RexProgramBuilder(inputRowType, rexBuilder) http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala index b711604..728694f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala @@ -24,6 +24,7 @@ import org.apache.calcite.adapter.java.JavaTypeFactory import org.apache.calcite.jdbc.JavaTypeFactoryImpl import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem} import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder} +import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.fun.SqlStdOperatorTable @@ -76,4 +77,8 @@ abstract class RexProgramTestBase { builder.getProgram } + protected def makeTypes(fieldTypes: SqlTypeName*): java.util.List[RelDataType] = { + fieldTypes.toList.map(typeFactory.createSqlType).asJava + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/TableSourceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/TableSourceITCase.scala index 2292e17..f0fe896 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/TableSourceITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/TableSourceITCase.scala @@ -19,17 +19,17 @@ package org.apache.flink.table.runtime.batch.table import java.lang.{Boolean => JBool, Integer => JInt, Long => JLong} +import java.sql.{Date, Time, Timestamp} import org.apache.calcite.runtime.SqlFunctions.{internalToTimestamp => toTimestamp} -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.DataSet -import org.apache.flink.api.java.{ExecutionEnvironment => JExecEnv} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, 
TypeInformation} import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo} +import org.apache.flink.api.java.{DataSet, ExecutionEnvironment => JExecEnv} import org.apache.flink.api.scala.ExecutionEnvironment -import org.apache.flink.table.api.{TableEnvironment, TableException, TableSchema, Types} import org.apache.flink.table.api.scala._ -import org.apache.flink.table.runtime.utils.{CommonTestData, TableProgramsCollectionTestBase} +import org.apache.flink.table.api.{TableEnvironment, TableException, TableSchema, Types} import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.runtime.utils.{CommonTestData, TableProgramsCollectionTestBase} import org.apache.flink.table.sources.BatchTableSource import org.apache.flink.table.utils._ import org.apache.flink.test.util.TestBaseUtils @@ -101,7 +101,7 @@ class TableSourceITCase( val tableName = "MyTable" val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env, config) - tableEnv.registerTableSource(tableName, new TestFilterableTableSource) + tableEnv.registerTableSource(tableName, TestFilterableTableSource()) val results = tableEnv .scan(tableName) .where("amount > 4 && price < 9") @@ -250,6 +250,37 @@ class TableSourceITCase( "Mary,1970-01-01 00:00:00.0,40", "Bob,1970-01-01 00:00:00.0,20", "Liz,1970-01-01 00:00:02.0,40").mkString("\n") + } + + @Test + def testTableSourceWithFilterableDate(): Unit = { + val tableName = "MyTable" + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + + val rowTypeInfo = new RowTypeInfo( + Array[TypeInformation[_]](BasicTypeInfo.INT_TYPE_INFO, SqlTimeTypeInfo.DATE), + Array("id", "date_val")) + + val rows = Seq( + makeRow(23, Date.valueOf("2017-04-23")), + makeRow(24, Date.valueOf("2017-04-24")), + makeRow(25, Date.valueOf("2017-04-25")), + makeRow(26, Date.valueOf("2017-04-26")) + ) + + val query 
= + """ + |select id from MyTable + |where date_val >= DATE '2017-04-24' and date_val < DATE '2017-04-26' + """.stripMargin + val tableSource = TestFilterableTableSource(rowTypeInfo, rows, Set("date_val")) + tableEnv.registerTableSource(tableName, tableSource) + val results = tableEnv + .sqlQuery(query) + .collect() + + val expected = Seq(24, 25).mkString("\n") TestBaseUtils.compareResultAsText(results.asJava, expected) } @@ -473,6 +504,7 @@ class TableSourceITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test def testProjectOnlyProctime(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) @@ -504,6 +536,7 @@ class TableSourceITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test def testProjectOnlyRowtime(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) @@ -643,4 +676,76 @@ class TableSourceITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test + def testTableSourceWithFilterableTime(): Unit = { + val tableName = "MyTable" + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + + val rowTypeInfo = new RowTypeInfo( + Array[TypeInformation[_]](BasicTypeInfo.INT_TYPE_INFO, SqlTimeTypeInfo.TIME), + Array("id", "time_val")) + + val rows = Seq( + makeRow(1, Time.valueOf("7:23:19")), + makeRow(2, Time.valueOf("11:45:00")), + makeRow(3, Time.valueOf("11:45:01")), + makeRow(4, Time.valueOf("12:14:23")), + makeRow(5, Time.valueOf("13:33:12")) + ) + + val query = + """ + |select id from MyTable + |where time_val >= TIME '11:45:00' and time_val < TIME '12:14:23' + """.stripMargin + val tableSource = TestFilterableTableSource(rowTypeInfo, rows, Set("time_val")) + tableEnv.registerTableSource(tableName, tableSource) + val results = tableEnv + .sqlQuery(query) + .collect() + + val expected = Seq(2, 
3).mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testTableSourceWithFilterableTimestamp(): Unit = { + val tableName = "MyTable" + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + + val rowTypeInfo = new RowTypeInfo( + Array[TypeInformation[_]](BasicTypeInfo.INT_TYPE_INFO, SqlTimeTypeInfo.TIMESTAMP), + Array("id", "ts")) + + val rows = Seq( + makeRow(1, Timestamp.valueOf("2017-07-11 7:23:19")), + makeRow(2, Timestamp.valueOf("2017-07-12 11:45:00")), + makeRow(3, Timestamp.valueOf("2017-07-13 11:45:01")), + makeRow(4, Timestamp.valueOf("2017-07-14 12:14:23")), + makeRow(5, Timestamp.valueOf("2017-07-13 13:33:12")) + ) + + val query = + """ + |select id from MyTable + |where ts >= TIMESTAMP '2017-07-12 11:45:00' and ts < TIMESTAMP '2017-07-14 12:14:23' + """.stripMargin + val tableSource = TestFilterableTableSource(rowTypeInfo, rows, Set("ts")) + tableEnv.registerTableSource(tableName, tableSource) + val results = tableEnv + .sqlQuery(query) + .collect() + + val expected = Seq(2, 3, 5).mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + private def makeRow(fields: Any*): Row = { + val row = new Row(fields.length) + val addField = (value: Any, pos: Int) => row.setField(pos, value) + fields.zipWithIndex.foreach(addField.tupled) + row + } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala index a9e9632..77c1e08 100644 --- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala @@ -104,7 +104,7 @@ class TableSourceITCase extends StreamingMultipleProgramsTestBase { val tableName = "MyTable" val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) - tEnv.registerTableSource(tableName, new TestFilterableTableSource) + tEnv.registerTableSource(tableName, TestFilterableTableSource()) tEnv.scan(tableName) .where("amount > 4 && price < 9") .select("id, name") http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala index 4042f50..5f8f5d6 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala @@ -212,11 +212,9 @@ case class BatchTableTestUtil() extends TableTestUtil { def printSql(query: String): Unit = { printTable(tableEnv.sqlQuery(query)) } - } case class StreamTableTestUtil() extends TableTestUtil { - val javaEnv = mock(classOf[JStreamExecutionEnvironment]) when(javaEnv.getStreamTimeCharacteristic).thenReturn(TimeCharacteristic.EventTime) val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv) http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala new file mode 100644 index 0000000..ae2b1d6 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.utils + +import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.table.api.TableSchema +import org.apache.flink.table.api.Types._ +import org.apache.flink.table.expressions._ +import org.apache.flink.table.sources.{BatchTableSource, FilterableTableSource, StreamTableSource, TableSource} +import org.apache.flink.types.Row + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +object TestFilterableTableSource { + /** + * @return The default filterable table source. + */ + def apply(): TestFilterableTableSource = { + apply(defaultTypeInfo, defaultRows, defaultFilterableFields) + } + + /** + * A filterable data source with custom data. + * @param rowTypeInfo The type of the data. + * Its expected that both types and field names are provided + * @param rows The data as a sequence of rows. + * @param filterableFields The fields that are allowed to be filtered on. + * @return The table source. 
+ */ + def apply(rowTypeInfo: RowTypeInfo, + rows: Seq[Row], + filterableFields: Set[String]): TestFilterableTableSource = { + new TestFilterableTableSource(rowTypeInfo, rows, filterableFields) + } + + private lazy val defaultFilterableFields = Set("amount") + + private lazy val defaultTypeInfo: RowTypeInfo = { + val fieldNames: Array[String] = Array("name", "id", "amount", "price") + val fieldTypes: Array[TypeInformation[_]] = Array(STRING, LONG, INT, DOUBLE) + new RowTypeInfo(fieldTypes, fieldNames) + } + + + private lazy val defaultRows: Seq[Row] = { + for { + cnt <- 0 until 33 + } yield { + Row.of( + s"Record_$cnt", + cnt.toLong.asInstanceOf[Object], + cnt.toInt.asInstanceOf[Object], + cnt.toDouble.asInstanceOf[Object]) + } + } +} + + +/** + * + * + * A data source that implements some very basic filtering in-memory in order to test + * expression push-down logic. + * + * @param rowTypeInfo The type info for the rows. + * @param data The data that filtering is applied to in order to get the final dataset. + * @param filterableFields The fields that are allowed to be filtered. + * @param filterPredicates The predicates that should be used to filter. + * @param filterPushedDown Whether predicates have been pushed down yet. 
+ */
+class TestFilterableTableSource(rowTypeInfo: RowTypeInfo,
+                                data: Seq[Row],
+                                filterableFields: Set[String] = Set(),
+                                filterPredicates: Seq[Expression] = Seq(),
+                                val filterPushedDown: Boolean = false)
+    extends BatchTableSource[Row]
+    with StreamTableSource[Row]
+    with FilterableTableSource[Row] {
+
+  val fieldNames: Array[String] = rowTypeInfo.getFieldNames
+
+  val fieldTypes: Array[TypeInformation[_]] = rowTypeInfo.getFieldTypes
+
+  /** Emits the rows of `data` that satisfy every pushed-down predicate as a DataSet. */
+  override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
+    execEnv.fromCollection[Row](applyPredicatesToRows(data).asJava, getReturnType)
+  }
+
+  /** Emits the rows of `data` that satisfy every pushed-down predicate as a DataStream. */
+  override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = {
+    execEnv.fromCollection[Row](applyPredicatesToRows(data).asJava, getReturnType)
+  }
+
+  /** Renders the conjunction of the pushed-down predicates, or "" when there are none. */
+  override def explainSource(): String = {
+    if (filterPredicates.nonEmpty) {
+      s"filter=[${filterPredicates.reduce((l, r) => And(l, r)).toString}]"
+    } else {
+      ""
+    }
+  }
+
+  override def getReturnType: TypeInformation[Row] = rowTypeInfo
+
+  /**
+    * Takes over every predicate accepted by `shouldPushDown` and removes it from
+    * `predicates`; the remaining entries are left in the list untouched.
+    *
+    * @param predicates mutable list of candidate predicates; accepted ones are removed in place
+    * @return a new source over the same data that evaluates the accepted predicates itself
+    *         and reports `isFilterPushedDown == true`
+    */
+  override def applyPredicate(predicates: JList[Expression]): TableSource[Row] = {
+    // The Java list has to be mutated in place, hence the explicit iterator.
+    val predicatesToUse = new mutable.ListBuffer[Expression]()
+    val iterator = predicates.iterator()
+    while (iterator.hasNext) {
+      val expr = iterator.next()
+      if (shouldPushDown(expr)) {
+        predicatesToUse += expr
+        iterator.remove()
+      }
+    }
+
+    new TestFilterableTableSource(
+      rowTypeInfo,
+      data,
+      filterableFields,
+      // hand over an immutable snapshot instead of the local buffer
+      predicatesToUse.toList,
+      filterPushedDown = true)
+  }
+
+  override def isFilterPushedDown: Boolean = filterPushedDown
+
+  private def applyPredicatesToRows(rows: Seq[Row]): Seq[Row] = {
+    rows.filter(shouldKeep)
+  }
+
+  /** Only binary comparisons are candidates for push-down. */
+  private def shouldPushDown(expr: Expression): Boolean = {
+    expr match {
+      case binExpr: BinaryComparison => shouldPushDown(binExpr)
+      case _ => false
+    }
+  }
+
+  /**
+    * A comparison is pushed down when it relates a filterable field to a literal (in either
+    * operand order) or two filterable fields to each other.
+    */
+  private def shouldPushDown(expr: BinaryComparison): Boolean = {
+    (expr.left, expr.right) match {
+      case (f: ResolvedFieldReference, _: Literal) =>
+        filterableFields.contains(f.name)
+      case (_: Literal, f: ResolvedFieldReference) =>
+        filterableFields.contains(f.name)
+      case (f1: ResolvedFieldReference, f2: ResolvedFieldReference) =>
+        filterableFields.contains(f1.name) && filterableFields.contains(f2.name)
+      case (_, _) => false
+    }
+  }
+
+  /** A row is kept when all pushed-down predicates hold (trivially true without predicates). */
+  private def shouldKeep(row: Row): Boolean = {
+    filterPredicates.forall {
+      case expr: BinaryComparison => binaryFilterApplies(expr, row)
+      case expr => throw new RuntimeException(expr + " not supported!")
+    }
+  }
+
+  /**
+    * Evaluates a single comparison against `row`.
+    *
+    * The comparison operator is matched by type only. The operand shapes (field/literal,
+    * literal/field, field/field, literal/literal) are resolved by `extractValues`, so every
+    * shape accepted by `shouldPushDown` can be evaluated here instead of failing with a
+    * `MatchError`.
+    */
+  private def binaryFilterApplies(expr: BinaryComparison, row: Row): Boolean = {
+    val (lhsValue, rhsValue) = extractValues(expr, row)
+
+    if (lhsValue == null || rhsValue == null) {
+      // A comparison with NULL is never true in SQL; this also avoids an NPE in compareTo.
+      false
+    } else {
+      val cmp = lhsValue.compareTo(rhsValue)
+      expr match {
+        case _: GreaterThan => cmp > 0
+        case _: LessThan => cmp < 0
+        case _: GreaterThanOrEqual => cmp >= 0
+        case _: LessThanOrEqual => cmp <= 0
+        case _: EqualTo => cmp == 0
+        case _: NotEqualTo => cmp != 0
+        case _ => throw new RuntimeException(expr + " not supported!")
+      }
+    }
+  }
+
+  /**
+    * Resolves both operands of `expr` to comparable values: field references are read from
+    * `row` by name, literals contribute their constant value.
+    *
+    * @throws RuntimeException if an operand is neither a field reference nor a literal
+    */
+  private def extractValues(expr: BinaryComparison,
+                            row: Row): (Comparable[Any], Comparable[Any]) = {
+    (expr.left, expr.right) match {
+      case (l: ResolvedFieldReference, r: Literal) =>
+        val idx = rowTypeInfo.getFieldIndex(l.name)
+        val lv = row.getField(idx).asInstanceOf[Comparable[Any]]
+        val rv = r.value.asInstanceOf[Comparable[Any]]
+        (lv, rv)
+      case (l: Literal, r: ResolvedFieldReference) =>
+        val idx = rowTypeInfo.getFieldIndex(r.name)
+        val lv = l.value.asInstanceOf[Comparable[Any]]
+        val rv = row.getField(idx).asInstanceOf[Comparable[Any]]
+        (lv, rv)
+      case (l: Literal, r: Literal) =>
+        val lv = l.value.asInstanceOf[Comparable[Any]]
+        val rv = r.value.asInstanceOf[Comparable[Any]]
+        (lv, rv)
+      case (l: ResolvedFieldReference, r: ResolvedFieldReference) =>
+        val lidx = rowTypeInfo.getFieldIndex(l.name)
+        val ridx = rowTypeInfo.getFieldIndex(r.name)
+        val lv = row.getField(lidx).asInstanceOf[Comparable[Any]]
+        val rv = row.getField(ridx).asInstanceOf[Comparable[Any]]
+        (lv, rv)
+      case _ => throw new RuntimeException(expr + " not supported!")
+    }
+  }
+
+  override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/3b333b28/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/testTableSources.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/testTableSources.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/testTableSources.scala index f11f0ca..c2eba32 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/testTableSources.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/testTableSources.scala @@ -147,107 +147,6 @@ class TestProjectableTableSource( } } -/** - * This source can only handle simple comparision with field "amount". - * Supports ">, <, >=, <=, =, <>" with an integer. 
- */ -class TestFilterableTableSource( - val recordNum: Int = 33) - extends BatchTableSource[Row] - with StreamTableSource[Row] - with FilterableTableSource[Row] { - - var filterPushedDown: Boolean = false - - val fieldNames: Array[String] = Array("name", "id", "amount", "price") - - val fieldTypes: Array[TypeInformation[_]] = Array(STRING, LONG, INT, DOUBLE) - - // all predicates with field "amount" - private var filterPredicates = new mutable.ArrayBuffer[Expression] - - // all comparing values for field "amount" - private val filterValues = new mutable.ArrayBuffer[Int] - - override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = { - execEnv.fromCollection[Row](generateDynamicCollection().asJava, getReturnType) - } - - override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = { - execEnv.fromCollection[Row](generateDynamicCollection().asJava, getReturnType) - } - - override def explainSource(): String = { - if (filterPredicates.nonEmpty) { - s"filter=[${filterPredicates.reduce((l, r) => And(l, r)).toString}]" - } else { - "" - } - } - - override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames) - - override def applyPredicate(predicates: JList[Expression]): TableSource[Row] = { - val newSource = new TestFilterableTableSource(recordNum) - newSource.filterPushedDown = true - - val iterator = predicates.iterator() - while (iterator.hasNext) { - iterator.next() match { - case expr: BinaryComparison => - (expr.left, expr.right) match { - case (f: ResolvedFieldReference, v: Literal) if f.name.equals("amount") => - newSource.filterPredicates += expr - newSource.filterValues += v.value.asInstanceOf[Number].intValue() - iterator.remove() - case (_, _) => - } - case _ => - } - } - - newSource - } - - override def isFilterPushedDown: Boolean = filterPushedDown - - private def generateDynamicCollection(): Seq[Row] = { - Preconditions.checkArgument(filterPredicates.length == filterValues.length) - - for { - 
cnt <- 0 until recordNum - if shouldCreateRow(cnt) - } yield { - Row.of( - s"Record_$cnt", - cnt.toLong.asInstanceOf[Object], - cnt.toInt.asInstanceOf[Object], - cnt.toDouble.asInstanceOf[Object]) - } - } - - private def shouldCreateRow(value: Int): Boolean = { - filterPredicates.zip(filterValues).forall { - case (_: GreaterThan, v) => - value > v - case (_: LessThan, v) => - value < v - case (_: GreaterThanOrEqual, v) => - value >= v - case (_: LessThanOrEqual, v) => - value <= v - case (_: EqualTo, v) => - value == v - case (_: NotEqualTo, v) => - value != v - case (expr, _) => - throw new RuntimeException(expr + " not supported!") - } - } - - override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes) -} - class TestNestedProjectableTableSource( tableSchema: TableSchema, returnType: TypeInformation[Row],
