gemini-code-assist[bot] commented on code in PR #455: URL: https://github.com/apache/tvm-ffi/pull/455#discussion_r2810690513
########## src/ffi/extra/type_checker.cc: ########## @@ -0,0 +1,538 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/list.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/extra/json.h> +#include <tvm/ffi/extra/type_checker.h> + +#include <sstream> +#include <unordered_map> + +namespace tvm { +namespace ffi { + +// --------------------------------------------------------------------------- +// TypeSchema +// --------------------------------------------------------------------------- + +TypeSchema TypeSchema::FromJSON(const Any& obj) { + Map<Any, Any> map = obj.cast<Map<Any, Any>>(); + TypeSchema result; + result.origin = map[String("type")].cast<String>(); + if (auto it = map.find(String("args")); it != map.end()) { + Array<Any> json_args = (*it).second.cast<Array<Any>>(); + result.args.reserve(json_args.size()); + for (const Any& arg : json_args) { + result.args.push_back(TypeSchema::FromJSON(arg)); + } + } + return result; +} + +TypeSchema TypeSchema::FromJSONStr(const String& json_str) { + return TypeSchema::FromJSON(json::Parse(json_str)); +} + +std::string TypeSchema::Repr() const { + std::vector<std::string> arg_reprs; + 
arg_reprs.reserve(args.size()); + for (const auto& a : args) { + arg_reprs.push_back(a.Repr()); + } + + if (origin == "Variant") { + std::ostringstream oss; + for (size_t i = 0; i < arg_reprs.size(); ++i) { + if (i > 0) oss << " | "; + oss << arg_reprs[i]; + } + return oss.str(); + } + if (origin == "Optional") { + return arg_reprs.at(0) + " | None"; + } + if (origin == "Callable" || origin == "ffi.Function") { + if (arg_reprs.empty()) { + return "Callable[..., Any]"; + } + std::ostringstream oss; + oss << "Callable[["; + for (size_t i = 1; i < arg_reprs.size(); ++i) { + if (i > 1) oss << ", "; + oss << arg_reprs[i]; + } + oss << "], " << arg_reprs[0] << "]"; + return oss.str(); + } + if (arg_reprs.empty()) { + return origin; + } + std::ostringstream oss; + oss << origin << "["; + for (size_t i = 0; i < arg_reprs.size(); ++i) { + if (i > 0) oss << ", "; + oss << arg_reprs[i]; + } + oss << "]"; + return oss.str(); +} + +// --------------------------------------------------------------------------- +// TypeChecker – helpers +// --------------------------------------------------------------------------- + +namespace { + +/*! + * \brief Runtime check if `object_type_index` is an instance of the type + * identified by `target_type_index`. + * + * This is the runtime equivalent of `details::IsObjectInstance<T>()`. 
+ */ +bool IsObjectInstanceRuntime(int32_t object_type_index, int32_t target_type_index) { + if (target_type_index == TypeIndex::kTVMFFIObject) return true; + if (object_type_index == target_type_index) return true; + const TVMFFITypeInfo* obj_info = TVMFFIGetTypeInfo(object_type_index); + const TVMFFITypeInfo* tgt_info = TVMFFIGetTypeInfo(target_type_index); + if (obj_info->type_depth <= tgt_info->type_depth) return false; + return obj_info->type_ancestors[tgt_info->type_depth]->type_index == target_type_index; +} + +} // namespace + +// --------------------------------------------------------------------------- +// TypeChecker – Kind resolution +// --------------------------------------------------------------------------- + +TypeChecker::Kind TypeChecker::ResolveKind(const std::string& origin, + int32_t* resolved_type_index) { + // Static lookup table mapping origin strings to Kind values. + // Heap-allocated to avoid destruction-order issues. + static const auto* table = new std::unordered_map<std::string, Kind>{ + // POD + {"Any", Kind::kAny}, + {"None", Kind::kNone}, + {"int", Kind::kInt}, + {"bool", Kind::kBool}, + {"float", Kind::kFloat}, + // String family + {"ffi.String", Kind::kStr}, + {"ffi.SmallStr", Kind::kStr}, + {"std::string", Kind::kStr}, + {"const char*", Kind::kStr}, + // Bytes family + {"ffi.Bytes", Kind::kBytes}, + {"ffi.SmallBytes", Kind::kBytes}, + {"TVMFFIByteArray*", Kind::kBytes}, + // Special POD + {"DataType", Kind::kDataType}, + {"Device", Kind::kDevice}, + {"void*", Kind::kOpaquePtr}, + {"DLTensor*", Kind::kDLTensorPtr}, + // Callable + {"ffi.Function", Kind::kCallable}, + // Containers + {"ffi.Array", Kind::kArray}, + {"ffi.List", Kind::kList}, + {"ffi.Map", Kind::kMap}, + // Generic wrappers + {"Optional", Kind::kOptional}, + {"Variant", Kind::kVariant}, + {"Tuple", Kind::kTuple}, + }; + + auto it = table->find(origin); + if (it != table->end()) { + Kind kind = it->second; + if (kind == Kind::kCallable) { + *resolved_type_index = 
TypeIndex::kTVMFFIFunction; + } + return kind; + } + // Anything else is treated as a named Object type. + *resolved_type_index = TypeIndex::kTVMFFIObject; // default fallback + TVMFFIByteArray key_bytes{origin.data(), origin.size()}; + int32_t tindex = -1; + if (TVMFFITypeKeyToIndex(&key_bytes, &tindex) == 0) { + *resolved_type_index = tindex; + } + return Kind::kObject; +} + +// --------------------------------------------------------------------------- +// TypeChecker – construction +// --------------------------------------------------------------------------- + +TypeChecker::TypeChecker(const TypeSchema& schema) : schema_(schema) { + kind_ = ResolveKind(schema_.origin, &resolved_type_index_); + args_.reserve(schema_.args.size()); + for (const auto& arg : schema_.args) { + args_.emplace_back(arg); + } +} Review Comment:

