dianfu commented on a change in pull request #14347:
URL: https://github.com/apache/flink/pull/14347#discussion_r539779543
##########
File path: flink-python/pyflink/table/tests/test_row_based_operation.py
##########
@@ -49,6 +49,35 @@ def test_map(self):
actual = source_sink_utils.results()
self.assert_equals(actual, ["4,9", "3,4", "7,36", "10,81", "5,16"])
+ def test_map_pandas(self):
Review comment:
```suggestion
def test_map_with_pandas_udf(self):
```
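(Not part of the suggestion, just naming context: a rough sketch of the kind of usage a `test_map_with_pandas_udf` case would exercise, assuming the pandas map semantics this PR introduces; the table `t`, the sink name, the field names and the lambda body are hypothetical.)
```python
from pyflink.table import DataTypes
from pyflink.table.udf import udf

# hypothetical pandas UDF: receives a pandas.DataFrame and returns one
pandas_map = udf(lambda pdf: pdf.assign(a=pdf.a * 2),
                 result_type=DataTypes.ROW(
                     [DataTypes.FIELD("a", DataTypes.BIGINT()),
                      DataTypes.FIELD("b", DataTypes.BIGINT())]),
                 func_type="pandas")

t.map(pandas_map).execute_insert("Results").wait()
```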
##########
File path: flink-python/pyflink/table/utils.py
##########
@@ -38,16 +38,35 @@ def create_array(s, t):
"pyarrow.Array (%s)."
raise RuntimeError(error_msg % (s.dtype, t), e)
- arrays = [create_array(
- tz_convert_to_internal(series[i], field_types[i], timezone),
- schema.types[i]) for i in range(0, len(schema))]
+ arrays = []
+ for i in range(len(schema)):
+ s = series[i]
+ field_type = field_types[i]
+ schema_type = schema.types[i]
+ if type(field_type) == RowType:
+ array_names = [(create_array(s[s.columns[i]], field.type), field.name)
Review comment:
```suggestion
array_names = [(create_array(s[s.columns[j]], field.type), field.name)
```
##########
File path: flink-python/pyflink/table/utils.py
##########
@@ -38,16 +38,35 @@ def create_array(s, t):
"pyarrow.Array (%s)."
raise RuntimeError(error_msg % (s.dtype, t), e)
- arrays = [create_array(
- tz_convert_to_internal(series[i], field_types[i], timezone),
- schema.types[i]) for i in range(0, len(schema))]
+ arrays = []
+ for i in range(len(schema)):
+ s = series[i]
+ field_type = field_types[i]
+ schema_type = schema.types[i]
+ if type(field_type) == RowType:
+ array_names = [(create_array(s[s.columns[i]], field.type), field.name)
+ for i, field in enumerate(schema_type)]
Review comment:
```suggestion
for j, field in enumerate(schema_type)]
```
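(Follow-up on the two index suggestions: inside the comprehension the loop variable shadows the outer `i` that selects the series, which is easy to misread; renaming it to `j` makes it clear that it only walks the struct's fields. A self-contained pyarrow sketch of the pattern, with hypothetical field names:)
```python
import pandas as pd
import pyarrow as pa

# hypothetical ROW-typed column: a DataFrame whose columns are the row fields
s = pd.DataFrame({"f0": [1, 2, 3], "f1": ["a", "b", "c"]})
fields = [pa.field("f0", pa.int64()), pa.field("f1", pa.string())]

# j indexes the fields of the struct, independent of any outer column index i
array_names = [(pa.Array.from_pandas(s[s.columns[j]], type=field.type), field.name)
               for j, field in enumerate(fields)]
struct_arrays, struct_names = zip(*array_names)
struct_array = pa.StructArray.from_arrays(struct_arrays, list(struct_names))
print(struct_array)
```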
##########
File path: flink-python/pyflink/table/utils.py
##########
@@ -38,16 +38,35 @@ def create_array(s, t):
"pyarrow.Array (%s)."
raise RuntimeError(error_msg % (s.dtype, t), e)
- arrays = [create_array(
- tz_convert_to_internal(series[i], field_types[i], timezone),
- schema.types[i]) for i in range(0, len(schema))]
+ arrays = []
+ for i in range(len(schema)):
+ s = series[i]
+ field_type = field_types[i]
+ schema_type = schema.types[i]
+ if type(field_type) == RowType:
+ array_names = [(create_array(s[s.columns[i]], field.type), field.name)
+ for i, field in enumerate(schema_type)]
+ struct_arrays, struct_names = zip(*array_names)
+ arrays.append(pa.StructArray.from_arrays(struct_arrays, struct_names))
+ else:
+ arrays.append(create_array(
+ tz_convert_to_internal(s, field_type, timezone), schema_type))
return pa.RecordBatch.from_arrays(arrays, schema)
def arrow_to_pandas(timezone, field_types, batches):
+ def arrow_column_to_pandas(arrow_column, t: DataType):
+ if type(t) == RowType:
+ import pandas as pd
+ series = [column.to_pandas(date_as_object=True).rename(field.name)
+ for column, field in zip(arrow_column.flatten(), arrow_column.type)]
+ s = pd.concat(series, axis=1)
+ else:
+ s = arrow_column.to_pandas(date_as_object=True)
Review comment:
```suggestion
return arrow_column.to_pandas(date_as_object=True)
```
##########
File path: flink-python/pyflink/table/utils.py
##########
@@ -38,16 +38,35 @@ def create_array(s, t):
"pyarrow.Array (%s)."
raise RuntimeError(error_msg % (s.dtype, t), e)
- arrays = [create_array(
- tz_convert_to_internal(series[i], field_types[i], timezone),
- schema.types[i]) for i in range(0, len(schema))]
+ arrays = []
+ for i in range(len(schema)):
+ s = series[i]
+ field_type = field_types[i]
+ schema_type = schema.types[i]
+ if type(field_type) == RowType:
+ array_names = [(create_array(s[s.columns[i]], field.type), field.name)
+ for i, field in enumerate(schema_type)]
+ struct_arrays, struct_names = zip(*array_names)
+ arrays.append(pa.StructArray.from_arrays(struct_arrays, struct_names))
+ else:
+ arrays.append(create_array(
+ tz_convert_to_internal(s, field_type, timezone), schema_type))
return pa.RecordBatch.from_arrays(arrays, schema)
def arrow_to_pandas(timezone, field_types, batches):
+ def arrow_column_to_pandas(arrow_column, t: DataType):
+ if type(t) == RowType:
+ import pandas as pd
+ series = [column.to_pandas(date_as_object=True).rename(field.name)
+ for column, field in zip(arrow_column.flatten(), arrow_column.type)]
+ s = pd.concat(series, axis=1)
Review comment:
```suggestion
return pd.concat(series, axis=1)
```
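(Taken together with the earlier suggestion in this helper, both branches then return directly and the intermediate `s` is no longer needed. A minimal sketch of the resulting helper, assuming the real imports come from `pyflink.table.types`:)
```python
import pandas as pd
from pyflink.table.types import DataType, RowType

def arrow_column_to_pandas(arrow_column, t: DataType):
    if type(t) == RowType:
        # a struct column flattens into one child array per field
        series = [column.to_pandas(date_as_object=True).rename(field.name)
                  for column, field in zip(arrow_column.flatten(), arrow_column.type)]
        return pd.concat(series, axis=1)
    else:
        return arrow_column.to_pandas(date_as_object=True)
```
Returning directly on each branch keeps the two cases self-contained and avoids a dangling local that later edits would have to remember to propagate.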
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]