This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 72b539f542 ARROW-17521: [Python] Add python bindings for NamedTableProvider for Substrait consumer (#14024)
72b539f542 is described below

commit 72b539f54233f6610b01ec7381755a84c652d151
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Thu Sep 15 01:15:58 2022 +0530

    ARROW-17521: [Python] Add python bindings for NamedTableProvider for Substrait consumer (#14024)
    
    This PR adds a basic Python binding for the Substrait NamedTable feature, so that Python tests can be written against in-memory PyArrow tables.
    
    Authored-by: Vibhatha Abeykoon <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/src/arrow/compute/exec/exec_plan.cc            |   4 +
 cpp/src/arrow/compute/exec/exec_plan.h             |   5 +
 cpp/src/arrow/engine/substrait/function_test.cc    |   4 +-
 .../arrow/engine/substrait/relation_internal.cc    |   6 +
 cpp/src/arrow/engine/substrait/serde_test.cc       |  13 +--
 cpp/src/arrow/engine/substrait/util.cc             |  23 ++--
 cpp/src/arrow/engine/substrait/util.h              |   8 +-
 python/pyarrow/_exec_plan.pyx                      |   2 +-
 python/pyarrow/_substrait.pyx                      |  97 +++++++++++++++-
 python/pyarrow/includes/libarrow.pxd               |   1 +
 python/pyarrow/includes/libarrow_substrait.pxd     |  28 ++++-
 python/pyarrow/tests/test_substrait.py             | 126 +++++++++++++++++++++
 12 files changed, 289 insertions(+), 28 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index b6a3916de1..00415495aa 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -643,6 +643,10 @@ Declaration Declaration::Sequence(std::vector<Declaration> decls) {
   return out;
 }
 
+bool Declaration::IsValid(ExecFactoryRegistry* registry) const {
+  return !this->factory_name.empty() && this->options != nullptr;
+}
+
 namespace internal {
 
 void RegisterSourceNode(ExecFactoryRegistry*);
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 263f3634a5..a9481e21a6 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -451,6 +451,8 @@ inline Result<ExecNode*> MakeExecNode(
 struct ARROW_EXPORT Declaration {
   using Input = util::Variant<ExecNode*, Declaration>;
 
+  Declaration() = default;
+
   Declaration(std::string factory_name, std::vector<Input> inputs,
               std::shared_ptr<ExecNodeOptions> options, std::string label)
       : factory_name{std::move(factory_name)},
@@ -514,6 +516,9 @@ struct ARROW_EXPORT Declaration {
   Result<ExecNode*> AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry =
                                                   default_exec_factory_registry()) const;
 
+  /// \brief True if factory_name is non-empty and options is set (registry is unused)
+  bool IsValid(ExecFactoryRegistry* registry = default_exec_factory_registry()) const;
+
   std::string factory_name;
   std::vector<Input> inputs;
   std::shared_ptr<ExecNodeOptions> options;
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc 
b/cpp/src/arrow/engine/substrait/function_test.cc
index 0bcb475d31..3465f00e13 100644
--- a/cpp/src/arrow/engine/substrait/function_test.cc
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -132,8 +132,8 @@ void CheckValidTestCases(const 
std::vector<FunctionTestCase>& valid_cases) {
     ASSERT_FINISHES_OK(plan->finished());
 
     // Could also modify the Substrait plan with an emit to drop the leading 
columns
-    ASSERT_OK_AND_ASSIGN(output_table,
-                         
output_table->SelectColumns({output_table->num_columns() - 1}));
+    int result_column = output_table->num_columns() - 1;  // last column holds 
result
+    ASSERT_OK_AND_ASSIGN(output_table, 
output_table->SelectColumns({result_column}));
 
     ASSERT_OK_AND_ASSIGN(
         std::shared_ptr<Table> expected_output,
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc 
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 4213895b61..3911373b7b 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -135,8 +135,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& 
rel, const ExtensionSet&
         const substrait::ReadRel::NamedTable& named_table = read.named_table();
         std::vector<std::string> table_names(named_table.names().begin(),
                                              named_table.names().end());
+        if (table_names.empty()) {
+          return Status::Invalid("names for NamedTable not provided");
+        }
         ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
                               named_table_provider(table_names));
+        if (!source_decl.IsValid()) {
+          return Status::Invalid("Invalid NamedTable Source");
+        }
         return ProcessEmit(std::move(read),
                            DeclarationInfo{std::move(source_decl), 
base_schema},
                            std::move(base_schema));
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc 
b/cpp/src/arrow/engine/substrait/serde_test.cc
index 251c2bfe35..b50e1c6084 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -1924,7 +1924,6 @@ TEST(Substrait, BasicPlanRoundTripping) {
 
   ASSERT_OK_AND_ASSIGN(auto tempdir,
                        
arrow::internal::TemporaryDir::Make("substrait-tempdir-"));
-  std::cout << "file_path_str " << tempdir->path().ToString() << std::endl;
   ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
   std::string file_path_str = file_path.ToString();
 
@@ -2189,7 +2188,7 @@ TEST(Substrait, ProjectRel) {
               }
             },
             "namedTable": {
-                     "names": []
+              "names": ["A"]
             }
           }
         }
@@ -2313,7 +2312,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) {
               }
             },
             "namedTable": {
-                     "names": []
+              "names": ["A"]
             }
           }
         }
