This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c5ccbfef31af1db7eebd18f275a4f47dab65c855 Author: Jark Wu <imj...@gmail.com> AuthorDate: Wed Jul 3 10:37:45 2019 +0800 [FLINK-12978][table] Support LookupableTableSource for CsvTableSource --- .../apache/flink/table/sources/CsvTableSource.java | 123 ++++++++++++++++++++- .../runtime/batch/sql/TableSourceITCase.scala | 31 +++++- .../runtime/stream/sql/TableSourceITCase.scala | 30 +++++ .../apache/flink/table/util/testTableSources.scala | 47 +++++++- 4 files changed, 228 insertions(+), 3 deletions(-) diff --git a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java index 160bc9a..1cd49ae 100644 --- a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java +++ b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java @@ -24,15 +24,23 @@ import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.io.CsvInputFormat; import org.apache.flink.api.java.io.RowCsvInputFormat; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.core.fs.FileInputSplit; import org.apache.flink.core.fs.Path; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.functions.AsyncTableFunction; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.TableFunction; import org.apache.flink.types.Row; import java.io.Serializable; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.stream.IntStream; @@ -41,7 +49,8 @@ import 
java.util.stream.IntStream; * (logically) unlimited number of fields. */ public class CsvTableSource - implements StreamTableSource<Row>, BatchTableSource<Row>, ProjectableTableSource<Row> { + implements StreamTableSource<Row>, BatchTableSource<Row>, LookupableTableSource<Row>, + ProjectableTableSource<Row> { private final CsvInputFormatConfig config; @@ -178,6 +187,21 @@ public class CsvTableSource } @Override + public TableFunction<Row> getLookupFunction(String[] lookupKeys) { + return new CsvLookupFunction(config, lookupKeys); + } + + @Override + public AsyncTableFunction<Row> getAsyncLookupFunction(String[] lookupKeys) { + throw new UnsupportedOperationException("CSV do not support async lookup"); + } + + @Override + public boolean isAsyncEnabled() { + return false; + } + + @Override public String explainSource() { String[] fields = config.getSelectedFieldNames(); return "CsvTableSource(read fields: " + String.join(", ", fields) + ")"; @@ -321,6 +345,103 @@ public class CsvTableSource } + // ------------------------------------------------------------------------------------ + // private utilities + // ------------------------------------------------------------------------------------ + + /** + * LookupFunction to support lookup in CsvTableSource. 
+ */ + public static class CsvLookupFunction extends TableFunction<Row> { + private static final long serialVersionUID = 1L; + + private final CsvInputFormatConfig config; + + private final List<Integer> sourceKeys = new ArrayList<>(); + private final List<Integer> targetKeys = new ArrayList<>(); + private final Map<Object, List<Row>> dataMap = new HashMap<>(); + + CsvLookupFunction(CsvInputFormatConfig config, String[] lookupKeys) { + this.config = config; + + List<String> fields = Arrays.asList(config.getSelectedFieldNames()); + for (int i = 0; i < lookupKeys.length; i++) { + sourceKeys.add(i); + int targetIdx = fields.indexOf(lookupKeys[i]); + assert targetIdx != -1; + targetKeys.add(targetIdx); + } + } + + @Override + public TypeInformation<Row> getResultType() { + return new RowTypeInfo(config.getSelectedFieldTypes(), config.getSelectedFieldNames()); + } + + @Override + public void open(FunctionContext context) throws Exception { + super.open(context); + TypeInformation<Row> rowType = getResultType(); + + RowCsvInputFormat inputFormat = config.createInputFormat(); + FileInputSplit[] inputSplits = inputFormat.createInputSplits(1); + for (FileInputSplit split : inputSplits) { + inputFormat.open(split); + Row row = new Row(rowType.getArity()); + while (true) { + Row r = inputFormat.nextRecord(row); + if (r == null) { + break; + } else { + Object key = getTargetKey(r); + List<Row> rows = dataMap.computeIfAbsent(key, k -> new ArrayList<>()); + rows.add(Row.copy(r)); + } + } + inputFormat.close(); + } + } + + public void eval(Object... 
values) { + Object srcKey = getSourceKey(Row.of(values)); + if (dataMap.containsKey(srcKey)) { + for (Row row1 : dataMap.get(srcKey)) { + collect(row1); + } + } + } + + private Object getSourceKey(Row source) { + return getKey(source, sourceKeys); + } + + private Object getTargetKey(Row target) { + return getKey(target, targetKeys); + } + + private Object getKey(Row input, List<Integer> keys) { + if (keys.size() == 1) { + int keyIdx = keys.get(0); + if (input.getField(keyIdx) != null) { + return input.getField(keyIdx); + } + return null; + } else { + Row key = new Row(keys.size()); + for (int i = 0; i < keys.size(); i++) { + int keyIdx = keys.get(i); + key.setField(i, input.getField(keyIdx)); + } + return key; + } + } + + @Override + public void close() throws Exception { + super.close(); + } + } + private static class CsvInputFormatConfig implements Serializable { private static final long serialVersionUID = 1L; diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala index 4abf691..464dcd1 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala @@ -26,7 +26,7 @@ import org.apache.flink.table.runtime.utils.{BatchTestBase, TestData} import org.apache.flink.table.types.TypeInfoDataTypeConverter import org.apache.flink.table.util.{TestFilterableTableSource, TestNestedProjectableTableSource, TestProjectableTableSource, TestTableSources} import org.apache.flink.types.Row -import org.junit.{Before, Test} +import org.junit.{Before, Ignore, Test} import java.lang.{Boolean => JBool, Integer => JInt, Long => JLong} @@ -168,4 +168,33 @@ class TableSourceITCase extends BatchTestBase { ) ) } 
+ + @Ignore("[FLINK-13075] Project pushdown rule shouldn't require" + + " the TableSource return a modified schema in blink planner") + @Test + def testLookupJoinCsvTemporalTable(): Unit = { + val orders = TestTableSources.getOrdersCsvTableSource + val rates = TestTableSources.getRatesCsvTableSource + tEnv.registerTableSource("orders", orders) + tEnv.registerTableSource("rates", rates) + + val sql = + """ + |SELECT o.amount, o.currency, r.rate + |FROM (SELECT *, PROCTIME() as proc FROM orders) AS o + |JOIN rates FOR SYSTEM_TIME AS OF o.proc AS r + |ON o.currency = r.currency + """.stripMargin + + checkResult( + sql, + Seq( + row(2, "Euro", 119), + row(1, "US Dollar", 102), + row(50, "Yen", 1), + row(3, "Euro", 119), + row(5, "US Dollar", 102) + ) + ) + } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala index 9f4c2c3..fa4922a 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala @@ -336,4 +336,34 @@ class TableSourceITCase extends StreamingTestBase { assertEquals(expected.sorted, sink.getAppendResults.sorted) } + @Test + def testLookupJoinCsvTemporalTable(): Unit = { + val orders = TestTableSources.getOrdersCsvTableSource + val rates = TestTableSources.getRatesCsvTableSource + tEnv.registerTableSource("orders", orders) + tEnv.registerTableSource("rates", rates) + + val sql = + """ + |SELECT o.amount, o.currency, r.rate + |FROM (SELECT *, PROCTIME() as proc FROM orders) AS o + |JOIN rates FOR SYSTEM_TIME AS OF o.proc AS r + |ON o.currency = r.currency + """.stripMargin + + val sink = new TestingAppendSink() + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + + 
env.execute() + + val expected = Seq( + "2,Euro,119", + "1,US Dollar,102", + "50,Yen,1", + "3,Euro,119", + "5,US Dollar,102" + ) + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } + } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala index 776d2c4..54d14e4 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala @@ -57,7 +57,10 @@ object TestTableSources { "Kelly#8#2.34#Williams" ) - val tempFilePath = writeToTempFile(csvRecords.mkString("$"), "csv-test", "tmp") + val tempFilePath = writeToTempFile( + csvRecords.mkString("$"), + "csv-test", + "tmp") CsvTableSource.builder() .path(tempFilePath) .field("first", Types.STRING) @@ -71,6 +74,48 @@ object TestTableSources { .build() } + def getOrdersCsvTableSource: CsvTableSource = { + val csvRecords = Seq( + "2,Euro,2", + "1,US Dollar,3", + "50,Yen,4", + "3,Euro,5", + "5,US Dollar,6" + ) + val tempFilePath = writeToTempFile( + csvRecords.mkString("$"), + "csv-order-test", + "tmp") + CsvTableSource.builder() + .path(tempFilePath) + .field("amount", Types.LONG) + .field("currency", Types.STRING) + .field("ts",Types.LONG) + .fieldDelimiter(",") + .lineDelimiter("$") + .build() + } + + def getRatesCsvTableSource: CsvTableSource = { + val csvRecords = Seq( + "US Dollar,102", + "Yen,1", + "Euro,119", + "RMB,702" + ) + val tempFilePath = writeToTempFile( + csvRecords.mkString("$"), + "csv-rate-test", + "tmp") + CsvTableSource.builder() + .path(tempFilePath) + .field("currency", Types.STRING) + .field("rate", Types.LONG) + .fieldDelimiter(",") + .lineDelimiter("$") + .build() + } + private def writeToTempFile( contents: String, filePrefix: String,