westonpace commented on code in PR #34373:
URL: https://github.com/apache/arrow/pull/34373#discussion_r1120837528
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
Review Comment:
```suggestion
test_table = pa.Table.from_pydict({"x": [1, 2, 3]})
```
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
Review Comment:
```suggestion
return test_table
```
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
Review Comment:
You could even simplify to:
```
def table_provider(_names, _schema):
return test_table
```
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = b"""
+ {
+ "extensionUris": [
+ {
+ "extensionUriAnchor": 1
+ },
+ {
+ "extensionUriAnchor": 2,
+ "uri": "urn:arrow:substrait_simple_extension_function"
+ }
+ ],
+ "extensions": [
+ {
+ "extensionFunction": {
+ "extensionUriReference": 2,
+ "functionAnchor": 1,
+ "name": "y=x+1"
+ }
+ }
+ ],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "project": {
+ "common": {
+ "emit": {
+ "outputMapping": [
+ 1,
+ 2,
+ ]
+ }
+ },
+ "input": {
+ "read": {
+ "baseSchema": {
+ "names": [
+ "t",
+ ],
+ "struct": {
+ "types": [
+ {
+ "i64": {
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": [
+ "t1"
+ ]
+ }
+ }
+ },
+ "expressions": [
+ {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ },
+ {
+ "scalarFunction": {
+ "functionReference": 1,
+ "outputType": {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ },
+ "names": [
+ "x",
+ "y",
+ ]
+ }
+ }
+ ]
+}
+ """
+
+ buf = pa._substrait._parse_json_plan(substrait_query)
+ reader = pa.substrait.run_query(
+ buf, table_provider=table_provider, use_threads=use_threads)
+ res_tb = reader.read_all()
+
+ function, name = unary_func_fixture
+ expected_tb = test_table_1.add_column(1, 'y', function(
+ mock_scalar_udf_context(10), test_table_1['x']))
+ res_tb = res_tb.rename_columns(['x', 'y'])
+ assert res_tb == expected_tb
+
+
+def test_udf_via_substrait_wrong_udf_name():
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
Review Comment:
```suggestion
test_table = pa.Table.from_pydict({"x": [1, 2, 3]})
```
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = b"""
+ {
+ "extensionUris": [
+ {
+ "extensionUriAnchor": 1
+ },
+ {
+ "extensionUriAnchor": 2,
+ "uri": "urn:arrow:substrait_simple_extension_function"
+ }
+ ],
+ "extensions": [
+ {
+ "extensionFunction": {
+ "extensionUriReference": 2,
+ "functionAnchor": 1,
+ "name": "y=x+1"
+ }
+ }
+ ],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "project": {
+ "common": {
+ "emit": {
+ "outputMapping": [
+ 1,
+ 2,
+ ]
+ }
+ },
+ "input": {
+ "read": {
+ "baseSchema": {
+ "names": [
+ "t",
+ ],
+ "struct": {
+ "types": [
+ {
+ "i64": {
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": [
+ "t1"
+ ]
+ }
+ }
+ },
+ "expressions": [
+ {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ },
+ {
+ "scalarFunction": {
+ "functionReference": 1,
+ "outputType": {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ },
+ "names": [
+ "x",
+ "y",
+ ]
+ }
+ }
+ ]
+}
+ """
+
+ buf = pa._substrait._parse_json_plan(substrait_query)
+ reader = pa.substrait.run_query(
+ buf, table_provider=table_provider, use_threads=use_threads)
+ res_tb = reader.read_all()
+
+ function, name = unary_func_fixture
+ expected_tb = test_table_1.add_column(1, 'y', function(
+ mock_scalar_udf_context(10), test_table_1['x']))
+ res_tb = res_tb.rename_columns(['x', 'y'])
+ assert res_tb == expected_tb
+
+
+def test_udf_via_substrait_wrong_udf_name():
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
Review Comment:
```suggestion
return test_table
```
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = b"""
+ {
+ "extensionUris": [
+ {
+ "extensionUriAnchor": 1
+ },
+ {
+ "extensionUriAnchor": 2,
+ "uri": "urn:arrow:substrait_simple_extension_function"
+ }
+ ],
+ "extensions": [
+ {
+ "extensionFunction": {
+ "extensionUriReference": 2,
+ "functionAnchor": 1,
+ "name": "y=x+1"
+ }
+ }
+ ],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "project": {
+ "common": {
+ "emit": {
+ "outputMapping": [
+ 1,
+ 2,
+ ]
+ }
+ },
+ "input": {
+ "read": {
+ "baseSchema": {
+ "names": [
+ "t",
+ ],
+ "struct": {
+ "types": [
+ {
+ "i64": {
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": [
+ "t1"
+ ]
+ }
+ }
+ },
+ "expressions": [
+ {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ },
+ {
+ "scalarFunction": {
+ "functionReference": 1,
+ "outputType": {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ },
+ "names": [
+ "x",
+ "y",
+ ]
+ }
+ }
+ ]
+}
+ """
+
+ buf = pa._substrait._parse_json_plan(substrait_query)
+ reader = pa.substrait.run_query(
+ buf, table_provider=table_provider, use_threads=use_threads)
+ res_tb = reader.read_all()
+
+ function, name = unary_func_fixture
+ expected_tb = test_table_1.add_column(1, 'y', function(
+ mock_scalar_udf_context(10), test_table_1['x']))
Review Comment:
```suggestion
mock_scalar_udf_context(10), test_table['x']))
```
##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -315,3 +320,246 @@ def table_provider(names, _):
exec_message = "names for NamedTable not provided"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, _):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = b"""
+ {
+ "extensionUris": [
+ {
+ "extensionUriAnchor": 1
+ },
+ {
+ "extensionUriAnchor": 2,
+ "uri": "urn:arrow:substrait_simple_extension_function"
+ }
+ ],
+ "extensions": [
+ {
+ "extensionFunction": {
+ "extensionUriReference": 2,
+ "functionAnchor": 1,
+ "name": "y=x+1"
+ }
+ }
+ ],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "project": {
+ "common": {
+ "emit": {
+ "outputMapping": [
+ 1,
+ 2,
+ ]
+ }
+ },
+ "input": {
+ "read": {
+ "baseSchema": {
+ "names": [
+ "t",
+ ],
+ "struct": {
+ "types": [
+ {
+ "i64": {
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": [
+ "t1"
+ ]
+ }
+ }
+ },
+ "expressions": [
+ {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ },
+ {
+ "scalarFunction": {
+ "functionReference": 1,
+ "outputType": {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ },
+ "names": [
+ "x",
+ "y",
+ ]
+ }
+ }
+ ]
+}
+ """
+
+ buf = pa._substrait._parse_json_plan(substrait_query)
+ reader = pa.substrait.run_query(
+ buf, table_provider=table_provider, use_threads=use_threads)
+ res_tb = reader.read_all()
+
+ function, name = unary_func_fixture
+ expected_tb = test_table_1.add_column(1, 'y', function(
Review Comment:
```suggestion
expected_tb = test_table.add_column(1, 'y', function(
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]