This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 72b539f542 ARROW-17521: [Python] Add python bindings for
NamedTableProvider for Substrait consumer (#14024)
72b539f542 is described below
commit 72b539f54233f6610b01ec7381755a84c652d151
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Thu Sep 15 01:15:58 2022 +0530
ARROW-17521: [Python] Add python bindings for NamedTableProvider for
Substrait consumer (#14024)
This PR adds a basic way to use the NamedTable feature of Substrait from Python.
The goal is to make it possible to write Python tests against in-memory
PyArrow tables.
Authored-by: Vibhatha Abeykoon <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/compute/exec/exec_plan.cc | 4 +
cpp/src/arrow/compute/exec/exec_plan.h | 5 +
cpp/src/arrow/engine/substrait/function_test.cc | 4 +-
.../arrow/engine/substrait/relation_internal.cc | 6 +
cpp/src/arrow/engine/substrait/serde_test.cc | 13 +--
cpp/src/arrow/engine/substrait/util.cc | 23 ++--
cpp/src/arrow/engine/substrait/util.h | 8 +-
python/pyarrow/_exec_plan.pyx | 2 +-
python/pyarrow/_substrait.pyx | 97 +++++++++++++++-
python/pyarrow/includes/libarrow.pxd | 1 +
python/pyarrow/includes/libarrow_substrait.pxd | 28 ++++-
python/pyarrow/tests/test_substrait.py | 126 +++++++++++++++++++++
12 files changed, 289 insertions(+), 28 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc
b/cpp/src/arrow/compute/exec/exec_plan.cc
index b6a3916de1..00415495aa 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -643,6 +643,10 @@ Declaration Declaration::Sequence(std::vector<Declaration>
decls) {
return out;
}
+// Report whether this declaration is minimally well-formed: it must name a
+// factory and carry a non-null options object. The registry argument is
+// accepted but not consulted here.
+bool Declaration::IsValid(ExecFactoryRegistry* registry) const {
+  if (factory_name.empty()) {
+    return false;
+  }
+  return options != nullptr;
+}
+
namespace internal {
void RegisterSourceNode(ExecFactoryRegistry*);
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h
b/cpp/src/arrow/compute/exec/exec_plan.h
index 263f3634a5..a9481e21a6 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -451,6 +451,8 @@ inline Result<ExecNode*> MakeExecNode(
struct ARROW_EXPORT Declaration {
using Input = util::Variant<ExecNode*, Declaration>;
+ Declaration() {}
+
Declaration(std::string factory_name, std::vector<Input> inputs,
std::shared_ptr<ExecNodeOptions> options, std::string label)
: factory_name{std::move(factory_name)},
@@ -514,6 +516,9 @@ struct ARROW_EXPORT Declaration {
Result<ExecNode*> AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry =
default_exec_factory_registry()) const;
+ /// \brief Check whether this declaration is minimally well-formed
+ ///
+ /// Returns true when factory_name is non-empty and options is non-null.
+ /// NOTE(review): the current implementation does not consult `registry`.
+ bool IsValid(ExecFactoryRegistry* registry =
default_exec_factory_registry()) const;
+
std::string factory_name;
std::vector<Input> inputs;
std::shared_ptr<ExecNodeOptions> options;
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc
b/cpp/src/arrow/engine/substrait/function_test.cc
index 0bcb475d31..3465f00e13 100644
--- a/cpp/src/arrow/engine/substrait/function_test.cc
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -132,8 +132,8 @@ void CheckValidTestCases(const
std::vector<FunctionTestCase>& valid_cases) {
ASSERT_FINISHES_OK(plan->finished());
// Could also modify the Substrait plan with an emit to drop the leading
columns
- ASSERT_OK_AND_ASSIGN(output_table,
-
output_table->SelectColumns({output_table->num_columns() - 1}));
+ int result_column = output_table->num_columns() - 1; // last column holds
result
+ ASSERT_OK_AND_ASSIGN(output_table,
output_table->SelectColumns({result_column}));
ASSERT_OK_AND_ASSIGN(
std::shared_ptr<Table> expected_output,
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 4213895b61..3911373b7b 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -135,8 +135,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel&
rel, const ExtensionSet&
const substrait::ReadRel::NamedTable& named_table = read.named_table();
std::vector<std::string> table_names(named_table.names().begin(),
named_table.names().end());
+ if (table_names.empty()) {
+ return Status::Invalid("names for NamedTable not provided");
+ }
ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
named_table_provider(table_names));
+ if (!source_decl.IsValid()) {
+ return Status::Invalid("Invalid NamedTable Source");
+ }
return ProcessEmit(std::move(read),
DeclarationInfo{std::move(source_decl),
base_schema},
std::move(base_schema));
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc
b/cpp/src/arrow/engine/substrait/serde_test.cc
index 251c2bfe35..b50e1c6084 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -1924,7 +1924,6 @@ TEST(Substrait, BasicPlanRoundTripping) {
ASSERT_OK_AND_ASSIGN(auto tempdir,
arrow::internal::TemporaryDir::Make("substrait-tempdir-"));
- std::cout << "file_path_str " << tempdir->path().ToString() << std::endl;
ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
std::string file_path_str = file_path.ToString();
@@ -2189,7 +2188,7 @@ TEST(Substrait, ProjectRel) {
}
},
"namedTable": {
- "names": []
+ "names": ["A"]
}
}
}
@@ -2313,7 +2312,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) {
}
},
"namedTable": {
- "names": []
+ "names": ["A"]
}
}
}
@@ -2396,7 +2395,7 @@ TEST(Substrait, ReadRelWithEmit) {
}
},
"namedTable": {
- "names" : []
+ "names" : ["A"]
}
}
}
@@ -2501,7 +2500,7 @@ TEST(Substrait, FilterRelWithEmit) {
}
},
"namedTable": {
- "names" : []
+ "names" : ["A"]
}
}
}
@@ -2885,7 +2884,7 @@ TEST(Substrait, AggregateRel) {
}
},
"namedTable" : {
- "names": []
+ "names": ["A"]
}
}
},
@@ -3004,7 +3003,7 @@ TEST(Substrait, AggregateRelEmit) {
}
},
"namedTable" : {
- "names" : []
+ "names" : ["A"]
}
}
},
diff --git a/cpp/src/arrow/engine/substrait/util.cc
b/cpp/src/arrow/engine/substrait/util.cc
index 936bde5c65..f51666ef85 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -63,8 +63,12 @@ class SubstraitSinkConsumer : public
compute::SinkNodeConsumer {
class SubstraitExecutor {
public:
explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan,
- compute::ExecContext exec_context)
- : plan_(std::move(plan)), plan_started_(false),
exec_context_(exec_context) {}
+ compute::ExecContext exec_context,
+ const ConversionOptions& conversion_options = {})
+ : plan_(std::move(plan)),
+ plan_started_(false),
+ exec_context_(exec_context),
+ conversion_options_(conversion_options) {}
~SubstraitExecutor() { ARROW_UNUSED(this->Close()); }
@@ -95,8 +99,8 @@ class SubstraitExecutor {
return sink_consumer_;
};
ARROW_ASSIGN_OR_RAISE(
- declarations_,
- engine::DeserializePlans(substrait_buffer, consumer_factory,
registry));
+ declarations_, engine::DeserializePlans(substrait_buffer,
consumer_factory,
+ registry, nullptr,
conversion_options_));
return Status::OK();
}
@@ -107,19 +111,20 @@ class SubstraitExecutor {
bool plan_started_;
compute::ExecContext exec_context_;
std::shared_ptr<SubstraitSinkConsumer> sink_consumer_;
+  // Stored by value rather than by reference: the constructor's default
+  // argument (`= {}`) is a temporary that is destroyed as soon as the
+  // constructor returns, so a reference member would be left dangling.
+  ConversionOptions conversion_options_;
};
} // namespace
Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
- const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry,
- compute::FunctionRegistry* func_registry) {
- // TODO(ARROW-15732)
+ const Buffer& substrait_buffer, const ExtensionIdRegistry* registry,
+ compute::FunctionRegistry* func_registry,
+ const ConversionOptions& conversion_options) {
compute::ExecContext exec_context(arrow::default_memory_pool(),
::arrow::internal::GetCpuThreadPool(),
func_registry);
ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context));
- SubstraitExecutor executor(std::move(plan), exec_context);
- RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry));
+ SubstraitExecutor executor(std::move(plan), exec_context,
conversion_options);
+ RETURN_NOT_OK(executor.Init(substrait_buffer, registry));
ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute());
// check closing here, not in destructor, to expose error to caller
RETURN_NOT_OK(executor.Close());
diff --git a/cpp/src/arrow/engine/substrait/util.h
b/cpp/src/arrow/engine/substrait/util.h
index 3ac9320e1d..90cb4e3dd2 100644
--- a/cpp/src/arrow/engine/substrait/util.h
+++ b/cpp/src/arrow/engine/substrait/util.h
@@ -20,6 +20,7 @@
#include <memory>
#include "arrow/compute/registry.h"
#include "arrow/engine/substrait/api.h"
+#include "arrow/engine/substrait/options.h"
#include "arrow/util/iterator.h"
#include "arrow/util/optional.h"
@@ -27,10 +28,13 @@ namespace arrow {
namespace engine {
-/// \brief Retrieve a RecordBatchReader from a Substrait plan.
+using PythonTableProvider =
+ std::function<Result<std::shared_ptr<Table>>(const
std::vector<std::string>&)>;
+
ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>>
ExecuteSerializedPlan(
const Buffer& substrait_buffer, const ExtensionIdRegistry* registry =
NULLPTR,
- compute::FunctionRegistry* func_registry = NULLPTR);
+ compute::FunctionRegistry* func_registry = NULLPTR,
+ const ConversionOptions& conversion_options = {});
/// \brief Get a Serialized Plan from a Substrait JSON plan.
/// This is a helper method for Python tests.
diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx
index 89e474f439..9506caf7d2 100644
--- a/python/pyarrow/_exec_plan.pyx
+++ b/python/pyarrow/_exec_plan.pyx
@@ -92,7 +92,7 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan,
c_bool use_threads
node_factory = "table_source"
c_in_table = pyarrow_unwrap_table(ipt)
c_tablesourceopts = make_shared[CTableSourceNodeOptions](
- c_in_table, 1 << 20)
+ c_in_table)
c_input_node_opts = static_pointer_cast[CExecNodeOptions,
CTableSourceNodeOptions](
c_tablesourceopts)
elif isinstance(ipt, Dataset):
diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx
index 05794a95a2..47a519cf16 100644
--- a/python/pyarrow/_substrait.pyx
+++ b/python/pyarrow/_substrait.pyx
@@ -17,15 +17,38 @@
# cython: language_level = 3
from cython.operator cimport dereference as deref
+from libcpp.vector cimport vector as std_vector
from pyarrow import Buffer
-from pyarrow.lib import frombytes
+from pyarrow.lib import frombytes, tobytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *
-def run_query(plan):
+cdef CDeclaration _create_named_table_provider(dict named_args, const std_vector[c_string]& names):
+    """
+    Build a "table_source" declaration for a Substrait NamedTable.
+
+    The table name components in ``names`` are decoded to Python strings and
+    passed, as a list, to the callable stored under ``named_args["provider"]``.
+    The object it returns is unwrapped as a Table and used as the source.
+    """
+    cdef:
+        shared_ptr[CTable] c_table
+        shared_ptr[CTableSourceNodeOptions] c_source_opts
+        shared_ptr[CExecNodeOptions] c_node_opts
+        vector[CDeclaration.Input] c_no_inputs
+
+    # Decode every name component before handing the list to Python code.
+    decoded_names = [frombytes(names[i]) for i in range(names.size())]
+
+    provider = named_args["provider"]
+    c_table = pyarrow_unwrap_table(provider(decoded_names))
+
+    # A table source has no upstream inputs; it only needs its options.
+    c_source_opts = make_shared[CTableSourceNodeOptions](c_table)
+    c_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
+        c_source_opts)
+    return CDeclaration(tobytes("table_source"),
+                        c_no_inputs, c_node_opts)
+
+
+def run_query(plan, table_provider=None):
"""
Execute a Substrait plan and read the results as a RecordBatchReader.
@@ -33,6 +56,63 @@ def run_query(plan):
----------
plan : Buffer
The serialized Substrait plan to execute.
+ table_provider : object (optional)
+ A function to resolve any NamedTable relation to a table.
+ The function will receive a single argument which will be a list
+ of strings representing the table name and should return a
pyarrow.Table.
+
+ Returns
+ -------
+ RecordBatchReader
+ A reader containing the result of the executed query
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> from pyarrow.lib import tobytes
+ >>> import pyarrow.substrait as substrait
+ >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+ >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
+ >>> def table_provider(names):
+ ... if not names:
+ ... raise Exception("No names provided")
+ ... elif names[0] == "t1":
+ ... return test_table_1
+ ... elif names[1] == "t2":
+ ... return test_table_2
+ ... else:
+ ... raise Exception("Unrecognized table name")
+ ...
+ >>> substrait_query = '''
+ ... {
+ ... "relations": [
+ ... {"rel": {
+ ... "read": {
+ ... "base_schema": {
+ ... "struct": {
+ ... "types": [
+ ... {"i64": {}}
+ ... ]
+ ... },
+ ... "names": [
+ ... "x"
+ ... ]
+ ... },
+ ... "namedTable": {
+ ... "names": ["t1"]
+ ... }
+ ... }
+ ... }}
+ ... ]
+ ... }
+ ... '''
+ >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+ >>> reader = pa.substrait.run_query(buf, table_provider)
+ >>> reader.read_all()
+ pyarrow.Table
+ x: int64
+ ----
+ x: [[1,2,3]]
"""
cdef:
@@ -41,10 +121,21 @@ def run_query(plan):
RecordBatchReader reader
c_string c_str_plan
shared_ptr[CBuffer] c_buf_plan
+ function[CNamedTableProvider] c_named_table_provider
+ CConversionOptions c_conversion_options
c_buf_plan = pyarrow_unwrap_buffer(plan)
+
+ if table_provider is not None:
+ named_table_args = {
+ "provider": table_provider
+ }
+ c_conversion_options.named_table_provider =
BindFunction[CNamedTableProvider](
+ &_create_named_table_provider, named_table_args)
+
with nogil:
- c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan))
+ c_res_reader = ExecuteSerializedPlan(
+ deref(c_buf_plan), default_extension_id_registry(),
GetFunctionRegistry(), c_conversion_options)
c_reader = GetResultValue(c_res_reader)
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index be273975f9..489d73bf27 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2574,6 +2574,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h"
namespace "arrow::compute" nog
c_string label
vector[Input] inputs
+ CDeclaration()
CDeclaration(c_string factory_name, CExecNodeOptions options)
CDeclaration(c_string factory_name, vector[Input] inputs,
shared_ptr[CExecNodeOptions] options)
diff --git a/python/pyarrow/includes/libarrow_substrait.pxd
b/python/pyarrow/includes/libarrow_substrait.pxd
index 0b3ace75d9..04990380d9 100644
--- a/python/pyarrow/includes/libarrow_substrait.pxd
+++ b/python/pyarrow/includes/libarrow_substrait.pxd
@@ -22,10 +22,22 @@ from libcpp.vector cimport vector as std_vector
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
-
-cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine"
nogil:
- CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const
CBuffer& substrait_buffer)
- CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string&
substrait_json)
+ctypedef CResult[CDeclaration] CNamedTableProvider(const std_vector[c_string]&)
+
+cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine"
nogil:
+ cdef enum ConversionStrictness \
+ "arrow::engine::ConversionStrictness":
+ EXACT_ROUNDTRIP \
+ "arrow::engine::ConversionStrictness::EXACT_ROUNDTRIP"
+ PRESERVE_STRUCTURE \
+ "arrow::engine::ConversionStrictness::PRESERVE_STRUCTURE"
+ BEST_EFFORT \
+ "arrow::engine::ConversionStrictness::BEST_EFFORT"
+
+ cdef cppclass CConversionOptions \
+ "arrow::engine::ConversionOptions":
+ ConversionStrictness conversion_strictness
+ function[CNamedTableProvider] named_table_provider
cdef extern from "arrow/engine/substrait/extension_set.h" \
namespace "arrow::engine" nogil:
@@ -34,3 +46,11 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \
std_vector[c_string] GetSupportedSubstraitFunctions()
ExtensionIdRegistry* default_extension_id_registry()
+
+
+cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine"
nogil:
+ CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(
+ const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry,
+ CFunctionRegistry* func_registry, const CConversionOptions&
conversion_options)
+
+ CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string&
substrait_json)
diff --git a/python/pyarrow/tests/test_substrait.py
b/python/pyarrow/tests/test_substrait.py
index c8fa6afcb9..c8fd8048aa 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -165,3 +165,129 @@ def test_get_supported_functions():
'functions_arithmetic.yaml', 'add')
assert has_function(supported_functions,
'functions_arithmetic.yaml', 'sum')
+
+
+def test_named_table():
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+ test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
+
+ def table_provider(names):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ elif names[1] == "t2":
+ return test_table_2
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = """
+ {
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [
+ {"i64": {}}
+ ]
+ },
+ "names": [
+ "x"
+ ]
+ },
+ "namedTable": {
+ "names": ["t1"]
+ }
+ }
+ }}
+ ]
+ }
+ """
+
+ buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+ reader = pa.substrait.run_query(buf, table_provider)
+ res_tb = reader.read_all()
+ assert res_tb == test_table_1
+
+
+def test_named_table_invalid_table_name():
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = """
+ {
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [
+ {"i64": {}}
+ ]
+ },
+ "names": [
+ "x"
+ ]
+ },
+ "namedTable": {
+ "names": ["t3"]
+ }
+ }
+ }}
+ ]
+ }
+ """
+
+ buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+ exec_message = "Invalid NamedTable Source"
+ with pytest.raises(ArrowInvalid, match=exec_message):
+ substrait.run_query(buf, table_provider)
+
+
+def test_named_table_empty_names():
+ test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names):
+ if not names:
+ raise Exception("No names provided")
+ elif names[0] == "t1":
+ return test_table_1
+ else:
+ raise Exception("Unrecognized table name")
+
+ substrait_query = """
+ {
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [
+ {"i64": {}}
+ ]
+ },
+ "names": [
+ "x"
+ ]
+ },
+ "namedTable": {
+ "names": []
+ }
+ }
+ }}
+ ]
+ }
+ """
+ query = tobytes(substrait_query)
+ buf = pa._substrait._parse_json_plan(tobytes(query))
+ exec_message = "names for NamedTable not provided"
+ with pytest.raises(ArrowInvalid, match=exec_message):
+ substrait.run_query(buf, table_provider)