This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c5ccbfef31af1db7eebd18f275a4f47dab65c855
Author: Jark Wu <imj...@gmail.com>
AuthorDate: Wed Jul 3 10:37:45 2019 +0800

    [FLINK-12978][table] Support LookupableTableSource for CsvTableSource
---
 .../apache/flink/table/sources/CsvTableSource.java | 123 ++++++++++++++++++++-
 .../runtime/batch/sql/TableSourceITCase.scala      |  31 +++++-
 .../runtime/stream/sql/TableSourceITCase.scala     |  30 +++++
 .../apache/flink/table/util/testTableSources.scala |  47 +++++++-
 4 files changed, 228 insertions(+), 3 deletions(-)
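
For context, a minimal sketch of what the new capability enables. The file
path and the surrounding setup are illustrative assumptions, not part of
this commit:

    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.table.functions.TableFunction;
    import org.apache.flink.table.sources.CsvTableSource;
    import org.apache.flink.types.Row;

    // Build a CSV-backed source; the schema mirrors the "rates" test data below.
    CsvTableSource rates = CsvTableSource.builder()
            .path("/tmp/rates.csv")          // hypothetical path
            .field("currency", Types.STRING)
            .field("rate", Types.LONG)
            .build();

    // The planner can now ask the source for a synchronous lookup function
    // over the join keys instead of scanning the whole table:
    TableFunction<Row> lookup = rates.getLookupFunction(new String[]{"currency"});

    // Async lookup is deliberately unsupported:
    //   rates.isAsyncEnabled()              -> false
    //   rates.getAsyncLookupFunction(keys)  -> UnsupportedOperationException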

diff --git a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java
index 160bc9a..1cd49ae 100644
--- a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java
+++ b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/sources/CsvTableSource.java
@@ -24,15 +24,23 @@ import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.io.CsvInputFormat;
 import org.apache.flink.api.java.io.RowCsvInputFormat;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.core.fs.FileInputSplit;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.TableSchema;
+import org.apache.flink.table.functions.AsyncTableFunction;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.TableFunction;
 import org.apache.flink.types.Row;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.stream.IntStream;
 
@@ -41,7 +49,8 @@ import java.util.stream.IntStream;
  * (logically) unlimited number of fields.
  */
 public class CsvTableSource
-       implements StreamTableSource<Row>, BatchTableSource<Row>, ProjectableTableSource<Row> {
+       implements StreamTableSource<Row>, BatchTableSource<Row>, LookupableTableSource<Row>,
+       ProjectableTableSource<Row> {
 
        private final CsvInputFormatConfig config;
 
@@ -178,6 +187,21 @@ public class CsvTableSource
        }
 
        @Override
+       public TableFunction<Row> getLookupFunction(String[] lookupKeys) {
+               return new CsvLookupFunction(config, lookupKeys);
+       }
+
+       @Override
+       public AsyncTableFunction<Row> getAsyncLookupFunction(String[] lookupKeys) {
+               throw new UnsupportedOperationException("CSV does not support async lookup");
+       }
+
+       @Override
+       public boolean isAsyncEnabled() {
+               return false;
+       }
+
+       @Override
        public String explainSource() {
                String[] fields = config.getSelectedFieldNames();
                return "CsvTableSource(read fields: " + String.join(", ", fields) + ")";
@@ -321,6 +345,103 @@ public class CsvTableSource
 
        }
 
+       // ------------------------------------------------------------------------------------
+       // private utilities
+       // ------------------------------------------------------------------------------------
+
+       /**
+        * LookupFunction to support lookup in CsvTableSource.
+        */
+       public static class CsvLookupFunction extends TableFunction<Row> {
+               private static final long serialVersionUID = 1L;
+
+               private final CsvInputFormatConfig config;
+
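+               // sourceKeys: positions of the lookup keys in the eval() arguments;
+               // targetKeys: positions of the same keys in the CSV schema;
+               // dataMap: all CSV rows indexed by key, populated in open().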
+               private final List<Integer> sourceKeys = new ArrayList<>();
+               private final List<Integer> targetKeys = new ArrayList<>();
+               private final Map<Object, List<Row>> dataMap = new HashMap<>();
+
+               CsvLookupFunction(CsvInputFormatConfig config, String[] lookupKeys) {
+                       this.config = config;
+
+                       List<String> fields = Arrays.asList(config.getSelectedFieldNames());
+                       for (int i = 0; i < lookupKeys.length; i++) {
+                               sourceKeys.add(i);
+                               int targetIdx = fields.indexOf(lookupKeys[i]);
+                               // fail fast even when JVM assertions are disabled
+                               if (targetIdx < 0) {
+                                       throw new IllegalArgumentException(
+                                               "Invalid lookup key: " + lookupKeys[i]);
+                               }
+                               targetKeys.add(targetIdx);
+                       }
+               }
+
+               @Override
+               public TypeInformation<Row> getResultType() {
+                       return new RowTypeInfo(config.getSelectedFieldTypes(), config.getSelectedFieldNames());
+               }
+
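+               // Eagerly loads the whole CSV file into dataMap, grouped by
+               // lookup key, so each eval() call is a single hash-map probe.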
+               @Override
+               public void open(FunctionContext context) throws Exception {
+                       super.open(context);
+                       TypeInformation<Row> rowType = getResultType();
+
+                       RowCsvInputFormat inputFormat = config.createInputFormat();
+                       FileInputSplit[] inputSplits = inputFormat.createInputSplits(1);
+                       for (FileInputSplit split : inputSplits) {
+                               inputFormat.open(split);
+                               Row row = new Row(rowType.getArity());
+                               while (true) {
+                                       Row r = inputFormat.nextRecord(row);
+                                       if (r == null) {
+                                               break;
+                                       } else {
+                                               Object key = getTargetKey(r);
+                                               List<Row> rows = dataMap.computeIfAbsent(key, k -> new ArrayList<>());
+                                               rows.add(Row.copy(r));
+                                       }
+                               }
+                               inputFormat.close();
+                       }
+               }
+
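+               // Invoked by the runtime once per probe-side row; emits every
+               // cached row whose key equals the probe key.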
+               public void eval(Object... values) {
+                       Object srcKey = getSourceKey(Row.of(values));
+                       if (dataMap.containsKey(srcKey)) {
+                               for (Row row1 : dataMap.get(srcKey)) {
+                                       collect(row1);
+                               }
+                       }
+               }
+
+               private Object getSourceKey(Row source) {
+                       return getKey(source, sourceKeys);
+               }
+
+               private Object getTargetKey(Row target) {
+                       return getKey(target, targetKeys);
+               }
+
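+               // A single-field key is used as-is; a composite key is wrapped in
+               // a Row so that equals()/hashCode() cover all key fields.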
+               private Object getKey(Row input, List<Integer> keys) {
+                       if (keys.size() == 1) {
+                               int keyIdx = keys.get(0);
+                               return input.getField(keyIdx);
+                       } else {
+                               Row key = new Row(keys.size());
+                               for (int i = 0; i < keys.size(); i++) {
+                                       int keyIdx = keys.get(i);
+                                       key.setField(i, input.getField(keyIdx));
+                               }
+                               return key;
+                       }
+               }
+
+               @Override
+               public void close() throws Exception {
+                       super.close();
+               }
+       }
+
        private static class CsvInputFormatConfig implements Serializable {
                private static final long serialVersionUID = 1L;
 
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala
index 4abf691..464dcd1 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/TableSourceITCase.scala
@@ -26,7 +26,7 @@ import org.apache.flink.table.runtime.utils.{BatchTestBase, TestData}
 import org.apache.flink.table.types.TypeInfoDataTypeConverter
 import org.apache.flink.table.util.{TestFilterableTableSource, TestNestedProjectableTableSource, TestProjectableTableSource, TestTableSources}
 import org.apache.flink.types.Row
-import org.junit.{Before, Test}
+import org.junit.{Before, Ignore, Test}
 
 import java.lang.{Boolean => JBool, Integer => JInt, Long => JLong}
 
@@ -168,4 +168,33 @@ class TableSourceITCase extends BatchTestBase {
       )
     )
   }
+
+  @Ignore("[FLINK-13075] Project pushdown rule shouldn't require" +
+    " the TableSource to return a modified schema in blink planner")
+  @Test
+  def testLookupJoinCsvTemporalTable(): Unit = {
+    val orders = TestTableSources.getOrdersCsvTableSource
+    val rates = TestTableSources.getRatesCsvTableSource
+    tEnv.registerTableSource("orders", orders)
+    tEnv.registerTableSource("rates", rates)
+
+    val sql =
+      """
+        |SELECT o.amount, o.currency, r.rate
+        |FROM (SELECT *, PROCTIME() as proc FROM orders) AS o
+        |JOIN rates FOR SYSTEM_TIME AS OF o.proc AS r
+        |ON o.currency = r.currency
+      """.stripMargin
+
+    checkResult(
+      sql,
+      Seq(
+        row(2, "Euro", 119),
+        row(1, "US Dollar", 102),
+        row(50, "Yen", 1),
+        row(3, "Euro", 119),
+        row(5, "US Dollar", 102)
+      )
+    )
+  }
 }
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala
index 9f4c2c3..fa4922a 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TableSourceITCase.scala
@@ -336,4 +336,34 @@ class TableSourceITCase extends StreamingTestBase {
     assertEquals(expected.sorted, sink.getAppendResults.sorted)
   }
 
+  @Test
+  def testLookupJoinCsvTemporalTable(): Unit = {
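+    // Probe side: "orders" with an attached processing time; build side:
+    // "rates", served by the new CsvTableSource lookup function.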
+    val orders = TestTableSources.getOrdersCsvTableSource
+    val rates = TestTableSources.getRatesCsvTableSource
+    tEnv.registerTableSource("orders", orders)
+    tEnv.registerTableSource("rates", rates)
+
+    val sql =
+      """
+        |SELECT o.amount, o.currency, r.rate
+        |FROM (SELECT *, PROCTIME() as proc FROM orders) AS o
+        |JOIN rates FOR SYSTEM_TIME AS OF o.proc AS r
+        |ON o.currency = r.currency
+      """.stripMargin
+
+    val sink = new TestingAppendSink()
+    tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink)
+
+    env.execute()
+
+    val expected = Seq(
+      "2,Euro,119",
+      "1,US Dollar,102",
+      "50,Yen,1",
+      "3,Euro,119",
+      "5,US Dollar,102"
+    )
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
+
 }
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala
index 776d2c4..54d14e4 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/testTableSources.scala
@@ -57,7 +57,10 @@ object TestTableSources {
       "Kelly#8#2.34#Williams"
     )
 
-    val tempFilePath = writeToTempFile(csvRecords.mkString("$"), "csv-test", "tmp")
+    val tempFilePath = writeToTempFile(
+      csvRecords.mkString("$"),
+      "csv-test",
+      "tmp")
     CsvTableSource.builder()
       .path(tempFilePath)
       .field("first", Types.STRING)
@@ -71,6 +74,48 @@ object TestTableSources {
       .build()
   }
 
+  def getOrdersCsvTableSource: CsvTableSource = {
+    val csvRecords = Seq(
+      "2,Euro,2",
+      "1,US Dollar,3",
+      "50,Yen,4",
+      "3,Euro,5",
+      "5,US Dollar,6"
+    )
+    val tempFilePath = writeToTempFile(
+      csvRecords.mkString("$"),
+      "csv-order-test",
+      "tmp")
+    CsvTableSource.builder()
+      .path(tempFilePath)
+      .field("amount", Types.LONG)
+      .field("currency", Types.STRING)
+      .field("ts", Types.LONG)
+      .fieldDelimiter(",")
+      .lineDelimiter("$")
+      .build()
+  }
+
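+  /** A CSV-backed "rates" table (currency, rate) used as the lookup side. */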
+  def getRatesCsvTableSource: CsvTableSource = {
+    val csvRecords = Seq(
+      "US Dollar,102",
+      "Yen,1",
+      "Euro,119",
+      "RMB,702"
+    )
+    val tempFilePath = writeToTempFile(
+      csvRecords.mkString("$"),
+      "csv-rate-test",
+      "tmp")
+    CsvTableSource.builder()
+      .path(tempFilePath)
+      .field("currency", Types.STRING)
+      .field("rate", Types.LONG)
+      .fieldDelimiter(",")
+      .lineDelimiter("$")
+      .build()
+  }
+
   private def writeToTempFile(
       contents: String,
       filePrefix: String,
