This is an automated email from the ASF dual-hosted git repository.

wanglijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 59850a91f96 [FLINK-32491][table-runtime] Introduce RuntimeFilterCodeGenerator
59850a91f96 is described below

commit 59850a91f96ddb5c58d8554b76e46bc6245dda7b
Author: Lijie Wang <[email protected]>
AuthorDate: Tue Jul 18 00:18:18 2023 +0800

    [FLINK-32491][table-runtime] Introduce RuntimeFilterCodeGenerator
    
    This closes #23007
---
 .../runtimefilter/RuntimeFilterCodeGenerator.scala | 115 +++++++++++++++
 .../RuntimeFilterCodeGeneratorTest.java            | 154 +++++++++++++++++++++
 .../LocalRuntimeFilterBuilderOperator.java         |   2 +-
 .../runtimefilter/util/RuntimeFilterUtils.java     |  10 +-
 .../GlobalRuntimeFilterBuilderOperatorTest.java    |   4 +-
 .../LocalRuntimeFilterBuilderOperatorTest.java     |   6 +-
 6 files changed, 283 insertions(+), 8 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/runtimefilter/RuntimeFilterCodeGenerator.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/runtimefilter/RuntimeFilterCodeGenerator.scala
new file mode 100644
index 00000000000..bcf6d77ee60
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/runtimefilter/RuntimeFilterCodeGenerator.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.codegen.runtimefilter
+
+import org.apache.flink.runtime.operators.util.BloomFilter
+import org.apache.flink.table.data.RowData
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, 
OperatorCodeGenerator, ProjectionCodeGenerator}
+import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, 
newName, DEFAULT_INPUT1_TERM, DEFAULT_INPUT2_TERM, ROW_DATA}
+import 
org.apache.flink.table.planner.codegen.OperatorCodeGenerator.{generateCollect, 
INPUT_SELECTION}
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
+import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
+import org.apache.flink.table.types.logical.RowType
+import org.apache.flink.util.Preconditions
+
+/**
+ * Operator code generator for the runtime filter operator.
+ *
+ * `gen` produces a factory for a generated two-input operator named "RuntimeFilterOperator":
+ *   - input 1 (`buildType`) is read first; field 1 of a row, when non-null, is deserialized with
+ *     `BloomFilter.fromBytes` and kept as the filter (only the first non-null one is used);
+ *   - input 2 (`probeType`) is read after input 1 has ended; each row is projected onto
+ *     `probeIndices`, and the row is forwarded only if the hash code of that projection passes
+ *     `testHash`. If no filter was received, every probe row is forwarded unchanged.
+ *
+ * The generated code enforces this ordering with `Preconditions.checkState` and by selecting
+ * `InputSelection.FIRST` until the build input ends and `InputSelection.SECOND` afterwards.
+ */
+object RuntimeFilterCodeGenerator {
+  def gen(
+      ctx: CodeGeneratorContext,
+      buildType: RowType,
+      probeType: RowType,
+      probeIndices: Array[Int]): CodeGenOperatorFactory[RowData] = {
+    val probeGenProj = ProjectionCodeGenerator.generateProjection(
+      ctx,
+      "RuntimeFilterProjection",
+      probeType,
+      RowTypeUtils.projectRowType(probeType, probeIndices),
+      probeIndices)
+    ctx.addReusableInnerClass(probeGenProj.getClassName, probeGenProj.getCode)
+
+    val probeProjection = newName("probeToBinaryRow") // field holding the projection instance
+    ctx.addReusableMember(s"private transient ${probeGenProj.getClassName} 
$probeProjection;")
+    val probeProjRefs = ctx.addReusableObject(probeGenProj.getReferences, 
"probeProjRefs")
+    ctx.addReusableOpenStatement(
+      s"$probeProjection = new ${probeGenProj.getClassName}($probeProjRefs);")
+
+    val buildComplete = newName("buildComplete") // flag flipped to true when input 1 ends
+    ctx.addReusableMember(s"private transient boolean $buildComplete;")
+    ctx.addReusableOpenStatement(s"$buildComplete = false;")
+
+    val filter = newName("filter") // bloom filter field; no initializer, assigned from input 1
+    val filterClass = className[BloomFilter]
+    ctx.addReusableMember(s"private transient $filterClass $filter;")
+
+    val processElement1Code = // build side: take the filter bytes from field 1, first one wins
+      s"""
+         |${className[Preconditions]}.checkState(!$buildComplete, "Should not 
build completed.");
+         |
+         |if ($filter == null && !$DEFAULT_INPUT1_TERM.isNullAt(1)) {
+         |    $filter = 
$filterClass.fromBytes($DEFAULT_INPUT1_TERM.getBinary(1));
+         |}
+         |""".stripMargin
+
+    val processElement2Code = // probe side: forward on filter hit, or always if no filter
+      s"""
+         |${className[Preconditions]}.checkState($buildComplete, "Should build 
completed.");
+         |
+         |if ($filter != null) {
+         |    final int hashCode = 
$probeProjection.apply($DEFAULT_INPUT2_TERM).hashCode();
+         |    if ($filter.testHash(hashCode)) {
+         |        ${generateCollect(s"$DEFAULT_INPUT2_TERM")}
+         |    }
+         |} else {
+         |    ${generateCollect(s"$DEFAULT_INPUT2_TERM")}
+         |}
+         |""".stripMargin
+
+    val nextSelectionCode = // read input 1 until it ends, then switch to input 2
+      s"return $buildComplete ? $INPUT_SELECTION.SECOND : 
$INPUT_SELECTION.FIRST;"
+
+    val endInputCode1 = // end of build input: mark the build phase as complete
+      s"""
+         |${className[Preconditions]}.checkState(!$buildComplete, "Should not 
build completed.");
+         |
+         |LOG.info("RuntimeFilter build completed.");
+         |$buildComplete = true;
+         |""".stripMargin
+
+    val endInputCode2 = // end of probe input: only logs, after checking the phase order
+      s"""
+         |${className[Preconditions]}.checkState($buildComplete, "Should build 
completed.");
+         |
+         |LOG.info("Finish RuntimeFilter probe phase.");
+         |""".stripMargin
+
+    new CodeGenOperatorFactory[RowData](
+      OperatorCodeGenerator.generateTwoInputStreamOperator(
+        ctx,
+        "RuntimeFilterOperator",
+        processElement1Code,
+        processElement2Code,
+        buildType,
+        probeType,
+        DEFAULT_INPUT1_TERM,
+        DEFAULT_INPUT2_TERM,
+        nextSelectionCode = Some(nextSelectionCode),
+        endInputCode1 = Some(endInputCode1),
+        endInputCode2 = Some(endInputCode2)
+      ))
+  }
+}
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/runtimefilter/RuntimeFilterCodeGeneratorTest.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/runtimefilter/RuntimeFilterCodeGeneratorTest.java
new file mode 100644
index 00000000000..29a26388d31
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/runtimefilter/RuntimeFilterCodeGeneratorTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.codegen.runtimefilter;
+
+import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
+import 
org.apache.flink.table.runtime.operators.runtimefilter.util.RuntimeFilterUtils;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarBinaryType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static 
org.apache.flink.table.runtime.operators.runtimefilter.LocalRuntimeFilterBuilderOperatorTest.createLocalRuntimeFilterBuilderOperatorHarnessAndProcessElements;
+import static 
org.apache.flink.table.runtime.operators.runtimefilter.LocalRuntimeFilterBuilderOperatorTest.createRowDataRecord;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Test for {@link RuntimeFilterCodeGenerator}.
+ *
+ * <p>Each test runs the generated operator inside a two-input task harness: input 0 receives the
+ * build-side record (row count plus serialized filter bytes, or a null filter), input 1 receives
+ * the probe rows, and the harness output is compared with the rows expected to survive.
+ */
+class RuntimeFilterCodeGeneratorTest {
+    private StreamTaskMailboxTestHarness<RowData> testHarness;
+
+    @BeforeEach
+    void setup() throws Exception {
+        final RowType leftType = RowType.of(new IntType(), new 
VarBinaryType());
+        final RowType rightType = RowType.of(new VarCharType(), new IntType()); // probe rows
+        final CodeGeneratorContext ctx =
+                new CodeGeneratorContext(
+                        TableConfig.getDefault(), 
Thread.currentThread().getContextClassLoader());
+        final CodeGenOperatorFactory<RowData> operatorFactory =
+                RuntimeFilterCodeGenerator.gen(ctx, leftType, rightType, new 
int[] {0});
+
+        testHarness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                TwoInputStreamTask::new, 
InternalTypeInfo.of(rightType))
+                        .setupOutputForSingletonOperatorChain(operatorFactory)
+                        .addInput(InternalTypeInfo.of(leftType))
+                        .addInput(InternalTypeInfo.of(rightType))
+                        .build();
+    }
+
+    @AfterEach
+    void cleanup() throws Exception {
+        if (testHarness != null) {
+            testHarness.close();
+        }
+    }
+
+    @Test
+    void testNormalFilter() throws Exception {
+        // build phase: feed a record produced by the local filter builder, then end input 0
+        finishBuildPhase(createNormalInput());
+
+        // probe phase: send six rows on input 1, then end it; only var1/var3/var5 should pass
+        testHarness.processElement(createRowDataRecord("var1", 111), 1);
+        testHarness.processElement(createRowDataRecord("var3", 333), 1);
+        testHarness.processElement(createRowDataRecord("var5", 555), 1);
+        testHarness.processElement(createRowDataRecord("var6", 666), 1);
+        testHarness.processElement(createRowDataRecord("var8", 888), 1);
+        testHarness.processElement(createRowDataRecord("var9", 999), 1);
+        testHarness.processEvent(new EndOfData(StopMode.DRAIN), 1);
+
+        assertThat(getOutputRowData())
+                .containsExactly(
+                        GenericRowData.of("var1", 111),
+                        GenericRowData.of("var3", 333),
+                        GenericRowData.of("var5", 555));
+    }
+
+    @Test
+    void testOverMaxRowCountLimitFilter() throws Exception {
+        // build phase: feed a record whose filter field is null, then end input 0
+        finishBuildPhase(createOverMaxRowCountLimitInput());
+
+        // probe phase: with no filter available, all six rows are expected in the output
+        testHarness.processElement(createRowDataRecord("var1", 111), 1);
+        testHarness.processElement(createRowDataRecord("var3", 333), 1);
+        testHarness.processElement(createRowDataRecord("var5", 555), 1);
+        testHarness.processElement(createRowDataRecord("var6", 666), 1);
+        testHarness.processElement(createRowDataRecord("var8", 888), 1);
+        testHarness.processElement(createRowDataRecord("var9", 999), 1);
+        testHarness.processEvent(new EndOfData(StopMode.DRAIN), 1);
+
+        assertThat(getOutputRowData())
+                .containsExactly(
+                        GenericRowData.of("var1", 111),
+                        GenericRowData.of("var3", 333),
+                        GenericRowData.of("var5", 555),
+                        GenericRowData.of("var6", 666),
+                        GenericRowData.of("var8", 888),
+                        GenericRowData.of("var9", 999));
+    }
+
+    private void finishBuildPhase(StreamRecord<RowData> leftInput) throws 
Exception {
+        testHarness.processElement(leftInput, 0); // single build-side record on input 0
+        testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0); // ends the build input
+    }
+
+    private List<GenericRowData> getOutputRowData() {
+        return testHarness.getOutput().stream()
+                .map(record -> ((StreamRecord<RowData>) record).getValue())
+                .map(
+                        rowData -> {
+                            assertThat(rowData.getArity()).isEqualTo(2); // rows keep both fields
+                            return GenericRowData.of(
+                                    rowData.getString(0).toString(), 
rowData.getInt(1));
+                        })
+                .collect(Collectors.toList());
+    }
+
+    private static StreamRecord<RowData> createNormalInput() throws Exception {
+        StreamTaskMailboxTestHarness<RowData> localRuntimeFilterBuilder =
+                
createLocalRuntimeFilterBuilderOperatorHarnessAndProcessElements(5, 10);
+        StreamRecord<RowData> normalFilter =
+                (StreamRecord<RowData>) 
localRuntimeFilterBuilder.getOutput().poll();
+        localRuntimeFilterBuilder.close(); // helper harness is only used to produce the record
+        return normalFilter;
+    }
+
+    private static StreamRecord<RowData> createOverMaxRowCountLimitInput() {
+        return new 
StreamRecord<>(GenericRowData.of(RuntimeFilterUtils.OVER_MAX_ROW_COUNT, null));
+    }
+}
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperator.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperator.java
index 6cffe5cd984..db78166da36 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperator.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperator.java
@@ -68,7 +68,7 @@ public class LocalRuntimeFilterBuilderOperator extends 
TableStreamOperator<RowDa
         super.open();
 
         this.buildSideProjection = 