The constructor does not validate the number of type arguments for generic types. This can lead to runtime errors or incorrect behavior if a malformed `TypeSchema` is provided. For example:

- A `TypeSchema` for `Optional` without any arguments will cause a crash due to out-of-bounds access in `CheckStrict` and `TryCast`.
- A `TypeSchema` for `Map` with one argument is treated as a fully untyped map (`Map<Any, Any>`), which is likely not the intended behavior and can hide type errors.

It would be safer to validate the arity of type arguments in the constructor based on the `kind_`.
For instance:

```cpp
TypeChecker::TypeChecker(const TypeSchema& schema) : schema_(schema) {
  kind_ = ResolveKind(schema_.origin, &resolved_type_index_);
  // Add validation for argument count based on kind
  switch (kind_) {
    case Kind::kOptional:
      if (schema_.args.size() != 1) {
        TVM_FFI_THROW(ValueError) << "Invalid TypeSchema: Optional expects 1 type argument, but got "
                                  << schema_.args.size();
      }
      break;
    case Kind::kMap:
      if (!schema_.args.empty() && schema_.args.size() != 2) {
        TVM_FFI_THROW(ValueError) << "Invalid TypeSchema: Map expects 0 or 2 type arguments, but got "
                                  << schema_.args.size();
      }
      break;
    // ... similar checks for Array and List (0 or 1 type argument)
    default:
      break;
  }
  args_.reserve(schema_.args.size());
  for (const auto& arg : schema_.args) {
    args_.emplace_back(arg);
  }
}
```

This would make the `TypeChecker` more robust against invalid schemas.

