This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 24031e5 [FLINK-22297][python] Perform early validation for the result of Pandas UDF
24031e5 is described below
commit 24031e55e4cf35a5818db2e927e65b290a9b2aed
Author: Dian Fu <[email protected]>
AuthorDate: Mon Apr 19 09:53:53 2021 +0800
[FLINK-22297][python] Perform early validation for the result of Pandas UDF
This closes #15681.
---
.../pyflink/fn_execution/flink_fn_execution_pb2.py | 163 +++++++++++----------
.../pyflink/fn_execution/operation_utils.py | 18 ++-
.../pyflink/proto/flink-fn-execution.proto | 3 +
.../pyflink/table/tests/test_pandas_udf.py | 24 +++
.../streaming/api/utils/PythonOperatorUtils.java | 4 +
5 files changed, 133 insertions(+), 79 deletions(-)
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index 89205d3..39830ba 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='flink-fn-execution.proto',
package='org.apache.flink.fn_execution.v1',
syntax='proto3',
- serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12
org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01
\x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02
\x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03
\x01(\x0cH\x00\x42\x07\n\x05input\"\x91\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01
\x01(\x0c\x12\x37\n\x06inputs\x18\x02
\x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...]
+ serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12
org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01
\x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02
\x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03
\x01(\x0cH\x00\x42\x07\n\x05input\"\xa8\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01
\x01(\x0c\x12\x37\n\x06inputs\x18\x02
\x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...]
)
@@ -82,8 +82,8 @@ _OVERWINDOW_WINDOWTYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=670,
- serialized_end=878,
+ serialized_start=693,
+ serialized_end=901,
)
_sym_db.RegisterEnumDescriptor(_OVERWINDOW_WINDOWTYPE)
@@ -132,8 +132,8 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=1612,
- serialized_end=1768,
+ serialized_start=1635,
+ serialized_end=1791,
)
_sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE)
@@ -158,8 +158,8 @@ _GROUPWINDOW_WINDOWTYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=2917,
- serialized_end=3008,
+ serialized_start=2940,
+ serialized_end=3031,
)
_sym_db.RegisterEnumDescriptor(_GROUPWINDOW_WINDOWTYPE)
@@ -188,8 +188,8 @@ _GROUPWINDOW_WINDOWPROPERTY = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=3010,
- serialized_end=3109,
+ serialized_start=3033,
+ serialized_end=3132,
)
_sym_db.RegisterEnumDescriptor(_GROUPWINDOW_WINDOWPROPERTY)
@@ -286,8 +286,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=5363,
- serialized_end=5652,
+ serialized_start=5386,
+ serialized_end=5675,
)
_sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME)
@@ -316,8 +316,8 @@ _CODERPARAM_DATATYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=5946,
- serialized_end=6002,
+ serialized_start=5969,
+ serialized_end=6025,
)
_sym_db.RegisterEnumDescriptor(_CODERPARAM_DATATYPE)
@@ -342,8 +342,8 @@ _CODERPARAM_OUTPUTMODE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=6004,
- serialized_end=6065,
+ serialized_start=6027,
+ serialized_end=6088,
)
_sym_db.RegisterEnumDescriptor(_CODERPARAM_OUTPUTMODE)
@@ -440,8 +440,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=6903,
- serialized_end=7180,
+ serialized_start=6926,
+ serialized_end=7203,
)
_sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME)
@@ -529,6 +529,13 @@ _USERDEFINEDFUNCTION = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='is_pandas_udf',
full_name='org.apache.flink.fn_execution.v1.UserDefinedFunction.is_pandas_udf',
index=4,
+ number=5, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -542,7 +549,7 @@ _USERDEFINEDFUNCTION = _descriptor.Descriptor(
oneofs=[
],
serialized_start=200,
- serialized_end=345,
+ serialized_end=368,
)
@@ -586,8 +593,8 @@ _USERDEFINEDFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=348,
- serialized_end=526,
+ serialized_start=371,
+ serialized_end=549,
)
@@ -632,8 +639,8 @@ _OVERWINDOW = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=529,
- serialized_end=878,
+ serialized_start=552,
+ serialized_end=901,
)
@@ -670,8 +677,8 @@ _USERDEFINEDDATASTREAMFUNCTION_JOBPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1228,
- serialized_end=1270,
+ serialized_start=1251,
+ serialized_end=1293,
)
_USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT = _descriptor.Descriptor(
@@ -749,8 +756,8 @@ _USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1273,
- serialized_end=1609,
+ serialized_start=1296,
+ serialized_end=1632,
)
_USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
@@ -808,8 +815,8 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=881,
- serialized_end=1768,
+ serialized_start=904,
+ serialized_end=1791,
)
@@ -839,8 +846,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_LISTVIEW = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2299,
- serialized_end=2383,
+ serialized_start=2322,
+ serialized_end=2406,
)
_USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_MAPVIEW = _descriptor.Descriptor(
@@ -876,8 +883,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC_MAPVIEW = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2386,
- serialized_end=2537,
+ serialized_start=2409,
+ serialized_end=2560,
)
_USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC = _descriptor.Descriptor(
@@ -930,8 +937,8 @@ _USERDEFINEDAGGREGATEFUNCTION_DATAVIEWSPEC = _descriptor.Descriptor(
name='data_view',
full_name='org.apache.flink.fn_execution.v1.UserDefinedAggregateFunction.DataViewSpec.data_view',
index=0, containing_type=None, fields=[]),
],
- serialized_start=2036,
- serialized_end=2550,
+ serialized_start=2059,
+ serialized_end=2573,
)
_USERDEFINEDAGGREGATEFUNCTION = _descriptor.Descriptor(
@@ -995,8 +1002,8 @@ _USERDEFINEDAGGREGATEFUNCTION = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1771,
- serialized_end=2550,
+ serialized_start=1794,
+ serialized_end=2573,
)
@@ -1091,8 +1098,8 @@ _GROUPWINDOW = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2553,
- serialized_end=3109,
+ serialized_start=2576,
+ serialized_end=3132,
)
@@ -1199,8 +1206,8 @@ _USERDEFINEDAGGREGATEFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3112,
- serialized_end=3621,
+ serialized_start=3135,
+ serialized_end=3644,
)
@@ -1237,8 +1244,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3699,
- serialized_end=3850,
+ serialized_start=3722,
+ serialized_end=3873,
)
_SCHEMA_TIMEINFO = _descriptor.Descriptor(
@@ -1267,8 +1274,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3852,
- serialized_end=3881,
+ serialized_start=3875,
+ serialized_end=3904,
)
_SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
@@ -1297,8 +1304,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3883,
- serialized_end=3917,
+ serialized_start=3906,
+ serialized_end=3940,
)
_SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -1327,8 +1334,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3919,
- serialized_end=3963,
+ serialized_start=3942,
+ serialized_end=3986,
)
_SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -1357,8 +1364,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3965,
- serialized_end=4004,
+ serialized_start=3988,
+ serialized_end=4027,
)
_SCHEMA_DECIMALINFO = _descriptor.Descriptor(
@@ -1394,8 +1401,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4006,
- serialized_end=4053,
+ serialized_start=4029,
+ serialized_end=4076,
)
_SCHEMA_BINARYINFO = _descriptor.Descriptor(
@@ -1424,8 +1431,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4055,
- serialized_end=4083,
+ serialized_start=4078,
+ serialized_end=4106,
)
_SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
@@ -1454,8 +1461,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4085,
- serialized_end=4116,
+ serialized_start=4108,
+ serialized_end=4139,
)
_SCHEMA_CHARINFO = _descriptor.Descriptor(
@@ -1484,8 +1491,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4118,
- serialized_end=4144,
+ serialized_start=4141,
+ serialized_end=4167,
)
_SCHEMA_VARCHARINFO = _descriptor.Descriptor(
@@ -1514,8 +1521,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=4146,
- serialized_end=4175,
+ serialized_start=4169,
+ serialized_end=4198,
)
_SCHEMA_FIELDTYPE = _descriptor.Descriptor(
@@ -1638,8 +1645,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
name='type_info',
full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=4178,
- serialized_end=5250,
+ serialized_start=4201,
+ serialized_end=5273,
)
_SCHEMA_FIELD = _descriptor.Descriptor(
@@ -1682,8 +1689,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=5252,
- serialized_end=5360,
+ serialized_start=5275,
+ serialized_end=5383,
)
_SCHEMA = _descriptor.Descriptor(
@@ -1713,8 +1720,8 @@ _SCHEMA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3624,
- serialized_end=5652,
+ serialized_start=3647,
+ serialized_end=5675,
)
@@ -1770,8 +1777,8 @@ _CODERPARAM = _descriptor.Descriptor(
name='data_info',
full_name='org.apache.flink.fn_execution.v1.CoderParam.data_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=5655,
- serialized_end=6078,
+ serialized_start=5678,
+ serialized_end=6101,
)
@@ -1808,8 +1815,8 @@ _TYPEINFO_MAPTYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6492,
- serialized_end=6631,
+ serialized_start=6515,
+ serialized_end=6654,
)
_TYPEINFO_ROWTYPEINFO_FIELD = _descriptor.Descriptor(
@@ -1845,8 +1852,8 @@ _TYPEINFO_ROWTYPEINFO_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6727,
- serialized_end=6818,
+ serialized_start=6750,
+ serialized_end=6841,
)
_TYPEINFO_ROWTYPEINFO = _descriptor.Descriptor(
@@ -1875,8 +1882,8 @@ _TYPEINFO_ROWTYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6634,
- serialized_end=6818,
+ serialized_start=6657,
+ serialized_end=6841,
)
_TYPEINFO_TUPLETYPEINFO = _descriptor.Descriptor(
@@ -1905,8 +1912,8 @@ _TYPEINFO_TUPLETYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=6820,
- serialized_end=6900,
+ serialized_start=6843,
+ serialized_end=6923,
)
_TYPEINFO = _descriptor.Descriptor(
@@ -1967,8 +1974,8 @@ _TYPEINFO = _descriptor.Descriptor(
name='type_info',
full_name='org.apache.flink.fn_execution.v1.TypeInfo.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=6081,
- serialized_end=7193,
+ serialized_start=6104,
+ serialized_end=7216,
)
_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py b/flink-python/pyflink/fn_execution/operation_utils.py
index c280c77..bf1576f 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -17,6 +17,7 @@
################################################################################
import datetime
from enum import Enum
+from functools import partial
from typing import Any, Tuple, Dict, List
@@ -65,6 +66,18 @@ def wrap_inputs_as_row(*args):
return Row(*args)
+def check_pandas_udf_result(f, *input_args):
+ output = f(*input_args)
+ import pandas as pd
+ assert type(output) == pd.Series or type(output) == pd.DataFrame, \
+ "The result type of Pandas UDF '%s' must be pandas.Series or
pandas.DataFrame, got %s" \
+ % (f.__name__, type(output))
+ assert len(output) == len(input_args[0]), \
+ "The result length '%d' of Pandas UDF '%s' is not equal to the input
length '%d'" \
+ % (len(output), f.__name__, len(input_args[0]))
+ return output
+
+
def extract_over_window_user_defined_function(user_defined_function_proto):
window_index = user_defined_function_proto.window_index
return (*extract_user_defined_function(user_defined_function_proto, True),
window_index)
@@ -117,7 +130,10 @@ def extract_user_defined_function(user_defined_function_proto, pandas_udaf=False
func_name = 'f%s' % _next_func_num()
if isinstance(user_defined_func, DelegatingScalarFunction) \
or isinstance(user_defined_func, DelegationTableFunction):
- variable_dict[func_name] = user_defined_func.func
+ if user_defined_function_proto.is_pandas_udf:
+ variable_dict[func_name] = partial(check_pandas_udf_result,
user_defined_func.func)
+ else:
+ variable_dict[func_name] = user_defined_func.func
else:
variable_dict[func_name] = user_defined_func.eval
user_defined_funcs.append(user_defined_func)
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto
index 26e9c65..f18c455 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -50,6 +50,9 @@ message UserDefinedFunction {
// Whether the UDF takes row as input instead of each columns of a row
bool takes_row_as_input = 4;
+
+ // Whether it's pandas UDF
+ bool is_pandas_udf = 5;
}
// A list of user-defined functions to be executed in a batch.
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index ee3172c..8d3334d 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -283,6 +283,30 @@ class PandasUDFITTests(object):
"1970-01-02 00:00:00.123, [hello, 中文, null], [1970-01-02
00:00:00.123], "
"[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02 00:00:00.123,
[1, 2]]]"])
+ def test_invalid_pandas_udf(self):
+
+ @udf(result_type=DataTypes.INT(), udf_type="pandas")
+ def length_mismatch(i):
+ return i[1:]
+
+ @udf(result_type=DataTypes.INT(), udf_type="pandas")
+ def result_type_not_series(i):
+ return i.iloc[0]
+
+ t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a',
'b', 'c'])
+
+ msg = "The result length '0' of Pandas UDF 'length_mismatch' is not
equal " \
+ "to the input length '1'"
+ from py4j.protocol import Py4JJavaError
+ with self.assertRaisesRegex(Py4JJavaError, expected_regex=msg):
+ t.select(length_mismatch(t.a)).to_pandas()
+
+ msg = "The result type of Pandas UDF 'result_type_not_series' must be
pandas.Series or " \
+ "pandas.DataFrame, got <class 'numpy.int64'>"
+ from py4j.protocol import Py4JJavaError
+ with self.assertRaisesRegex(Py4JJavaError, expected_regex=msg):
+ t.select(result_type_not_series(t.a)).to_pandas()
+
class BlinkPandasUDFITTests(object):
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/utils/PythonOperatorUtils.java b/flink-python/src/main/java/org/apache/flink/streaming/api/utils/PythonOperatorUtils.java
index fd2ef99..76a8a45 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/utils/PythonOperatorUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/utils/PythonOperatorUtils.java
@@ -26,6 +26,7 @@ import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionI
import
org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionKeyedStateBackend;
import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.typeutils.DataViewUtils;
import com.google.protobuf.ByteString;
@@ -61,6 +62,9 @@ public enum PythonOperatorUtils {
builder.addInputs(inputProto);
}
builder.setTakesRowAsInput(pythonFunctionInfo.getPythonFunction().takesRowAsInput());
+ builder.setIsPandasUdf(
+ pythonFunctionInfo.getPythonFunction().getPythonFunctionKind()
+ == PythonFunctionKind.PANDAS);
return builder.build();
}