buildProjectionCode.newInstance(getUserCodeClassloader());
-        this.filter = 
RuntimeFilterUtils.createOnHeapBloomFilter(estimatedRowCount, 0.05);
+        this.filter = 
RuntimeFilterUtils.createOnHeapBloomFilter(estimatedRowCount);
         this.collector = new StreamRecordCollector<>(output);
         this.actualRowCount = 0;
     }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/util/RuntimeFilterUtils.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/util/RuntimeFilterUtils.java
index 20aab8be28b..b86d998d911 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/util/RuntimeFilterUtils.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/runtimefilter/util/RuntimeFilterUtils.java
@@ -30,8 +30,14 @@ public class RuntimeFilterUtils {
 
     public static final int OVER_MAX_ROW_COUNT = -1;
 
-    public static BloomFilter createOnHeapBloomFilter(int numExpectedEntries, 
double fpp) {
-        int byteSize = (int) 
Math.ceil(BloomFilter.optimalNumOfBits(numExpectedEntries, fpp) / 8D);
+    private static final double EXPECTED_FPP = 0.05;
+
+    public static BloomFilter createOnHeapBloomFilter(int numExpectedEntries) {
+        int byteSize =
+                (int)
+                        Math.ceil(
+                                
BloomFilter.optimalNumOfBits(numExpectedEntries, EXPECTED_FPP)
+                                        / 8D);
         final BloomFilter filter = new BloomFilter(numExpectedEntries, 
byteSize);
         
filter.setBitsLocation(MemorySegmentFactory.allocateUnpooledSegment(byteSize), 
0);
         return filter;
diff --git 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/GlobalRuntimeFilterBuilderOperatorTest.java
 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/GlobalRuntimeFilterBuilderOperatorTest.java
index 002de3b7d9d..d0bc0d08a64 100644
--- 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/GlobalRuntimeFilterBuilderOperatorTest.java
+++ 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/GlobalRuntimeFilterBuilderOperatorTest.java
@@ -140,7 +140,7 @@ class GlobalRuntimeFilterBuilderOperatorTest {
     }
 
     private static BloomFilter createBloomFilter1() {
-        final BloomFilter bloomFilter1 = 
RuntimeFilterUtils.createOnHeapBloomFilter(10, 0.05);
+        final BloomFilter bloomFilter1 = 
RuntimeFilterUtils.createOnHeapBloomFilter(10);
         bloomFilter1.addHash("var1".hashCode());
         bloomFilter1.addHash("var2".hashCode());
         bloomFilter1.addHash("var3".hashCode());
@@ -150,7 +150,7 @@ class GlobalRuntimeFilterBuilderOperatorTest {
     }
 
     private static BloomFilter createBloomFilter2() {
-        final BloomFilter bloomFilter2 = 
RuntimeFilterUtils.createOnHeapBloomFilter(10, 0.05);
+        final BloomFilter bloomFilter2 = 
RuntimeFilterUtils.createOnHeapBloomFilter(10);
         bloomFilter2.addHash("var6".hashCode());
         bloomFilter2.addHash("var7".hashCode());
         bloomFilter2.addHash("var8".hashCode());
diff --git 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperatorTest.java
 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperatorTest.java
index a855479660a..838694482fd 100644
--- 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperatorTest.java
+++ 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/runtimefilter/LocalRuntimeFilterBuilderOperatorTest.java
@@ -47,7 +47,7 @@ import static 
org.apache.flink.table.runtime.operators.runtimefilter.util.Runtim
 import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
 
 /** Test for {@link LocalRuntimeFilterBuilderOperator}. */
-class LocalRuntimeFilterBuilderOperatorTest implements Serializable {
+public class LocalRuntimeFilterBuilderOperatorTest implements Serializable {
 
     @Test
     void testNormalOutput() throws Exception {
@@ -106,11 +106,11 @@ class LocalRuntimeFilterBuilderOperatorTest implements 
Serializable {
                 
projection.apply(GenericRowData.of(StringData.fromString(string))).hashCode());
     }
 
-    private static StreamRecord<RowData> createRowDataRecord(String string, 
int integer) {
+    public static StreamRecord<RowData> createRowDataRecord(String string, int 
integer) {
         return new 
StreamRecord<>(GenericRowData.of(StringData.fromString(string), integer));
     }
 
-    private static StreamTaskMailboxTestHarness<RowData>
+    public static StreamTaskMailboxTestHarness<RowData>
             createLocalRuntimeFilterBuilderOperatorHarnessAndProcessElements(
                     int estimatedRowCount, int maxRowCount) throws Exception {
         final GeneratedProjection buildProjectionCode =

Reply via email to