########## src/ffi/extra/type_checker.cc: ########## @@ -0,0 +1,538 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/list.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/extra/json.h> +#include <tvm/ffi/extra/type_checker.h> + +#include <sstream> +#include <unordered_map> + +namespace tvm { +namespace ffi { + +// --------------------------------------------------------------------------- +// TypeSchema +// --------------------------------------------------------------------------- + +TypeSchema TypeSchema::FromJSON(const Any& obj) { + Map<Any, Any> map = obj.cast<Map<Any, Any>>(); + TypeSchema result; + result.origin = map[String("type")].cast<String>(); + if (auto it = map.find(String("args")); it != map.end()) { + Array<Any> json_args = (*it).second.cast<Array<Any>>(); + result.args.reserve(json_args.size()); + for (const Any& arg : json_args) { + result.args.push_back(TypeSchema::FromJSON(arg)); + } + } + return result; +} + +TypeSchema TypeSchema::FromJSONStr(const String& json_str) { + return TypeSchema::FromJSON(json::Parse(json_str)); +} + +std::string TypeSchema::Repr() const { + std::vector<std::string> arg_reprs; + arg_reprs.reserve(args.size()); + for (const auto& a : args) { + arg_reprs.push_back(a.Repr()); + } + + if (origin == "Variant") { + std::ostringstream oss; + for (size_t i = 0; i < arg_reprs.size(); ++i) { + if (i > 0) oss << " | "; + oss << arg_reprs[i]; + } + return oss.str(); + } + if (origin == "Optional") { + return arg_reprs.at(0) + " | None"; + } + if (origin == "Callable" || origin == "ffi.Function") { + if (arg_reprs.empty()) { + return "Callable[..., Any]"; + } + std::ostringstream oss; + oss << "Callable[["; + for (size_t i = 1; i < arg_reprs.size(); ++i) { + if (i > 1) oss << ", "; + oss << arg_reprs[i]; + } + oss << "], " << arg_reprs[0] << "]"; + return oss.str(); + } + if (arg_reprs.empty()) { + return origin; + } + std::ostringstream oss; + oss << origin << "["; + for (size_t i = 0; i < arg_reprs.size(); ++i) { + if (i > 0) oss << ", "; + oss << arg_reprs[i]; + 
} + oss << "]"; + return oss.str(); +} + +// --------------------------------------------------------------------------- +// TypeChecker – helpers +// --------------------------------------------------------------------------- + +namespace { + +/*! + * \brief Runtime check if `object_type_index` is an instance of the type + * identified by `target_type_index`. + * + * This is the runtime equivalent of `details::IsObjectInstance<T>()`. + */ +bool IsObjectInstanceRuntime(int32_t object_type_index, int32_t target_type_index) { + if (target_type_index == TypeIndex::kTVMFFIObject) return true; + if (object_type_index == target_type_index) return true; + const TVMFFITypeInfo* obj_info = TVMFFIGetTypeInfo(object_type_index); + const TVMFFITypeInfo* tgt_info = TVMFFIGetTypeInfo(target_type_index); + if (obj_info->type_depth <= tgt_info->type_depth) return false; + return obj_info->type_ancestors[tgt_info->type_depth]->type_index == target_type_index; +} + +} // namespace + +// --------------------------------------------------------------------------- +// TypeChecker – Kind resolution +// --------------------------------------------------------------------------- + +TypeChecker::Kind TypeChecker::ResolveKind(const std::string& origin, + int32_t* resolved_type_index) { + // Static lookup table mapping origin strings to Kind values. + // Heap-allocated to avoid destruction-order issues. 
+ static const auto* table = new std::unordered_map<std::string, Kind>{ + // POD + {"Any", Kind::kAny}, + {"None", Kind::kNone}, + {"int", Kind::kInt}, + {"bool", Kind::kBool}, + {"float", Kind::kFloat}, + // String family + {"ffi.String", Kind::kStr}, + {"ffi.SmallStr", Kind::kStr}, + {"std::string", Kind::kStr}, + {"const char*", Kind::kStr}, + // Bytes family + {"ffi.Bytes", Kind::kBytes}, + {"ffi.SmallBytes", Kind::kBytes}, + {"TVMFFIByteArray*", Kind::kBytes}, + // Special POD + {"DataType", Kind::kDataType}, + {"Device", Kind::kDevice}, + {"void*", Kind::kOpaquePtr}, + {"DLTensor*", Kind::kDLTensorPtr}, + // Callable + {"ffi.Function", Kind::kCallable}, + // Containers + {"ffi.Array", Kind::kArray}, + {"ffi.List", Kind::kList}, + {"ffi.Map", Kind::kMap}, + // Generic wrappers + {"Optional", Kind::kOptional}, + {"Variant", Kind::kVariant}, + {"Tuple", Kind::kTuple}, + }; + + auto it = table->find(origin); + if (it != table->end()) { + Kind kind = it->second; + if (kind == Kind::kCallable) { + *resolved_type_index = TypeIndex::kTVMFFIFunction; + } + return kind; + } + // Anything else is treated as a named Object type. 
+ *resolved_type_index = TypeIndex::kTVMFFIObject; // default fallback + TVMFFIByteArray key_bytes{origin.data(), origin.size()}; + int32_t tindex = -1; + if (TVMFFITypeKeyToIndex(&key_bytes, &tindex) == 0) { + *resolved_type_index = tindex; + } + return Kind::kObject; +} + +// --------------------------------------------------------------------------- +// TypeChecker – construction +// --------------------------------------------------------------------------- + +TypeChecker::TypeChecker(const TypeSchema& schema) : schema_(schema) { + kind_ = ResolveKind(schema_.origin, &resolved_type_index_); + args_.reserve(schema_.args.size()); + for (const auto& arg : schema_.args) { + args_.emplace_back(arg); + } +} + +// --------------------------------------------------------------------------- +// TypeChecker – CheckStrict +// --------------------------------------------------------------------------- + +bool TypeChecker::CheckStrict(AnyView src) const { + int32_t ti = src.type_index(); + + switch (kind_) { + case Kind::kAny: + return true; + case Kind::kNone: + return ti == TypeIndex::kTVMFFINone; + case Kind::kInt: + return ti == TypeIndex::kTVMFFIInt; + case Kind::kBool: + return ti == TypeIndex::kTVMFFIBool; + case Kind::kFloat: + return ti == TypeIndex::kTVMFFIFloat; + case Kind::kStr: + return ti == TypeIndex::kTVMFFIStr || ti == TypeIndex::kTVMFFISmallStr; + case Kind::kBytes: + return ti == TypeIndex::kTVMFFIBytes || ti == TypeIndex::kTVMFFISmallBytes; + case Kind::kDataType: + return ti == TypeIndex::kTVMFFIDataType; + case Kind::kDevice: + return ti == TypeIndex::kTVMFFIDevice; + case Kind::kOpaquePtr: + return ti == TypeIndex::kTVMFFIOpaquePtr; + case Kind::kDLTensorPtr: + return ti == TypeIndex::kTVMFFIDLTensorPtr; + case Kind::kCallable: + return ti >= TypeIndex::kTVMFFIStaticObjectBegin && + IsObjectInstanceRuntime(ti, TypeIndex::kTVMFFIFunction); + case Kind::kObject: { + // Small strings/bytes are logically Object-derived types but have POD type indices. 
+ // Remap them to their heap counterparts for the subtype check. + int32_t effective_ti = ti; + if (ti == TypeIndex::kTVMFFISmallStr) + effective_ti = TypeIndex::kTVMFFIStr; + else if (ti == TypeIndex::kTVMFFISmallBytes) + effective_ti = TypeIndex::kTVMFFIBytes; + return effective_ti >= TypeIndex::kTVMFFIStaticObjectBegin && + IsObjectInstanceRuntime(effective_ti, resolved_type_index_); + } + case Kind::kOptional: + if (ti == TypeIndex::kTVMFFINone) return true; + return args_[0].CheckStrict(src); + case Kind::kArray: { + if (ti != TypeIndex::kTVMFFIArray) return false; + if (args_.empty()) return true; // Array<Any> + Array<Any> arr = src.cast<Array<Any>>(); + for (const auto& elem : arr) { + if (!args_[0].CheckStrict(elem)) return false; + } + return true; + } + case Kind::kList: { + if (ti != TypeIndex::kTVMFFIList) return false; + if (args_.empty()) return true; + // List and Array share SeqBaseObj, cast via Array<Any> + List<Any> list = src.cast<List<Any>>(); + for (const auto& elem : list) { + if (!args_[0].CheckStrict(elem)) return false; + } + return true; + } + case Kind::kMap: { + if (ti != TypeIndex::kTVMFFIMap) return false; + if (args_.size() < 2) return true; // Map<Any, Any> + Map<Any, Any> map = src.cast<Map<Any, Any>>(); + for (const auto& kv : map) { + if (!args_[0].CheckStrict(kv.first)) return false; + if (!args_[1].CheckStrict(kv.second)) return false; + } + return true; + } + case Kind::kVariant: { + for (const auto& checker : args_) { + if (checker.CheckStrict(src)) return true; + } + return false; + } + case Kind::kTuple: { + if (ti != TypeIndex::kTVMFFIArray) return false; + Array<Any> arr = src.cast<Array<Any>>(); + if (arr.size() != args_.size()) return false; + for (size_t i = 0; i < args_.size(); ++i) { + if (!args_[i].CheckStrict(arr[static_cast<int64_t>(i)])) return false; + } + return true; + } + } + return false; +} + +// --------------------------------------------------------------------------- +// TypeChecker – TryCast +// 
--------------------------------------------------------------------------- + +std::optional<Any> TypeChecker::TryCast(AnyView src) const { + int32_t ti = src.type_index(); + + switch (kind_) { + case Kind::kAny: + return Any(src); + + case Kind::kNone: + if (ti == TypeIndex::kTVMFFINone) return Any(nullptr); + return std::nullopt; + + case Kind::kInt: + if (ti == TypeIndex::kTVMFFIInt || ti == TypeIndex::kTVMFFIBool) { + return Any(src.cast<int64_t>()); + } + return std::nullopt; + + case Kind::kBool: + if (ti == TypeIndex::kTVMFFIBool || ti == TypeIndex::kTVMFFIInt) { + return Any(src.cast<bool>()); + } + return std::nullopt; + + case Kind::kFloat: + if (ti == TypeIndex::kTVMFFIFloat) { + return Any(src.cast<double>()); + } + if (ti == TypeIndex::kTVMFFIInt || ti == TypeIndex::kTVMFFIBool) { + return Any(src.cast<double>()); + } + return std::nullopt; + + case Kind::kStr: + if (ti == TypeIndex::kTVMFFIStr || ti == TypeIndex::kTVMFFISmallStr || + ti == TypeIndex::kTVMFFIRawStr) { + return Any(src.cast<String>()); + } + return std::nullopt; + + case Kind::kBytes: + if (ti == TypeIndex::kTVMFFIBytes || ti == TypeIndex::kTVMFFISmallBytes) { + return Any(src); + } + return std::nullopt; + + case Kind::kDataType: + if (ti == TypeIndex::kTVMFFIDataType) return Any(src); + return std::nullopt; + + case Kind::kDevice: + if (ti == TypeIndex::kTVMFFIDevice) return Any(src); + return std::nullopt; + + case Kind::kOpaquePtr: + if (ti == TypeIndex::kTVMFFIOpaquePtr) return Any(src); + if (ti == TypeIndex::kTVMFFINone) return Any(static_cast<void*>(nullptr)); + return std::nullopt; + + case Kind::kDLTensorPtr: + // DLTensor* cannot be stored in Any — return as AnyView-backed Any + if (ti == TypeIndex::kTVMFFIDLTensorPtr || ti == TypeIndex::kTVMFFITensor) { + // Delegate to the compile-time TypeTraits<DLTensor*>::TryCastFromAnyView + if (auto opt = src.try_cast<DLTensor*>()) { + // Cannot store DLTensor* in Any, so return the Tensor object instead + if (ti == 
TypeIndex::kTVMFFITensor) return Any(src); + // For raw DLTensor*, we can't store it in Any either. + // Return the source as-is since it's already the right type view. + return Any(src); + } Review Comment:

The logic inside `if (auto opt = src.try_cast<DLTensor*>())` can be simplified. Both branches of the inner `if (ti == TypeIndex::kTVMFFITensor)` statement result in `return Any(src)`. You can combine them for better readability; this also drops the otherwise unused `opt` variable.

```cpp
if (src.try_cast<DLTensor*>()) {
  // Cannot store DLTensor* in Any, so return the original object
  // which is either a Tensor or a view of a raw DLTensor*.
  return Any(src);
}
```

-- 
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at: [email protected]

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
