rtpsw commented on code in PR #13232:
URL: https://github.com/apache/arrow/pull/13232#discussion_r888237895


##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
   return &it->second;
 }
 
-ExtensionIdRegistry* default_extension_id_registry() {
-  static struct Impl : ExtensionIdRegistry {
-    Impl() {
-      struct TypeName {
-        std::shared_ptr<DataType> type;
-        util::string_view name;
-      };
-
-      // The type (variation) mappings listed below need to be kept in sync
-      // with the YAML at substrait/format/extension_types.yaml manually;
-      // see ARROW-15535.
-      for (TypeName e : {
-               TypeName{uint8(), "u8"},
-               TypeName{uint16(), "u16"},
-               TypeName{uint32(), "u32"},
-               TypeName{uint64(), "u64"},
-               TypeName{float16(), "fp16"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
-      }
-
-      for (TypeName e : {
-               TypeName{null(), "null"},
-               TypeName{month_interval(), "interval_month"},
-               TypeName{day_time_interval(), "interval_day_milli"},
-               TypeName{month_day_nano_interval(), "interval_month_day_nano"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
-      }
-
-      // TODO: this is just a placeholder right now. We'll need a YAML file for
-      // all functions (and prototypes) that Arrow provides that are relevant
-      // for Substrait, and include mappings for all of them here. See
-      // ARROW-15535.
-      for (util::string_view name : {
-               "add",
-               "equal",
-               "is_not_distinct_from",
-           }) {
-        DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
-      }
+namespace {
+
+struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+  virtual ~ExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    return {uris_.begin(), uris_.end()};
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    if (auto index = GetIndex(type_to_index_, &type)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  util::optional<TypeRecord> GetType(Id id) const override {
+    if (auto index = GetIndex(id_to_index_, id)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    if (id_to_index_.find(id) != id_to_index_.end()) {
+      return Status::Invalid("Type id was already registered");
+    }
+    if (type_to_index_.find(&*type) != type_to_index_.end()) {
+      return Status::Invalid("Type was already registered");
+    }
+    return Status::OK();
+  }
+
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    DCHECK_EQ(type_ids_.size(), types_.size());
+
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
+
+    auto index = static_cast<int>(type_ids_.size());
+
+    auto it_success = id_to_index_.emplace(copied_id, index);
+
+    if (!it_success.second) {
+      return Status::Invalid("Type id was already registered");
+    }
+
+    if (!type_to_index_.emplace(type.get(), index).second) {
+      id_to_index_.erase(it_success.first);
+      return Status::Invalid("Type was already registered");
     }
 
-    std::vector<util::string_view> Uris() const override {
-      return {uris_.begin(), uris_.end()};
+    type_ids_.push_back(copied_id);
+    types_.push_back(std::move(type));
+    return Status::OK();
+  }
+
+  util::optional<FunctionRecord> GetFunction(
+      util::string_view arrow_function_name) const override {
+    if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
+      return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(const DataType& type) const override {
-      if (auto index = GetIndex(type_to_index_, &type)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  util::optional<FunctionRecord> GetFunction(Id id) const override {
+    if (auto index = GetIndex(function_id_to_index_, id)) {
+      return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(Id id) const override {
-      if (auto index = GetIndex(id_to_index_, id)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  Status CanRegisterFunction(Id id,
+                             const std::string& arrow_function_name) const override {
+    if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
+      return Status::Invalid("Function id was already registered");
+    }
+    if (function_name_to_index_.find(arrow_function_name) !=
+        function_name_to_index_.end()) {
+      return Status::Invalid("Function name was already registered");
     }
+    return Status::OK();
+  }
 
-    Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
-      DCHECK_EQ(type_ids_.size(), types_.size());
+  Status RegisterFunction(Id id, std::string arrow_function_name) override {
+    DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
 
-      auto index = static_cast<int>(type_ids_.size());
+    const std::string& copied_function_name{
+        *function_names_.emplace(std::move(arrow_function_name)).first};
 
-      auto it_success = id_to_index_.emplace(copied_id, index);
+    auto index = static_cast<int>(function_ids_.size());
 
-      if (!it_success.second) {
-        return Status::Invalid("Type id was already registered");
-      }
+    auto it_success = function_id_to_index_.emplace(copied_id, index);
 
-      if (!type_to_index_.emplace(type.get(), index).second) {
-        id_to_index_.erase(it_success.first);
-        return Status::Invalid("Type was already registered");
-      }
+    if (!it_success.second) {
+      return Status::Invalid("Function id was already registered");
+    }
 
-      type_ids_.push_back(copied_id);
-      types_.push_back(std::move(type));
-      return Status::OK();
+    if (!function_name_to_index_.emplace(copied_function_name, index).second) {
+      function_id_to_index_.erase(it_success.first);
+      return Status::Invalid("Function name was already registered");
     }
 
-    util::optional<FunctionRecord> GetFunction(
-        util::string_view arrow_function_name) const override {
-      if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
-        return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
-      }
-      return {};
+    function_name_ptrs_.push_back(&copied_function_name);
+    function_ids_.push_back(copied_id);
+    return Status::OK();
+  }
+
+  // owning storage of uris, names, (arrow::)function_names, types
+  //    note that storing strings like this is safe since references into an
+  //    unordered_set are not invalidated on insertion
+  std::unordered_set<std::string> uris_, names_, function_names_;
+  DataTypeVector types_;
+
+  // non-owning lookup helpers
+  std::vector<Id> type_ids_, function_ids_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
+  std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
+
+  std::vector<const std::string*> function_name_ptrs_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
+  std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
+      function_name_to_index_;
+};
+
+struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
+  explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
+      : parent_(parent) {}
+
+  virtual ~NestedExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    std::vector<util::string_view> uris = parent_->Uris();
+    std::unordered_set<util::string_view> uri_set;
+    uri_set.insert(uris.begin(), uris.end());
+    uri_set.insert(uris_.begin(), uris_.end());
+    return std::vector<util::string_view>(uris);
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(type);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(type);
+  }
 
-    util::optional<FunctionRecord> GetFunction(Id id) const override {
-      if (auto index = GetIndex(function_id_to_index_, id)) {
-        return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
-      }
-      return {};
+  util::optional<TypeRecord> GetType(Id id) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(id);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(id);
+  }
 
-    Status RegisterFunction(Id id, std::string arrow_function_name) override {
-      DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::CanRegisterType(id, type);
+  }
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::RegisterType(id, type);

Review Comment:
   I tested this change, but it causes a segmentation fault in `ExtensionIdRegistryTest.RegisterTempTypes`, with the following gdb output:
   ```
   $  gdb --args ./release/arrow-substrait-substrait-test 
   GNU gdb (Ubuntu 9.2-0ubuntu1~20.04) 9.2
   Copyright (C) 2020 Free Software Foundation, Inc.
   License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>
   This is free software: you are free to change and redistribute it.
   There is NO WARRANTY, to the extent permitted by law.
   Type "show copying" and "show warranty" for details.
   This GDB was configured as "x86_64-linux-gnu".
   Type "show configuration" for configuration details.
   For bug reporting instructions, please see:
   <http://www.gnu.org/software/gdb/bugs/>.
   Find the GDB manual and other documentation resources online at:
       <http://www.gnu.org/software/gdb/documentation/>.
   
   For help, type "help".
   Type "apropos word" to search for commands related to "word"...
   Reading symbols from ./release/arrow-substrait-substrait-test...
   (gdb) run
   Starting program: /mnt/user1/tscontract/github/rtpsw/arrow/cpp/build/release/release/arrow-substrait-substrait-test
 
   [Thread debugging using libthread_db enabled]
   Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
   [New Thread 0x7ffff45ff700 (LWP 156004)]
   Running main() from /build/googletest-j5yxiC/googletest-1.10.0/googletest/src/gtest_main.cc
   [==========] Running 33 tests from 3 test suites.
   [----------] Global test environment set-up.
   [----------] 4 tests from ExtensionIdRegistryTest
   [ RUN      ] ExtensionIdRegistryTest.RegisterTempTypes
   
   Thread 1 "arrow-substrait" received signal SIGSEGV, Segmentation fault.
   0x00007ffff65199dc in arrow::DataType::Hash() const () from /mnt/user1/tscontract/github/rtpsw/arrow/cpp/build/release/release/libarrow.so.900
   (gdb) bt
   #0  0x00007ffff65199dc in arrow::DataType::Hash() const () from /mnt/user1/tscontract/github/rtpsw/arrow/cpp/build/release/release/libarrow.so.900
   #1  0x00007ffff7aa420c in std::_Hashtable<arrow::DataType const*, std::pair<arrow::DataType const* const, int>, std::allocator<std::pair<arrow::DataType const* const, int> >, std::__detail::_Select1st, arrow::engine::(anonymous namespace)::TypePtrHashEq, arrow::engine::(anonymous namespace)::TypePtrHashEq, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<true, false, true> >::find(arrow::DataType const* const&) const ()
      from /mnt/user1/tscontract/github/rtpsw/arrow/cpp/build/release/release/libarrow_substrait.so.900
   #2  0x00007ffff7aa7b7a in arrow::engine::(anonymous namespace)::ExtensionIdRegistryImpl::CanRegisterType(arrow::engine::ExtensionIdRegistry::Id, std::shared_ptr<arrow::DataType> const&) const ()
      from /mnt/user1/tscontract/github/rtpsw/arrow/cpp/build/release/release/libarrow_substrait.so.900
   #3  0x00007ffff7aa9a1c in arrow::engine::(anonymous namespace)::NestedExtensionIdRegistryImpl::RegisterType(arrow::engine::ExtensionIdRegistry::Id, std::shared_ptr<arrow::DataType>) ()
      from /mnt/user1/tscontract/github/rtpsw/arrow/cpp/build/release/release/libarrow_substrait.so.900
   #4  0x00005555555953da in arrow::engine::ExtensionIdRegistryTest_RegisterTempTypes_Test::TestBody() ()
   #5  0x00005555555ff3d1 in testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void> (location=0x5555556186bf "the test body", method=<optimized out>, object=0x555555670140)
       at ./googletest/src/gtest.cc:2414
   #6  testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void> (object=object@entry=0x555555670140, method=<optimized out>, location=location@entry=0x5555556186bf "the test body")
       at ./googletest/src/gtest.cc:2469
   #7  0x00005555555f3756 in testing::Test::Run (this=0x555555670140) at ./googletest/src/gtest.cc:2508
   #8  testing::Test::Run (this=0x555555670140) at ./googletest/src/gtest.cc:2498
   #9  0x00005555555f38b5 in testing::TestInfo::Run (this=0x5555556749f0) at ./googletest/src/gtest.cc:2684
   #10 testing::TestInfo::Run (this=0x5555556749f0) at ./googletest/src/gtest.cc:2657
   #11 0x00005555555f399d in testing::TestSuite::Run (this=0x5555556731f0) at ./googletest/src/gtest.cc:2816
   #12 testing::TestSuite::Run (this=0x5555556731f0) at ./googletest/src/gtest.cc:2795
   #13 0x00005555555f3ebc in testing::internal::UnitTestImpl::RunAllTests (this=0x55555565bec0) at /usr/include/c++/9/bits/stl_vector.h:1040
   #14 0x00005555555ff941 in testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool> (
       location=0x555555619a78 "auxiliary test code (environments or event listeners)", method=<optimized out>, object=0x55555565bec0) at ./googletest/src/gtest.cc:2414
   #15 testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool> (object=0x55555565bec0, method=<optimized out>,
       location=location@entry=0x555555619a78 "auxiliary test code (environments or event listeners)") at ./googletest/src/gtest.cc:2469
   #16 0x00005555555f40ec in testing::UnitTest::Run (this=0x55555563f560 <testing::UnitTest::GetInstance()::instance>) at ./googletest/include/gtest/gtest.h:1412
   #17 0x0000555555594274 in main () at /usr/include/c++/9/ext/new_allocator.h:89
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to