westonpace commented on code in PR #13232:
URL: https://github.com/apache/arrow/pull/13232#discussion_r888297879


##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, 
const Key& key) {
   return &it->second;
 }
 
-ExtensionIdRegistry* default_extension_id_registry() {
-  static struct Impl : ExtensionIdRegistry {
-    Impl() {
-      struct TypeName {
-        std::shared_ptr<DataType> type;
-        util::string_view name;
-      };
-
-      // The type (variation) mappings listed below need to be kept in sync
-      // with the YAML at substrait/format/extension_types.yaml manually;
-      // see ARROW-15535.
-      for (TypeName e : {
-               TypeName{uint8(), "u8"},
-               TypeName{uint16(), "u16"},
-               TypeName{uint32(), "u32"},
-               TypeName{uint64(), "u64"},
-               TypeName{float16(), "fp16"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, 
std::move(e.type)));
-      }
-
-      for (TypeName e : {
-               TypeName{null(), "null"},
-               TypeName{month_interval(), "interval_month"},
-               TypeName{day_time_interval(), "interval_day_milli"},
-               TypeName{month_day_nano_interval(), "interval_month_day_nano"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, 
std::move(e.type)));
-      }
-
-      // TODO: this is just a placeholder right now. We'll need a YAML file for
-      // all functions (and prototypes) that Arrow provides that are relevant
-      // for Substrait, and include mappings for all of them here. See
-      // ARROW-15535.
-      for (util::string_view name : {
-               "add",
-               "equal",
-               "is_not_distinct_from",
-           }) {
-        DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, 
name.to_string()));
-      }
+namespace {
+
+struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+  virtual ~ExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    return {uris_.begin(), uris_.end()};
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    if (auto index = GetIndex(type_to_index_, &type)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  util::optional<TypeRecord> GetType(Id id) const override {
+    if (auto index = GetIndex(id_to_index_, id)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const 
override {
+    if (id_to_index_.find(id) != id_to_index_.end()) {
+      return Status::Invalid("Type id was already registered");
+    }
+    if (type_to_index_.find(&*type) != type_to_index_.end()) {
+      return Status::Invalid("Type was already registered");
+    }
+    return Status::OK();
+  }
+
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    DCHECK_EQ(type_ids_.size(), types_.size());
+
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
+
+    auto index = static_cast<int>(type_ids_.size());
+
+    auto it_success = id_to_index_.emplace(copied_id, index);
+
+    if (!it_success.second) {
+      return Status::Invalid("Type id was already registered");
+    }
+
+    if (!type_to_index_.emplace(type.get(), index).second) {
+      id_to_index_.erase(it_success.first);
+      return Status::Invalid("Type was already registered");
     }
 
-    std::vector<util::string_view> Uris() const override {
-      return {uris_.begin(), uris_.end()};
+    type_ids_.push_back(copied_id);
+    types_.push_back(std::move(type));
+    return Status::OK();
+  }
+
+  util::optional<FunctionRecord> GetFunction(
+      util::string_view arrow_function_name) const override {
+    if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
+      return FunctionRecord{function_ids_[*index], 
*function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(const DataType& type) const override {
-      if (auto index = GetIndex(type_to_index_, &type)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  util::optional<FunctionRecord> GetFunction(Id id) const override {
+    if (auto index = GetIndex(function_id_to_index_, id)) {
+      return FunctionRecord{function_ids_[*index], 
*function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(Id id) const override {
-      if (auto index = GetIndex(id_to_index_, id)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  Status CanRegisterFunction(Id id,
+                             const std::string& arrow_function_name) const 
override {
+    if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
+      return Status::Invalid("Function id was already registered");
+    }
+    if (function_name_to_index_.find(arrow_function_name) !=
+        function_name_to_index_.end()) {
+      return Status::Invalid("Function name was already registered");
     }
+    return Status::OK();
+  }
 
-    Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
-      DCHECK_EQ(type_ids_.size(), types_.size());
+  Status RegisterFunction(Id id, std::string arrow_function_name) override {
+    DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
 
-      auto index = static_cast<int>(type_ids_.size());
+    const std::string& copied_function_name{
+        *function_names_.emplace(std::move(arrow_function_name)).first};
 
-      auto it_success = id_to_index_.emplace(copied_id, index);
+    auto index = static_cast<int>(function_ids_.size());
 
-      if (!it_success.second) {
-        return Status::Invalid("Type id was already registered");
-      }
+    auto it_success = function_id_to_index_.emplace(copied_id, index);
 
-      if (!type_to_index_.emplace(type.get(), index).second) {
-        id_to_index_.erase(it_success.first);
-        return Status::Invalid("Type was already registered");
-      }
+    if (!it_success.second) {
+      return Status::Invalid("Function id was already registered");
+    }
 
-      type_ids_.push_back(copied_id);
-      types_.push_back(std::move(type));
-      return Status::OK();
+    if (!function_name_to_index_.emplace(copied_function_name, index).second) {
+      function_id_to_index_.erase(it_success.first);
+      return Status::Invalid("Function name was already registered");
     }
 
-    util::optional<FunctionRecord> GetFunction(
-        util::string_view arrow_function_name) const override {
-      if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) 
{
-        return FunctionRecord{function_ids_[*index], 
*function_name_ptrs_[*index]};
-      }
-      return {};
+    function_name_ptrs_.push_back(&copied_function_name);
+    function_ids_.push_back(copied_id);
+    return Status::OK();
+  }
+
+  // owning storage of uris, names, (arrow::)function_names, types
+  //    note that storing strings like this is safe since references into an
+  //    unordered_set are not invalidated on insertion
+  std::unordered_set<std::string> uris_, names_, function_names_;
+  DataTypeVector types_;
+
+  // non-owning lookup helpers
+  std::vector<Id> type_ids_, function_ids_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
+  std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> 
type_to_index_;
+
+  std::vector<const std::string*> function_name_ptrs_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
+  std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
+      function_name_to_index_;
+};
+
+struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
+  explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
+      : parent_(parent) {}
+
+  virtual ~NestedExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    std::vector<util::string_view> uris = parent_->Uris();
+    std::unordered_set<util::string_view> uri_set;
+    uri_set.insert(uris.begin(), uris.end());
+    uri_set.insert(uris_.begin(), uris_.end());
+    return std::vector<util::string_view>(uris);
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(type);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(type);
+  }
 
-    util::optional<FunctionRecord> GetFunction(Id id) const override {
-      if (auto index = GetIndex(function_id_to_index_, id)) {
-        return FunctionRecord{function_ids_[*index], 
*function_name_ptrs_[*index]};
-      }
-      return {};
+  util::optional<TypeRecord> GetType(Id id) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(id);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(id);
+  }
 
-    Status RegisterFunction(Id id, std::string arrow_function_name) override {
-      DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const 
override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::CanRegisterType(id, type);
+  }
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::RegisterType(id, type);

Review Comment:
   I'll be the first to admit that there is a good possibility I'm just plain 
wrong or missing something :laughing:, so I apologize for the trouble in 
advance.  However, I think you are correct that a `std::move` is not safe 
here.  After doing some research, though, I think your explanation is slightly 
off.
   
   The function parameter is (correctly, IMO) declared as `RegisterType(Id id, 
std::shared_ptr<DataType> type)` and not `RegisterType(Id id, const 
std::shared_ptr<DataType>& type)`, so this function explicitly requests its 
own copy of the `shared_ptr`.
   
   So nothing we do here can invalidate the caller's copy.
   
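   Instead, I think this is a consequence of using `&` as the status-coalescing 
operator rather than `&&`.  I believe it was you who pointed out a while back 
that this is a little odd.  Perhaps it is not just odd but actually 
semantically incorrect as well.  The built-in `&&` guarantees that its left 
operand is evaluated before its right operand (older texts call this a 
"sequence point"; see https://en.cppreference.com/w/cpp/language/eval_order).  
By contrast, the order in which the two operands of `&` are evaluated is 
unspecified.  One caveat: `Status` is a class type, so `&&` would have to be a 
user-defined `operator&&`.  Such an overload never short-circuits, and before 
C++17 it does not sequence its operands either.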
   
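   What I think is happening is that the compiler chooses to evaluate the call 
to `ExtensionIdRegistryImpl::RegisterType(id, std::move(type))` before the call 
to `parent_->CanRegisterType(id, type)`.  In that case `type` has already been 
moved from (and is null) by the time `CanRegisterType` reads it.  When I first 
made this comment I thought that evaluation order was not permitted and was 
therefore a compiler bug.  However, after reading up on evaluation order, it 
seems to be perfectly legal, so using `std::move` is indeed incorrect here.  
This is more motivation to perhaps change the operator to `&&`, or simply to 
check the parent first and return early on error.  Note that `&` also always 
evaluates both operands, so the type is registered locally even when the 
parent's check fails.  But that does not need to be a concern for this PR.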
   
   So, thank you for bearing with me; I consider this comment resolved.  I'll 
follow up on the `&`/`&&` question in a separate JIRA to see how others feel 
about it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to