This is an automated email from the ASF dual-hosted git repository.

icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new e8107bfa58 GH-34333: [Python] Test run_query with a registered scalar UDF (#34373)
e8107bfa58 is described below

commit e8107bfa58ef5ad50c5c40d3f54bb7a96bdf2d0e
Author: Li Jin <[email protected]>
AuthorDate: Wed Mar 1 14:23:59 2023 -0500

    GH-34333: [Python] Test run_query with a registered scalar UDF (#34373)
    
    ### Rationale for this change
    
    Currently, Acero can execute a registered UDF via Substrait, but there
    are no tests for it.
    
    ### What changes are included in this PR?
    This PR adds a test for passing a registered UDF via a Substrait plan.
    
    ### Are these changes tested?
    N/A
    
    ### Are there any user-facing changes?
    No
    
    * Closes: #34333
    
    ---------
    
    Co-authored-by: Weston Pace <[email protected]>
---
 .../arrow/engine/substrait/relation_internal.cc    |   5 +-
 python/pyarrow/conftest.py                         |  23 ++
 python/pyarrow/tests/test_substrait.py             | 248 +++++++++++++++++++++
 python/pyarrow/tests/test_udf.py                   |  20 --
 4 files changed, 274 insertions(+), 22 deletions(-)

diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index e5b9116a0a..b21d8c5878 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -547,8 +547,9 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
         std::shared_ptr<Field> project_field;
         ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr,
                               FromProto(expr, ext_set, conversion_options));
-        auto bound_expr = des_expr.Bind(*input.output_schema);
-        if (auto* expr_call = bound_expr->call()) {
+        ARROW_ASSIGN_OR_RAISE(compute::Expression bound_expr,
+                              des_expr.Bind(*input.output_schema));
+        if (auto* expr_call = bound_expr.call()) {
           project_field = field(expr_call->function_name,
                                 expr_call->kernel->signature->out_type().type());
         } else if (auto* field_ref = des_expr.field_ref()) {
diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py
index bea735bd3a..e8e7228298 100644
--- a/python/pyarrow/conftest.py
+++ b/python/pyarrow/conftest.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import pytest
+import pyarrow as pa
 from pyarrow import Codec
 from pyarrow import fs
 
@@ -265,3 +266,25 @@ def add_fs(doctest_namespace, request, tmp_path):
         doctest_namespace["local_path"] = str(tmp_path)
         doctest_namespace["path"] = str(path)
     yield
+
+
+# Define udf fixture for test_udf.py and test_substrait.py
+@pytest.fixture(scope="session")
+def unary_func_fixture():
+    """
+    Register a unary scalar function.
+    """
+    from pyarrow import compute as pc
+
+    def unary_function(ctx, x):
+        return pc.call_function("add", [x, 1],
+                                memory_pool=ctx.memory_pool)
+    func_name = "y=x+1"
+    unary_doc = {"summary": "add function",
+                 "description": "test add function"}
+    pc.register_scalar_function(unary_function,
+                                func_name,
+                                unary_doc,
+                                {"array": pa.int64()},
+                                pa.int64())
+    return unary_function, func_name
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index 8c6a1871a3..87d3bfc444 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -34,6 +34,11 @@ except ImportError:
 pytestmark = [pytest.mark.dataset, pytest.mark.substrait]
 
 
+def mock_scalar_udf_context(batch_length=10):
+    from pyarrow._compute import _get_scalar_udf_context
+    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
+
+
 def _write_dummy_data_to_disk(tmpdir, file_name, table):
     path = os.path.join(str(tmpdir), file_name)
     with pa.ipc.RecordBatchFileWriter(path, schema=table.schema) as writer:
@@ -315,3 +320,246 @@ def test_named_table_empty_names():
     exec_message = "names for NamedTable not provided"
     with pytest.raises(ArrowInvalid, match=exec_message):
         substrait.run_query(buf, table_provider=table_provider)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+    test_table = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+    def table_provider(names, _):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = b"""
+    {
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1
+    },
+    {
+      "extensionUriAnchor": 2,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 2,
+        "functionAnchor": 1,
+        "name": "y=x+1"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  1,
+                  2,
+                ]
+              }
+            },
+            "input": {
+              "read": {
+                "baseSchema": {
+                  "names": [
+                    "t",
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                    ],
+                    "nullability": "NULLABILITY_REQUIRED"
+                  }
+                },
+                "namedTable": {
+                  "names": [
+                    "t1"
+                  ]
+                }
+              }
+            },
+            "expressions": [
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              },
+              {
+                "scalarFunction": {
+                  "functionReference": 1,
+                  "outputType": {
+                    "i64": {
+                      "nullability": "NULLABILITY_NULLABLE"
+                    }
+                  },
+                  "arguments": [
+                    {
+                      "value": {
+                        "selection": {
+                          "directReference": {
+                            "structField": {}
+                          },
+                          "rootReference": {}
+                        }
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "x",
+          "y",
+        ]
+      }
+    }
+  ]
+}
+    """
+
+    buf = pa._substrait._parse_json_plan(substrait_query)
+    reader = pa.substrait.run_query(
+        buf, table_provider=table_provider, use_threads=use_threads)
+    res_tb = reader.read_all()
+
+    function, name = unary_func_fixture
+    expected_tb = test_table.add_column(1, 'y', function(
+        mock_scalar_udf_context(10), test_table['x']))
+    res_tb = res_tb.rename_columns(['x', 'y'])
+    assert res_tb == expected_tb
+
+
+def test_udf_via_substrait_wrong_udf_name():
+    test_table = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+    def table_provider(names, _):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = b"""
+    {
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1
+    },
+    {
+      "extensionUriAnchor": 2,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 2,
+        "functionAnchor": 1,
+        "name": "wrong_udf_name"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  1,
+                  2,
+                ]
+              }
+            },
+            "input": {
+              "read": {
+                "baseSchema": {
+                  "names": [
+                    "t",
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                    ],
+                    "nullability": "NULLABILITY_REQUIRED"
+                  }
+                },
+                "namedTable": {
+                  "names": [
+                    "t1"
+                  ]
+                }
+              }
+            },
+            "expressions": [
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              },
+              {
+                "scalarFunction": {
+                  "functionReference": 1,
+                  "outputType": {
+                    "i64": {
+                      "nullability": "NULLABILITY_NULLABLE"
+                    }
+                  },
+                  "arguments": [
+                    {
+                      "value": {
+                        "selection": {
+                          "directReference": {
+                            "structField": {}
+                          },
+                          "rootReference": {}
+                        }
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "x",
+          "y",
+        ]
+      }
+    }
+  ]
+}
+    """
+
+    buf = pa._substrait._parse_json_plan(substrait_query)
+    with pytest.raises(pa.ArrowKeyError) as excinfo:
+        pa.substrait.run_query(buf, table_provider=table_provider)
+    assert "No function registered" in str(excinfo.value)
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 6a67e0bae9..0f336555f7 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -24,7 +24,6 @@ from pyarrow import compute as pc
 # UDFs are all tested with a dataset scan
 pytestmark = pytest.mark.dataset
 
-
 try:
     import pyarrow.dataset as ds
 except ImportError:
@@ -40,25 +39,6 @@ class MyError(RuntimeError):
     pass
 
 
-@pytest.fixture(scope="session")
-def unary_func_fixture():
-    """
-    Register a unary scalar function.
-    """
-    def unary_function(ctx, x):
-        return pc.call_function("add", [x, 1],
-                                memory_pool=ctx.memory_pool)
-    func_name = "y=x+1"
-    unary_doc = {"summary": "add function",
-                 "description": "test add function"}
-    pc.register_scalar_function(unary_function,
-                                func_name,
-                                unary_doc,
-                                {"array": pa.int64()},
-                                pa.int64())
-    return unary_function, func_name
-
-
 @pytest.fixture(scope="session")
 def binary_func_fixture():
     """
