This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new d8417565d6f [FLINK-28526][python] Fix Python UDF to support time indicator inputs
d8417565d6f is described below
commit d8417565d6fb7f907f54eeabb8a53ebf790ffad8
Author: Dian Fu <[email protected]>
AuthorDate: Mon Jan 16 14:05:59 2023 +0800
[FLINK-28526][python] Fix Python UDF to support time indicator inputs
This closes #21686.
---
flink-python/pyflink/table/tests/test_udf.py | 43 ++++++++++++++++++++++
.../plan/nodes/exec/utils/CommonPythonUtil.java | 33 ++++++++++++-----
2 files changed, 66 insertions(+), 10 deletions(-)
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index 851c43fc0b2..d974c2402bf 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -24,6 +24,7 @@ import uuid
import pytest
import pytz
+from pyflink.common import Row
from pyflink.table import DataTypes, expressions as expr
from pyflink.table.expressions import call
from pyflink.table.udf import ScalarFunction, udf, FunctionContext
@@ -860,6 +861,48 @@ class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
lines.sort()
self.assertEqual(lines, ['1,2', '2,3', '3,4'])
+ def test_udf_with_rowtime_arguments(self):
+ from pyflink.common import WatermarkStrategy
+ from pyflink.common.typeinfo import Types
+ from pyflink.common.watermark_strategy import TimestampAssigner
+ from pyflink.table import Schema
+
+ class MyTimestampAssigner(TimestampAssigner):
+
+ def extract_timestamp(self, value, record_timestamp) -> int:
+ return int(value[0])
+
+ ds = self.env.from_collection(
+ [(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")],
+ Types.ROW_NAMED(["a", "b", "c"], [Types.LONG(), Types.INT(), Types.STRING()]))
+
+ ds = ds.assign_timestamps_and_watermarks(
+ WatermarkStrategy.for_monotonous_timestamps()
+ .with_timestamp_assigner(MyTimestampAssigner()))
+
+ table = self.t_env.from_data_stream(
+ ds,
+ Schema.new_builder()
+ .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
+ .watermark("rowtime", "SOURCE_WATERMARK()")
+ .build())
+
+ @udf(result_type=DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.INT())]))
+ def inc(input_row):
+ return Row(input_row.b)
+
+ sink_table = generate_random_table_name()
+ sink_table_ddl = f"""
+ CREATE TABLE {sink_table}(
+ a INT
+ ) WITH ('connector'='test-sink')
+ """
+ self.t_env.execute_sql(sink_table_ddl)
+ table.map(inc).execute_insert(sink_table).wait()
+
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ['+I[42]', '+I[5]', '+I[1000]', '+I[1000]'])
+
class PyFlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
PyFlinkBatchTableTestCase):
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index 201407b718a..ff4ed47dc3b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -47,6 +47,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
+import org.apache.flink.table.planner.plan.schema.TimeIndicatorRelDataType;
import org.apache.flink.table.planner.plan.utils.AggregateInfo;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment;
@@ -70,10 +71,12 @@ import org.apache.flink.table.types.logical.StructuredType;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.type.SqlTypeName;
import java.lang.reflect.Field;
@@ -438,22 +441,32 @@ public class CommonPythonUtil {
for (RexNode operand : pythonRexCall.getOperands()) {
if (operand instanceof RexCall) {
RexCall childPythonRexCall = (RexCall) operand;
- PythonFunctionInfo argPythonInfo =
- createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader);
- inputs.add(argPythonInfo);
+ if (childPythonRexCall.getOperator() instanceof SqlCastFunction
+ && childPythonRexCall.getOperands().get(0) instanceof RexInputRef
+ && childPythonRexCall.getOperands().get(0).getType()
+ instanceof TimeIndicatorRelDataType) {
+ operand = childPythonRexCall.getOperands().get(0);
+ } else {
+ PythonFunctionInfo argPythonInfo =
+ createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader);
+ inputs.add(argPythonInfo);
+ continue;
+ }
} else if (operand instanceof RexLiteral) {
RexLiteral literal = (RexLiteral) operand;
inputs.add(
convertLiteralToPython(
literal, literal.getType().getSqlTypeName(), classLoader));
+ continue;
+ }
+
+ assert operand instanceof RexInputRef;
+ if (inputNodes.containsKey(operand)) {
+ inputs.add(inputNodes.get(operand));
} else {
- if (inputNodes.containsKey(operand)) {
- inputs.add(inputNodes.get(operand));
- } else {
- Integer inputOffset = inputNodes.size();
- inputs.add(inputOffset);
- inputNodes.put(operand, inputOffset);
- }
+ Integer inputOffset = inputNodes.size();
+ inputs.add(inputOffset);
+ inputNodes.put(operand, inputOffset);
}
}
return new PythonFunctionInfo((PythonFunction) functionDefinition, inputs.toArray());