This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
     new bc76a93239e  [FLINK-28420][table] Support partial caching in sync and async lookup runner

bc76a93239e is described below

commit bc76a93239e159fba7277da96336c3796887e6c5
Author: Qingsheng Ren <renqs...@gmail.com>
AuthorDate: Mon Aug 8 10:33:12 2022 +0800

    [FLINK-28420][table] Support partial caching in sync and async lookup runner

    This closes #20480
---
 .../flink/table/functions/AsyncLookupFunction.java |   5 +-
 .../flink/table/functions/LookupFunction.java      |   3 +
 .../nodes/exec/common/CommonExecLookupJoin.java    |   9 ++
 .../table/planner/plan/utils/LookupJoinUtil.java   |  26 ++-
 .../planner/codegen/LookupJoinCodeGenerator.scala  |  12 +-
 .../factories/TestValuesRuntimeFunctions.java      | 138 +++++++++++-----
 .../planner/factories/TestValuesTableFactory.java  |  71 ++++++---
 .../runtime/batch/sql/join/LookupJoinITCase.scala  | 110 ++++++++++++-
 .../runtime/stream/sql/AsyncLookupJoinITCase.scala | 149 ++++++++++++++++--
 .../runtime/stream/sql/LookupJoinITCase.scala      | 111 ++++++++++++-
 .../table/lookup/CachingAsyncLookupFunction.java   | 133 ++++++++++++++++
 .../table/lookup/CachingLookupFunction.java        | 167 ++++++++++++++++++++
 .../functions/table/lookup/LookupCacheManager.java | 174 +++++++++++++++++++++
 .../table/CachingAsyncLookupFunctionTest.java      | 129 +++++++++++++++
 .../functions/table/CachingLookupFunctionTest.java | 103 ++++++++++++
 15 files changed, 1255 insertions(+), 85 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncLookupFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncLookupFunction.java index 5e58b646d3b..3127a22a71e 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncLookupFunction.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncLookupFunction.java @@ -38,6 +38,9 @@ public abstract class AsyncLookupFunction extends AsyncTableFunction<RowData> { /** * Asynchronously lookup rows matching the lookup keys. * + * <p>Please note that the returned collection of RowData shouldn't be reused across + * invocations. + * * @param keyRow - A {@link RowData} that wraps lookup keys. * @return A collection of all matching rows in the lookup table. */ @@ -47,7 +50,7 @@ public abstract class AsyncLookupFunction extends AsyncTableFunction<RowData> { public final void eval(CompletableFuture<Collection<RowData>> future, Object... keys) { GenericRowData keyRow = GenericRowData.of(keys); asyncLookup(keyRow) - .whenCompleteAsync( + .whenComplete( (result, exception) -> { if (exception != null) { future.completeExceptionally( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/LookupFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/LookupFunction.java index bef18923589..26ef0e66197 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/LookupFunction.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/LookupFunction.java @@ -37,6 +37,9 @@ public abstract class LookupFunction extends TableFunction<RowData> { /** * Synchronously lookup rows matching the lookup keys. * + * <p>Please note that the returned collection of RowData shouldn't be reused across + * invocations. + * * @param keyRow - A {@link RowData} that wraps lookup keys. * @return A collection of all matching rows in the lookup table.
*/ diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java index fe25460c247..14914151d3e 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java @@ -333,6 +333,15 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> throw new UnsupportedOperationException("to be supported"); } + private LogicalType getLookupKeyLogicalType( + LookupJoinUtil.LookupKey lookupKey, RowType inputRowType) { + if (lookupKey instanceof LookupJoinUtil.FieldRefLookupKey) { + return inputRowType.getTypeAt(((LookupJoinUtil.FieldRefLookupKey) lookupKey).index); + } else { + return ((LookupJoinUtil.ConstantLookupKey) lookupKey).sourceType; + } + } + protected void validateLookupKeyType( final Map<Integer, LookupJoinUtil.LookupKey> lookupKeys, final RowType inputRowType, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java index 8b5966eba8e..2a397d7641c 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java @@ -23,10 +23,16 @@ import org.apache.flink.table.api.TableException; import org.apache.flink.table.connector.source.AsyncTableFunctionProvider; import org.apache.flink.table.connector.source.LookupTableSource; import org.apache.flink.table.connector.source.TableFunctionProvider; +import org.apache.flink.table.connector.source.lookup.AsyncLookupFunctionProvider; +import org.apache.flink.table.connector.source.lookup.LookupFunctionProvider; +import org.apache.flink.table.connector.source.lookup.PartialCachingAsyncLookupProvider; +import org.apache.flink.table.connector.source.lookup.PartialCachingLookupProvider; import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.planner.plan.schema.LegacyTableSourceTable; import org.apache.flink.table.planner.plan.schema.TableSourceTable; import org.apache.flink.table.runtime.connector.source.LookupRuntimeProviderContext; +import org.apache.flink.table.runtime.functions.table.lookup.CachingAsyncLookupFunction; +import org.apache.flink.table.runtime.functions.table.lookup.CachingLookupFunction; import org.apache.flink.table.sources.LookupableTableSource; import org.apache.flink.table.types.logical.LogicalType; @@ -182,7 +188,25 @@ public final class LookupJoinUtil { + "found in TableSourceTable: %s, please check the code to ensure a proper TableFunctionProvider is specified.", temporalTable.getQualifiedName())); } - if (provider instanceof TableFunctionProvider) { + if (provider instanceof LookupFunctionProvider) { + if (provider instanceof PartialCachingLookupProvider) { + PartialCachingLookupProvider partialCachingLookupProvider = + (PartialCachingLookupProvider) provider; + return new CachingLookupFunction( + partialCachingLookupProvider.getCache(), + partialCachingLookupProvider.createLookupFunction()); + } + return ((LookupFunctionProvider) 
provider).createLookupFunction(); + } else if (provider instanceof AsyncLookupFunctionProvider) { + if (provider instanceof PartialCachingAsyncLookupProvider) { + PartialCachingAsyncLookupProvider partialCachingLookupProvider = + (PartialCachingAsyncLookupProvider) provider; + return new CachingAsyncLookupFunction( + partialCachingLookupProvider.getCache(), + partialCachingLookupProvider.createAsyncLookupFunction()); + } + return ((AsyncLookupFunctionProvider) provider).createAsyncLookupFunction(); + } else if (provider instanceof TableFunctionProvider) { return ((TableFunctionProvider<?>) provider).createTableFunction(); } else if (provider instanceof AsyncTableFunctionProvider) { return ((AsyncTableFunctionProvider<?>) provider).createAsyncTableFunction(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala index 89c08cb4110..2a6e2c25fbe 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.catalog.DataTypeFactory import org.apache.flink.table.connector.source.{LookupTableSource, ScanTableSource} import org.apache.flink.table.data.{GenericRowData, RowData} import org.apache.flink.table.data.utils.JoinedRowData -import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction, UserDefinedFunctionHelper} +import org.apache.flink.table.functions.{AsyncLookupFunction, AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction, UserDefinedFunctionHelper} import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.GenerateUtils._ @@ -248,7 +248,15 @@ object LookupJoinCodeGenerator { val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala val defaultOutputDataType = callContext.getOutputDataType.get() - val outputClass = toScala(extractSimpleGeneric(baseClass, udf.getClass, 0)) + val outputClass = + if ( + udf.getClass.getSuperclass == classOf[LookupFunction] + || udf.getClass.getSuperclass == classOf[AsyncLookupFunction] + ) { + Some(classOf[RowData]) + } else { + toScala(extractSimpleGeneric(baseClass, udf.getClass, 0)) + } val (argDataTypes, outputDataType) = outputClass match { case Some(c) if c == classOf[Row] => (defaultArgDataTypes, defaultOutputDataType) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesRuntimeFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesRuntimeFunctions.java index 284d8602984..b915529ef62 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesRuntimeFunctions.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesRuntimeFunctions.java @@ -37,11 +37,13 @@ import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import 
org.apache.flink.table.connector.sink.DynamicTableSink.DataStructureConverter; +import org.apache.flink.table.connector.source.LookupTableSource; +import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.TimestampData; -import org.apache.flink.table.functions.AsyncTableFunction; +import org.apache.flink.table.functions.AsyncLookupFunction; import org.apache.flink.table.functions.FunctionContext; -import org.apache.flink.table.functions.TableFunction; +import org.apache.flink.table.functions.LookupFunction; import org.apache.flink.test.util.SuccessException; import org.apache.flink.types.Row; import org.apache.flink.types.RowKind; @@ -65,6 +67,7 @@ import java.util.concurrent.Executors; import static org.apache.flink.table.planner.factories.TestValuesTableFactory.RESOURCE_COUNTER; import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.assertj.core.api.Assertions.assertThat; /** Runtime function implementations for {@link TestValuesTableFactory}. */ @@ -550,57 +553,95 @@ final class TestValuesRuntimeFunctions { * A lookup function which find matched rows with the given fields. NOTE: We have to declare it * as public because it will be used in code generation. */ - public static class TestValuesLookupFunction extends TableFunction<Row> { + public static class TestValuesLookupFunction extends LookupFunction { private static final long serialVersionUID = 1L; - private final Map<Row, List<Row>> data; + private final List<Row> data; + private final int[] lookupIndices; + private final LookupTableSource.DataStructureConverter converter; + private transient Map<RowData, List<RowData>> indexedData; private transient boolean isOpenCalled = false; - protected TestValuesLookupFunction(Map<Row, List<Row>> data) { + protected TestValuesLookupFunction( + List<Row> data, + int[] lookupIndices, + LookupTableSource.DataStructureConverter converter) { this.data = data; + this.lookupIndices = lookupIndices; + this.converter = converter; } @Override public void open(FunctionContext context) throws Exception { RESOURCE_COUNTER.incrementAndGet(); isOpenCalled = true; + indexDataByKey(); } - public void eval(Object... 
inputs) { + @Override + public Collection<RowData> lookup(RowData keyRow) throws IOException { checkArgument(isOpenCalled, "open() is not called."); - Row key = Row.of(inputs); - if (Arrays.asList(inputs).contains(null)) { - throw new IllegalArgumentException( + for (int i = 0; i < keyRow.getArity(); i++) { + checkNotNull( + ((GenericRowData) keyRow).getField(i), String.format( "Lookup key %s contains null value, which should not happen.", - key)); - } - List<Row> list = data.get(key); - if (list != null) { - list.forEach(this::collect); + keyRow)); } + return indexedData.get(keyRow); } @Override public void close() throws Exception { RESOURCE_COUNTER.decrementAndGet(); } + + private void indexDataByKey() { + indexedData = new HashMap<>(); + data.forEach( + record -> { + GenericRowData rowData = (GenericRowData) converter.toInternal(record); + checkNotNull( + rowData, "Cannot convert record to internal GenericRowData type"); + RowData key = + GenericRowData.of( + Arrays.stream(lookupIndices) + .mapToObj(rowData::getField) + .toArray()); + List<RowData> list = indexedData.get(key); + if (list != null) { + list.add(rowData); + } else { + list = new ArrayList<>(); + list.add(rowData); + indexedData.put(key, list); + } + }); + } } /** * An async lookup function which find matched rows with the given fields. NOTE: We have to * declare it as public because it will be used in code generation. */ - public static class AsyncTestValueLookupFunction extends AsyncTableFunction<Row> { + public static class AsyncTestValueLookupFunction extends AsyncLookupFunction { private static final long serialVersionUID = 1L; - private final Map<Row, List<Row>> mapping; + private final List<Row> data; + private final int[] lookupIndices; + private final LookupTableSource.DataStructureConverter converter; private final Random random; private transient boolean isOpenCalled = false; private transient ExecutorService executor; + private transient Map<RowData, List<RowData>> indexedData; - protected AsyncTestValueLookupFunction(Map<Row, List<Row>> mapping) { - this.mapping = mapping; + protected AsyncTestValueLookupFunction( + List<Row> data, + int[] lookupIndices, + LookupTableSource.DataStructureConverter converter) { + this.data = data; + this.lookupIndices = lookupIndices; + this.converter = converter; this.random = new Random(); } @@ -610,33 +651,29 @@ final class TestValuesRuntimeFunctions { isOpenCalled = true; // generate unordered result for async lookup executor = Executors.newFixedThreadPool(2); + indexDataByKey(); } - public void eval(CompletableFuture<Collection<Row>> resultFuture, Object... 
inputs) { + @Override + public CompletableFuture<Collection<RowData>> asyncLookup(RowData keyRow) { checkArgument(isOpenCalled, "open() is not called."); - final Row key = Row.of(inputs); - if (Arrays.asList(inputs).contains(null)) { - throw new IllegalArgumentException( + for (int i = 0; i < keyRow.getArity(); i++) { + checkNotNull( + ((GenericRowData) keyRow).getField(i), String.format( "Lookup key %s contains null value, which should not happen.", - key)); + keyRow)); } - CompletableFuture.supplyAsync( - () -> { - try { - Thread.sleep(random.nextInt(5)); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - List<Row> list = mapping.get(key); - if (list == null) { - return Collections.<Row>emptyList(); - } else { - return list; - } - }, - executor) - .thenAccept(resultFuture::complete); + return CompletableFuture.supplyAsync( + () -> { + try { + Thread.sleep(random.nextInt(5)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return indexedData.get(keyRow); + }, + executor); } @Override @@ -646,5 +683,28 @@ final class TestValuesRuntimeFunctions { executor.shutdown(); } } + + private void indexDataByKey() { + indexedData = new HashMap<>(); + data.forEach( + record -> { + GenericRowData rowData = (GenericRowData) converter.toInternal(record); + checkNotNull( + rowData, "Cannot convert record to internal GenericRowData type"); + RowData key = + GenericRowData.of( + Arrays.stream(lookupIndices) + .mapToObj(rowData::getField) + .toArray()); + List<RowData> list = indexedData.get(key); + if (list != null) { + list.add(rowData); + } else { + list = new ArrayList<>(); + list.add(rowData); + indexedData.put(key, list); + } + }); + } } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index b3f545fe718..44d7c537be2 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -68,6 +68,13 @@ import org.apache.flink.table.connector.source.abilities.SupportsProjectionPushD import org.apache.flink.table.connector.source.abilities.SupportsReadingMetadata; import org.apache.flink.table.connector.source.abilities.SupportsSourceWatermark; import org.apache.flink.table.connector.source.abilities.SupportsWatermarkPushDown; +import org.apache.flink.table.connector.source.lookup.AsyncLookupFunctionProvider; +import org.apache.flink.table.connector.source.lookup.LookupFunctionProvider; +import org.apache.flink.table.connector.source.lookup.LookupOptions; +import org.apache.flink.table.connector.source.lookup.PartialCachingAsyncLookupProvider; +import org.apache.flink.table.connector.source.lookup.PartialCachingLookupProvider; +import org.apache.flink.table.connector.source.lookup.cache.DefaultLookupCache; +import org.apache.flink.table.connector.source.lookup.cache.LookupCache; import org.apache.flink.table.data.RowData; import org.apache.flink.table.expressions.AggregateExpression; import org.apache.flink.table.expressions.FieldReferenceExpression; @@ -128,6 +135,11 @@ import java.util.stream.Collectors; import scala.collection.Seq; +import static org.apache.flink.table.connector.source.lookup.LookupOptions.CACHE_TYPE; +import static 
org.apache.flink.table.connector.source.lookup.LookupOptions.PARTIAL_CACHE_CACHE_MISSING_KEY; +import static org.apache.flink.table.connector.source.lookup.LookupOptions.PARTIAL_CACHE_EXPIRE_AFTER_ACCESS; +import static org.apache.flink.table.connector.source.lookup.LookupOptions.PARTIAL_CACHE_EXPIRE_AFTER_WRITE; +import static org.apache.flink.table.connector.source.lookup.LookupOptions.PARTIAL_CACHE_MAX_ROWS; import static org.apache.flink.util.Preconditions.checkArgument; import static org.assertj.core.api.Assertions.assertThat; @@ -394,6 +406,10 @@ public final class TestValuesTableFactory boolean failingSource = helper.getOptions().get(FAILING_SOURCE); int numElementToSkip = helper.getOptions().get(SOURCE_NUM_ELEMENT_TO_SKIP); boolean internalData = helper.getOptions().get(INTERNAL_DATA); + DefaultLookupCache cache = null; + if (helper.getOptions().get(CACHE_TYPE).equals(LookupOptions.LookupCacheType.PARTIAL)) { + cache = DefaultLookupCache.fromConfig(helper.getOptions()); + } Optional<List<String>> filterableFields = helper.getOptions().getOptional(FILTERABLE_FIELDS); @@ -508,7 +524,8 @@ public final class TestValuesTableFactory Long.MAX_VALUE, partitions, readableMetadata, - null); + null, + cache); } } else { try { @@ -609,7 +626,12 @@ public final class TestValuesTableFactory ENABLE_WATERMARK_PUSH_DOWN, SINK_DROP_LATE_EVENT, SOURCE_NUM_ELEMENT_TO_SKIP, - INTERNAL_DATA)); + INTERNAL_DATA, + CACHE_TYPE, + PARTIAL_CACHE_EXPIRE_AFTER_ACCESS, + PARTIAL_CACHE_EXPIRE_AFTER_WRITE, + PARTIAL_CACHE_CACHE_MISSING_KEY, + PARTIAL_CACHE_MAX_ROWS)); } private static int validateAndExtractRowtimeIndex( @@ -1451,6 +1473,7 @@ public final class TestValuesTableFactory implements LookupTableSource, SupportsDynamicFiltering { private final @Nullable String lookupFunctionClass; + private final @Nullable LookupCache cache; private final boolean isAsync; private TestValuesScanLookupTableSource( @@ -1471,7 +1494,8 @@ public final class TestValuesTableFactory long limit, List<Map<String, String>> allPartitions, Map<String, DataType> readableMetadata, - @Nullable int[] projectedMetadataFields) { + @Nullable int[] projectedMetadataFields, + @Nullable LookupCache cache) { super( producedDataType, changelogMode, @@ -1491,6 +1515,7 @@ public final class TestValuesTableFactory projectedMetadataFields); this.lookupFunctionClass = lookupFunctionClass; this.isAsync = isAsync; + this.cache = cache; } @SuppressWarnings({"unchecked", "rawtypes"}) @@ -1513,7 +1538,6 @@ public final class TestValuesTableFactory } int[] lookupIndices = Arrays.stream(context.getKeys()).mapToInt(k -> k[0]).toArray(); - Map<Row, List<Row>> mapping = new HashMap<>(); Collection<Row> rows; if (allPartitions.equals(Collections.EMPTY_LIST)) { rows = data.getOrDefault(Collections.EMPTY_MAP, Collections.EMPTY_LIST); @@ -1531,27 +1555,25 @@ public final class TestValuesTableFactory data = data.subList(numElementToSkip, data.size()); } } - - data.forEach( - record -> { - Row key = - Row.of( - Arrays.stream(lookupIndices) - .mapToObj(record::getField) - .toArray()); - List<Row> list = mapping.get(key); - if (list != null) { - list.add(record); - } else { - list = new ArrayList<>(); - list.add(record); - mapping.put(key, list); - } - }); + DataStructureConverter converter = + context.createDataStructureConverter(producedDataType); if (isAsync) { - return AsyncTableFunctionProvider.of(new AsyncTestValueLookupFunction(mapping)); + AsyncTestValueLookupFunction asyncLookupFunction = + new AsyncTestValueLookupFunction(data, lookupIndices, converter); + if 
(cache == null) { + return AsyncLookupFunctionProvider.of(asyncLookupFunction); + } else { + + return PartialCachingAsyncLookupProvider.of(asyncLookupFunction, cache); + } } else { - return TableFunctionProvider.of(new TestValuesLookupFunction(mapping)); + TestValuesLookupFunction lookupFunction = + new TestValuesLookupFunction(data, lookupIndices, converter); + if (cache == null) { + return LookupFunctionProvider.of(lookupFunction); + } else { + return PartialCachingLookupProvider.of(lookupFunction, cache); + } } } @@ -1575,7 +1597,8 @@ public final class TestValuesTableFactory limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + cache); } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/LookupJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/LookupJoinITCase.scala index df9f37b9e0a..aaf22ab7150 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/LookupJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/LookupJoinITCase.scala @@ -18,10 +18,16 @@ package org.apache.flink.table.planner.runtime.batch.sql.join import org.apache.flink.table.api.{TableSchema, Types} +import org.apache.flink.table.connector.source.lookup.LookupOptions +import org.apache.flink.table.data.GenericRowData +import org.apache.flink.table.data.binary.BinaryStringData import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.runtime.utils.{BatchTestBase, InMemoryLookupableTableSource} +import org.apache.flink.table.runtime.functions.table.lookup.LookupCacheManager import org.apache.flink.types.Row +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.IterableAssert.assertThatIterable import org.junit.{After, Assume, Before, Test} import org.junit.Assert.assertEquals import org.junit.runner.RunWith @@ -33,7 +39,8 @@ import java.util import scala.collection.JavaConversions._ @RunWith(classOf[Parameterized]) -class LookupJoinITCase(legacyTableSource: Boolean, isAsyncMode: Boolean) extends BatchTestBase { +class LookupJoinITCase(legacyTableSource: Boolean, isAsyncMode: Boolean, enableCache: Boolean) + extends BatchTestBase { val data = List( rowOf(1L, 12L, "Julian"), @@ -96,12 +103,21 @@ class LookupJoinITCase(legacyTableSource: Boolean, isAsyncMode: Boolean) extends isBounded = true) } else { val dataId = TestValuesTableFactory.registerData(data) + val cacheOptions = + if (enableCache) + s""" + | '${LookupOptions.CACHE_TYPE.key()}' = '${LookupOptions.LookupCacheType.PARTIAL}', + | '${LookupOptions.PARTIAL_CACHE_MAX_ROWS.key()}' = '${Long.MaxValue}', + |""".stripMargin + else "" + tEnv.executeSql(s""" |CREATE TABLE $tableName ( | `age` INT, | `id` BIGINT, | `name` STRING |) WITH ( + | $cacheOptions | 'connector' = 'values', | 'data-id' = '$dataId', | 'async' = '$isAsyncMode', @@ -114,6 +130,13 @@ class LookupJoinITCase(legacyTableSource: Boolean, isAsyncMode: Boolean) extends private def createLookupTableWithComputedColumn(tableName: String, data: List[Row]): Unit = { if (!legacyTableSource) { val dataId = TestValuesTableFactory.registerData(data) + val cacheOptions = + if (enableCache) + s""" + | '${LookupOptions.CACHE_TYPE.key()}' = '${LookupOptions.LookupCacheType.PARTIAL}', + | '${LookupOptions.PARTIAL_CACHE_MAX_ROWS.key()}' = 
'${Long.MaxValue}', + |""".stripMargin + else "" tEnv.executeSql(s""" |CREATE TABLE $tableName ( | `age` INT, @@ -121,6 +144,7 @@ class LookupJoinITCase(legacyTableSource: Boolean, isAsyncMode: Boolean) extends | `name` STRING, | `nominal_age` as age + 1 |) WITH ( + | $cacheOptions | 'connector' = 'values', | 'data-id' = '$dataId', | 'async' = '$isAsyncMode', @@ -304,17 +328,91 @@ class LookupJoinITCase(legacyTableSource: Boolean, isAsyncMode: Boolean) extends BatchTestBase.row(3, 15, "Fabian", "Fabian", 33, 34)) checkResult(sql, expected) } + + @Test + def testLookupCacheSharingAcrossSubtasks(): Unit = { + if (!enableCache) { + return + } + // Keep the cache for later validation + LookupCacheManager.keepCacheOnRelease(true) + try { + // Use datagen source here to support parallel running + val sourceDdl = + s""" + |CREATE TABLE datagen_source ( + | id BIGINT, + | proc AS PROCTIME() + |) WITH ( + | 'connector' = 'datagen', + | 'fields.id.kind' = 'sequence', + | 'fields.id.start' = '1', + | 'fields.id.end' = '6', + | 'number-of-rows' = '6' + |) + |""".stripMargin + tEnv.executeSql(sourceDdl) + val sql = + """ + |SELECT T.id, D.name, D.age FROM datagen_source as T + |LEFT JOIN userTable FOR SYSTEM_TIME AS OF T.proc AS D + |ON T.id = D.id + |""".stripMargin + executeQuery(parseQuery(sql)) + + // Validate that only one cache is registered + val managedCaches = LookupCacheManager.getInstance().getManagedCaches + assertThat(managedCaches.size()).isEqualTo(1) + + // Validate 6 entries are cached + val cache = managedCaches.get(managedCaches.keySet().iterator().next()).getCache + assertThat(cache.size()).isEqualTo(6) + + // Validate contents of cached entries + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(1L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(11), jl(1L), BinaryStringData.fromString("Julian"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(2L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(22), jl(2L), BinaryStringData.fromString("Jark"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(3L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(33), jl(3L), BinaryStringData.fromString("Fabian"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(4L)))).isEmpty() + } finally { + LookupCacheManager.getInstance().checkAllReleased() + LookupCacheManager.getInstance().clear() + LookupCacheManager.keepCacheOnRelease(false) + } + } + + def ji(i: Int): java.lang.Integer = { + new java.lang.Integer(i) + } + + def jl(l: Long): java.lang.Long = { + new java.lang.Long(l) + } } object LookupJoinITCase { - @Parameterized.Parameters(name = "LegacyTableSource={0}, isAsyncMode = {1}") + val LEGACY_TABLE_SOURCE: JBoolean = JBoolean.TRUE; + val DYNAMIC_TABLE_SOURCE: JBoolean = JBoolean.FALSE; + val ASYNC_MODE: JBoolean = JBoolean.TRUE; + val SYNC_MODE: JBoolean = JBoolean.FALSE; + val ENABLE_CACHE: JBoolean = JBoolean.TRUE; + val DISABLE_CACHE: JBoolean = JBoolean.FALSE; + + @Parameterized.Parameters(name = "LegacyTableSource={0}, isAsyncMode = {1}, enableCache = {2}") def parameters(): util.Collection[Array[java.lang.Object]] = { Seq[Array[AnyRef]]( - Array(JBoolean.TRUE, JBoolean.TRUE), - Array(JBoolean.TRUE, JBoolean.FALSE), - Array(JBoolean.FALSE, JBoolean.TRUE), - Array(JBoolean.FALSE, JBoolean.FALSE) + Array(LEGACY_TABLE_SOURCE, ASYNC_MODE, DISABLE_CACHE), + Array(LEGACY_TABLE_SOURCE, SYNC_MODE, DISABLE_CACHE), + Array(DYNAMIC_TABLE_SOURCE, ASYNC_MODE, DISABLE_CACHE), + Array(DYNAMIC_TABLE_SOURCE, SYNC_MODE, DISABLE_CACHE), + 
Array(DYNAMIC_TABLE_SOURCE, ASYNC_MODE, ENABLE_CACHE), + Array(DYNAMIC_TABLE_SOURCE, SYNC_MODE, ENABLE_CACHE) ) } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AsyncLookupJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AsyncLookupJoinITCase.scala index 94858311ed6..5121972aa7b 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AsyncLookupJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AsyncLookupJoinITCase.scala @@ -22,13 +22,19 @@ import org.apache.flink.table.api.{TableSchema, Types} import org.apache.flink.table.api.bridge.scala._ import org.apache.flink.table.api.config.ExecutionConfigOptions import org.apache.flink.table.api.config.ExecutionConfigOptions.AsyncOutputMode +import org.apache.flink.table.connector.source.lookup.LookupOptions +import org.apache.flink.table.data.GenericRowData +import org.apache.flink.table.data.binary.BinaryStringData import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.runtime.utils.{InMemoryLookupableTableSource, StreamingWithStateTestBase, TestingAppendSink, TestingRetractSink} import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode} import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils._ +import org.apache.flink.table.runtime.functions.table.lookup.LookupCacheManager import org.apache.flink.types.Row import org.apache.flink.util.ExceptionUtils +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.IterableAssert.assertThatIterable import org.junit.{After, Before, Test} import org.junit.Assert.{assertEquals, assertTrue, fail} import org.junit.runner.RunWith @@ -44,7 +50,8 @@ class AsyncLookupJoinITCase( legacyTableSource: Boolean, backend: StateBackendMode, objectReuse: Boolean, - asyncOutputMode: AsyncOutputMode) + asyncOutputMode: AsyncOutputMode, + enableCache: Boolean) extends StreamingWithStateTestBase(backend) { val data = List( @@ -97,12 +104,20 @@ class AsyncLookupJoinITCase( tableName) } else { val dataId = TestValuesTableFactory.registerData(data) + val cacheOptions = + if (enableCache) + s""" + | '${LookupOptions.CACHE_TYPE.key()}' = '${LookupOptions.LookupCacheType.PARTIAL}', + | '${LookupOptions.PARTIAL_CACHE_MAX_ROWS.key()}' = '${Long.MaxValue}', + |""".stripMargin + else "" tEnv.executeSql(s""" |CREATE TABLE $tableName ( | `age` INT, | `id` BIGINT, | `name` STRING |) WITH ( + | $cacheOptions | 'connector' = 'values', | 'data-id' = '$dataId', | 'async' = 'true' @@ -305,19 +320,135 @@ class AsyncLookupJoinITCase( fail("NumberFormatException is expected here!") } + @Test + def testLookupCacheSharingAcrossSubtasks(): Unit = { + if (!enableCache) { + return + } + // Keep the cache for later validation + LookupCacheManager.keepCacheOnRelease(true) + try { + // Use datagen source here to support parallel running + val sourceDdl = + s""" + |CREATE TABLE T ( + | id BIGINT, + | proc AS PROCTIME() + |) WITH ( + | 'connector' = 'datagen', + | 'fields.id.kind' = 'sequence', + | 'fields.id.start' = '1', + | 'fields.id.end' = '6' + |) + |""".stripMargin + tEnv.executeSql(sourceDdl) + val sql = + """ + |SELECT T.id, D.name, D.age FROM T + |LEFT JOIN user_table FOR SYSTEM_TIME AS OF T.proc AS D + |ON T.id = D.id + 
|""".stripMargin + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + // Validate that only one cache is registered + val managedCaches = LookupCacheManager.getInstance().getManagedCaches + assertThat(managedCaches.size()).isEqualTo(1) + + // Validate 6 entries are cached + val cache = managedCaches.get(managedCaches.keySet().iterator().next()).getCache + assertThat(cache.size()).isEqualTo(6) + + // Validate contents of cached entries + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(1L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(11), jl(1L), BinaryStringData.fromString("Julian"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(2L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(22), jl(2L), BinaryStringData.fromString("Jark"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(3L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(33), jl(3L), BinaryStringData.fromString("Fabian"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(4L)))).isEmpty() + } finally { + LookupCacheManager.getInstance().checkAllReleased() + LookupCacheManager.getInstance().clear() + LookupCacheManager.keepCacheOnRelease(false) + } + } + + def ji(i: Int): java.lang.Integer = { + new java.lang.Integer(i) + } + + def jl(l: Long): java.lang.Long = { + new java.lang.Long(l) + } + } object AsyncLookupJoinITCase { - @Parameterized.Parameters( - name = "LegacyTableSource={0}, StateBackend={1}, ObjectReuse={2}, AsyncOutputMode={3}") + + val LEGACY_TABLE_SOURCE: JBoolean = JBoolean.TRUE; + val DYNAMIC_TABLE_SOURCE: JBoolean = JBoolean.FALSE; + val ENABLE_OBJECT_REUSE: JBoolean = JBoolean.TRUE; + val DISABLE_OBJECT_REUSE: JBoolean = JBoolean.FALSE; + val ENABLE_CACHE: JBoolean = JBoolean.TRUE; + val DISABLE_CACHE: JBoolean = JBoolean.FALSE; + + @Parameterized.Parameters(name = + "LegacyTableSource={0}, StateBackend={1}, ObjectReuse={2}, AsyncOutputMode={3}, EnableCache={4}") def parameters(): JCollection[Array[Object]] = { Seq[Array[AnyRef]]( - Array(JBoolean.TRUE, HEAP_BACKEND, JBoolean.TRUE, AsyncOutputMode.ALLOW_UNORDERED), - Array(JBoolean.TRUE, ROCKSDB_BACKEND, JBoolean.FALSE, AsyncOutputMode.ORDERED), - Array(JBoolean.FALSE, HEAP_BACKEND, JBoolean.FALSE, AsyncOutputMode.ORDERED), - Array(JBoolean.FALSE, HEAP_BACKEND, JBoolean.TRUE, AsyncOutputMode.ORDERED), - Array(JBoolean.FALSE, ROCKSDB_BACKEND, JBoolean.FALSE, AsyncOutputMode.ALLOW_UNORDERED), - Array(JBoolean.FALSE, ROCKSDB_BACKEND, JBoolean.TRUE, AsyncOutputMode.ALLOW_UNORDERED) + Array( + LEGACY_TABLE_SOURCE, + HEAP_BACKEND, + ENABLE_OBJECT_REUSE, + AsyncOutputMode.ALLOW_UNORDERED, + DISABLE_CACHE), + Array( + LEGACY_TABLE_SOURCE, + ROCKSDB_BACKEND, + DISABLE_OBJECT_REUSE, + AsyncOutputMode.ORDERED, + DISABLE_CACHE), + Array( + DYNAMIC_TABLE_SOURCE, + HEAP_BACKEND, + DISABLE_OBJECT_REUSE, + AsyncOutputMode.ORDERED, + DISABLE_CACHE), + Array( + DYNAMIC_TABLE_SOURCE, + HEAP_BACKEND, + ENABLE_OBJECT_REUSE, + AsyncOutputMode.ORDERED, + DISABLE_CACHE), + Array( + DYNAMIC_TABLE_SOURCE, + ROCKSDB_BACKEND, + DISABLE_OBJECT_REUSE, + AsyncOutputMode.ALLOW_UNORDERED, + DISABLE_CACHE), + Array( + DYNAMIC_TABLE_SOURCE, + ROCKSDB_BACKEND, + ENABLE_OBJECT_REUSE, + AsyncOutputMode.ALLOW_UNORDERED, + DISABLE_CACHE), + Array( + DYNAMIC_TABLE_SOURCE, + HEAP_BACKEND, + DISABLE_OBJECT_REUSE, + AsyncOutputMode.ORDERED, + ENABLE_CACHE), + Array( + DYNAMIC_TABLE_SOURCE, + HEAP_BACKEND, + ENABLE_OBJECT_REUSE, + AsyncOutputMode.ALLOW_UNORDERED, + 
ENABLE_CACHE) ) } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/LookupJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/LookupJoinITCase.scala index 32e1e556509..e7253da141d 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/LookupJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/LookupJoinITCase.scala @@ -20,11 +20,17 @@ package org.apache.flink.table.planner.runtime.stream.sql import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.api.bridge.scala._ +import org.apache.flink.table.connector.source.lookup.LookupOptions +import org.apache.flink.table.data.GenericRowData +import org.apache.flink.table.data.binary.BinaryStringData import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.runtime.utils.{InMemoryLookupableTableSource, StreamingTestBase, TestingAppendSink} import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.TestAddWithOpen +import org.apache.flink.table.runtime.functions.table.lookup.LookupCacheManager import org.apache.flink.types.Row +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.IterableAssert.assertThatIterable import org.junit.{After, Before, Test} import org.junit.Assert.{assertEquals, assertTrue} import org.junit.runner.RunWith @@ -37,7 +43,7 @@ import java.util.{Collection => JCollection} import scala.collection.JavaConversions._ @RunWith(classOf[Parameterized]) -class LookupJoinITCase(legacyTableSource: Boolean) extends StreamingTestBase { +class LookupJoinITCase(legacyTableSource: Boolean, enableCache: Boolean) extends StreamingTestBase { val data = List( rowOf(1L, 12, "Julian"), @@ -100,12 +106,21 @@ class LookupJoinITCase(legacyTableSource: Boolean) extends StreamingTestBase { tableName) } else { val dataId = TestValuesTableFactory.registerData(data) + val cacheOptions = + if (enableCache) + s""" + | '${LookupOptions.CACHE_TYPE.key()}' = '${LookupOptions.LookupCacheType.PARTIAL}', + | '${LookupOptions.PARTIAL_CACHE_MAX_ROWS.key()}' = '${Long.MaxValue}', + |""".stripMargin + else "" + tEnv.executeSql(s""" |CREATE TABLE $tableName ( | `age` INT, | `id` BIGINT, | `name` STRING |) WITH ( + | $cacheOptions | 'connector' = 'values', | 'data-id' = '$dataId' |) @@ -116,6 +131,13 @@ class LookupJoinITCase(legacyTableSource: Boolean) extends StreamingTestBase { private def createLookupTableWithComputedColumn(tableName: String, data: List[Row]): Unit = { if (!legacyTableSource) { val dataId = TestValuesTableFactory.registerData(data) + val cacheOptions = + if (enableCache) + s""" + | '${LookupOptions.CACHE_TYPE.key()}' = '${LookupOptions.LookupCacheType.PARTIAL}', + | '${LookupOptions.PARTIAL_CACHE_MAX_ROWS.key()}' = '${Long.MaxValue}', + |""".stripMargin + else "" tEnv.executeSql(s""" |CREATE TABLE $tableName ( | `age` INT, @@ -123,6 +145,7 @@ class LookupJoinITCase(legacyTableSource: Boolean) extends StreamingTestBase { | `name` STRING, | `nominal_age` as age + 1 |) WITH ( + | $cacheOptions | 'connector' = 'values', | 'data-id' = '$dataId' |) @@ -529,11 +552,93 @@ class LookupJoinITCase(legacyTableSource: Boolean) extends StreamingTestBase { env.execute() assertEquals(Seq(), sink.getAppendResults) } + + @Test + def testLookupCacheSharingAcrossSubtasks(): Unit = { 
+ if (!enableCache) { + return + } + // Keep the cache for later validation + LookupCacheManager.keepCacheOnRelease(true) + try { + // Use datagen source here to support parallel running + val sourceDdl = + s""" + |CREATE TABLE T ( + | id BIGINT, + | proc AS PROCTIME() + |) WITH ( + | 'connector' = 'datagen', + | 'fields.id.kind' = 'sequence', + | 'fields.id.start' = '1', + | 'fields.id.end' = '6' + |) + |""".stripMargin + tEnv.executeSql(sourceDdl) + val sql = + """ + |SELECT T.id, D.name, D.age FROM T + |LEFT JOIN user_table FOR SYSTEM_TIME AS OF T.proc AS D + |ON T.id = D.id + |""".stripMargin + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + // Validate that only one cache is registered + val managedCaches = LookupCacheManager.getInstance().getManagedCaches + assertThat(managedCaches.size()).isEqualTo(1) + + // Validate 6 entries are cached + val cache = managedCaches.get(managedCaches.keySet().iterator().next()).getCache + assertThat(cache.size()).isEqualTo(6) + + // Validate contents of cached entries + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(1L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(11), jl(1L), BinaryStringData.fromString("Julian"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(2L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(22), jl(2L), BinaryStringData.fromString("Jark"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(3L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(33), jl(3L), BinaryStringData.fromString("Fabian"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(4L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(11), jl(4L), BinaryStringData.fromString("Hello world"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(5L)))) + .containsExactlyInAnyOrder( + GenericRowData.of(ji(11), jl(5L), BinaryStringData.fromString("Hello world"))) + assertThatIterable(cache.getIfPresent(GenericRowData.of(jl(6L)))) + .isEmpty() + } finally { + LookupCacheManager.getInstance().checkAllReleased() + LookupCacheManager.getInstance().clear() + LookupCacheManager.keepCacheOnRelease(false) + } + } + + def ji(i: Int): java.lang.Integer = { + new java.lang.Integer(i) + } + + def jl(l: Long): java.lang.Long = { + new java.lang.Long(l) + } } object LookupJoinITCase { - @Parameterized.Parameters(name = "LegacyTableSource={0}") + + val LEGACY_TABLE_SOURCE: JBoolean = JBoolean.TRUE; + val DYNAMIC_TABLE_SOURCE: JBoolean = JBoolean.FALSE; + val ENABLE_CACHE: JBoolean = JBoolean.TRUE; + val DISABLE_CACHE: JBoolean = JBoolean.FALSE; + + @Parameterized.Parameters(name = "LegacyTableSource={0}, EnableCache={1}") def parameters(): JCollection[Array[Object]] = { - Seq[Array[AnyRef]](Array(JBoolean.TRUE), Array(JBoolean.FALSE)) + Seq[Array[AnyRef]]( + Array(LEGACY_TABLE_SOURCE, DISABLE_CACHE), + Array(DYNAMIC_TABLE_SOURCE, ENABLE_CACHE), + Array(DYNAMIC_TABLE_SOURCE, DISABLE_CACHE)) } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/CachingAsyncLookupFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/CachingAsyncLookupFunction.java new file mode 100644 index 00000000000..c087ef5ef9f --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/CachingAsyncLookupFunction.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.table.lookup; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.metrics.Counter; +import org.apache.flink.metrics.ThreadSafeSimpleCounter; +import org.apache.flink.metrics.groups.CacheMetricGroup; +import org.apache.flink.runtime.metrics.groups.InternalCacheMetricGroup; +import org.apache.flink.table.connector.source.lookup.cache.LookupCache; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.AsyncLookupFunction; +import org.apache.flink.table.functions.FunctionContext; + +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; + +/** + * A wrapper function around user-provided async lookup function with a cache layer. + * + * <p>This function will check the cache on lookup request and return entries directly on cache hit, + * otherwise the function will invoke the actual lookup function, and store the entry into the cache + * after lookup for later use. 
+ */ +public class CachingAsyncLookupFunction extends AsyncLookupFunction { + private static final long serialVersionUID = 1L; + + // Constants + public static final String LOOKUP_CACHE_METRIC_GROUP_NAME = "cache"; + private static final long UNINITIALIZED = -1; + + // The actual user-provided lookup function + private final AsyncLookupFunction delegate; + + private LookupCache cache; + private transient String cacheIdentifier; + + // Cache metrics + private transient CacheMetricGroup cacheMetricGroup; + private transient Counter loadCounter; + private transient Counter numLoadFailuresCounter; + private volatile long latestLoadTime = UNINITIALIZED; + + public CachingAsyncLookupFunction(LookupCache cache, AsyncLookupFunction delegate) { + this.cache = cache; + this.delegate = delegate; + } + + @Override + public void open(FunctionContext context) throws Exception { + // Get the shared cache from manager + cacheIdentifier = functionIdentifier(); + cache = LookupCacheManager.getInstance().registerCacheIfAbsent(cacheIdentifier, cache); + + // Register metrics + cacheMetricGroup = + new InternalCacheMetricGroup( + context.getMetricGroup(), LOOKUP_CACHE_METRIC_GROUP_NAME); + loadCounter = new ThreadSafeSimpleCounter(); + cacheMetricGroup.loadCounter(loadCounter); + numLoadFailuresCounter = new ThreadSafeSimpleCounter(); + cacheMetricGroup.numLoadFailuresCounter(numLoadFailuresCounter); + + cache.open(cacheMetricGroup); + delegate.open(context); + } + + @Override + public CompletableFuture<Collection<RowData>> asyncLookup(RowData keyRow) { + Collection<RowData> cachedValues = cache.getIfPresent(keyRow); + if (cachedValues != null) { + return CompletableFuture.completedFuture(cachedValues); + } else { + return delegate.asyncLookup(keyRow) + .whenComplete( + (lookupValues, throwable) -> { + if (throwable != null) { + // TODO: Should implement retry on failure logic as proposed in + // FLIP-234 + numLoadFailuresCounter.inc(); + throw new RuntimeException( + String.format("Failed to lookup key '%s'", keyRow), + throwable); + } + loadCounter.inc(); + updateLatestLoadTime(); + Collection<RowData> cachingValues = lookupValues; + if (lookupValues == null || lookupValues.isEmpty()) { + cachingValues = Collections.emptyList(); + } + cache.put(keyRow, cachingValues); + }); + } + } + + @Override + public void close() throws Exception { + delegate.close(); + if (cacheIdentifier != null) { + LookupCacheManager.getInstance().unregisterCache(cacheIdentifier); + } + } + + @VisibleForTesting + public LookupCache getCache() { + return cache; + } + + // --------------------------------- Helper functions ---------------------------- + private void updateLatestLoadTime() { + if (latestLoadTime == UNINITIALIZED) { + cacheMetricGroup.latestLoadTimeGauge(() -> latestLoadTime); + } + latestLoadTime = System.currentTimeMillis(); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/CachingLookupFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/CachingLookupFunction.java new file mode 100644 index 00000000000..d3bbf3b6d7c --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/CachingLookupFunction.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.table.lookup; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.metrics.Counter; +import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.metrics.groups.CacheMetricGroup; +import org.apache.flink.runtime.metrics.MetricNames; +import org.apache.flink.runtime.metrics.groups.InternalCacheMetricGroup; +import org.apache.flink.table.connector.source.lookup.cache.LookupCache; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.LookupFunction; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A wrapper function around user-provided lookup function with a cache layer. + * + * <p>This function will check the cache on lookup request and return entries directly on cache hit, + * otherwise the function will invoke the actual lookup function, and store the entry into the cache + * after lookup for later use. + */ +@Internal +public class CachingLookupFunction extends LookupFunction { + private static final long serialVersionUID = 1L; + + // Constants + public static final String LOOKUP_CACHE_METRIC_GROUP_NAME = "cache"; + private static final long UNINITIALIZED = -1; + + // The actual user-provided lookup function + private final LookupFunction delegate; + + private LookupCache cache; + private transient String cacheIdentifier; + + // Cache metrics + private transient CacheMetricGroup cacheMetricGroup; + private transient Counter loadCounter; + private transient Counter numLoadFailuresCounter; + private volatile long latestLoadTime = UNINITIALIZED; + + /** + * Create a {@link CachingLookupFunction}. + * + * <p>Please note that the cache may not be the final instance serving in this function. The + * actual cache instance will be retrieved from the {@link LookupCacheManager} during {@link + * #open}. + */ + public CachingLookupFunction(LookupCache cache, LookupFunction delegate) { + this.cache = cache; + this.delegate = delegate; + } + + /** + * Open the {@link CachingLookupFunction}. + * + * <p>In order to reduce the memory usage of the cache, {@link LookupCacheManager} is used to + * provide a shared cache instance across subtasks of this function. Here we use {@link + * #functionIdentifier()} as the id of the cache, which is generated by MD5 of serialized bytes + * of this function. As different subtasks of the function will generate the same MD5, this + * could promise that they will be served with the same cache instance. 
+ * + * @see #functionIdentifier() + */ + @Override + public void open(FunctionContext context) throws Exception { + // Get the shared cache from manager + cacheIdentifier = functionIdentifier(); + cache = LookupCacheManager.getInstance().registerCacheIfAbsent(cacheIdentifier, cache); + + // Register metrics + cacheMetricGroup = + new InternalCacheMetricGroup( + context.getMetricGroup(), LOOKUP_CACHE_METRIC_GROUP_NAME); + loadCounter = new SimpleCounter(); + cacheMetricGroup.loadCounter(loadCounter); + numLoadFailuresCounter = new SimpleCounter(); + cacheMetricGroup.numLoadFailuresCounter(numLoadFailuresCounter); + + // Initialize cache and the delegating function + cache.open(cacheMetricGroup); + delegate.open(context); + } + + @Override + public Collection<RowData> lookup(RowData keyRow) throws IOException { + Collection<RowData> cachedValues = cache.getIfPresent(keyRow); + if (cachedValues != null) { + // Cache hit + return cachedValues; + } else { + // Cache miss + Collection<RowData> lookupValues = lookupByDelegate(keyRow); + // Here we use keyRow as the cache key directly, as keyRow always contains the copy of + // key fields from left table, no matter if object reuse is enabled. + if (lookupValues == null || lookupValues.isEmpty()) { + cache.put(keyRow, Collections.emptyList()); + } else { + cache.put(keyRow, lookupValues); + } + return lookupValues; + } + } + + @Override + public void close() throws Exception { + delegate.close(); + if (cacheIdentifier != null) { + LookupCacheManager.getInstance().unregisterCache(cacheIdentifier); + } + } + + @VisibleForTesting + public LookupCache getCache() { + return cache; + } + + // -------------------------------- Helper functions ------------------------------ + private Collection<RowData> lookupByDelegate(RowData keyRow) throws IOException { + try { + Collection<RowData> lookupValues = delegate.lookup(keyRow); + loadCounter.inc(); + updateLatestLoadTime(); + return lookupValues; + } catch (Exception e) { + // TODO: Should implement retry on failure logic as proposed in FLIP-234 + numLoadFailuresCounter.inc(); + throw new IOException(String.format("Failed to lookup with key '%s'", keyRow), e); + } + } + + private void updateLatestLoadTime() { + checkNotNull( + cacheMetricGroup, + "Could not register metric '%s' as cache metric group is not initialized", + MetricNames.LATEST_LOAD_TIME); + // Lazily register the metric + if (latestLoadTime == UNINITIALIZED) { + cacheMetricGroup.latestLoadTimeGauge(() -> latestLoadTime); + } + latestLoadTime = System.currentTimeMillis(); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/LookupCacheManager.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/LookupCacheManager.java new file mode 100644 index 00000000000..610033e1975 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/table/lookup/LookupCacheManager.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.table.lookup; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.table.connector.source.lookup.cache.LookupCache; +import org.apache.flink.util.RefCounted; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Managing shared caches across different subtasks. + * + * <p>In order to reduce the memory usage of cache, different subtasks of the same lookup join + * runner will share the same cache instance. Caches are managed by the identifier of the lookup + * table for which it is serving. + */ +@Internal +public class LookupCacheManager { + private static LookupCacheManager instance; + private static boolean keepCacheOnRelease = false; + private final Map<String, RefCountedCache> managedCaches = new HashMap<>(); + + /** Default constructor is not allowed to use. */ + private LookupCacheManager() {} + + /** Get the shared instance of {@link LookupCacheManager}. */ + public static synchronized LookupCacheManager getInstance() { + if (instance == null) { + instance = new LookupCacheManager(); + } + return instance; + } + + /** + * Register a cache instance with identifier to the manager. + * + * <p>If the cache with the given identifier is already registered in the manager, this method + * will return the registered one, otherwise this method will register the given cache into the + * manager then return. + * + * @param cacheIdentifier identifier of the cache + * @param cache instance of cache trying to register + * @return instance of the shared cache + */ + public synchronized LookupCache registerCacheIfAbsent( + String cacheIdentifier, LookupCache cache) { + checkNotNull(cache, "Could not register null cache in the manager"); + RefCountedCache refCountedCache = + managedCaches.computeIfAbsent( + cacheIdentifier, identifier -> new RefCountedCache(cache)); + refCountedCache.retain(); + return refCountedCache.cache; + } + + /** + * Release the cache with the given identifier from the manager. + * + * <p>The manager will track a reference count of managed caches, and will close the cache if + * the reference count reaches 0. + */ + public synchronized void unregisterCache(String cacheIdentifier) { + RefCountedCache refCountedCache = + checkNotNull( + managedCaches.get(cacheIdentifier), + "Cache identifier '%s' is not registered", + cacheIdentifier); + if (refCountedCache.release()) { + managedCaches.remove(cacheIdentifier); + } + } + + /** + * A wrapper class of {@link LookupCache} which also tracks the reference count of it. + * + * <p>This class is exposed as public for testing purpose and not thread safe. Concurrent + * accesses should be guarded by synchronized methods provided by {@link LookupCacheManager}. 
+    /**
+     * A wrapper class of {@link LookupCache} that also tracks its reference count.
+     *
+     * <p>This class is exposed as public for testing purposes and is not thread-safe. Concurrent
+     * accesses should be guarded by the synchronized methods provided by {@link
+     * LookupCacheManager}.
+     */
+    @NotThreadSafe
+    @VisibleForTesting
+    public static class RefCountedCache implements RefCounted {
+        private final LookupCache cache;
+        private int refCount;
+
+        public RefCountedCache(LookupCache cache) {
+            this.cache = cache;
+            this.refCount = 0;
+        }
+
+        @Override
+        public void retain() {
+            refCount++;
+        }
+
+        @Override
+        public boolean release() {
+            checkState(refCount > 0, "Could not release a cache with refCount = 0");
+            if (--refCount == 0 && !keepCacheOnRelease) {
+                closeCache();
+                return true;
+            }
+            return false;
+        }
+
+        public LookupCache getCache() {
+            return cache;
+        }
+
+        private void closeCache() {
+            try {
+                cache.close();
+            } catch (Exception e) {
+                throw new RuntimeException("Failed to close the cache", e);
+            }
+        }
+    }
+
+    // ---------------------------- For testing purposes ------------------------------
+    public static void keepCacheOnRelease(boolean toKeep) {
+        keepCacheOnRelease = toKeep;
+    }
+
+    public void checkAllReleased() {
+        if (managedCaches.isEmpty()) {
+            return;
+        }
+        String leakedCaches =
+                managedCaches.entrySet().stream()
+                        .filter(entry -> entry.getValue().refCount != 0)
+                        .map(
+                                entry ->
+                                        String.format(
+                                                "#Reference: %d with ID: %s",
+                                                entry.getValue().refCount, entry.getKey()))
+                        .collect(Collectors.joining("\n"));
+        if (!leakedCaches.isEmpty()) {
+            throw new IllegalStateException(
+                    "Cache leak detected. Unreleased caches: \n" + leakedCaches);
+        }
+    }
+
+    public void clear() {
+        managedCaches.forEach((identifier, cache) -> cache.closeCache());
+        managedCaches.clear();
+    }
+
+    public Map<String, RefCountedCache> getManagedCaches() {
+        return managedCaches;
+    }
+}
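To make the reference-counting contract above concrete, here is a small sketch, again not part of the commit. The identifier "orders-dim" is hypothetical, and the keepCacheOnRelease testing hook is enabled so the unopened example caches are not actually closed when released.

import org.apache.flink.table.connector.source.lookup.cache.DefaultLookupCache;
import org.apache.flink.table.connector.source.lookup.cache.LookupCache;
import org.apache.flink.table.runtime.functions.table.lookup.LookupCacheManager;

public class CacheSharingSketch {
    public static void main(String[] args) {
        // Testing hook from the class above: keep caches on release so this
        // sketch does not need to open them before unregistering.
        LookupCacheManager.keepCacheOnRelease(true);
        LookupCacheManager manager = LookupCacheManager.getInstance();

        // Two subtasks register under the same (hypothetical) identifier...
        LookupCache first =
                manager.registerCacheIfAbsent(
                        "orders-dim", DefaultLookupCache.newBuilder().maximumSize(100L).build());
        LookupCache second =
                manager.registerCacheIfAbsent(
                        "orders-dim", DefaultLookupCache.newBuilder().maximumSize(100L).build());

        // ...and both receive the first managed instance; the second cache
        // argument is discarded by computeIfAbsent.
        System.out.println(first == second); // prints "true"

        // Each unregister decrements the reference count; only the release that
        // drops the count to zero would normally close and remove the cache.
        manager.unregisterCache("orders-dim"); // refCount 2 -> 1
        manager.unregisterCache("orders-dim"); // refCount 1 -> 0
    }
}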
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/functions/table/CachingAsyncLookupFunctionTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/functions/table/CachingAsyncLookupFunctionTest.java
new file mode 100644
index 00000000000..85f3b858ec4
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/functions/table/CachingAsyncLookupFunctionTest.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions.table;
+
+import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
+import org.apache.flink.table.connector.source.lookup.cache.DefaultLookupCache;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.AsyncLookupFunction;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.runtime.functions.table.lookup.CachingAsyncLookupFunction;
+import org.apache.flink.util.concurrent.FutureUtils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Unit test for {@link CachingAsyncLookupFunction}. */
+class CachingAsyncLookupFunctionTest {
+
+    private static final RowData KEY_1 = GenericRowData.of(1);
+    private static final Collection<RowData> VALUE_1 =
+            Collections.singletonList(GenericRowData.of(1, "Alice", 18L));
+    private static final RowData KEY_2 = GenericRowData.of(2);
+    private static final Collection<RowData> VALUE_2 =
+            Arrays.asList(GenericRowData.of(2, "Bob", 20L), GenericRowData.of(2, "Charlie", 22L));
+    private static final RowData NON_EXIST_KEY = GenericRowData.of(3);
+
+    @Test
+    void testCaching() throws Exception {
+        TestingAsyncLookupFunction delegate = new TestingAsyncLookupFunction();
+        CachingAsyncLookupFunction function = createCachingFunction(delegate);
+
+        // All cache miss
+        FutureUtils.completeAll(
+                        Arrays.asList(
+                                function.asyncLookup(KEY_1),
+                                function.asyncLookup(KEY_2),
+                                function.asyncLookup(NON_EXIST_KEY)))
+                .get();
+
+        // All cache hit
+        FutureUtils.completeAll(
+                        Arrays.asList(
+                                function.asyncLookup(KEY_1),
+                                function.asyncLookup(KEY_2),
+                                function.asyncLookup(NON_EXIST_KEY)))
+                .get();
+
+        assertThat(delegate.getLookupCount()).hasValue(3);
+        assertThat(function.getCache().getIfPresent(KEY_1))
+                .containsExactlyInAnyOrderElementsOf(VALUE_1);
+        assertThat(function.getCache().getIfPresent(KEY_2))
+                .containsExactlyInAnyOrderElementsOf(VALUE_2);
+        assertThat(function.getCache().getIfPresent(NON_EXIST_KEY)).isEmpty();
+    }
+
+    private CachingAsyncLookupFunction createCachingFunction(AsyncLookupFunction delegate)
+            throws Exception {
+        CachingAsyncLookupFunction function =
+                new CachingAsyncLookupFunction(
+                        DefaultLookupCache.newBuilder().maximumSize(Long.MAX_VALUE).build(),
+                        delegate);
+        function.open(new FunctionContext(new MockStreamingRuntimeContext(false, 1, 0)));
+        return function;
+    }
+
+    private static final class TestingAsyncLookupFunction extends AsyncLookupFunction {
+        private final transient ConcurrentMap<RowData, Collection<RowData>> data =
+                new ConcurrentHashMap<>();
+        private transient AtomicInteger lookupCount;
+        private transient ExecutorService executor;
+
+        @Override
+        public void open(FunctionContext context) throws Exception {
+            data.put(KEY_1, VALUE_1);
+            data.put(KEY_2, VALUE_2);
+            lookupCount = new AtomicInteger(0);
+            executor = Executors.newFixedThreadPool(3);
+        }
+
+        @Override
+        public CompletableFuture<Collection<RowData>> asyncLookup(RowData keyRow) {
+            return CompletableFuture.supplyAsync(
+                    () -> {
+                        try {
+                            Thread.sleep(ThreadLocalRandom.current().nextInt(0, 10));
+                            Collection<RowData> values = data.get(keyRow);
+                            lookupCount.incrementAndGet();
+                            return values;
+                        } catch (Exception e) {
+                            throw new RuntimeException("Failed to lookup value", e);
+                        }
+                    },
+                    executor);
+        }
+
+        public AtomicInteger getLookupCount() {
+            return lookupCount;
+        }
+    }
+}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/functions/table/CachingLookupFunctionTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/functions/table/CachingLookupFunctionTest.java
new file mode 100644
index 00000000000..b3755bf379f
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/functions/table/CachingLookupFunctionTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions.table;
+
+import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
+import org.apache.flink.table.connector.source.lookup.cache.DefaultLookupCache;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.LookupFunction;
+import org.apache.flink.table.runtime.functions.table.lookup.CachingLookupFunction;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Unit test for {@link CachingLookupFunction}. */
+class CachingLookupFunctionTest {
+    private static final RowData KEY_1 = GenericRowData.of(1);
+    private static final Collection<RowData> VALUE_1 =
+            Collections.singletonList(GenericRowData.of(1, "Alice", 18L));
+    private static final RowData KEY_2 = GenericRowData.of(2);
+    private static final Collection<RowData> VALUE_2 =
+            Arrays.asList(GenericRowData.of(2, "Bob", 20L), GenericRowData.of(2, "Charlie", 22L));
+    private static final RowData NON_EXIST_KEY = GenericRowData.of(3);
+
+    @Test
+    void testCaching() throws Exception {
+        TestingLookupFunction delegate = new TestingLookupFunction();
+        CachingLookupFunction function = createCachingFunction(delegate);
+
+        // All cache miss
+        function.lookup(KEY_1);
+        function.lookup(KEY_2);
+        function.lookup(NON_EXIST_KEY);
+
+        // All cache hit
+        function.lookup(KEY_1);
+        function.lookup(KEY_2);
+        function.lookup(NON_EXIST_KEY);
+
+        assertThat(delegate.getLookupCount()).isEqualTo(3);
+        assertThat(function.getCache().getIfPresent(KEY_1))
+                .containsExactlyInAnyOrderElementsOf(VALUE_1);
+        assertThat(function.getCache().getIfPresent(KEY_2))
+                .containsExactlyInAnyOrderElementsOf(VALUE_2);
+        assertThat(function.getCache().getIfPresent(NON_EXIST_KEY)).isEmpty();
+    }
+
+    private CachingLookupFunction createCachingFunction(LookupFunction delegate) throws Exception {
+        CachingLookupFunction function =
+                new CachingLookupFunction(
+                        DefaultLookupCache.newBuilder().maximumSize(Long.MAX_VALUE).build(),
+                        delegate);
+        function.open(new FunctionContext(new MockStreamingRuntimeContext(false, 1, 0)));
+        return function;
+    }
+
+    private static final class TestingLookupFunction extends LookupFunction {
+        private static final long serialVersionUID = 1L;
+
+        private final transient Map<RowData, Collection<RowData>> data = new HashMap<>();
+        private int lookupCount = 0;
+
+        @Override
+        public void open(FunctionContext context) {
+            data.put(KEY_1, VALUE_1);
+            data.put(KEY_2, VALUE_2);
+        }
+
+        @Override
+        public Collection<RowData> lookup(RowData keyRow) {
+            lookupCount++;
+            return data.get(keyRow);
+        }
+
+        public int getLookupCount() {
+            return lookupCount;
+        }
+    }
+}
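Finally, a matching sketch for the asynchronous path, mirroring the synchronous example earlier. It is not part of the commit: the inline delegate is an illustrative stand-in for a real async connector client, while CachingAsyncLookupFunction and its (cache, delegate) constructor are taken from the tests above.

import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;

import org.apache.flink.table.connector.source.lookup.cache.DefaultLookupCache;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.AsyncLookupFunction;
import org.apache.flink.table.runtime.functions.table.lookup.CachingAsyncLookupFunction;

public class CachingAsyncLookupSketch {

    public static CachingAsyncLookupFunction wrapWithCache() {
        // Stand-in for a real asynchronous connector lookup (e.g. an async
        // HBase or JDBC client completing a future per key).
        AsyncLookupFunction delegate =
                new AsyncLookupFunction() {
                    @Override
                    public CompletableFuture<Collection<RowData>> asyncLookup(RowData keyRow) {
                        return CompletableFuture.supplyAsync(
                                () ->
                                        Collections.singletonList(
                                                GenericRowData.of(keyRow.getInt(0), 42L)));
                    }
                };
        // Same construction pattern as the synchronous wrapper: cache first,
        // delegate second; the cache size is an arbitrary example value.
        return new CachingAsyncLookupFunction(
                DefaultLookupCache.newBuilder().maximumSize(10_000L).build(), delegate);
    }
}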