@@ -2396,7 +2395,7 @@ TEST(Substrait, ReadRelWithEmit) {
           }
         },
         "namedTable": {
-          "names" : []
+          "names" : ["A"]
         }
       }
     }
@@ -2501,7 +2500,7 @@ TEST(Substrait, FilterRelWithEmit) {
               }
             },
             "namedTable": {
-              "names" : []
+              "names" : ["A"]
             }
           }
         }
@@ -2885,7 +2884,7 @@ TEST(Substrait, AggregateRel) {
                 }
               },
               "namedTable" : {
-                "names": []
+                "names": ["A"]
               }
             }
           },
@@ -3004,7 +3003,7 @@ TEST(Substrait, AggregateRelEmit) {
                 }
               },
               "namedTable" : {
-                "names" : []
+                "names" : ["A"]
               }
             }
           },
diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc
index 936bde5c65..f51666ef85 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -63,8 +63,12 @@ class SubstraitSinkConsumer : public compute::SinkNodeConsumer {
 class SubstraitExecutor {
  public:
   explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan,
-                             compute::ExecContext exec_context)
-      : plan_(std::move(plan)), plan_started_(false), exec_context_(exec_context) {}
+                             compute::ExecContext exec_context,
+                             const ConversionOptions& conversion_options = {})
+      : plan_(std::move(plan)),
+        plan_started_(false),
+        exec_context_(exec_context),
+        conversion_options_(conversion_options) {}
 
   ~SubstraitExecutor() { ARROW_UNUSED(this->Close()); }
 
@@ -95,8 +99,8 @@ class SubstraitExecutor {
       return sink_consumer_;
     };
     ARROW_ASSIGN_OR_RAISE(
-        declarations_,
-        engine::DeserializePlans(substrait_buffer, consumer_factory, registry));
+        declarations_, engine::DeserializePlans(substrait_buffer, consumer_factory,
+                                                registry, nullptr, conversion_options_));
     return Status::OK();
   }
 
@@ -107,19 +111,20 @@ class SubstraitExecutor {
   bool plan_started_;
   compute::ExecContext exec_context_;
   std::shared_ptr<SubstraitSinkConsumer> sink_consumer_;
+  ConversionOptions conversion_options_;  // owned copy; a reference could dangle
 };
 
 }  // namespace
 
 Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
-    const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry,
-    compute::FunctionRegistry* func_registry) {
-  // TODO(ARROW-15732)
+    const Buffer& substrait_buffer, const ExtensionIdRegistry* registry,
+    compute::FunctionRegistry* func_registry,
+    const ConversionOptions& conversion_options) {
   compute::ExecContext exec_context(arrow::default_memory_pool(),
                                     ::arrow::internal::GetCpuThreadPool(), func_registry);
   ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context));
-  SubstraitExecutor executor(std::move(plan), exec_context);
-  RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry));
+  SubstraitExecutor executor(std::move(plan), exec_context, conversion_options);
+  RETURN_NOT_OK(executor.Init(substrait_buffer, registry));
   ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute());
   // check closing here, not in destructor, to expose error to caller
   RETURN_NOT_OK(executor.Close());
diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h
index 3ac9320e1d..90cb4e3dd2 100644
--- a/cpp/src/arrow/engine/substrait/util.h
+++ b/cpp/src/arrow/engine/substrait/util.h
@@ -20,6 +20,7 @@
 #include <memory>
 #include "arrow/compute/registry.h"
 #include "arrow/engine/substrait/api.h"
+#include "arrow/engine/substrait/options.h"
 #include "arrow/util/iterator.h"
 #include "arrow/util/optional.h"
 
@@ -27,10 +28,13 @@ namespace arrow {
 
 namespace engine {
 
-/// \brief Retrieve a RecordBatchReader from a Substrait plan.
+/// \brief Execute a serialized Substrait plan and return a reader over its output.
+/// \param[in] conversion_options options for converting the plan, e.g. the
+/// named_table_provider used to resolve NamedTable read relations
 ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
     const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR,
-    compute::FunctionRegistry* func_registry = NULLPTR);
+    compute::FunctionRegistry* func_registry = NULLPTR,
+    const ConversionOptions& conversion_options = {});
 
 /// \brief Get a Serialized Plan from a Substrait JSON plan.
 /// This is a helper method for Python tests.
diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx
index 89e474f439..9506caf7d2 100644
--- a/python/pyarrow/_exec_plan.pyx
+++ b/python/pyarrow/_exec_plan.pyx
@@ -92,7 +92,7 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, 
c_bool use_threads
             node_factory = "table_source"
             c_in_table = pyarrow_unwrap_table(ipt)
             c_tablesourceopts = make_shared[CTableSourceNodeOptions](
-                c_in_table, 1 << 20)
+                c_in_table)
             c_input_node_opts = static_pointer_cast[CExecNodeOptions, 
CTableSourceNodeOptions](
                 c_tablesourceopts)
         elif isinstance(ipt, Dataset):
diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx
index 05794a95a2..47a519cf16 100644
--- a/python/pyarrow/_substrait.pyx
+++ b/python/pyarrow/_substrait.pyx
@@ -17,15 +17,38 @@
 
 # cython: language_level = 3
 from cython.operator cimport dereference as deref
+from libcpp.vector cimport vector as std_vector
 
 from pyarrow import Buffer
-from pyarrow.lib import frombytes
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *
 
 
-def run_query(plan):
+cdef CDeclaration _create_named_table_provider(dict named_args, const 
std_vector[c_string]& names):  # wrap named_args["provider"](names) as a table_source Declaration
+    cdef:
+        c_string c_name
+        shared_ptr[CTable] c_in_table
+        shared_ptr[CTableSourceNodeOptions] c_tablesourceopts
+        shared_ptr[CExecNodeOptions] c_input_node_opts
+        vector[CDeclaration.Input] no_c_inputs
+
+    py_names = []
+    for i in range(names.size()):
+        c_name = names[i]
+        py_names.append(frombytes(c_name))
+
+    py_table = named_args["provider"](py_names)  # NOTE(review): no except clause; a raise here appears swallowed (surfaces as "Invalid NamedTable Source") - confirm
+    c_in_table = pyarrow_unwrap_table(py_table)
+    c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table)
+    c_input_node_opts = static_pointer_cast[CExecNodeOptions, 
CTableSourceNodeOptions](
+        c_tablesourceopts)
+    return CDeclaration(tobytes("table_source"),
+                        no_c_inputs, c_input_node_opts)
+
+
+def run_query(plan, table_provider=None):
     """
     Execute a Substrait plan and read the results as a RecordBatchReader.
 
@@ -33,6 +56,63 @@ def run_query(plan):
     ----------
     plan : Buffer
         The serialized Substrait plan to execute.
+    table_provider : object (optional)
+        A function to resolve any NamedTable relation to a table.
+        The function will receive a single argument which will be a list
+        of strings representing the table name and should return a 
pyarrow.Table.
+
+    Returns
+    -------
+    RecordBatchReader
+        A reader containing the result of the executed query
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> from pyarrow.lib import tobytes
+    >>> import pyarrow.substrait as substrait
+    >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+    >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
+    >>> def table_provider(names):
+    ...     if not names:
+    ...        raise Exception("No names provided")
+    ...     elif names[0] == "t1":
+    ...        return test_table_1
+    ...     elif names[0] == "t2":
+    ...        return test_table_2
+    ...     else:
+    ...        raise Exception("Unrecognized table name")
+    ... 
+    >>> substrait_query = '''
+    ...         {
+    ...             "relations": [
+    ...             {"rel": {
+    ...                 "read": {
+    ...                 "base_schema": {
+    ...                     "struct": {
+    ...                     "types": [
+    ...                                 {"i64": {}}
+    ...                             ]
+    ...                     },
+    ...                     "names": [
+    ...                             "x"
+    ...                             ]
+    ...                 },
+    ...                 "namedTable": {
+    ...                         "names": ["t1"]
+    ...                 }
+    ...                 }
+    ...             }}
+    ...             ]
+    ...         }
+    ... '''
+    >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    >>> reader = pa.substrait.run_query(buf, table_provider)
+    >>> reader.read_all()
+    pyarrow.Table
+    x: int64
+    ----
+    x: [[1,2,3]]
     """
 
     cdef:
@@ -41,10 +121,21 @@ def run_query(plan):
         RecordBatchReader reader
         c_string c_str_plan
         shared_ptr[CBuffer] c_buf_plan
+        function[CNamedTableProvider] c_named_table_provider
+        CConversionOptions c_conversion_options
 
     c_buf_plan = pyarrow_unwrap_buffer(plan)
+
+    if table_provider is not None:
+        named_table_args = {
+            "provider": table_provider
+        }
+        c_conversion_options.named_table_provider = 
BindFunction[CNamedTableProvider](
+            &_create_named_table_provider, named_table_args)
+
     with nogil:
