westonpace commented on code in PR #13401:
URL: https://github.com/apache/arrow/pull/13401#discussion_r907299931
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
+ }
+ }
Review Comment:
I'm not sure that introducing an enum, just so that we can switch on it, is
much cleaner than comparing the strings directly. For example:
```
if (rel_name == "scan") {
return AddReadRelation(declaration);
} else if (rel_name == "filter") {
return Status::NotImplemented("Filter operator not supported.");
}
...
} else {
return Status::Invalid("Unsupported factory name :", rel_name);
}
```
Do you intend to use this enum elsewhere?
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
+ }
+ }
+
+ Status AddReadRelation(const compute::Declaration& declaration) {
+ auto read_rel = internal::make_unique<substrait::ReadRel>();
+ const auto& scan_node_options =
+ internal::checked_cast<const
dataset::ScanNodeOptions&>(*declaration.options);
+
+ const auto& fds = internal::checked_cast<const
dataset::FileSystemDataset&>(
+ *scan_node_options.dataset);
Review Comment:
We don't know for certain that this cast will succeed (unlike the above cast
to ScanNodeOptions). A `checked_cast` is not a check we can recover from: in
debug builds a failed cast aborts, and in release builds it is an unchecked
`static_cast`, so a wrong dataset type would be undefined behavior.
Instead we should do a `dynamic_cast` to `dataset::FileSystemDataset*`.
It's slightly slower, but it will return `nullptr` if the cast fails, in which
case we can return an invalid status such as "Can only convert file system
datasets to a Substrait plan".
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
Review Comment:
```suggestion
return Status::Invalid("Unsupported exec node factory name :",
rel_name);
```
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
+ }
+ }
+
+ Status AddReadRelation(const compute::Declaration& declaration) {
+ auto read_rel = internal::make_unique<substrait::ReadRel>();
+ const auto& scan_node_options =
+ internal::checked_cast<const
dataset::ScanNodeOptions&>(*declaration.options);
+
+ const auto& fds = internal::checked_cast<const
dataset::FileSystemDataset&>(
+ *scan_node_options.dataset);
+
+ // set schema
+ ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*fds.schema(), ext_set_));
+ read_rel->set_allocated_base_schema(named_struct.release());
+
+ // set local files
+ auto read_rel_lfs = internal::make_unique<substrait::ReadRel_LocalFiles>();
+ for (const auto& file : fds.files()) {
+ auto read_rel_lfs_ffs =
+ internal::make_unique<substrait::ReadRel_LocalFiles_FileOrFiles>();
+ read_rel_lfs_ffs->set_uri_path("file://" + file);
+
+ // set file format
+ auto format_type_name = fds.format()->type_name();
+ if (format_type_name == "parquet" || format_type_name == "arrow" ||
+ format_type_name == "feather") {
Review Comment:
We should add a comment here explaining that `arrow` and `feather` are
temporarily mapped to the Parquet format until we upgrade to the latest
Substrait version. Otherwise this may be confusing to a future reader.
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1173,5 +1175,173 @@ TEST(Substrait, JoinPlanInvalidKeys) {
&ext_set));
}
+TEST(Substrait, SerializeRelation) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+ ExtensionSet ext_set;
+ auto dummy_schema = schema({field("foo", binary())});
+ // creating a dummy dataset using a dummy table
+ auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
Review Comment:
I think we can use a mock filesystem here instead of relying on
PARQUET_TEST_DATA. We aren't actually reading the file, so it doesn't need any
contents.
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
+ }
+ }
+
+ Status AddReadRelation(const compute::Declaration& declaration) {
+ auto read_rel = internal::make_unique<substrait::ReadRel>();
+ const auto& scan_node_options =
+ internal::checked_cast<const
dataset::ScanNodeOptions&>(*declaration.options);
+
+ const auto& fds = internal::checked_cast<const
dataset::FileSystemDataset&>(
+ *scan_node_options.dataset);
+
+ // set schema
+ ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*fds.schema(), ext_set_));
+ read_rel->set_allocated_base_schema(named_struct.release());
+
+ // set local files
+ auto read_rel_lfs = internal::make_unique<substrait::ReadRel_LocalFiles>();
+ for (const auto& file : fds.files()) {
+ auto read_rel_lfs_ffs =
+ internal::make_unique<substrait::ReadRel_LocalFiles_FileOrFiles>();
+ read_rel_lfs_ffs->set_uri_path("file://" + file);
+
+ // set file format
+ auto format_type_name = fds.format()->type_name();
+ if (format_type_name == "parquet" || format_type_name == "arrow" ||
+ format_type_name == "feather") {
+ read_rel_lfs_ffs->set_format(
+ substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET);
+ } else {
+ return Status::Invalid("Unsupported file type : ", format_type_name);
+ }
+ read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release());
+ }
+ *read_rel->mutable_local_files() = *read_rel_lfs.get();
+
+ rel_->set_allocated_read(read_rel.release());
Review Comment:
It looks like we aren't handling any kind of projection or filter
pushdown. Both of those would probably be a bit easier to do if I ever get
around to finishing ARROW-16072. Can you add a follow-up JIRA for them?
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1173,5 +1175,168 @@ TEST(Substrait, JoinPlanInvalidKeys) {
&ext_set));
}
+TEST(Substrait, SerializeRelation) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+ ExtensionSet ext_set;
+ auto dummy_schema = schema({field("f1", int32()), field("f2", int32())});
+ // creating a dummy dataset using a dummy table
+ auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {"/tmp/data1.parquet",
"/tmp/data2.parquet"};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ std::move(filesystem),
std::move(files),
+ std::move(format), {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto options = std::make_shared<dataset::ScanOptions>();
+ options->projection = compute::project({}, {});
+ auto scan_node_options = dataset::ScanNodeOptions{dataset, options};
+
+ auto scan_declaration = compute::Declaration({"scan", scan_node_options});
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_rel,
+ SerializeRelation(scan_declaration, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto deserialized_decl,
+ DeserializeRelation(*serialized_rel, ext_set));
+
+ auto dataset_comparator = [](std::shared_ptr<dataset::Dataset> ds_lhs,
+ std::shared_ptr<dataset::Dataset> ds_rhs) ->
bool {
+ const auto& fds_lhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& fds_rhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& files_lhs = fds_lhs.files();
+ const auto& files_rhs = fds_rhs.files();
+
+ bool cmp_fsize = files_lhs.size() == files_rhs.size();
+ uint64_t fidx = 0;
+ for (const auto& l_file : files_lhs) {
+ if (l_file != files_rhs[fidx]) {
+ return false;
+ }
+ fidx++;
+ }
+ bool cmp_file_format = fds_lhs.format()->Equals(*fds_lhs.format());
+ bool cmp_file_system = fds_lhs.filesystem()->Equals(fds_rhs.filesystem());
+ return cmp_fsize && cmp_file_format && cmp_file_system;
+ };
+
+ auto scan_option_comparator = [dataset_comparator](
+ const dataset::ScanNodeOptions& lhs,
+ const dataset::ScanNodeOptions& rhs) ->
bool {
+ bool cmp_rso = lhs.require_sequenced_output ==
rhs.require_sequenced_output;
+ bool cmp_ds = dataset_comparator(lhs.dataset, rhs.dataset);
+ return cmp_rso && cmp_ds;
+ };
Review Comment:
It makes sense for these types to have an `Equals` method, and it could even
be useful to users. However, I don't see any reason we need to add it now;
what you have here is also fine. If we end up doing the same comparison
elsewhere, we can turn the lambdas into helpers like `AssertDatasetEquals`
in a `test_util` file. So let's keep these as simple lambdas for
the moment. One thing to fix in the dataset lambda, though: `cmp_file_format`
compares `fds_lhs.format()` with itself instead of with `fds_rhs.format()`.
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1173,5 +1175,173 @@ TEST(Substrait, JoinPlanInvalidKeys) {
&ext_set));
}
+TEST(Substrait, SerializeRelation) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+ ExtensionSet ext_set;
+ auto dummy_schema = schema({field("foo", binary())});
+ // creating a dummy dataset using a dummy table
+ auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+
+ ASSERT_OK_AND_ASSIGN(std::string dir_string,
+ arrow::internal::GetEnvVar("PARQUET_TEST_DATA"));
+ auto file_name =
+
arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet");
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {file_name->ToString()};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ std::move(filesystem),
std::move(files),
+ std::move(format), {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto options = std::make_shared<dataset::ScanOptions>();
+ options->projection = compute::project({}, {});
+ auto scan_node_options = dataset::ScanNodeOptions{dataset, options};
+
+ auto scan_declaration = compute::Declaration({"scan", scan_node_options});
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_rel,
+ SerializeRelation(scan_declaration, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto deserialized_decl,
+ DeserializeRelation(*serialized_rel, ext_set));
+
+ auto dataset_comparator = [](std::shared_ptr<dataset::Dataset> ds_lhs,
+ std::shared_ptr<dataset::Dataset> ds_rhs) ->
bool {
+ const auto& fds_lhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& fds_rhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& files_lhs = fds_lhs.files();
+ const auto& files_rhs = fds_rhs.files();
+
+ bool cmp_fsize = files_lhs.size() == files_rhs.size();
Review Comment:
```suggestion
if (files_lhs.size() != files_rhs.size()) {
return false;
}
```
If `files_lhs.size() > files_rhs.size()`, then `files_rhs[fidx]` below reads
out of bounds (undefined behavior, likely a segmentation fault). Best to bail
out early if the number of files differs.
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1173,5 +1175,173 @@ TEST(Substrait, JoinPlanInvalidKeys) {
&ext_set));
}
+TEST(Substrait, SerializeRelation) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+ ExtensionSet ext_set;
+ auto dummy_schema = schema({field("foo", binary())});
+ // creating a dummy dataset using a dummy table
+ auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+
+ ASSERT_OK_AND_ASSIGN(std::string dir_string,
+ arrow::internal::GetEnvVar("PARQUET_TEST_DATA"));
+ auto file_name =
+
arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet");
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {file_name->ToString()};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ std::move(filesystem),
std::move(files),
+ std::move(format), {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto options = std::make_shared<dataset::ScanOptions>();
+ options->projection = compute::project({}, {});
+ auto scan_node_options = dataset::ScanNodeOptions{dataset, options};
+
+ auto scan_declaration = compute::Declaration({"scan", scan_node_options});
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_rel,
+ SerializeRelation(scan_declaration, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto deserialized_decl,
+ DeserializeRelation(*serialized_rel, ext_set));
+
+ auto dataset_comparator = [](std::shared_ptr<dataset::Dataset> ds_lhs,
+ std::shared_ptr<dataset::Dataset> ds_rhs) ->
bool {
+ const auto& fds_lhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& fds_rhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& files_lhs = fds_lhs.files();
+ const auto& files_rhs = fds_rhs.files();
+
+ bool cmp_fsize = files_lhs.size() == files_rhs.size();
+ uint64_t fidx = 0;
+ for (const auto& l_file : files_lhs) {
+ if (l_file != files_rhs[fidx]) {
+ return false;
+ }
+ fidx++;
+ }
+ bool cmp_file_format = fds_lhs.format()->Equals(*fds_lhs.format());
+ bool cmp_file_system = fds_lhs.filesystem()->Equals(fds_rhs.filesystem());
+ return cmp_fsize && cmp_file_format && cmp_file_system;
+ };
+
+ auto scan_option_comparator = [dataset_comparator](
+ const dataset::ScanNodeOptions& lhs,
+ const dataset::ScanNodeOptions& rhs) ->
bool {
+ bool cmp_rso = lhs.require_sequenced_output ==
rhs.require_sequenced_output;
+ bool cmp_ds = dataset_comparator(lhs.dataset, rhs.dataset);
+ return cmp_rso && cmp_ds;
+ };
+
+ EXPECT_EQ(deserialized_decl.factory_name, scan_declaration.factory_name);
+ const auto& lhs =
+ checked_cast<const
dataset::ScanNodeOptions&>(*deserialized_decl.options);
+ const auto& rhs =
+ checked_cast<const dataset::ScanNodeOptions&>(*scan_declaration.options);
+ ASSERT_TRUE(scan_option_comparator(lhs, rhs));
+#endif
+}
+
+TEST(Substrait, SerializeRelationEndToEnd) {
Review Comment:
What is this test adding that the previous test does not already cover?
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1173,5 +1175,173 @@ TEST(Substrait, JoinPlanInvalidKeys) {
&ext_set));
}
+TEST(Substrait, SerializeRelation) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+ ExtensionSet ext_set;
+ auto dummy_schema = schema({field("foo", binary())});
+ // creating a dummy dataset using a dummy table
+ auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+
+ ASSERT_OK_AND_ASSIGN(std::string dir_string,
+ arrow::internal::GetEnvVar("PARQUET_TEST_DATA"));
+ auto file_name =
+
arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet");
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {file_name->ToString()};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ std::move(filesystem),
std::move(files),
+ std::move(format), {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto options = std::make_shared<dataset::ScanOptions>();
+ options->projection = compute::project({}, {});
+ auto scan_node_options = dataset::ScanNodeOptions{dataset, options};
+
+ auto scan_declaration = compute::Declaration({"scan", scan_node_options});
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_rel,
+ SerializeRelation(scan_declaration, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto deserialized_decl,
+ DeserializeRelation(*serialized_rel, ext_set));
+
+ auto dataset_comparator = [](std::shared_ptr<dataset::Dataset> ds_lhs,
+ std::shared_ptr<dataset::Dataset> ds_rhs) ->
bool {
+ const auto& fds_lhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
+ const auto& fds_rhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
Review Comment:
```suggestion
const auto& fsd_lhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
const auto& fsd_rhs = checked_cast<const
dataset::FileSystemDataset&>(*ds_lhs);
```
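Minor nit. Separately, note that `fds_rhs` (and `fsd_rhs` in the suggestion
above) is cast from `*ds_lhs` rather than `*ds_rhs`, so the comparator ends up
comparing the left-hand dataset with itself; it should use `*ds_rhs`.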
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
+ }
+ }
+
+ Status AddReadRelation(const compute::Declaration& declaration) {
+ auto read_rel = internal::make_unique<substrait::ReadRel>();
+ const auto& scan_node_options =
+ internal::checked_cast<const
dataset::ScanNodeOptions&>(*declaration.options);
+
+ const auto& fds = internal::checked_cast<const
dataset::FileSystemDataset&>(
+ *scan_node_options.dataset);
+
+ // set schema
+ ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*fds.schema(), ext_set_));
+ read_rel->set_allocated_base_schema(named_struct.release());
+
+ // set local files
+ auto read_rel_lfs = internal::make_unique<substrait::ReadRel_LocalFiles>();
+ for (const auto& file : fds.files()) {
+ auto read_rel_lfs_ffs =
+ internal::make_unique<substrait::ReadRel_LocalFiles_FileOrFiles>();
+ read_rel_lfs_ffs->set_uri_path("file://" + file);
+
+ // set file format
+ auto format_type_name = fds.format()->type_name();
+ if (format_type_name == "parquet" || format_type_name == "arrow" ||
+ format_type_name == "feather") {
+ read_rel_lfs_ffs->set_format(
+ substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET);
+ } else {
+ return Status::Invalid("Unsupported file type : ", format_type_name);
+ }
+ read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release());
+ }
+ *read_rel->mutable_local_files() = *read_rel_lfs.get();
+
+ rel_->set_allocated_read(read_rel.release());
+ return Status::OK();
+ }
+
+ Status operator()(const compute::Declaration& declaration) {
+ return AddRelation(declaration);
+ }
+
+ private:
+ substrait::Rel* rel_;
+ ExtensionSet* ext_set_;
+};
+
+} // namespace
+
+Result<std::unique_ptr<substrait::Rel>> ToProto(const compute::Declaration&
declaration,
+ ExtensionSet* ext_set) {
+ auto out = internal::make_unique<substrait::Rel>();
+ RETURN_NOT_OK(ExtractRelation(out.get(), ext_set)(declaration));
Review Comment:
Why do we need a struct here instead of something like:
```
RETURN_NOT_OK(AddRelation(declaration, ext_set, out.get()));
```
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -316,5 +323,97 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
rel.DebugString());
}
+namespace {
+// TODO: add other types
+enum ArrowRelationType : uint8_t {
+ SCAN,
+ FILTER,
+ PROJECT,
+ JOIN,
+ AGGREGATE,
+};
+
+const std::map<std::string, ArrowRelationType> enum_map{
+ {"scan", ArrowRelationType::SCAN}, {"filter",
ArrowRelationType::FILTER},
+ {"project", ArrowRelationType::PROJECT}, {"join",
ArrowRelationType::JOIN},
+ {"aggregate", ArrowRelationType::AGGREGATE},
+};
+
+struct ExtractRelation {
+ explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set)
+ : rel_(rel), ext_set_(ext_set) {}
+
+ Status AddRelation(const compute::Declaration& declaration) {
+ const std::string& rel_name = declaration.factory_name;
+ switch (enum_map.find(rel_name)->second) {
+ case ArrowRelationType::SCAN:
+ return AddReadRelation(declaration);
+ case ArrowRelationType::FILTER:
+ return Status::NotImplemented("Filter operator not supported.");
+ case ArrowRelationType::PROJECT:
+ return Status::NotImplemented("Project operator not supported.");
+ case ArrowRelationType::JOIN:
+ return Status::NotImplemented("Join operator not supported.");
+ case ArrowRelationType::AGGREGATE:
+ return Status::NotImplemented("Aggregate operator not supported.");
+ default:
+ return Status::Invalid("Unsupported factory name :", rel_name);
+ }
+ }
+
+ Status AddReadRelation(const compute::Declaration& declaration) {
+ auto read_rel = internal::make_unique<substrait::ReadRel>();
+ const auto& scan_node_options =
+ internal::checked_cast<const
dataset::ScanNodeOptions&>(*declaration.options);
+
+ const auto& fds = internal::checked_cast<const
dataset::FileSystemDataset&>(
Review Comment:
Minor nit: the name `fds` is a little confusing. Perhaps `dataset`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]