junrushao1994 commented on code in PR #12101:
URL: https://github.com/apache/tvm/pull/12101#discussion_r931685347


##########
src/node/structural_equal.cc:
##########
@@ -42,6 +62,133 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
   return fsequal_reduce_[tindex](self, other, equal);
 }
 
+// State carried by SEqualReducer while mismatch-path tracing is enabled.
+struct SEqualReducer::PathTracingData {
+  // (lhs, rhs) paths of the pair of objects whose attributes are being compared.
+  ObjectPathPair current_paths;
+  // The objects whose attributes are being compared; attribute keys are
+  // resolved against these objects by field address.
+  ObjectRef lhs_object;
+  ObjectRef rhs_object;
+  // Output slot that receives the first mismatching pair of paths.
+  Optional<ObjectPathPair>* first_mismatch;
+
+  // Builds the attribute paths for `lhs` / `rhs`. Their addresses are looked up
+  // in lhs_object / rhs_object, so they must refer to those objects' fields
+  // themselves, not to copies.
+  ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const {
+    Optional<String> lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs);
+    Optional<String> rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs);
+    return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key),
+                          current_paths->rhs_path->Attr(rhs_attr_key));
+  }
+};
+
+// Compares two object-valued attributes. Without path tracing this forwards
+// straight to the handler with no current paths; otherwise it delegates to
+// ObjectAttrsEqual.
+bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
+  if (tracing_data_ == nullptr) {
+    // Fast path: no tracing
+    return handler_->SEqualReduce(lhs, rhs, map_free_vars_, NullOpt);
+  }
+  return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr);
+}
+
+// Like operator()(ObjectRef, ObjectRef), but always passes map_free_vars = true.
+bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
+  if (tracing_data_ == nullptr) {
+    // Fast path: no tracing
+    return handler_->SEqualReduce(lhs, rhs, true, NullOpt);
+  }
+  return ObjectAttrsEqual(lhs, rhs, true, nullptr);
+}
+
+// If tracing is enabled and no mismatch has been recorded yet, resolves the
+// attribute keys at `lhs_address` / `rhs_address` within lhs_object / rhs_object
+// and stores the resulting attribute paths as the first mismatch.
+// No-op when `tracing_data` is null or a mismatch is already recorded.
+/* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch(
+    const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) {
+  if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) {
+    Optional<String> lhs_attr_key =
+        GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address);
+    Optional<String> rhs_attr_key =
+        GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address);
+    *tracing_data->first_mismatch =
+        ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key),
+                       tracing_data->current_paths->rhs_path->Attr(rhs_attr_key));
+  }
+}
+
+// Compares two attribute values with BaseValueEqual. On mismatch, records the
+// attribute paths (when tracing is enabled). `lhs` / `rhs` must refer to the
+// attribute fields themselves, because their addresses are used to look up the
+// attribute keys.
+template <typename T>
+/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
+                                                        const PathTracingData* tracing_data) {
+  if (BaseValueEqual()(lhs, rhs)) {
+    return true;
+  } else {
+    GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
+    return false;
+  }
+}
+
+// Primitive-valued attribute comparisons. Each forwards to
+// CompareAttributeValues, which records the mismatch path when tracing is on.
+bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+// Compares two values passed as int. The values arrive by copy, so the
+// addresses of the original attribute fields are supplied separately; on
+// mismatch they are used to record the attribute paths (when tracing is on).
+bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
+                                   const void* rhs_address) const {
+  if (lhs == rhs) {
+    return true;
+  } else {
+    GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
+    return false;
+  }
+}
+
+// Returns the (lhs, rhs) paths of the objects currently being compared.
+// ICHECK-fails when path tracing is disabled.
+const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {
+  ICHECK(tracing_data_ != nullptr)
+      << "GetCurrentObjectPaths() can only be called when path tracing is enabled";
+  return tracing_data_->current_paths;
+}
+
+// Records `paths` as the first mismatch unless one has already been recorded.
+// ICHECK-fails when path tracing is disabled.
+void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const {
+  ICHECK(tracing_data_ != nullptr)
+      << "RecordMismatchPaths() can only be called when path tracing is enabled";
+  if (!tracing_data_->first_mismatch->defined()) {
+    *tracing_data_->first_mismatch = paths;
+  }
+}
+
+bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& 
rhs, bool map_free_vars,
+                                     const ObjectPathPair* paths) const {
+  if (tracing_data_ == nullptr) {
+    // Fast path: no tracing
+    return handler_->SEqualReduce(lhs, rhs, map_free_vars, {});

Review Comment:
   nit
   
   ```suggestion
       return handler_->SEqualReduce(lhs, rhs, map_free_vars, NullOpt);
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to