This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch sequal in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 441b41e245e8542abad92e8293258d650c9ecd92 Author: tqchen <[email protected]> AuthorDate: Thu Jul 17 11:51:40 2025 -0400 wip --- ffi/src/ffi/reflection/structural_equal.cc | 225 +++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc new file mode 100644 index 0000000000..ced25425d8 --- /dev/null +++ b/ffi/src/ffi/reflection/structural_equal.cc @@ -0,0 +1,225 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/reflection/structural_equal.cc + * + * \brief Structural equal implementation. + */ +#include <tvm/ffi/any.h> +#include <tvm/ffi/string.h> +#include <tvm/ffi/container/ndarray.h> +#include <tvm/ffi/container/shape.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/container/ndarray.h> +#include <tvm/ffi/reflection/access_path.h> +#include <tvm/ffi/reflection/accessor.h> + +#include <unordered_map> + + +namespace tvm { +namespace ffi { +namespace reflection { + +class StructuralEqual{ + public: + + + + private: + bool CompareAny(ffi::Any lhs, ffi::Any rhs) { + const TVMFFIAny* lhs_data = details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs); + const TVMFFIAny* rhs_data = details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs); + if (lhs_data->type_index != rhs_data->type_index) { + return false; + } + if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + // this is mostly POD data, we can just compare the value + return lhs_data->v_int64 == rhs_data->v_int64; + } + switch (lhs_data->type_index) { + case TypeIndex::kTVMFFIStr: + case TypeIndex::kTVMFFIBytes: { + // compare bytes + const BytesObjBase* lhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs); + const BytesObjBase* rhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs); + return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; + } + case TypeIndex::kTVMFFIArray: { + return CompareArray( + details::AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(lhs)), + details::AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(rhs)) + ); + } + case TypeIndex::kTVMFFIMap: { + return CompareMap( + details::AnyUnsafe::MoveFromAnyAfterCheck<Map<Any, Any>>(std::move(lhs)), + details::AnyUnsafe::MoveFromAnyAfterCheck<Map<Any, Any>>(std::move(rhs)) + ); + } + case TypeIndex::kTVMFFIShape: { + return CompareShape( + details::AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(lhs)), + details::AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(rhs)) + ); + } + case TypeIndex::kTVMFFINDArray: { + return CompareNDArray( + details::AnyUnsafe::MoveFromAnyAfterCheck<NDArray>(std::move(lhs)), + details::AnyUnsafe::MoveFromAnyAfterCheck<NDArray>(std::move(rhs)) + ); + } + default: { + return CompareObject( + details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(lhs)), + details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(rhs)) + ); + } + } + } + + bool CompareObject(ObjectRef lhs, ObjectRef rhs) { + // NOTE: invariant: lhs and rhs are already the same type + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index()); + if (type_info->extra_info == nullptr || + type_info->extra_info->structural_eq_hash_kind == kTVMFFIStructuralEqHashKindUnsupported || + type_info->extra_info->structural_eq_hash_kind == kTVMFFIStructuralEqHashKindUniqueInstance) { + // use pointer comparison + return lhs.same_as(rhs); + } + if (type_info->extra_info->structural_eq_hash_kind == kTVMFFIStructuralEqHashKindConstTreeNode) { + // constant tree node, pointer equality indicate equality and avoid content comparison + if (lhs.same_as(rhs)) return true; + } + if (type_info->extra_info->structural_eq_hash_kind == kTVMFFIStructuralEqHashKindDAGNode || + type_info->extra_info->structural_eq_hash_kind == kTVMFFIStructuralEqHashKindFreeVar) { + // DAG node, we need to compare the children + // TODO: implement DAG node comparison + return false; + } + + bool success = true; + // go over each field + ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { + // skip fields that are marked as structural eq hash ignore + if (field_info->flags & kTVMFFIFieldFlagBitMaskStructuralEqHashIgnore) return false; + FieldGetter getter(field_info); + Any lhs_value = getter(lhs); + Any rhs_value = getter(rhs); + + // field is in def region, enable free var mapping + bool allow_free_var = (field_info->flags & kTVMFFIFieldFlagBitMaskStructuralEqHashDef) != 0; + std::swap(allow_free_var, map_free_var_); + success = CompareAny(lhs_value, rhs_value); + std::swap(allow_free_var, map_free_var_); + + if (!success) { + // record the first mismatching field + if (first_mismatch_reverse_path_ != nullptr) { + first_mismatch_reverse_path_->emplace_back( + AccessStep::ObjectField(String(field_info->name)) + ); + } + // return true to indicate early stop + return true; + } else { + // return false to continue checking other fields + return false; + } + }); + return success; + } + + bool CompareMap(Map<Any, Any> lhs, Map<Any, Any> rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto& [key, value] : lhs) { + if (!CompareAny(key, value)) { + return false; + } + } + return true; + } + + bool CompareArray(ffi::Array<Any> lhs, ffi::Array<Any> rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (!CompareAny(lhs[i], rhs[i])) { + return false; + } + } + return true; + } + + bool CompareShape(Shape lhs, Shape rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; + } + + bool CompareNDArray(NDArray lhs, NDArray rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs->ndim != rhs->ndim) return false; + for (int i = 0; i < lhs->ndim; ++i) { + if (lhs->shape[i] != rhs->shape[i]) return false; + } + if (lhs->dtype != rhs->dtype) return false; + + TVM_FFI_ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; + TVM_FFI_ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; + TVM_FFI_ICHECK(lhs.IsContiguous()) << "Can only compare contiguous tensor"; + TVM_FFI_ICHECK(rhs.IsContiguous()) << "Can only compare contiguous tensor"; + + size_t data_size = GetDataSize(*(lhs.operator->())); + if (compare_ndarray_data_) { + return std::memcmp(lhs->data, rhs->data, data_size) == 0; + } else { + return true; + } + } + + // whether we map free variables that are not defined + bool map_free_var_{false}; + // whether we compare ndarray data + bool compare_ndarray_data_{true}; + + // the root lhs for result printing + std::vector<AccessPath>* first_mismatch_reverse_path_; + // map from lhs to rhs + std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_; + // map from rhs to lhs + std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_; +}; + +} // namespace reflection +} // namespace ffi +} // namespace tvm
