This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch sequal
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 441b41e245e8542abad92e8293258d650c9ecd92
Author: tqchen <[email protected]>
AuthorDate: Thu Jul 17 11:51:40 2025 -0400

    wip
---
 ffi/src/ffi/reflection/structural_equal.cc | 225 +++++++++++++++++++++++++++++
 1 file changed, 225 insertions(+)

diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc
new file mode 100644
index 0000000000..ced25425d8
--- /dev/null
+++ b/ffi/src/ffi/reflection/structural_equal.cc
@@ -0,0 +1,225 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/reflection/structural_equal.cc
+ *
+ * \brief Structural equal implementation.
+ */
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/ndarray.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>

#include <cmath>
#include <cstring>
#include <unordered_map>
#include <utility>
#include <vector>
+
+
+namespace tvm {
+namespace ffi {
+namespace reflection {
+
+class StructuralEqual{
+ public:
+
+
+
+ private:
+  bool CompareAny(ffi::Any lhs, ffi::Any rhs) {
+    const TVMFFIAny* lhs_data = details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs);
+    const TVMFFIAny* rhs_data = details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs);
+    if (lhs_data->type_index != rhs_data->type_index) {
+      return false;
+    }
+    if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+      // this is mostly POD data, we can just compare the value
+      return lhs_data->v_int64 == rhs_data->v_int64;
+    }
+    switch (lhs_data->type_index) {
+      case TypeIndex::kTVMFFIStr:
+      case TypeIndex::kTVMFFIBytes: {
+        // compare bytes
+        const BytesObjBase* lhs_str =
+            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const 
BytesObjBase*>(lhs);
+        const BytesObjBase* rhs_str =
+            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const 
BytesObjBase*>(rhs);
+        return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, 
rhs_str->size) == 0;
+      }
+      case TypeIndex::kTVMFFIArray: {
+        return CompareArray(
+          
details::AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(lhs)),
+          details::AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(rhs))
+        );
+      }
+      case TypeIndex::kTVMFFIMap: {
+        return CompareMap(
+          details::AnyUnsafe::MoveFromAnyAfterCheck<Map<Any, 
Any>>(std::move(lhs)),
+          details::AnyUnsafe::MoveFromAnyAfterCheck<Map<Any, 
Any>>(std::move(rhs))
+        );
+      }
+      case TypeIndex::kTVMFFIShape: {
+        return CompareShape(
+          details::AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(lhs)),
+          details::AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(rhs))
+        );
+      }
+      case TypeIndex::kTVMFFINDArray: {
+        return CompareNDArray(
+          details::AnyUnsafe::MoveFromAnyAfterCheck<NDArray>(std::move(lhs)),
+          details::AnyUnsafe::MoveFromAnyAfterCheck<NDArray>(std::move(rhs))
+        );
+      }
+      default: {
+        return CompareObject(
+          details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(lhs)),
+          details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(rhs))
+        );
+      }
+    }
+  }
+
+  bool CompareObject(ObjectRef lhs, ObjectRef rhs) {
+    // NOTE: invariant: lhs and rhs are already the same type
+    const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
+    if (type_info->extra_info == nullptr ||
+        type_info->extra_info->structural_eq_hash_kind == 
kTVMFFIStructuralEqHashKindUnsupported ||
+        type_info->extra_info->structural_eq_hash_kind == 
kTVMFFIStructuralEqHashKindUniqueInstance) {
+      // use pointer comparison
+      return lhs.same_as(rhs);
+    }
+    if (type_info->extra_info->structural_eq_hash_kind == 
kTVMFFIStructuralEqHashKindConstTreeNode) {
+      // constant tree node, pointer equality indicate equality and avoid 
content comparison
+      if (lhs.same_as(rhs)) return true;
+    }
+    if (type_info->extra_info->structural_eq_hash_kind == 
kTVMFFIStructuralEqHashKindDAGNode ||
+        type_info->extra_info->structural_eq_hash_kind == 
kTVMFFIStructuralEqHashKindFreeVar) {
+      // DAG node, we need to compare the children
+      // TODO: implement DAG node comparison
+      return false;
+    }
+
+    bool success = true;
+    // go over each field
+    ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* 
field_info) {
+      // skip fields that are marked as structural eq hash ignore
+      if (field_info->flags & kTVMFFIFieldFlagBitMaskStructuralEqHashIgnore) 
return false;
+      FieldGetter getter(field_info);
+      Any lhs_value = getter(lhs);
+      Any rhs_value = getter(rhs);
+
+      // field is in def region, enable free var mapping
+      bool allow_free_var = (field_info->flags & 
kTVMFFIFieldFlagBitMaskStructuralEqHashDef) != 0;
+      std::swap(allow_free_var, map_free_var_);
+      success = CompareAny(lhs_value, rhs_value);
+      std::swap(allow_free_var, map_free_var_);
+
+      if (!success) {
+        // record the first mismatching field
+        if (first_mismatch_reverse_path_ != nullptr) {
+          first_mismatch_reverse_path_->emplace_back(
+            AccessStep::ObjectField(String(field_info->name))
+          );
+        }
+        // return true to indicate early stop
+        return true;
+      } else {
+        // return false to continue checking other fields
+        return false;
+      }
+    });
+    return success;
+  }
+
+  bool CompareMap(Map<Any, Any> lhs, Map<Any, Any> rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+    for (const auto& [key, value] : lhs) {
+      if (!CompareAny(key, value)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool CompareArray(ffi::Array<Any> lhs, ffi::Array<Any> rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < lhs.size(); i++) {
+      if (!CompareAny(lhs[i], rhs[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool CompareShape(Shape lhs, Shape rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (lhs[i] != rhs[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool CompareNDArray(NDArray lhs, NDArray rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (lhs->ndim != rhs->ndim) return false;
+    for (int i = 0; i < lhs->ndim; ++i) {
+      if (lhs->shape[i] != rhs->shape[i]) return false;
+    }
+    if (lhs->dtype != rhs->dtype) return false;
+
+    TVM_FFI_ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare 
CPU tensor";
+    TVM_FFI_ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare 
CPU tensor";
+    TVM_FFI_ICHECK(lhs.IsContiguous()) << "Can only compare contiguous tensor";
+    TVM_FFI_ICHECK(rhs.IsContiguous()) << "Can only compare contiguous tensor";
+
+    size_t data_size = GetDataSize(*(lhs.operator->()));
+    if (compare_ndarray_data_) {
+      return std::memcmp(lhs->data, rhs->data, data_size) == 0;
+    } else {
+      return true;
+    }
+  }
+
+  // whether we map free variables that are not defined
+  bool map_free_var_{false};
+  // whether we compare ndarray data
+  bool compare_ndarray_data_{true};
+
+  // the root lhs for result printing
+  std::vector<AccessPath>* first_mismatch_reverse_path_;
+  // map from lhs to rhs
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_lhs_;
+  // map from rhs to lhs
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_rhs_;
+};
+
+}  // namespace reflection
+}  // namespace ffi
+}  // namespace tvm

Reply via email to