-        c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan))
+        c_res_reader = ExecuteSerializedPlan(
+            deref(c_buf_plan), default_extension_id_registry(), 
GetFunctionRegistry(), c_conversion_options)
 
     c_reader = GetResultValue(c_res_reader)
 
diff --git a/python/pyarrow/includes/libarrow.pxd 
b/python/pyarrow/includes/libarrow.pxd
index be273975f9..489d73bf27 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2574,6 +2574,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h" 
namespace "arrow::compute" nog
         c_string label
         vector[Input] inputs
 
+        CDeclaration()
         CDeclaration(c_string factory_name, CExecNodeOptions options)
         CDeclaration(c_string factory_name, vector[Input] inputs, 
shared_ptr[CExecNodeOptions] options)
 
diff --git a/python/pyarrow/includes/libarrow_substrait.pxd 
b/python/pyarrow/includes/libarrow_substrait.pxd
index 0b3ace75d9..04990380d9 100644
--- a/python/pyarrow/includes/libarrow_substrait.pxd
+++ b/python/pyarrow/includes/libarrow_substrait.pxd
@@ -22,10 +22,22 @@ from libcpp.vector cimport vector as std_vector
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
 
-
-cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" 
nogil:
-    CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const 
CBuffer& substrait_buffer)
-    CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& 
substrait_json)
+ctypedef CResult[CDeclaration] CNamedTableProvider(const std_vector[c_string]&)
+
+cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" 
nogil:
+    cdef enum ConversionStrictness \
+            "arrow::engine::ConversionStrictness":
+        EXACT_ROUNDTRIP \
+            "arrow::engine::ConversionStrictness::EXACT_ROUNDTRIP"
+        PRESERVE_STRUCTURE \
+            "arrow::engine::ConversionStrictness::PRESERVE_STRUCTURE"
+        BEST_EFFORT \
+            "arrow::engine::ConversionStrictness::BEST_EFFORT"
+
+    cdef cppclass CConversionOptions \
+            "arrow::engine::ConversionOptions":
+        ConversionStrictness conversion_strictness
+        function[CNamedTableProvider] named_table_provider
 
 cdef extern from "arrow/engine/substrait/extension_set.h" \
         namespace "arrow::engine" nogil:
@@ -34,3 +46,11 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \
         std_vector[c_string] GetSupportedSubstraitFunctions()
 
     ExtensionIdRegistry* default_extension_id_registry()
+
+
+cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" 
nogil:
+    CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(
+        const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry,
+        CFunctionRegistry* func_registry, const CConversionOptions& 
conversion_options)
+
+    CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& 
substrait_json)
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index c8fa6afcb9..c8fd8048aa 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -165,3 +165,129 @@ def test_get_supported_functions():
                          'functions_arithmetic.yaml', 'add')
     assert has_function(supported_functions,
                          'functions_arithmetic.yaml', 'sum')
+
+
+def test_named_table():
+    test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+    test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
+
+    def table_provider(names):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        elif names[0] == "t2":
+            return test_table_2
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+        "relations": [
+        {"rel": {
+            "read": {
+            "base_schema": {
+                "struct": {
+                "types": [
+                            {"i64": {}}
+                        ]
+                },
+                "names": [
+                        "x"
+                        ]
+            },
+            "namedTable": {
+                    "names": ["t1"]
+            }
+            }
+        }}
+        ]
+    }
+    """
+
+    buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    reader = pa.substrait.run_query(buf, table_provider)
+    res_tb = reader.read_all()
+    assert res_tb == test_table_1
+
+
+def test_named_table_invalid_table_name():
+    test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+    def table_provider(names):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+        "relations": [
+        {"rel": {
+            "read": {
+            "base_schema": {
+                "struct": {
+                "types": [
+                            {"i64": {}}
+                        ]
+                },
+                "names": [
+                        "x"
+                        ]
+            },
+            "namedTable": {
+                    "names": ["t3"]
+            }
+            }
+        }}
+        ]
+    }
+    """
+
+    buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    exec_message = "Invalid NamedTable Source"
+    with pytest.raises(ArrowInvalid, match=exec_message):
+        substrait.run_query(buf, table_provider)
+
+
+def test_named_table_empty_names():
+    test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+    def table_provider(names):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+        "relations": [
+        {"rel": {
+            "read": {
+            "base_schema": {
+                "struct": {
+                "types": [
+                            {"i64": {}}
+                        ]
+                },
+                "names": [
+                        "x"
+                        ]
+            },
+            "namedTable": {
+                    "names": []
+            }
+            }
+        }}
+        ]
+    }
+    """
+    query = tobytes(substrait_query)
+    buf = pa._substrait._parse_json_plan(query)
+    exec_message = "names for NamedTable not provided"
+    with pytest.raises(ArrowInvalid, match=exec_message):
+        substrait.run_query(buf, table_provider)

Reply via email to