westonpace commented on code in PR #34373:
URL: https://github.com/apache/arrow/pull/34373#discussion_r1119524825


##########
python/pyarrow/tests/test_udf.py:
##########
@@ -20,6 +20,7 @@
 
 import pyarrow as pa
 from pyarrow import compute as pc
+from pyarrow.lib import tobytes

Review Comment:
   Since we are in a pure-Python context (not a Cython context), I think you 
can just use:
   
   ```
   substrait_query = b"""
   ...
   ```
   
   Then you don't have to rely on `tobytes`.  Even if that doesn't work, I think 
`substrait_query.encode()` is still preferable to `tobytes`.



##########
python/pyarrow/tests/test_udf.py:
##########
@@ -613,3 +614,180 @@ def test_udt_datasource1_generator():
 def test_udt_datasource1_exception():
     with pytest.raises(RuntimeError, match='datasource1_exception'):
         _test_datasource1_udt(datasource1_exception)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+    test_table_1 = pa.Table.from_pydict({"t": [1, 2, 3], "p": [4, 5, 6]})
+
+    def table_provider(names, _):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1
+    },
+    {
+      "extensionUriAnchor": 2,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 2,
+        "functionAnchor": 1,
+        "name": "y=x+1"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {

Review Comment:
   Why are there two project nodes?



##########
python/pyarrow/tests/test_udf.py:
##########
@@ -613,3 +614,180 @@ def test_udt_datasource1_generator():
 def test_udt_datasource1_exception():
     with pytest.raises(RuntimeError, match='datasource1_exception'):
         _test_datasource1_udt(datasource1_exception)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+    test_table_1 = pa.Table.from_pydict({"t": [1, 2, 3], "p": [4, 5, 6]})
+
+    def table_provider(names, _):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1
+    },
+    {
+      "extensionUriAnchor": 2,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 2,
+        "functionAnchor": 1,
+        "name": "y=x+1"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  2,
+                  3,
+                  4
+                ]
+              }
+            },
+            "input": {
+              "project": {
+                "common": {
+                  "emit": {
+                    "outputMapping": [
+                      2,
+                      3
+                    ]
+                  }
+                },
+                "input": {
+                  "read": {
+                    "baseSchema": {
+                      "names": [
+                        "t",
+                        "p"
+                      ],
+                      "struct": {
+                        "types": [
+                          {
+                            "i64": {
+                              "nullability": "NULLABILITY_REQUIRED"
+                            }
+                          },
+                          {
+                            "i64": {
+                              "nullability": "NULLABILITY_NULLABLE"
+                            }
+                          }
+                        ],
+                        "nullability": "NULLABILITY_REQUIRED"
+                      }
+                    },
+                    "namedTable": {
+                      "names": [
+                        "t1"
+                      ]
+                    }
+                  }
+                },
+                "expressions": [
+                  {
+                    "selection": {
+                      "directReference": {
+                        "structField": {}
+                      },
+                      "rootReference": {}
+                    }
+                  },
+                  {
+                    "selection": {
+                      "directReference": {
+                        "structField": {
+                          "field": 1
+                        }
+                      },
+                      "rootReference": {}
+                    }
+                  }
+                ]
+              }
+            },
+            "expressions": [
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              },
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {
+                      "field": 1
+                    }
+                  },
+                  "rootReference": {}
+                }
+              },
+              {
+                "scalarFunction": {
+                  "functionReference": 1,
+                  "outputType": {
+                    "i64": {
+                      "nullability": "NULLABILITY_NULLABLE"
+                    }
+                  },
+                  "arguments": [
+                    {
+                      "value": {
+                        "selection": {
+                          "directReference": {
+                            "structField": {
+                              "field": 1
+                            }
+                          },
+                          "rootReference": {}
+                        }
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "t",
+          "p",
+          "p2"
+        ]
+      }
+    }
+  ]
+}
+    """
+
+    buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    reader = pa.substrait.run_query(
+        buf, table_provider=table_provider, use_threads=use_threads)
+    res_tb = reader.read_all()
+
+    function, name = unary_func_fixture
+    expected_tb = test_table_1.add_column(2, 'p2', function(
+        mock_scalar_udf_context(10), test_table_1['p']))
+    res_tb = res_tb.rename_columns(['t', 'p', 'p2'])

Review Comment:
   Needing to rename the output columns here may explain the problem we are seeing in 
https://github.com/ibis-project/ibis-substrait/pull/414



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to