gemini-code-assist[bot] commented on code in PR #18857: URL: https://github.com/apache/tvm/pull/18857#discussion_r2868384335
########## src/tir/transform/tvm_ffi_binder.cc: ########## @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm_ffi_binder.cc + * \brief Helper utility to match and bind packed function arguments. + */ +#include "tvm_ffi_binder.h" + +#include <tvm/runtime/device_api.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/expr.h> +#include <tvm/tir/op.h> + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +using ffi::reflection::AccessPath; +using ffi::reflection::AccessStep; + +// ============================================================ +// Constructor +// ============================================================ + +ArgBinder::ArgBinder(std::unordered_map<const VarNode*, PrimExpr>* def_map, + const std::string& func_name, const ffi::Array<Var>& params, + const ffi::Map<Var, Buffer>& buffer_map, const Var& v_packed_args) + : def_map_(def_map), + func_name_(func_name), + params_(params), + buffer_map_(buffer_map), + v_packed_args_(v_packed_args) { + // Build function signature string + std::ostringstream os; + os << func_name << "("; + for (size_t i = 0; i < params.size(); ++i) { + if (i > 0) os << ", "; + Var param = params[i]; + if (buffer_map.count(param)) { + Buffer buf = buffer_map[param]; + std::string buf_name = buf->name; + os << buf_name << ": Tensor(["; + for (size_t j = 0; j < buf->shape.size(); ++j) { + if (j > 0) os << ", "; + std::ostringstream shape_os; + shape_os << buf->shape[j]; + os << shape_os.str(); + } + os << "], " << buf->dtype << ")"; + param_names_[static_cast<int>(i)] = buf_name; + } else { + os << param->name_hint << ": " << param.dtype(); + param_names_[static_cast<int>(i)] = param->name_hint; + } + } + os << ")"; + func_signature_ = os.str(); + // Cache StringImm values for sharing across assertions + sig_imm_ = tvm::tir::StringImm(func_signature_); + when_calling_imm_ = tvm::tir::StringImm(" when calling:\n `"); +} + +// ============================================================ +// EmitRichAssert +// ============================================================ + +void ArgBinder::EmitRichAssert(const std::string& kind, PrimExpr cond, const std::string& detail, + const std::string& expectation, std::vector<Stmt>* target) { + ffi::Array<StringImm> parts; + parts.push_back(tvm::tir::StringImm(detail)); + parts.push_back(when_calling_imm_); + parts.push_back(sig_imm_); + parts.push_back(tvm::tir::StringImm("`,\n expected " + expectation)); + target->emplace_back(AssertStmt(tvm::tir::StringImm(kind), cond, parts)); +} Review Comment:  To improve consistency with how other static string parts are handled (like `when_calling_imm_` and `sig_imm_`) and to potentially enable more string reuse, you could split the string concatenation on line 91. A further improvement would be to cache the static part of this message as a new `StringImm` member in `ArgBinder`, similar to `sig_imm_`. ```c void ArgBinder::EmitRichAssert(const std::string& kind, PrimExpr cond, const std::string& detail, const std::string& expectation, std::vector<Stmt>* target) { ffi::Array<StringImm> parts; parts.push_back(tvm::tir::StringImm(detail)); parts.push_back(when_calling_imm_); parts.push_back(sig_imm_); parts.push_back(tvm::tir::StringImm("`, expected ")); parts.push_back(tvm::tir::StringImm(expectation)); target->emplace_back(AssertStmt(tvm::tir::StringImm(kind), cond, parts)); } ``` ########## src/tir/transform/make_packed_api.cc: ########## @@ -167,13 +168,34 @@ class SubroutineCallRewriter : public StmtExprMutator { } // namespace -inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { - return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg)); +/*! + * \brief Create an assert that lhs == rhs, with multi-part error message. + * \param kind The error kind (e.g. "TypeError", "ValueError", "RuntimeError"). + * \param lhs Left-hand side of equality check. + * \param rhs Right-hand side of equality check. + * \param detail The detail message string. + * \param func_signature Function signature for context. + */ +inline Stmt MakeAssertEQ(const std::string& kind, PrimExpr lhs, PrimExpr rhs, + const std::string& detail, const std::string& func_signature) { + ffi::Array<StringImm> parts; + parts.push_back(tvm::tir::StringImm(detail + " when calling:\n `")); + parts.push_back(tvm::tir::StringImm(func_signature)); + parts.push_back(tvm::tir::StringImm("`")); + return AssertStmt(tvm::tir::StringImm(kind), lhs == rhs, parts); } -inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { +/*! + * \brief Create an assert that ptr is not NULL, with multi-part error message. + */ +inline Stmt MakeAssertNotNull(const std::string& kind, PrimExpr ptr, const std::string& detail, + const std::string& func_signature) { Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr}); - return AssertStmt(!isnull, tvm::tir::StringImm(msg)); + ffi::Array<StringImm> parts; + parts.push_back(tvm::tir::StringImm(detail + " when calling:\n `")); + parts.push_back(tvm::tir::StringImm(func_signature)); + parts.push_back(tvm::tir::StringImm("`")); + return AssertStmt(tvm::tir::StringImm(kind), !isnull, parts); } Review Comment:  These helper functions are used to generate assertions related to packed function calls and construct error messages that include the function signature. Since the `ArgBinder` class already manages the function signature and has cached `StringImm`s for parts of the error message, it would be beneficial to move `MakeAssertEQ` and `MakeAssertNotNull` into the `ArgBinder` class as public methods. This would encapsulate all related error message generation logic within `ArgBinder`, promote code reuse (e.g., for `when_calling_imm_`), and improve the overall organization. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
