This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dafdafe883 [FFI][REFACTOR] Cleanup to align to latest ffi (#18183)
dafdafe883 is described below
commit dafdafe8838338d4c95bfb888d22f096e0587956
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Aug 1 15:49:54 2025 -0400
[FFI][REFACTOR] Cleanup to align to latest ffi (#18183)
This PR modernizes legacy usage to align with the latest FFI.
- Use `Any` instead of `ObjectRef` to represent general values
- Use `Optional<T>::has_value()` instead of `defined()`
---
include/tvm/relax/tir_pattern.h | 2 +-
include/tvm/runtime/profiling.h | 8 ++---
include/tvm/script/printer/doc.h | 8 ++---
include/tvm/script/printer/ir_docsifier.h | 20 ++++++++----
include/tvm/tir/stmt.h | 4 +--
src/contrib/msc/core/ir/graph_builder.cc | 14 ++++----
src/contrib/msc/core/printer/cpp_printer.cc | 6 ++--
src/contrib/msc/core/printer/msc_base_printer.cc | 2 +-
src/contrib/msc/core/printer/prototxt_printer.cc | 6 ++--
src/contrib/msc/core/printer/prototxt_printer.h | 2 +-
src/contrib/msc/core/printer/python_printer.cc | 10 +++---
.../msc/core/transform/bind_named_params.cc | 11 +++----
src/contrib/msc/core/transform/fuse_tuple.cc | 4 +--
src/contrib/msc/core/transform/inline_params.cc | 2 +-
src/contrib/msc/core/transform/set_byoc_attrs.cc | 4 +--
src/contrib/msc/core/transform/set_expr_name.cc | 8 ++---
src/contrib/msc/core/utils.cc | 11 ++++---
src/contrib/msc/core/utils.h | 4 +--
src/contrib/msc/framework/tensorrt/codegen.cc | 2 +-
.../msc/framework/tensorrt/transform_tensorrt.cc | 2 +-
src/ir/apply_pass_to_function.cc | 2 +-
src/ir/name_supply.cc | 1 -
src/meta_schedule/arg_info.cc | 4 +--
src/meta_schedule/database/database.cc | 2 +-
src/meta_schedule/database/database_utils.cc | 2 +-
.../measure_callback/update_cost_model.cc | 4 +--
.../postproc/rewrite_reduction_block.cc | 4 +--
src/meta_schedule/schedule_rule/auto_inline.cc | 2 +-
.../space_generator/post_order_apply.cc | 2 +-
src/meta_schedule/task_scheduler/task_scheduler.cc | 6 ++--
src/meta_schedule/trace_apply.cc | 1 -
src/node/object_path.cc | 4 +--
src/node/repr_printer.cc | 2 ++
src/node/serialization.cc | 19 +++++++++++
src/node/structural_hash.cc | 12 -------
src/relax/analysis/well_formed.cc | 2 +-
.../backend/contrib/codegen_json/codegen_json.h | 28 +++++++---------
src/relax/backend/contrib/tensorrt/codegen.cc | 2 +-
src/relax/backend/vm/codegen_vm.cc | 8 ++---
src/relax/backend/vm/codegen_vm_tir.cc | 10 +++---
src/relax/ir/dataflow_expr_rewriter.cc | 2 +-
src/relax/transform/alter_op_impl.cc | 2 +-
src/relax/transform/attach_global_symbol.cc | 2 +-
src/relax/transform/bind_params.cc | 11 +++----
src/relax/transform/bind_symbolic_vars.cc | 22 ++++++-------
src/relax/transform/expand_tuple_arguments.cc | 2 +-
src/relax/transform/few_shot_tuning.cc | 6 ++--
src/relax/transform/fuse_ops.cc | 12 +++----
src/relax/transform/inline_functions.cc | 2 +-
src/relax/transform/lazy_transform_params.cc | 4 +--
src/relax/transform/merge_composite_functions.cc | 2 +-
src/relax/transform/meta_schedule.cc | 2 +-
src/relax/transform/remove_unused_outputs.cc | 2 +-
src/relax/transform/remove_unused_parameters.cc | 2 +-
src/relax/transform/split_call_tir_by_pattern.cc | 4 +--
src/relax/transform/static_plan_block_memory.cc | 28 ++++------------
src/relax/transform/utils.h | 2 +-
src/runtime/device_api.cc | 4 +--
src/runtime/disco/protocol.h | 4 +--
src/runtime/hexagon/hexagon_device_api.cc | 2 +-
src/runtime/memory/memory_manager.cc | 4 +--
src/runtime/meta_data.h | 2 --
src/runtime/opencl/opencl_device_api.cc | 12 +++----
src/runtime/profiling.cc | 38 +++++++++++-----------
src/runtime/rpc/rpc_endpoint.cc | 2 +-
src/runtime/vm/attn_backend.cc | 14 ++++----
src/runtime/vm/attn_backend.h | 11 +++----
src/runtime/vm/paged_kv_cache.cc | 16 ++++-----
src/runtime/vm/vm.cc | 2 +-
src/script/ir_builder/relax/frame.cc | 6 ++--
src/script/ir_builder/relax/ir.cc | 2 +-
src/script/ir_builder/tir/frame.cc | 6 ++--
src/script/ir_builder/tir/ir.cc | 2 +-
src/script/printer/doc.cc | 2 +-
.../printer/doc_printer/python_doc_printer.cc | 18 +++++-----
src/script/printer/ir/misc.cc | 8 -----
src/script/printer/ir_docsifier.cc | 14 ++++----
src/script/printer/tir/expr.cc | 4 +--
src/script/printer/tir/stmt.cc | 5 +--
src/support/ffi_testing.cc | 4 +--
src/target/llvm/codegen_cpu.cc | 4 +--
src/target/llvm/codegen_hexagon.cc | 2 +-
src/target/llvm/codegen_nvptx.cc | 2 +-
src/target/llvm/llvm_module.cc | 2 +-
src/target/parsers/cpu.cc | 2 +-
src/target/source/codegen_c_host.cc | 2 +-
src/target/source/codegen_metal.cc | 4 +--
src/target/source/codegen_webgpu.cc | 4 +--
src/target/spirv/spirv_utils.cc | 2 +-
src/tir/ir/data_type_rewriter.cc | 2 +-
src/tir/ir/index_map.cc | 2 +-
src/tir/ir/stmt.cc | 2 +-
src/tir/ir/tir_visitor_with_path.cc | 2 +-
src/tir/schedule/analysis/analysis.cc | 2 +-
src/tir/schedule/concrete_schedule.cc | 4 +--
src/tir/schedule/instruction_traits.h | 2 +-
src/tir/schedule/primitive/for_kind.cc | 6 ++--
src/tir/schedule/trace.cc | 10 +++---
src/tir/schedule/traced_schedule.cc | 2 +-
src/tir/schedule/utils.h | 2 +-
src/tir/transforms/bind_target.cc | 6 ++--
src/tir/transforms/compact_buffer_region.cc | 1 -
src/tir/transforms/inject_permuted_layout.cc | 4 +--
src/tir/transforms/inline_private_functions.cc | 2 +-
src/tir/transforms/loop_partition.cc | 4 +--
src/tir/transforms/make_packed_api.cc | 6 ++--
src/tir/transforms/make_unpacked_api.cc | 4 +--
src/tir/transforms/primfunc_utils.cc | 2 +-
108 files changed, 305 insertions(+), 320 deletions(-)
diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h
index 6d8620b65a..1397bafc36 100644
--- a/include/tvm/relax/tir_pattern.h
+++ b/include/tvm/relax/tir_pattern.h
@@ -74,7 +74,7 @@ class MatchResult : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode);
};
-using FCodegen = ffi::TypedFunction<Array<ObjectRef>(Array<MatchResult>
match_results)>;
+using FCodegen = ffi::TypedFunction<Array<ffi::Any>(Array<MatchResult>
match_results)>;
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_TIR_PATTERN_H_
diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h
index cf467870c6..66c0b64f18 100644
--- a/include/tvm/runtime/profiling.h
+++ b/include/tvm/runtime/profiling.h
@@ -317,7 +317,7 @@ class MetricCollectorNode : public Object {
* \returns A set of metric names and the associated values. Values must be
* one of DurationNode, PercentNode, CountNode, or StringObj.
*/
- virtual Map<String, ffi::Any> Stop(ObjectRef obj) = 0;
+ virtual Map<String, ffi::Any> Stop(ffi::ObjectRef obj) = 0;
virtual ~MetricCollectorNode() {}
@@ -340,7 +340,7 @@ struct CallFrame {
/*! Runtime of the function or op */
Timer timer;
/*! Extra performance metrics */
- std::unordered_map<std::string, ObjectRef> extra_metrics;
+ std::unordered_map<std::string, ffi::Any> extra_metrics;
/*! User defined metric collectors. Each pair is the MetricCollector and its
* associated data (returned from MetricCollector.Start).
*/
@@ -404,12 +404,12 @@ class Profiler {
* `StartCall` and `StopCall` must be nested properly.
*/
void StartCall(String name, Device dev,
- std::unordered_map<std::string, ObjectRef> extra_metrics =
{});
+ std::unordered_map<std::string, ffi::Any> extra_metrics = {});
/*! \brief Stop the last `StartCall`.
* \param extra_metrics Optional additional profiling information to add to
* the frame (input sizes, allocations).
*/
- void StopCall(std::unordered_map<std::string, ObjectRef> extra_metrics = {});
+ void StopCall(std::unordered_map<std::string, ffi::Any> extra_metrics = {});
/*! \brief A report of total runtime between `Start` and `Stop` as
* well as individual statistics for each `StartCall`-`StopCall` pair.
* \returns A `Report` that can either be formatted as CSV (with `.AsCSV`)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index dfb9c0beee..b19bcab4c3 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -246,7 +246,7 @@ class LiteralDocNode : public ExprDocNode {
* - String
* - null
*/
- ObjectRef value;
+ ffi::Any value;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -265,16 +265,14 @@ class LiteralDocNode : public ExprDocNode {
*/
class LiteralDoc : public ExprDoc {
protected:
- explicit LiteralDoc(ObjectRef value, const Optional<ObjectPath>&
object_path);
+ explicit LiteralDoc(ffi::Any value, const Optional<ObjectPath>& object_path);
public:
/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
* \param p The object path
*/
- static LiteralDoc None(const Optional<ObjectPath>& p) {
- return LiteralDoc(ObjectRef(nullptr), p);
- }
+ static LiteralDoc None(const Optional<ObjectPath>& p) { return
LiteralDoc(ffi::Any(nullptr), p); }
/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
diff --git a/include/tvm/script/printer/ir_docsifier.h
b/include/tvm/script/printer/ir_docsifier.h
index 909f13ecc0..9d189dda09 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -145,7 +145,7 @@ class IRDocsifierNode : public Object {
/*! \brief Mapping from a var to its info */
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual>
obj2info;
/*! \brief Metadata printing */
- std::unordered_map<String, Array<ObjectRef>> metadata;
+ std::unordered_map<String, Array<ffi::Any>> metadata;
/*! \brief GlobalInfo printing */
std::unordered_map<String, Array<GlobalInfo>> global_infos;
/*! \brief The variable names used already */
@@ -206,7 +206,7 @@ class IRDocsifierNode : public Object {
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
/*! \brief Add a TVM object to the metadata section*/
- ExprDoc AddMetadata(const ObjectRef& obj);
+ ExprDoc AddMetadata(const ffi::Any& obj);
/*! \brief Add a GlobalInfo to the global_infos map.
* \param name The name of key of global_infos.
* \param ginfo The GlobalInfo to be added.
@@ -275,7 +275,7 @@ inline static void AddDocDecoration(const Doc& d, const
ObjectRef& obj, const Ob
const PrinterConfig& cfg) {
if (cfg->obj_to_annotate.count(obj)) {
if (const auto* stmt = d.as<StmtDocNode>()) {
- if (stmt->comment.defined()) {
+ if (stmt->comment.has_value()) {
stmt->comment = stmt->comment.value() + "\n" +
cfg->obj_to_annotate.at(obj);
} else {
stmt->comment = cfg->obj_to_annotate.at(obj);
@@ -295,7 +295,7 @@ inline static void AddDocDecoration(const Doc& d, const
ObjectRef& obj, const Ob
String attn = pair.second;
if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) {
if (const auto* stmt = d.as<StmtDocNode>()) {
- if (stmt->comment.defined()) {
+ if (stmt->comment.has_value()) {
stmt->comment = stmt->comment.value() + "\n" + attn;
} else {
stmt->comment = attn;
@@ -319,8 +319,16 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const
ObjectPath& path) con
return Downcast<TDoc>(LiteralDoc::Int(value.as<int64_t>().value(),
path));
case ffi::TypeIndex::kTVMFFIFloat:
return Downcast<TDoc>(LiteralDoc::Float(value.as<double>().value(),
path));
- case ffi::TypeIndex::kTVMFFIStr:
- return Downcast<TDoc>(LiteralDoc::Str(value.as<String>().value(), path));
+ case ffi::TypeIndex::kTVMFFIStr: {
+ std::string string_value = value.cast<std::string>();
+ bool has_multiple_lines = string_value.find_first_of('\n') !=
std::string::npos;
+ if (has_multiple_lines) {
+ Doc d = const_cast<IRDocsifierNode*>(this)->AddMetadata(string_value);
+ // TODO(tqchen): cross check AddDocDecoration
+ return Downcast<TDoc>(d);
+ }
+ return Downcast<TDoc>(LiteralDoc::Str(string_value, path));
+ }
case ffi::TypeIndex::kTVMFFIDataType:
return
Downcast<TDoc>(LiteralDoc::DataType(value.as<runtime::DataType>().value(),
path));
case ffi::TypeIndex::kTVMFFIDevice:
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 250475c61d..37410b1271 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -115,7 +115,7 @@ class LetStmt : public Stmt {
class AttrStmtNode : public StmtNode {
public:
/*! \brief this is attribute about certain node */
- ObjectRef node;
+ ffi::Any node;
/*! \brief the type key of the attribute */
String attr_key;
/*! \brief The attribute value, value is well defined at current scope. */
@@ -142,7 +142,7 @@ class AttrStmtNode : public StmtNode {
*/
class AttrStmt : public Stmt {
public:
- TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body,
Span span = Span());
+ TVM_DLL AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
index f1f9e08ab9..4670abe52e 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -342,7 +342,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr,
const Optional<Expr>& bin
if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
const auto& func =
Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
const auto& name_opt = func->GetAttr<String>(relax::attr::kComposite);
- if (name_opt.defined()) {
+ if (name_opt.has_value()) {
attrs = FuncAttrGetter().GetAttrs(func);
}
} else if (call_node->op->IsInstance<VarNode>()) {
@@ -760,7 +760,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode*
binding, const DataflowVa
void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const
FunctionNode* val) {
const auto& name_opt = val->GetAttr<String>(relax::attr::kComposite);
- ICHECK(name_opt.defined()) << "Unexpected target func without composite";
+ ICHECK(name_opt.has_value()) << "Unexpected target func without composite";
ICHECK(config_.target.size() > 0 &&
StringUtils::StartsWith(name_opt.value(), config_.target))
<< "Target should be given for target function";
target_funcs_.Set(binding->var, GetRef<Function>(val));
@@ -770,18 +770,18 @@ const std::tuple<String, String, String>
GraphBuilder::ParseFunc(const Function&
String node_name, optype, layout;
const auto& name_opt = func->GetAttr<String>(msc_attr::kUnique);
// get node_name
- if (name_opt.defined()) {
+ if (name_opt.has_value()) {
node_name = name_opt.value();
}
// get optype
const auto& codegen_opt = func->GetAttr<String>(relax::attr::kCodegen);
const auto& optype_opt = func->GetAttr<String>(msc_attr::kOptype);
const auto& composite_opt = func->GetAttr<String>(relax::attr::kComposite);
- if (codegen_opt.defined()) {
+ if (codegen_opt.has_value()) {
optype = codegen_opt.value();
- } else if (optype_opt.defined()) {
+ } else if (optype_opt.has_value()) {
optype = optype_opt.value();
- } else if (composite_opt.defined()) {
+ } else if (composite_opt.has_value()) {
optype = composite_opt.value();
if (config_.target.size() > 0) {
optype = StringUtils::Replace(composite_opt.value(), config_.target +
".", "");
@@ -789,7 +789,7 @@ const std::tuple<String, String, String>
GraphBuilder::ParseFunc(const Function&
}
// get layout
const auto& layout_opt = func->GetAttr<String>(msc_attr::kLayout);
- if (layout_opt.defined()) {
+ if (layout_opt.has_value()) {
layout = layout_opt.value();
}
return std::make_tuple(node_name, optype, layout);
diff --git a/src/contrib/msc/core/printer/cpp_printer.cc
b/src/contrib/msc/core/printer/cpp_printer.cc
index f162f5db1e..6ae71860b6 100644
--- a/src/contrib/msc/core/printer/cpp_printer.cc
+++ b/src/contrib/msc/core/printer/cpp_printer.cc
@@ -28,9 +28,9 @@ namespace contrib {
namespace msc {
void CppPrinter::PrintTypedDoc(const LiteralDoc& doc) {
- const ObjectRef& value = doc->value;
+ const ffi::Any& value = doc->value;
bool defined = false;
- if (!value.defined()) {
+ if (value == nullptr) {
output_ << "nullptr";
defined = true;
} else if (const auto* int_imm = value.as<IntImmNode>()) {
@@ -217,7 +217,7 @@ void CppPrinter::PrintTypedDoc(const ClassDoc& doc) {
}
void CppPrinter::PrintTypedDoc(const CommentDoc& doc) {
- if (doc->comment.defined()) {
+ if (doc->comment.has_value()) {
output_ << "// " << doc->comment.value();
}
}
diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc
b/src/contrib/msc/core/printer/msc_base_printer.cc
index 0f0b24fd3a..31869f29bb 100644
--- a/src/contrib/msc/core/printer/msc_base_printer.cc
+++ b/src/contrib/msc/core/printer/msc_base_printer.cc
@@ -158,7 +158,7 @@ void MSCBasePrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
}
void MSCBasePrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) {
- if (stmt->comment.defined()) {
+ if (stmt->comment.has_value()) {
if (multi_lines) {
for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) {
PrintDoc(CommentDoc(l));
diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc
b/src/contrib/msc/core/printer/prototxt_printer.cc
index 44a915ae7b..82d15dc718 100644
--- a/src/contrib/msc/core/printer/prototxt_printer.cc
+++ b/src/contrib/msc/core/printer/prototxt_printer.cc
@@ -30,7 +30,7 @@ namespace tvm {
namespace contrib {
namespace msc {
-LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) {
+LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) {
if (obj.as<ffi::StringObj>()) {
return LiteralDoc::Str(Downcast<String>(obj), std::nullopt);
} else if (obj.as<IntImmNode>()) {
@@ -51,7 +51,7 @@ DictDoc PrototxtPrinter::ToDictDoc(const Map<String,
ffi::Any>& dict) {
if (pair.second.as<DictDocNode>()) {
values.push_back(Downcast<DictDoc>(pair.second));
} else {
- values.push_back(ToLiteralDoc(pair.second.cast<ObjectRef>()));
+ values.push_back(ToLiteralDoc(pair.second));
}
}
return DictDoc(keys, values);
@@ -65,7 +65,7 @@ DictDoc PrototxtPrinter::ToDictDoc(const
std::vector<std::pair<String, Any>>& di
if (pair.second.as<DictDocNode>()) {
values.push_back(Downcast<DictDoc>(pair.second));
} else {
- values.push_back(ToLiteralDoc(pair.second.cast<ObjectRef>()));
+ values.push_back(ToLiteralDoc(pair.second));
}
}
return DictDoc(keys, values);
diff --git a/src/contrib/msc/core/printer/prototxt_printer.h
b/src/contrib/msc/core/printer/prototxt_printer.h
index 97c0f91818..e760a179d8 100644
--- a/src/contrib/msc/core/printer/prototxt_printer.h
+++ b/src/contrib/msc/core/printer/prototxt_printer.h
@@ -50,7 +50,7 @@ class PrototxtPrinter : public MSCBasePrinter {
explicit PrototxtPrinter(const std::string& options = "") :
MSCBasePrinter(options) {}
/*! \brief Change object to LiteralDoc*/
- static LiteralDoc ToLiteralDoc(const ObjectRef& obj);
+ static LiteralDoc ToLiteralDoc(const ffi::Any& obj);
/*! \brief Change map to DictDoc*/
static DictDoc ToDictDoc(const Map<String, ffi::Any>& dict);
diff --git a/src/contrib/msc/core/printer/python_printer.cc
b/src/contrib/msc/core/printer/python_printer.cc
index f1a13c7fd0..184d7ce870 100644
--- a/src/contrib/msc/core/printer/python_printer.cc
+++ b/src/contrib/msc/core/printer/python_printer.cc
@@ -30,9 +30,9 @@ namespace contrib {
namespace msc {
void PythonPrinter::PrintTypedDoc(const LiteralDoc& doc) {
- const ObjectRef& value = doc->value;
+ const ffi::Any& value = doc->value;
bool defined = false;
- if (!value.defined()) {
+ if (value == nullptr) {
output_ << "None";
defined = true;
} else if (const auto* int_imm = value.as<IntImmNode>()) {
@@ -176,7 +176,7 @@ void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) {
output_ << ":";
- if (doc->comment.defined()) {
+ if (doc->comment.has_value()) {
IncreaseIndent();
MaybePrintComment(doc, true);
DecreaseIndent();
@@ -197,7 +197,7 @@ void PythonPrinter::PrintTypedDoc(const ClassDoc& doc) {
}
void PythonPrinter::PrintTypedDoc(const CommentDoc& doc) {
- if (doc->comment.defined()) {
+ if (doc->comment.has_value()) {
output_ << "# " << doc->comment.value();
}
}
@@ -234,7 +234,7 @@ void PythonPrinter::PrintTypedDoc(const SwitchDoc& doc) {
}
void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) {
- if (stmt->comment.defined() && multi_lines) {
+ if (stmt->comment.has_value() && multi_lines) {
NewLine();
output_ << "\"\"\"";
for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) {
diff --git a/src/contrib/msc/core/transform/bind_named_params.cc
b/src/contrib/msc/core/transform/bind_named_params.cc
index 481a3092fe..df534f4cfa 100644
--- a/src/contrib/msc/core/transform/bind_named_params.cc
+++ b/src/contrib/msc/core/transform/bind_named_params.cc
@@ -49,7 +49,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>>
NormalizeNamedBindings(
Map<relax::Var, relax::Expr> relax_var_remap;
- auto normalize_key = [&](ObjectRef obj) -> relax::Var {
+ auto normalize_key = [&](ffi::Any obj) -> relax::Var {
if (auto opt_str = obj.as<String>()) {
std::string str = opt_str.value();
auto it = string_lookup.find(str);
@@ -77,18 +77,17 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>>
NormalizeNamedBindings(
LOG(FATAL)
<< "Expected bound parameter to be a relax::Var, "
<< " or a string that uniquely identifies a relax::Var param within
the function. "
- << "However, received object " << obj << " of type " <<
obj->GetTypeKey();
+ << "However, received object " << obj << " of type " <<
obj.GetTypeKey();
}
};
- auto normalize_value = [&](Var key, ObjectRef obj) -> relax::Expr {
+ auto normalize_value = [&](Var key, ffi::Any obj) -> relax::Expr {
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName,
key->name_hint());
return Constant(opt.value(), StructInfo(), span);
} else {
- LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
- << " into relax expression";
+ LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << "
into relax expression";
}
};
@@ -130,7 +129,7 @@ IRModule BindNamedParam(IRModule m, String func_name,
Map<ObjectRef, ObjectRef>
if (relax_f->GetLinkageType() == LinkageType::kExternal) {
// Use global_symbol if it's external linkage
Optional<String> gsymbol =
relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- if (gsymbol.defined() && gsymbol.value() == func_name) {
+ if (gsymbol.has_value() && gsymbol.value() == func_name) {
Function f_after_bind =
FunctionBindNamedParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc
b/src/contrib/msc/core/transform/fuse_tuple.cc
index 297f4a94fe..19b8f08f47 100644
--- a/src/contrib/msc/core/transform/fuse_tuple.cc
+++ b/src/contrib/msc/core/transform/fuse_tuple.cc
@@ -55,7 +55,7 @@ class TupleFuser : public ExprMutator {
main_var = gv;
} else {
const auto& name_opt = func->GetAttr<String>(attr::kComposite);
- if (name_opt.defined() && StringUtils::StartsWith(name_opt.value(),
target_)) {
+ if (name_opt.has_value() && StringUtils::StartsWith(name_opt.value(),
target_)) {
target_funcs_.Set(gv, Downcast<Function>(func));
}
}
@@ -76,7 +76,7 @@ class TupleFuser : public ExprMutator {
if (arg->IsInstance<TupleNode>()) {
String tuple_name;
const auto& name_opt =
target_funcs_[val->op]->GetAttr<String>(msc_attr::kUnique);
- if (name_opt.defined()) {
+ if (name_opt.has_value()) {
if (val->args.size() == 1) {
tuple_name = name_opt.value() + "_input";
} else {
diff --git a/src/contrib/msc/core/transform/inline_params.cc
b/src/contrib/msc/core/transform/inline_params.cc
index c68948cef5..086c475f6d 100644
--- a/src/contrib/msc/core/transform/inline_params.cc
+++ b/src/contrib/msc/core/transform/inline_params.cc
@@ -63,7 +63,7 @@ class ParamsInliner : public ExprMutator {
}
if (struct_info->IsInstance<FuncStructInfoNode>()) {
const auto& optype_opt = func->GetAttr<String>(msc_attr::kOptype);
- ICHECK(optype_opt.defined())
+ ICHECK(optype_opt.has_value())
<< "Can not find attr " << msc_attr::kOptype << " form extern
func";
extern_types_.Set(p, optype_opt.value());
continue;
diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc
b/src/contrib/msc/core/transform/set_byoc_attrs.cc
index 8687f72647..85819ea58d 100644
--- a/src/contrib/msc/core/transform/set_byoc_attrs.cc
+++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc
@@ -55,7 +55,7 @@ class ByocNameSetter : public ExprMutator {
continue;
}
const auto& name_opt = func->GetAttr<String>(attr::kCodegen);
- if (name_opt.defined() && name_opt.value() == target_) {
+ if (name_opt.has_value() && name_opt.value() == target_) {
const String& func_name = target_ + "_" + std::to_string(func_cnt);
const auto& new_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique,
func_name));
@@ -75,7 +75,7 @@ class ByocNameSetter : public ExprMutator {
if (val->op->IsInstance<relax::VarNode>()) {
ICHECK(local_funcs_.count(val->op)) << "Can not find local func " <<
val->op;
const auto& name_opt =
local_funcs_[val->op]->GetAttr<String>(msc_attr::kUnique);
- if (name_opt.defined()) {
+ if (name_opt.has_value()) {
val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value());
}
}
diff --git a/src/contrib/msc/core/transform/set_expr_name.cc
b/src/contrib/msc/core/transform/set_expr_name.cc
index c9cf65e783..14ea3ccfec 100644
--- a/src/contrib/msc/core/transform/set_expr_name.cc
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -160,7 +160,7 @@ class RelaxExprNameSetter : public ExprVisitor {
void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) {
ExprVisitor::VisitBinding_(binding, val);
const auto& name_opt = val->GetAttr<String>(attr::kComposite);
- if (name_opt.defined()) {
+ if (name_opt.has_value()) {
local_funcs_.Set(binding->var, GetRef<Function>(val));
}
}
@@ -260,9 +260,9 @@ class RelaxExprNameSetter : public ExprVisitor {
String optype;
const auto& comp_opt = func->GetAttr<String>(attr::kComposite);
const auto& code_opt = func->GetAttr<String>(attr::kCodegen);
- if (comp_opt.defined()) {
+ if (comp_opt.has_value()) {
optype = comp_opt.value();
- } else if (code_opt.defined()) {
+ } else if (code_opt.has_value()) {
optype = code_opt.value();
} else {
optype = "extern_func";
@@ -277,7 +277,7 @@ class RelaxExprNameSetter : public ExprVisitor {
String name;
// get from unique
const auto& name_opt = func->GetAttr<String>(msc_attr::kUnique);
- if (name_opt.defined()) {
+ if (name_opt.has_value()) {
return name_opt.value();
}
// get from exprs in the func
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index caac67d6f5..f4a79602f5 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -108,6 +108,7 @@ const String CommonUtils::ToAttrKey(const String& key) {
return msc_attr::kConsumerType;
}
LOG_FATAL << "Unexpected key " << key;
+ TVM_FFI_UNREACHABLE();
}
bool StringUtils::Contains(const String& src_string, const String& sub_string)
{
@@ -261,12 +262,12 @@ const String StringUtils::Lower(const String& src_string)
{
return str;
}
-const String StringUtils::ToString(const runtime::ObjectRef& obj) {
+const String StringUtils::ToString(const ffi::Any& obj) {
String obj_string;
- if (!obj.defined()) {
+ if (obj == nullptr) {
obj_string = "";
- } else if (obj.as<ffi::StringObj>()) {
- obj_string = Downcast<String>(obj);
+ } else if (auto opt_str = obj.as<String>()) {
+ obj_string = *opt_str;
} else if (const auto* n = obj.as<IntImmNode>()) {
obj_string = std::to_string(n->value);
} else if (const auto* n = obj.as<FloatImmNode>()) {
@@ -370,7 +371,7 @@ const Span SpanUtils::SetAttr(const Span& span, const
String& key, const String&
return Span(SourceName::Get(new_source), 0, 0, 0, 0);
}
-const String SpanUtils::GetAttr(const Span& span, const String& key) {
+String SpanUtils::GetAttr(const Span& span, const String& key) {
if (span.defined() && span->source_name.defined()) {
Array<String> tokens{"<" + key + ">", "</" + key + ">"};
return StringUtils::GetClosureOnce(span->source_name->name, tokens[0],
tokens[1]);
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index 84e3c66741..aeb7f9eb88 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -173,7 +173,7 @@ class StringUtils {
* \brief Change Object to String.
* \return The String.
*/
- TVM_DLL static const String ToString(const runtime::ObjectRef& obj);
+ TVM_DLL static const String ToString(const ffi::Any& obj);
};
/*!
@@ -287,7 +287,7 @@ class SpanUtils {
* \brief Get the value in <key>value</key> from the Span.
* \return The value String.
*/
- TVM_DLL static const String GetAttr(const Span& span, const String& key);
+ TVM_DLL static String GetAttr(const Span& span, const String& key);
/*!
* \brief Get all the key:value in format <key>value</key> from the Span.
diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc
b/src/contrib/msc/framework/tensorrt/codegen.cc
index 55a46789b8..684abbe38c 100644
--- a/src/contrib/msc/framework/tensorrt/codegen.cc
+++ b/src/contrib/msc/framework/tensorrt/codegen.cc
@@ -606,7 +606,7 @@ Array<runtime::Module> MSCTensorRTCompiler(Array<Function>
functions,
for (const auto& func : functions) {
VLOG(1) << "MSC.TensorRT partition:" << std::endl << func;
const auto& name_opt = func->GetAttr<String>(msc_attr::kUnique);
- ICHECK(name_opt.defined()) << "Can not find " << msc_attr::kUnique << "
from attrs";
+ ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << "
from attrs";
const auto& name = name_opt.value();
std::string func_name = GetExtSymbol(func);
ICHECK(target_option.count(name)) << "Can not find target option for " <<
name;
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index 5c2965de76..3d43c74958 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -285,7 +285,7 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var,
const Call& src_call
// causal_mask
Expr s_value;
- if (!src_attrs->causal_mask.defined()) {
+ if (!src_attrs->causal_mask.has_value()) {
auto softmax_attrs = make_object<SoftmaxAttrs>();
softmax_attrs->axis = 2;
s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"act"), softmax_op,
diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc
index c77bbf89d0..3436d49b02 100644
--- a/src/ir/apply_pass_to_function.cc
+++ b/src/ir/apply_pass_to_function.cc
@@ -73,7 +73,7 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
std::string name = gvar->name_hint;
if (tvm::runtime::regex_match(name, func_name_regex)) {
at_least_one_function_matched_regex = true;
- if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value()) {
// Function may be mutated, but is an internal function. Mark
// it as externally-exposed, so that any call-tracing internal
// transforms do not remove this function, in case it its
diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc
index 40412ad307..77c0480f85 100644
--- a/src/ir/name_supply.cc
+++ b/src/ir/name_supply.cc
@@ -68,7 +68,6 @@ String NameSupplyNode::add_prefix_to_name(const String& name)
{
}
std::ostringstream ss;
- ICHECK(name.defined());
ss << prefix_ << "_" << name;
return ss.str();
}
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 7994ff6863..78278103fe 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -69,7 +69,7 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) {
// The JSON object is always an array whose first element is a tag. For
example:
// `['TENSOR', 'float32', [1, 224, 224, 3]]
// Step 1. Extract the tag
- String tag{ffi::ObjectPtr<ffi::StringObj>(nullptr)};
+ Optional<String> tag{std::nullopt};
try {
const ffi::ArrayObj* json_array = json_obj.as<ffi::ArrayObj>();
CHECK(json_array && json_array->size() >= 1);
@@ -124,7 +124,7 @@ ObjectRef TensorInfoNode::AsJSON() const {
static String tag = "TENSOR";
String dtype = DLDataTypeToString(this->dtype);
Array<Integer> shape = support::AsArray(this->shape);
- return Array<ObjectRef>{tag, dtype, shape};
+ return Array<ffi::Any>{tag, dtype, shape};
}
TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) {
diff --git a/src/meta_schedule/database/database.cc
b/src/meta_schedule/database/database.cc
index 11b3c6dc3e..a3a48c6a9f 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -46,7 +46,7 @@ ObjectRef WorkloadNode::AsJSON() const {
// Dump the JSON string to base64
std::string b64_mod = Base64Encode(json_mod);
// Output
- return Array<ObjectRef>{SHash2Str(this->shash), String(b64_mod)};
+ return Array<ffi::Any>{SHash2Str(this->shash), String(b64_mod)};
}
Workload Workload::FromJSON(const ObjectRef& json_obj) {
diff --git a/src/meta_schedule/database/database_utils.cc
b/src/meta_schedule/database/database_utils.cc
index 1f39688272..230e4d3509 100644
--- a/src/meta_schedule/database/database_utils.cc
+++ b/src/meta_schedule/database/database_utils.cc
@@ -75,7 +75,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) {
if (i != 0) {
os << ",";
}
- os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"';
+ os << '"' << support::StrEscape(kv.first.data(), kv.first.size()) << '"';
os << ":";
JSONDumps(kv.second, os);
}
diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc
b/src/meta_schedule/measure_callback/update_cost_model.cc
index 65089c5ab8..97062d1122 100644
--- a/src/meta_schedule/measure_callback/update_cost_model.cc
+++ b/src/meta_schedule/measure_callback/update_cost_model.cc
@@ -44,8 +44,8 @@ class UpdateCostModelNode : public MeasureCallbackNode {
pruned_candidate.reserve(n);
pruned_runner_result.reserve(n);
for (int i = 0; i < n; i++) {
- if (!builder_results[i]->error_msg.defined() && //
- (runner_results[i]->error_msg.defined() || //
+ if (!builder_results[i]->error_msg.has_value() && //
+ (runner_results[i]->error_msg.has_value() || //
(runner_results[i]->run_secs.defined() &&
Sum(runner_results[i]->run_secs.value()) > 0))) {
pruned_candidate.push_back(measure_candidates[i]);
diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc
b/src/meta_schedule/postproc/rewrite_reduction_block.cc
index 571ed5675e..c17c90fe2d 100644
--- a/src/meta_schedule/postproc/rewrite_reduction_block.cc
+++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -147,7 +147,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule&
sch) {
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv,
loop_rvs[decompose_point]);
// Rewrite auto tensorization related annotations
- if (tir::GetAnn<String>(block_sref,
tir::attr::meta_schedule_auto_tensorize).defined()) {
+ if (tir::GetAnn<String>(block_sref,
tir::attr::meta_schedule_auto_tensorize).has_value()) {
// Remove tensorization annotation as it shouldn't be propagated to
the init block.
sch->Unannotate(init_block_rv,
tir::attr::meta_schedule_auto_tensorize);
Optional<String> tensorize_init =
@@ -157,7 +157,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule&
sch) {
// Annotate to hint `RewriteTensorize` postprocessor even if
tensorize_init is std::nullopt.
sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
tensorize_init.value_or(""));
- if (tensorize_init.defined()) {
+ if (tensorize_init.has_value()) {
sch->Unannotate(block_rv,
tir::attr::meta_schedule_auto_tensorize_init);
sch->Unannotate(init_block_rv,
tir::attr::meta_schedule_auto_tensorize_init);
}
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc
b/src/meta_schedule/schedule_rule/auto_inline.cc
index bcf927803e..746b0487ad 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -162,7 +162,7 @@ inline InlineType AutoInlineNode::CheckInline(const
tir::Schedule& sch,
if (producer_srefs.size() == 1 &&
tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
CanReverseComputeInline(state, block_sref) &&
- !GetAnn<String>(producer_srefs[0],
tir::attr::meta_schedule_auto_tensorize).defined()) {
+ !GetAnn<String>(producer_srefs[0],
tir::attr::meta_schedule_auto_tensorize).has_value()) {
return InlineType::kInlineIntoProducer;
}
}
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc
b/src/meta_schedule/space_generator/post_order_apply.cc
index 9f71989fa4..780b404299 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -80,7 +80,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
continue;
}
if (!ScheduleRule::IsApplyCustomRule(sch_rule)) {
- if (tir::GetAnn<String>(sch->GetSRef(block_rv),
"schedule_rule").defined()) {
+ if (tir::GetAnn<String>(sch->GetSRef(block_rv),
"schedule_rule").has_value()) {
stack.emplace_back(sch, blocks);
continue;
}
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc
b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 1d9d8e89ad..453239dd4a 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -71,7 +71,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner)
{
for (int i = 0; i < n; ++i) {
const MeasureCandidate& candidate = candidates[i];
const BuilderResult& builder_result = builder_results[i];
- if (builder_result->error_msg.defined()) {
+ if (builder_result->error_msg.has_value()) {
++n_build_errors;
continue;
}
@@ -88,7 +88,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner)
{
results.reserve(n);
for (int i = 0, j = 0; i < n; ++i) {
const BuilderResult& builder_result = builder_results[i];
- if (builder_result->error_msg.defined()) {
+ if (builder_result->error_msg.has_value()) {
results.push_back(RunnerFuture(
/*f_done=*/[]() -> bool { return true; },
/*f_result=*/
@@ -129,7 +129,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const
Array<RunnerResult>& r
TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) //
<< "[Task #" << task_id << ": " << name << "]
Trial #" << trials
<< ": Error in "
- << (builder_result->error_msg.defined() ?
"building" : "running")
+ << (builder_result->error_msg.has_value() ?
"building" : "running")
<< ":\n"
<< err << "\n"
<< sch->mod() << "\n"
diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc
index 4037f6757b..114afc0ad7 100644
--- a/src/meta_schedule/trace_apply.cc
+++ b/src/meta_schedule/trace_apply.cc
@@ -57,7 +57,6 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace,
Target target) {
for (const auto& inst : anchor_trace->insts) {
if (inst->kind.same_as(kind_get_block)) {
auto block_name = Downcast<String>(inst->attrs[0]);
- ICHECK(block_name.defined());
get_block_names.insert(block_name);
}
}
diff --git a/src/node/object_path.cc b/src/node/object_path.cc
index 6fd7a43a04..3e68e0d0ef 100644
--- a/src/node/object_path.cc
+++ b/src/node/object_path.cc
@@ -101,7 +101,7 @@ ObjectPath ObjectPathNode::Attr(const char* attr_key) const
{
}
ObjectPath ObjectPathNode::Attr(Optional<String> attr_key) const {
- if (attr_key.defined()) {
+ if (attr_key.has_value()) {
return ObjectPath(make_object<AttributeAccessPathNode>(this,
attr_key.value()));
} else {
return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this));
@@ -235,7 +235,7 @@ RootPathNode::RootPathNode(Optional<String> name) :
ObjectPathNode(nullptr), nam
bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const {
const auto* other = static_cast<const RootPathNode*>(other_path);
- if (other->name.defined() != name.defined()) {
+ if (other->name.has_value() != name.has_value()) {
return false;
} else if (name && other->name) {
return name.value() == other->name.value();
diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc
index d273ba6a73..34b08994d0 100644
--- a/src/node/repr_printer.cc
+++ b/src/node/repr_printer.cc
@@ -26,6 +26,8 @@
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/device_api.h>
+#include "../support/str_escape.h"
+
namespace tvm {
void ReprPrinter::Print(const ObjectRef& node) {
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 2570a5e800..c3060fc91f 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -108,6 +108,9 @@ class NodeIndexer {
MakeIndex(kv.second);
}
}
+ } else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
+ node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
+ // skip content index for string and bytes
} else if (auto opt_object = node.as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
// if the node already have repr bytes, no need to visit Attrs.
@@ -272,6 +275,10 @@ class JSONAttrGetter {
node_->data.push_back(node_index_->at(kv.second));
}
}
+ } else if (auto opt_str = node.as<String>()) {
+ node_->repr_bytes = *opt_str;
+ } else if (auto opt_bytes = node.as<Bytes>()) {
+ node_->repr_bytes = *opt_bytes;
} else if (auto opt_object = node.as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
// do not need to print additional things once we have repr bytes.
@@ -399,6 +406,11 @@ class FieldDependencyFinder {
if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
return;
}
+ if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
+ node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
+ // skip indexing content of string and bytes
+ return;
+ }
// Skip the objects that have their own string repr
if (jnode->repr_bytes.length() > 0 ||
reflection_->GetReprBytes(node.cast<const Object*>(), nullptr)) {
@@ -552,6 +564,10 @@ class JSONAttrSetter {
setter.ParseValue("v_device_type", &device_type);
setter.ParseValue("v_device_id", &device_id);
return Any(DLDevice{static_cast<DLDeviceType>(device_type), device_id});
+ } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr) {
+ return Any(String(jnode->repr_bytes));
+ } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
+ return Any(Bytes(jnode->repr_bytes));
} else {
return ObjectRef(reflection->CreateInitObject(jnode->type_key,
jnode->repr_bytes));
}
@@ -581,6 +597,9 @@ class JSONAttrSetter {
}
}
*node = result;
+ } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
+ // skip set attrs for string and bytes
} else if (auto opt_object = node->as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
if (n == nullptr) return;
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 383f344fac..bf9d7b23d5 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -58,24 +58,12 @@ struct RefToObjectPtr : public ObjectRef {
}
};
-TVM_REGISTER_REFLECTION_VTABLE(ffi::StringObj)
- .set_creator([](const std::string& bytes) { return
RefToObjectPtr::Get(String(bytes)); })
- .set_repr_bytes([](const Object* n) -> std::string {
- return GetRef<String>(static_cast<const ffi::StringObj*>(n)).operator
std::string();
- });
-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ffi::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ffi::StringObj*>(node.get());
p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
});
-TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj)
- .set_creator([](const std::string& bytes) { return
RefToObjectPtr::Get(String(bytes)); })
- .set_repr_bytes([](const Object* n) -> std::string {
- return GetRef<ffi::Bytes>(static_cast<const ffi::BytesObj*>(n)).operator
std::string();
- });
-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ffi::BytesObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ffi::BytesObj*>(node.get());
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index db216aba96..a1bc99ee75 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -140,7 +140,7 @@ class WellFormedChecker : public relax::ExprVisitor,
// check name in global var and gsymbol
Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- if (gsymbol.defined() && gsymbol != var->name_hint) {
+ if (gsymbol.has_value() && gsymbol != var->name_hint) {
Malformed(Diagnostic::Error(func->span)
<< "Name in GlobalVar is not equal to name in gsymbol: " << var
<< " != " << gsymbol.value());
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h
b/src/relax/backend/contrib/codegen_json/codegen_json.h
index 8d2e97a011..b2c3e47c73 100644
--- a/src/relax/backend/contrib/codegen_json/codegen_json.h
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -112,7 +112,7 @@ class OpAttrExtractor {
}
}
- void Visit(const char* key, runtime::ObjectRef* value) {
+ void Visit(const char* key, ffi::Any* value) {
if (const auto* an = (*value).as<ffi::ArrayObj>()) {
std::vector<std::string> attr;
for (size_t i = 0; i < an->size(); ++i) {
@@ -120,25 +120,23 @@ class OpAttrExtractor {
attr.push_back(std::to_string(im->value));
} else if (const auto* fm = (*an)[i].as<FloatImmNode>()) {
attr.push_back(Fp2String(fm->value));
- } else if (const auto* str = (*an)[i].as<ffi::StringObj>()) {
- String s = GetRef<String>(str);
- attr.push_back(s);
+ } else if (auto opt_str = (*an)[i].as<String>()) {
+ attr.push_back(*opt_str);
} else {
LOG(FATAL) << "Not supported type: " << (*an)[i].GetTypeKey();
}
}
SetNodeAttr(key, attr);
- } else if (!(*value).defined()) { // Skip NullValue
+ } else if (*value == nullptr) { // Skip NullValue
SetNodeAttr(key, std::vector<std::string>{""});
} else if (const auto* im = (*value).as<IntImmNode>()) {
SetNodeAttr(key, std::vector<std::string>{std::to_string(im->value)});
} else if (const auto* fm = (*value).as<FloatImmNode>()) {
SetNodeAttr(key, std::vector<std::string>{Fp2String(fm->value)});
- } else if (const auto* str = (*value).as<ffi::StringObj>()) {
- String s = GetRef<String>(str);
- SetNodeAttr(key, std::vector<std::string>{s});
+ } else if (const auto opt_str = (*value).as<ffi::String>()) {
+ SetNodeAttr(key, std::vector<std::string>{*opt_str});
} else {
- LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ":
" << *value;
+ LOG(FATAL) << "Not yet supported type: " << (*value).GetTypeKey();
}
}
@@ -178,14 +176,12 @@ class OpAttrExtractor {
break;
}
case ffi::TypeIndex::kTVMFFINDArray: {
- runtime::NDArray value = field_value.cast<runtime::NDArray>();
- this->Visit(field_info->name.data, &value);
+ this->Visit(field_info->name.data, &field_value);
break;
}
default: {
if (field_value.type_index() >=
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- ObjectRef obj = field_value.cast<ObjectRef>();
- this->Visit(field_info->name.data, &obj);
+ this->Visit(field_info->name.data, &field_value);
break;
}
LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey();
@@ -294,7 +290,7 @@ class JSONSerializer : public
relax::MemoizedExprTranslator<NodeEntries> {
} else if (const auto* fn = cn->op.as<FunctionNode>()) {
ICHECK(false);
auto pattern = fn->GetAttr<String>(attr::kPartitionedFromPattern);
- ICHECK(pattern.defined());
+ ICHECK(pattern.has_value());
std::vector<std::string> values;
values.push_back(pattern.value());
std::vector<dmlc::any> attr;
@@ -394,7 +390,7 @@ class JSONSerializer : public
relax::MemoizedExprTranslator<NodeEntries> {
name = op_node->name;
} else if (const auto* fn = cn->op.as<FunctionNode>()) {
auto comp = fn->GetAttr<String>(attr::kComposite);
- ICHECK(comp.defined()) << "JSON runtime only supports composite
functions.";
+ ICHECK(comp.has_value()) << "JSON runtime only supports composite
functions.";
name = comp.value();
} else {
LOG(FATAL) << "JSON runtime does not support calls to " <<
cn->op->GetTypeKey();
@@ -422,7 +418,7 @@ class JSONSerializer : public
relax::MemoizedExprTranslator<NodeEntries> {
}
NodeEntries VisitExpr_(const FunctionNode* fn) {
- ICHECK(fn->GetAttr<String>(attr::kComposite).defined())
+ ICHECK(fn->GetAttr<String>(attr::kComposite).has_value())
<< "JSON runtime only supports composite functions";
// FunctionNode should be handled by the caller.
diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc
b/src/relax/backend/contrib/tensorrt/codegen.cc
index 8665fe347e..53c1ca0397 100644
--- a/src/relax/backend/contrib/tensorrt/codegen.cc
+++ b/src/relax/backend/contrib/tensorrt/codegen.cc
@@ -141,7 +141,7 @@ class TensorRTJSONSerializer : public JSONSerializer {
const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
auto opt_composite = fn->GetAttr<String>(attr::kComposite);
- ICHECK(opt_composite.defined());
+ ICHECK(opt_composite.has_value());
std::string name = opt_composite.value();
// Collect the constants and attributes of all operator calls inside the
composite body.
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index 13fe82d4bc..27165db343 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -83,8 +83,8 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const
Expr&)> {
void Codegen(const Function& func) {
Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(gsymbol.defined()) << "there should be no local functions in Relax
VM codegen phase. "
- "Did you forget to apply LambdaLift or
AttachGlobalSymbol Pass?";
+ ICHECK(gsymbol.has_value()) << "there should be no local functions in
Relax VM codegen phase. "
+ "Did you forget to apply LambdaLift or
AttachGlobalSymbol Pass?";
Array<String> param_names;
for (Var param : func->params) {
@@ -293,12 +293,12 @@ class CodeGenVM : public
ExprFunctor<Instruction::Arg(const Expr&)> {
// At this point: all global var must corresponds to the right symbol.
// TODO(relax-team): switch everything to extern before splitting TIR/relax
// so we do not have idle global var here.
- if (!symbol.defined()) {
+ if (!symbol.has_value()) {
symbol = gvar->name_hint;
kind = VMFuncInfo::FuncKind::kPackedFunc;
}
// declare the function to be safe.
- ICHECK(symbol.defined());
+ ICHECK(symbol.has_value());
builder_->DeclareFunction(symbol.value(), kind);
return builder_->GetFunction(symbol.value());
}
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc
b/src/relax/backend/vm/codegen_vm_tir.cc
index 042bd5301a..c7cf06ea9d 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -127,7 +127,7 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array<PrimExpr>&
args,
int64_t dst_anylist_slot = -1) {
Optional<String> gsymbol =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(gsymbol.defined()) << "All functions must have global symbol at
this phase";
+ ICHECK(gsymbol.has_value()) << "All functions must have global symbol at
this phase";
Array<PrimExpr> all_args;
// negative index indicate return value can be discarded, emit call_packed
if (dst_anylist_slot >= 0) {
@@ -148,8 +148,8 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
tir::PrimFunc Codegen(const Function& func) {
Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(gsymbol.defined()) << "there should be no local functions in Relax
VM codegen phase. "
- "Did you forget to apply LambdaLift or
AttachGlobalSymbol Pass?";
+ ICHECK(gsymbol.has_value()) << "there should be no local functions in
Relax VM codegen phase. "
+ "Did you forget to apply LambdaLift or
AttachGlobalSymbol Pass?";
// initialize the state
stmt_stack_ = {};
registers_num_ = 0;
@@ -379,7 +379,7 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
Optional<PrimExpr> VisitExpr_(const GlobalVarNode* op) final {
VMFuncInfo::FuncKind kind;
auto symbol = LookupFunction(GetRef<Expr>(op), &kind);
- ICHECK(symbol.defined());
+ ICHECK(symbol.has_value());
builder_->DeclareFunction(symbol.value(), kind);
return FuncListGet(builder_->GetFunction(symbol.value()).value());
}
@@ -452,7 +452,7 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
VMFuncInfo::FuncKind kind;
auto symbol = LookupFunction(call_node->op, &kind);
- if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) {
+ if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) {
// primfunc in the same module.
// use cpacked to directly invoke without named based lookup
if (Optional<tir::PrimFunc> prim_func = LookupPrimFunc(symbol.value())) {
diff --git a/src/relax/ir/dataflow_expr_rewriter.cc
b/src/relax/ir/dataflow_expr_rewriter.cc
index 123b18d81c..5462154bab 100644
--- a/src/relax/ir/dataflow_expr_rewriter.cc
+++ b/src/relax/ir/dataflow_expr_rewriter.cc
@@ -689,7 +689,7 @@ PatternMatchingRewriter
PatternMatchingRewriter::FromModule(IRModule mod) {
Map<GlobalVar, BaseFunc> new_subroutines;
for (const auto& [gvar, func] : mod->functions) {
if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") {
- bool is_public =
func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_public =
func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
CHECK(!is_public) << "ValueError: "
<< "Expected module to have no publicly-exposed
functions "
<< "other than 'pattern' and 'replacement'. "
diff --git a/src/relax/transform/alter_op_impl.cc
b/src/relax/transform/alter_op_impl.cc
index 9ae492262d..4013d3aad1 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -123,7 +123,7 @@ class AlterOpImplMutator : public ExprMutator {
// If the callee does not have kOperatorName attribute or no replacement
is requested for
// it, nothing to do here.
- if (!maybe_op_kind.defined() || op_impl_map_.count(maybe_op_kind.value())
== 0) return call;
+ if (!maybe_op_kind.has_value() ||
op_impl_map_.count(maybe_op_kind.value()) == 0) return call;
auto op_kind = maybe_op_kind.value();
const auto& replacement_func = op_impl_map_[op_kind];
diff --git a/src/relax/transform/attach_global_symbol.cc
b/src/relax/transform/attach_global_symbol.cc
index 5c2fb9a797..9ef135608d 100644
--- a/src/relax/transform/attach_global_symbol.cc
+++ b/src/relax/transform/attach_global_symbol.cc
@@ -55,7 +55,7 @@ Pass AttachGlobalSymbol() {
new_func = WithAttr(GetRef<Function>(relax_func),
tvm::attr::kGlobalSymbol, new_name);
}
- if (new_name.defined() && (!old_name.defined() || old_name.value() !=
new_name.value())) {
+ if (new_name.has_value() && (!old_name.has_value() || old_name.value()
!= new_name.value())) {
updates->Add(gvar, new_func);
if (new_name.value() != gvar->name_hint) {
GlobalVar new_gvar(new_name.value());
diff --git a/src/relax/transform/bind_params.cc
b/src/relax/transform/bind_params.cc
index 6103dbbaec..49fe469e89 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -97,7 +97,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>>
NormalizeBindings(
Map<relax::Var, relax::Expr> relax_var_remap;
- auto normalize_key = [&](ObjectRef obj) -> relax::Var {
+ auto normalize_key = [&](ffi::Any obj) -> relax::Var {
if (auto opt_str = obj.as<String>()) {
std::string str = opt_str.value();
auto it = string_lookup.find(str);
@@ -125,17 +125,16 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>>
NormalizeBindings(
LOG(FATAL)
<< "Expected bound parameter to be a relax::Var, "
<< " or a string that uniquely identifies a relax::Var param within
the function. "
- << "However, received object " << obj << " of type " <<
obj->GetTypeKey();
+ << "However, received object " << obj << " of type " <<
obj.GetTypeKey();
}
};
- auto normalize_value = [&](ObjectRef obj) -> relax::Expr {
+ auto normalize_value = [&](ffi::Any obj) -> relax::Expr {
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
return Constant(opt.value());
} else {
- LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
- << " into relax expression";
+ LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << "
into relax expression";
}
};
@@ -181,7 +180,7 @@ IRModule BindParam(IRModule m, String func_name,
Map<ObjectRef, ObjectRef> bind_
if (relax_f->GetLinkageType() == LinkageType::kExternal) {
// Use global_symbol if it's external linkage
Optional<String> gsymbol =
relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- if (gsymbol.defined() && gsymbol.value() == func_name) {
+ if (gsymbol.has_value() && gsymbol.value() == func_name) {
Function f_after_bind =
FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
diff --git a/src/relax/transform/bind_symbolic_vars.cc
b/src/relax/transform/bind_symbolic_vars.cc
index 7fd75ed7d3..22c557874c 100644
--- a/src/relax/transform/bind_symbolic_vars.cc
+++ b/src/relax/transform/bind_symbolic_vars.cc
@@ -31,7 +31,7 @@
namespace tvm {
namespace relax {
-Function FunctionBindSymbolicVars(Function func, Map<ObjectRef, PrimExpr>
obj_remap) {
+Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr>
obj_remap) {
// Early bail-out if no updates need to be made.
if (obj_remap.empty()) {
return func;
@@ -50,7 +50,7 @@ Function FunctionBindSymbolicVars(Function func,
Map<ObjectRef, PrimExpr> obj_re
// Replacement map to be used when rewriting the function.
Map<tir::Var, PrimExpr> var_remap;
for (const auto& [key, replacement] : obj_remap) {
- if (auto opt = key.as<String>()) {
+ if (auto opt = key.as<ffi::String>()) {
String string_key = opt.value();
auto it = string_lookup.find(string_key);
CHECK(it != string_lookup.end())
@@ -74,7 +74,7 @@ Function FunctionBindSymbolicVars(Function func,
Map<ObjectRef, PrimExpr> obj_re
var_remap.Set(var, replacement);
} else {
LOG(FATAL) << "Expected symbolic variable to be a tir::Var or a string
name, "
- << "but " << key << " was of type " << key->GetTypeKey();
+ << "but " << key << " was of type " << key.GetTypeKey();
}
}
@@ -90,15 +90,15 @@ Function FunctionBindSymbolicVars(Function func,
Map<ObjectRef, PrimExpr> obj_re
}
namespace {
-IRModule ModuleBindSymbolicVars(IRModule mod, Map<ObjectRef, PrimExpr>
binding_map) {
- std::unordered_set<const Object*> used;
+IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr>
binding_map) {
+ std::unordered_set<ffi::Any, ffi::AnyHash, ffi::AnyEqual> used;
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<Function>()) {
auto func = opt.value();
// Collect bindings that are used by this function.
- auto func_binding_map = [&]() -> Map<ObjectRef, PrimExpr> {
+ auto func_binding_map = [&]() -> Map<ffi::Any, PrimExpr> {
std::unordered_set<std::string> var_names;
std::unordered_set<const tir::VarNode*> vars;
for (const auto& var : DefinedSymbolicVars(func)) {
@@ -106,7 +106,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod,
Map<ObjectRef, PrimExpr> binding_m
vars.insert(var.get());
}
- Map<ObjectRef, PrimExpr> out;
+ Map<ffi::Any, PrimExpr> out;
for (const auto& [key, replacement] : binding_map) {
bool used_by_function = false;
if (auto opt = key.as<String>()) {
@@ -115,10 +115,10 @@ IRModule ModuleBindSymbolicVars(IRModule mod,
Map<ObjectRef, PrimExpr> binding_m
used_by_function = vars.count(ptr);
} else {
LOG(FATAL) << "Expected symbolic variable to be a tir::Var "
- << "or a string name, but " << key << " was of type "
<< key->GetTypeKey();
+ << "or a string name, but " << key << " was of type "
<< key.GetTypeKey();
}
if (used_by_function) {
- used.insert(key.get());
+ used.insert(key);
out.Set(key, replacement);
}
}
@@ -132,9 +132,9 @@ IRModule ModuleBindSymbolicVars(IRModule mod,
Map<ObjectRef, PrimExpr> binding_m
}
}
- Array<ObjectRef> unused;
+ Array<ffi::Any> unused;
for (const auto& [key, replacement] : binding_map) {
- if (!used.count(key.get())) {
+ if (!used.count(key)) {
unused.push_back(key);
}
}
diff --git a/src/relax/transform/expand_tuple_arguments.cc
b/src/relax/transform/expand_tuple_arguments.cc
index 17ab181ec9..5b711b7675 100644
--- a/src/relax/transform/expand_tuple_arguments.cc
+++ b/src/relax/transform/expand_tuple_arguments.cc
@@ -33,7 +33,7 @@ template <typename T, typename U>
using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
Optional<Function> ExpandParams(Function func) {
- bool is_exposed =
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_exposed =
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (is_exposed) return std::nullopt;
bool has_tuple_param = std::any_of(
diff --git a/src/relax/transform/few_shot_tuning.cc
b/src/relax/transform/few_shot_tuning.cc
index a9ebdfebf3..819de35e20 100644
--- a/src/relax/transform/few_shot_tuning.cc
+++ b/src/relax/transform/few_shot_tuning.cc
@@ -86,7 +86,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc&
prim_func, const Target&
int idx = 0;
bool no_valid = true; // whether there is no valid schedule in this
iteration
for (const meta_schedule::BuilderResult& builder_result : builder_results)
{
- if (!builder_result->error_msg.defined()) {
+ if (!builder_result->error_msg.has_value()) {
results.push_back(candidates.value()[idx]->sch->mod());
valid_count--;
no_valid = false;
@@ -98,7 +98,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc&
prim_func, const Target&
Array<meta_schedule::RunnerInput> runner_inputs;
int idx = 0;
for (const meta_schedule::BuilderResult& builder_result :
builder_results) {
- if (!builder_result->error_msg.defined()) {
+ if (!builder_result->error_msg.has_value()) {
runner_inputs.push_back(meta_schedule::RunnerInput(
/*artifact_path=*/builder_result->artifact_path.value(),
/*device_type=*/target->kind->name,
@@ -109,7 +109,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc&
prim_func, const Target&
Array<meta_schedule::RunnerFuture> runner_futures =
runner->Run(runner_inputs);
for (const meta_schedule::RunnerFuture& runner_future : runner_futures) {
meta_schedule::RunnerResult runner_result = runner_future->Result();
- if (runner_result->error_msg.defined()) {
+ if (runner_result->error_msg.has_value()) {
costs.push_back(1e10);
} else {
double sum = 0;
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index bfc278b9c7..c6f9470167 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -120,7 +120,7 @@ class GraphCreator : public ExprVisitor {
// true.
const auto* func = it.second.as<FunctionNode>();
if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) ||
- func->GetAttr<String>(attr::kCodegen).defined()) {
+ func->GetAttr<String>(attr::kCodegen).has_value()) {
continue;
}
creator(GetRef<Function>(func));
@@ -733,7 +733,7 @@ class OperatorFusor : public ExprMutator {
// Only visit Relax functions with neither attr::kPrimitive nor
// attr::kCodegen.
if (func->IsInstance<relax::FunctionNode>() &&
!func->HasNonzeroAttr(attr::kPrimitive) &&
- !func->GetAttr<String>(attr::kCodegen).defined()) {
+ !func->GetAttr<String>(attr::kCodegen).has_value()) {
auto updated_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, updated_func);
}
@@ -1263,8 +1263,8 @@ class CompositeFunctionAnnotator : public ExprMutator {
}
const auto& base_func = (*it).second;
if (const auto* func = base_func.as<FunctionNode>()) {
- if (func->GetAttr<String>(attr::kComposite).defined() ||
- func->GetAttr<String>(attr::kCodegen).defined()) {
+ if (func->GetAttr<String>(attr::kComposite).has_value() ||
+ func->GetAttr<String>(attr::kCodegen).has_value()) {
continue;
}
@@ -1363,8 +1363,8 @@ IRModule FuseOpsByPattern(const
tvm::Array<transform::FusionPattern>& patterns,
}
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<bool>(attr::kPrimitive).value_or(false) ||
- function->GetAttr<String>(attr::kComposite).defined() ||
- function->GetAttr<String>(attr::kCodegen).defined()) {
+ function->GetAttr<String>(attr::kComposite).has_value() ||
+ function->GetAttr<String>(attr::kCodegen).has_value()) {
continue;
}
entry_functions.push_back(Downcast<Function>(base_func));
diff --git a/src/relax/transform/inline_functions.cc
b/src/relax/transform/inline_functions.cc
index 2c393a4a93..44363e1946 100644
--- a/src/relax/transform/inline_functions.cc
+++ b/src/relax/transform/inline_functions.cc
@@ -178,7 +178,7 @@ Pass InlinePrivateFunctions() {
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<relax::Function>()) {
auto func = opt.value();
- bool is_private =
!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_private =
!func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (is_private) {
replacements.Set(gvar, func);
}
diff --git a/src/relax/transform/lazy_transform_params.cc
b/src/relax/transform/lazy_transform_params.cc
index 23c99eb928..9b59b680ec 100644
--- a/src/relax/transform/lazy_transform_params.cc
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -249,7 +249,7 @@ namespace transform {
Pass LazyGetInput() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
- if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value()) {
return func;
}
return WithLazyInputs(func);
@@ -267,7 +267,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
Pass LazySetOutput() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
- if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value()) {
return func;
}
return WithLazyOutputs(func);
diff --git a/src/relax/transform/merge_composite_functions.cc
b/src/relax/transform/merge_composite_functions.cc
index 3c8511f8ed..025e91c3c3 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -270,7 +270,7 @@ class CompositeGroupsBuilder : public
MemoizedExprTranslator<Group*> {
std::vector<Group*> GetGroupsToMerge(const CallNode* call) {
Optional<String> codegen_name = GetCodegenName(call->op);
- if (!codegen_name.defined()) {
+ if (!codegen_name.has_value()) {
return {};
}
diff --git a/src/relax/transform/meta_schedule.cc
b/src/relax/transform/meta_schedule.cc
index 85b021e2f5..acad7d1544 100644
--- a/src/relax/transform/meta_schedule.cc
+++ b/src/relax/transform/meta_schedule.cc
@@ -84,7 +84,7 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir,
bool enable_warning =
if (Database::Current().defined()) {
database = Database::Current().value();
} else {
- ICHECK(work_dir.defined());
+ ICHECK(work_dir.has_value());
String path_workload = work_dir.value() + "/database_workload.json";
String path_tuning_record = work_dir.value() +
"/database_tuning_record.json";
LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload
diff --git a/src/relax/transform/remove_unused_outputs.cc
b/src/relax/transform/remove_unused_outputs.cc
index 3a2dd9c219..26145cde1d 100644
--- a/src/relax/transform/remove_unused_outputs.cc
+++ b/src/relax/transform/remove_unused_outputs.cc
@@ -44,7 +44,7 @@ class PartialTupleUsageCollector : ExprVisitor {
PMap<GlobalVar, size_t> num_outputs;
for (const auto& [gvar, base_func] : mod->functions) {
- bool is_exposed =
base_func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_exposed =
base_func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (!is_exposed) {
if (auto relax_func = base_func.as<FunctionNode>()) {
diff --git a/src/relax/transform/remove_unused_parameters.cc
b/src/relax/transform/remove_unused_parameters.cc
index 778e551e9a..2e88ebe417 100644
--- a/src/relax/transform/remove_unused_parameters.cc
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -55,7 +55,7 @@ struct CalleeAnalysis {
};
std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
- bool is_exposed =
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_exposed =
func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (is_exposed) return std::nullopt;
auto free_relax_vars = [&]() -> PSet<Var> {
diff --git a/src/relax/transform/split_call_tir_by_pattern.cc
b/src/relax/transform/split_call_tir_by_pattern.cc
index e669979d09..41528c7d86 100644
--- a/src/relax/transform/split_call_tir_by_pattern.cc
+++ b/src/relax/transform/split_call_tir_by_pattern.cc
@@ -571,7 +571,7 @@ std::pair<PrimFunc, Optional<PrimFunc>>
SplitFunctions(PrimFunc func,
if (match_results.empty()) {
return {func, std::nullopt};
}
- Array<ObjectRef> codegen_result = f_codegen(match_results);
+ Array<ffi::Any> codegen_result = f_codegen(match_results);
ICHECK(codegen_result.size() == 3);
String library_code = Downcast<String>(codegen_result[0]);
int num_matched_ops = Downcast<Integer>(codegen_result[1])->value;
@@ -662,7 +662,7 @@ void StringReplace(std::string* subject, const std::string&
search, const std::s
tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String
global_symbol) {
using namespace tvm::tir;
Optional<String> library_code = pf->attrs.GetAttr<String>(kLibraryKernel);
- if (!library_code.defined()) {
+ if (!library_code.has_value()) {
return GetRef<tir::PrimFunc>(pf);
}
std::string source = library_code.value();
diff --git a/src/relax/transform/static_plan_block_memory.cc
b/src/relax/transform/static_plan_block_memory.cc
index 9e54a8fac8..f2e185ebd2 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -379,30 +379,16 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer*
ana,
// appear in the **function signature**.
Map<String, IntImm> var_upper_bound_attr_raw =
func->GetAttr<Map<String,
IntImm>>("tir_var_upper_bound").value_or(Map<String, IntImm>());
- Array<ObjectRef> non_negative_var_attr_raw =
-
func->GetAttr<Array<ObjectRef>>("tir_non_negative_var").value_or(Array<ObjectRef>());
+ Array<String> non_negative_var_attr_raw =
+
func->GetAttr<Array<String>>("tir_non_negative_var").value_or(Array<String>());
std::unordered_map<String, IntImm> var_upper_bound_attr;
std::unordered_set<String> non_negative_var_attr;
// We manually check the value type to ensure the values are all positive
IntImm.
- for (auto it : var_upper_bound_attr_raw) {
- const auto* key = it.first.as<ffi::StringObj>();
- const auto* value = it.second.as<IntImmNode>();
- CHECK(key != nullptr)
- << "The entry key of attr `tir_var_upper_bound` should be string.
However "
- << it.first->GetTypeKey() << " is got.";
- CHECK(value != nullptr)
- << "The entry value of attr `tir_var_upper_bound` should be integer.
However "
- << it.second.GetTypeKey() << " is got.";
- CHECK_GT(value->value, 0)
- << "The entry value of attr `tir_var_upper_bound` should be a positive
integer, while "
- << value->value << " is got.";
- var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
- }
- for (ObjectRef var_name : non_negative_var_attr_raw) {
- const auto* key = var_name.as<ffi::StringObj>();
- CHECK(key != nullptr) << "The element of attr `tir_non_negative_var`
should be string. However "
- << var_name->GetTypeKey() << " is got.";
- non_negative_var_attr.insert(GetRef<String>(key));
+ for (auto [key, value] : var_upper_bound_attr_raw) {
+ var_upper_bound_attr[key] = value;
+ }
+ for (const String& var_name : non_negative_var_attr_raw) {
+ non_negative_var_attr.insert(var_name);
}
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
for (const tir::Var& tir_var : var_in_signature) {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index edd953e312..009d002607 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -125,7 +125,7 @@ TVM_DLL IRModule DeadCodeElimination(const IRModule& mod,
Array<String> entry_fu
*/
inline std::string GetExtSymbol(const Function& func) {
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(name_node.defined()) << "Fail to retrieve external symbol.";
+ ICHECK(name_node.has_value()) << "Fail to retrieve external symbol.";
return std::string(name_node.value());
}
diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc
index 06eb0284f9..947d8884a5 100644
--- a/src/runtime/device_api.cc
+++ b/src/runtime/device_api.cc
@@ -107,7 +107,7 @@ static size_t GetDataAlignment(const DLDataType dtype) {
}
size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional<String> mem_scope)
{
- if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value()
== "global") {
+ if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value()
== "global") {
size_t size = 1;
for (int i = 0; i < arr.ndim; ++i) {
size *= static_cast<size_t>(arr.shape[i]);
@@ -121,7 +121,7 @@ size_t DeviceAPI::GetDataSize(const DLTensor& arr,
Optional<String> mem_scope) {
void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape,
DLDataType dtype,
Optional<String> mem_scope) {
- if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() ==
"global") {
+ if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value()
== "global") {
// by default, we can always redirect to the flat memory allocations
DLTensor temp;
temp.data = nullptr;
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
index 30a1e6ed66..a6af311e6e 100644
--- a/src/runtime/disco/protocol.h
+++ b/src/runtime/disco/protocol.h
@@ -81,7 +81,7 @@ struct DiscoProtocol {
}
support::Arena arena_;
- std::vector<ObjectRef> object_arena_;
+ std::vector<Any> object_arena_;
friend struct RPCReference;
};
@@ -175,7 +175,7 @@ inline void
DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
template <class SubClassType>
inline void DiscoProtocol<SubClassType>::ReadObject(TVMFFIAny* out) {
SubClassType* self = static_cast<SubClassType*>(this);
- ObjectRef result{nullptr};
+ ffi::Any result{nullptr};
uint32_t type_index;
self->template Read<uint32_t>(&type_index);
if (type_index == TypeIndex::kRuntimeDiscoDRef) {
diff --git a/src/runtime/hexagon/hexagon_device_api.cc
b/src/runtime/hexagon/hexagon_device_api.cc
index fd8d6e53bf..a26f113f1e 100644
--- a/src/runtime/hexagon/hexagon_device_api.cc
+++ b/src/runtime/hexagon/hexagon_device_api.cc
@@ -74,7 +74,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim,
const int64_t* shap
// in Hexagon's "indirect tensor" format:
// - shape[0] indicates the number of tensor-content memory
allocations.
// - shape[1] indicates the size of each tensor-content memory
allocation.
- if (!mem_scope.defined() || mem_scope.value() == "global") {
+ if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value()
== "global") {
return DeviceAPI::AllocDataSpace(dev, ndim, shape, dtype, mem_scope);
}
diff --git a/src/runtime/memory/memory_manager.cc
b/src/runtime/memory/memory_manager.cc
index 7634100809..cef445ee91 100644
--- a/src/runtime/memory/memory_manager.cc
+++ b/src/runtime/memory/memory_manager.cc
@@ -234,10 +234,10 @@ NDArray Allocator::Empty(ffi::Shape shape, DLDataType
dtype, DLDevice dev,
size_t size = ffi::GetDataSize(shape.Product(), dtype);
Buffer buffer;
- if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value()
== "global") {
+ if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) ==
"global") {
buffer = this->Alloc(dev, size, alignment, dtype);
} else {
- buffer = this->Alloc(dev, shape, dtype, mem_scope.value());
+ buffer = this->Alloc(dev, shape, dtype, *mem_scope);
}
return NDArray::FromNDAlloc(BufferAlloc(buffer), shape, dtype, dev);
}
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 8acefecaad..aa629aef50 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -40,8 +40,6 @@ namespace runtime {
inline String get_name_mangled(const String& module_name, const String& name) {
std::stringstream ss;
- ICHECK(module_name.defined());
- ICHECK(name.defined());
ss << module_name << "_" << name;
return ss.str();
}
diff --git a/src/runtime/opencl/opencl_device_api.cc
b/src/runtime/opencl/opencl_device_api.cc
index 15616f1267..176884383d 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -77,7 +77,7 @@ ImageInfo GetImageInfo(const cl::BufferDescriptor* desc,
const DLTensor* tensor)
cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope(
Optional<String> mem_scope) {
- if (!mem_scope.defined()) {
+ if (!mem_scope.has_value()) {
return cl::BufferDescriptor::MemoryLayout::kBuffer1D;
} else if (mem_scope.value() == "global.texture") {
return cl::BufferDescriptor::MemoryLayout::kImage2DActivation;
@@ -277,7 +277,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t
width, size_t height, D
back_buffer->mbuf = buf;
}
- if (!mem_scope.defined()) {
+ if (!mem_scope.has_value()) {
mem_scope = String("global.texture");
}
return AllocCLImage(dev, back_buffer, width, height, row_pitch, type_hint,
mem_scope);
@@ -286,7 +286,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t
width, size_t height, D
void* OpenCLWorkspace::AllocDataSpace(Device dev, int ndim, const int64_t*
shape, DLDataType dtype,
Optional<String> mem_scope) {
this->Init();
- if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value()
== "global") {
+ if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) ==
"global") {
size_t size = GetMemObjectSize(dev, ndim, shape, dtype);
cl::BufferDescriptor* ret_buffer = nullptr;
auto buf = MemoryManager::GetOrCreateAllocator(dev, AllocatorType::kPooled)
@@ -349,7 +349,7 @@ void* OpenCLWorkspace::AllocCLImage(Device dev, void*
back_buffer, size_t width,
}
size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional<String>
mem_scope) {
- if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value()
== "global") {
+ if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) ==
"global") {
return DeviceAPI::GetDataSize(arr);
}
cl_uint row_align = GetImageAlignment(GetThreadEntry()->device.device_id);
@@ -366,7 +366,7 @@ void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void*
data, ffi::Shape sha
// Fall back for devices w/o "cl_khr_image2d_from_buffer"
if (!IsBufferToImageSupported(dev.device_id)) {
cl::BufferDescriptor* ret_desc = desc; // buffer -> buffer
- if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value()
== "global") {
+ if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) ==
"global") {
if (desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D) {
// image -> buffer
size_t nbytes = GetMemObjectSize(dev, shape.size(), shape.data(),
dtype);
@@ -389,7 +389,7 @@ void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void*
data, ffi::Shape sha
return ret_desc;
}
- if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value()
== "global") {
+ if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) ==
"global") {
if (desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D) {
// buffer -> buffer
return desc;
diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index 390f383c00..4e29dcc392 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -151,7 +151,7 @@ void Profiler::Start() {
}
void Profiler::StartCall(String name, Device dev,
- std::unordered_map<std::string, ObjectRef>
extra_metrics) {
+ std::unordered_map<std::string, ffi::Any>
extra_metrics) {
std::vector<std::pair<MetricCollector, ObjectRef>> objs;
for (auto& collector : collectors_) {
ObjectRef obj = collector->Start(dev);
@@ -162,7 +162,7 @@ void Profiler::StartCall(String name, Device dev,
in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics,
objs});
}
-void Profiler::StopCall(std::unordered_map<std::string, ObjectRef>
extra_metrics) {
+void Profiler::StopCall(std::unordered_map<std::string, ffi::Any>
extra_metrics) {
CallFrame cf = in_flight_.top();
cf.timer->Stop();
for (auto& p : extra_metrics) {
@@ -172,7 +172,7 @@ void Profiler::StopCall(std::unordered_map<std::string,
ObjectRef> extra_metrics
for (const auto& obj : cf.extra_collectors) {
auto collector_metrics = obj.first->Stop(obj.second);
for (auto& p : collector_metrics) {
- cf.extra_metrics[p.first] = p.second.cast<ObjectRef>();
+ cf.extra_metrics[p.first] = p.second;
}
}
in_flight_.pop();
@@ -303,10 +303,10 @@ String ReportNode::AsCSV() const {
}
namespace {
-void metric_as_json(std::ostream& os, ObjectRef o) {
- if (o.as<ffi::StringObj>()) {
+void metric_as_json(std::ostream& os, ffi::Any o) {
+ if (auto opt_str = o.as<String>()) {
os << "{\"string\":"
- << "\"" << Downcast<String>(o) << "\""
+ << "\"" << *opt_str << "\""
<< "}";
} else if (const CountNode* n = o.as<CountNode>()) {
os << "{\"count\":" << n->value << "}";
@@ -320,7 +320,7 @@ void metric_as_json(std::ostream& os, ObjectRef o) {
os << "{\"ratio\":" <<
std::setprecision(std::numeric_limits<double>::max_digits10)
<< std::fixed << n->ratio << "}";
} else {
- LOG(FATAL) << "Unprintable type " << o->GetTypeKey();
+ LOG(FATAL) << "Unprintable type " << o.GetTypeKey();
}
}
} // namespace
@@ -340,7 +340,7 @@ String ReportNode::AsJSON() const {
s << "{";
for (const auto& kv : calls[i]) {
s << "\"" << kv.first << "\":";
- metric_as_json(s, kv.second.cast<ObjectRef>());
+ metric_as_json(s, kv.second);
if (j < calls[i].size() - 1) {
s << ",";
}
@@ -360,7 +360,7 @@ String ReportNode::AsJSON() const {
s << "\"" << dev_kv.first << "\":{";
for (const auto& metric_kv : dev_kv.second) {
s << "\"" << metric_kv.first << "\":";
- metric_as_json(s, metric_kv.second.cast<ObjectRef>());
+ metric_as_json(s, metric_kv.second);
if (j < dev_kv.second.size() - 1) {
s << ",";
}
@@ -378,7 +378,7 @@ String ReportNode::AsJSON() const {
size_t k = 0;
for (const auto& kv : configuration) {
s << "\"" << kv.first << "\":";
- metric_as_json(s, kv.second.cast<ObjectRef>());
+ metric_as_json(s, kv.second);
if (k < configuration.size() - 1) {
s << ",";
}
@@ -392,7 +392,7 @@ String ReportNode::AsJSON() const {
// Aggregate a set of values for a metric. Computes sum for Duration, Count,
// and Percent; average for Ratio; and assumes all Strings are the same. All
// ObjectRefs in metrics must have the same type.
-ObjectRef AggregateMetric(const std::vector<ObjectRef>& metrics) {
+Any AggregateMetric(const std::vector<ffi::Any>& metrics) {
ICHECK_GT(metrics.size(), 0) << "Must pass a non-zero number of metrics";
if (metrics[0].as<DurationNode>()) {
double sum = 0;
@@ -421,7 +421,7 @@ ObjectRef AggregateMetric(const std::vector<ObjectRef>&
metrics) {
} else if (metrics[0].as<ffi::StringObj>()) {
for (auto& m : metrics) {
if (Downcast<String>(metrics[0]) != Downcast<String>(m)) {
- return ObjectRef(String(""));
+ return String("");
}
}
// Assume all strings in metrics are the same.
@@ -429,8 +429,8 @@ ObjectRef AggregateMetric(const std::vector<ObjectRef>&
metrics) {
} else {
LOG(FATAL) << "Can only aggregate metrics with types DurationNode,
CountNode, "
"PercentNode, RatioNode, and StringObj, but got "
- << metrics[0]->GetTypeKey();
- return ObjectRef(); // To silence warnings
+ << metrics[0].GetTypeKey();
+ return ffi::Any(); // To silence warnings
}
}
@@ -446,7 +446,7 @@ static void set_locale_for_separators(std::stringstream& s)
{
}
}
-static String print_metric(ObjectRef metric) {
+static String print_metric(ffi::Any metric) {
std::string val;
if (metric.as<CountNode>()) {
std::stringstream s;
@@ -470,7 +470,7 @@ static String print_metric(ObjectRef metric) {
} else if (metric.as<ffi::StringObj>()) {
val = Downcast<String>(metric);
} else {
- LOG(FATAL) << "Cannot print metric of type " << metric->GetTypeKey();
+ LOG(FATAL) << "Cannot print metric of type " << metric.GetTypeKey();
}
return val;
}
@@ -509,7 +509,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool
compute_col_sums) con
}
}
for (const std::string& metric : metrics) {
- std::vector<ObjectRef> per_call;
+ std::vector<ffi::Any> per_call;
for (auto i : p.second) {
auto& call = calls[i];
auto it = std::find_if(call.begin(), call.end(),
@@ -517,7 +517,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool
compute_col_sums) con
return std::string(call_metric.first) ==
metric;
});
if (it != call.end()) {
- per_call.push_back((*it).second.cast<ObjectRef>());
+ per_call.push_back((*it).second);
}
}
if (per_call.size() > 0) {
@@ -719,7 +719,7 @@ Map<String, ffi::Any> parse_metrics(dmlc::JSONReader*
reader) {
std::string metric_name, metric_value_name;
Map<String, ffi::Any> metrics;
while (reader->NextObjectItem(&metric_name)) {
- ObjectRef o;
+ ffi::Any o;
reader->BeginObject();
reader->NextObjectItem(&metric_value_name);
if (metric_value_name == "microseconds") {
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index 7ee7214056..9b9816a4d9 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -1160,7 +1160,7 @@ class RPCClientSession : public RPCSession, public
DeviceAPI {
temp.shape = const_cast<int64_t*>(shape);
temp.strides = nullptr;
temp.byte_offset = 0;
- if (mem_scope.defined()) {
+ if (mem_scope.has_value()) {
return endpoint_
->SysCallRemote(RPCCode::kDevAllocDataWithScope, &temp,
static_cast<std::string>(mem_scope.value()))
diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc
index f3aaf3f688..04e5094d8e 100644
--- a/src/runtime/vm/attn_backend.cc
+++ b/src/runtime/vm/attn_backend.cc
@@ -25,12 +25,12 @@ namespace tvm {
namespace runtime {
namespace vm {
-std::unique_ptr<PagedPrefillFunc> ConvertPagedPrefillFunc(Array<ObjectRef>
args,
+std::unique_ptr<PagedPrefillFunc> ConvertPagedPrefillFunc(Array<ffi::Any> args,
AttnKind attn_kind) {
if (args.empty()) {
return nullptr;
}
- String backend_name = Downcast<String>(args[0]);
+ String backend_name = args[0].cast<String>();
if (backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
@@ -47,12 +47,12 @@ std::unique_ptr<PagedPrefillFunc>
ConvertPagedPrefillFunc(Array<ObjectRef> args,
throw;
}
-std::unique_ptr<RaggedPrefillFunc> ConvertRaggedPrefillFunc(Array<ObjectRef>
args,
+std::unique_ptr<RaggedPrefillFunc> ConvertRaggedPrefillFunc(Array<ffi::Any>
args,
AttnKind
attn_kind) {
if (args.empty()) {
return nullptr;
}
- String backend_name = Downcast<String>(args[0]);
+ String backend_name = args[0].cast<String>();
if (backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
@@ -69,7 +69,7 @@ std::unique_ptr<RaggedPrefillFunc>
ConvertRaggedPrefillFunc(Array<ObjectRef> arg
throw;
}
-std::unique_ptr<PagedDecodeFunc> ConvertPagedDecodeFunc(Array<ObjectRef> args,
AttnKind attn_kind) {
+std::unique_ptr<PagedDecodeFunc> ConvertPagedDecodeFunc(Array<ffi::Any> args,
AttnKind attn_kind) {
if (args.empty()) {
return nullptr;
}
@@ -90,7 +90,7 @@ std::unique_ptr<PagedDecodeFunc>
ConvertPagedDecodeFunc(Array<ObjectRef> args, A
throw;
}
-std::unique_ptr<PagedPrefillTreeMaskFunc>
ConvertPagedPrefillTreeMaskFunc(Array<ObjectRef> args,
+std::unique_ptr<PagedPrefillTreeMaskFunc>
ConvertPagedPrefillTreeMaskFunc(Array<ffi::Any> args,
AttnKind attn_kind) {
if (args.empty()) {
return nullptr;
@@ -105,7 +105,7 @@ std::unique_ptr<PagedPrefillTreeMaskFunc>
ConvertPagedPrefillTreeMaskFunc(Array<
throw;
}
-std::unique_ptr<RaggedPrefillTreeMaskFunc>
ConvertRaggedPrefillTreeMaskFunc(Array<ObjectRef> args,
+std::unique_ptr<RaggedPrefillTreeMaskFunc>
ConvertRaggedPrefillTreeMaskFunc(Array<ffi::Any> args,
AttnKind attn_kind) {
if (args.empty()) {
return nullptr;
diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h
index 21d8d81a8b..449a1def0a 100644
--- a/src/runtime/vm/attn_backend.h
+++ b/src/runtime/vm/attn_backend.h
@@ -499,8 +499,7 @@ class TIRRaggedPrefillTreeMaskFunc : public
RaggedPrefillTreeMaskFunc {
* ffi::Functions. \param attn_kind The attention kind of the function.
\return The created
* PagedPrefillFunc pointer.
*/
-std::unique_ptr<PagedPrefillFunc> ConvertPagedPrefillFunc(Array<ObjectRef>
args,
- AttnKind attn_kind);
+std::unique_ptr<PagedPrefillFunc> ConvertPagedPrefillFunc(Array<ffi::Any>
args, AttnKind attn_kind);
/*!
* \brief Create a PagedDecodeFunc from the given arguments and the attention
kind.
@@ -508,7 +507,7 @@ std::unique_ptr<PagedPrefillFunc>
ConvertPagedPrefillFunc(Array<ObjectRef> args,
* ffi::Functions. \param attn_kind The attention kind of the function.
\return The created
* PagedDecodeFunc pointer.
*/
-std::unique_ptr<PagedDecodeFunc> ConvertPagedDecodeFunc(Array<ObjectRef> args,
AttnKind attn_kind);
+std::unique_ptr<PagedDecodeFunc> ConvertPagedDecodeFunc(Array<ffi::Any> args,
AttnKind attn_kind);
/*!
* \brief Create a RaggedPrefillFunc from the given arguments and the
attention kind.
@@ -516,7 +515,7 @@ std::unique_ptr<PagedDecodeFunc>
ConvertPagedDecodeFunc(Array<ObjectRef> args, A
* ffi::Functions. \param attn_kind The attention kind of the function.
\return The created
* RaggedPrefillFunc pointer.
*/
-std::unique_ptr<RaggedPrefillFunc> ConvertRaggedPrefillFunc(Array<ObjectRef>
args,
+std::unique_ptr<RaggedPrefillFunc> ConvertRaggedPrefillFunc(Array<ffi::Any>
args,
AttnKind
attn_kind);
/*!
@@ -525,7 +524,7 @@ std::unique_ptr<RaggedPrefillFunc>
ConvertRaggedPrefillFunc(Array<ObjectRef> arg
* ffi::Functions. \param attn_kind The attention kind of the function.
\return The created
* PagedPrefillTreeMaskFunc pointer.
*/
-std::unique_ptr<PagedPrefillTreeMaskFunc>
ConvertPagedPrefillTreeMaskFunc(Array<ObjectRef> args,
+std::unique_ptr<PagedPrefillTreeMaskFunc>
ConvertPagedPrefillTreeMaskFunc(Array<ffi::Any> args,
AttnKind attn_kind);
/*!
@@ -534,7 +533,7 @@ std::unique_ptr<PagedPrefillTreeMaskFunc>
ConvertPagedPrefillTreeMaskFunc(Array<
* ffi::Functions. \param attn_kind The attention kind of the function.
\return The created
* RaggedPrefillTreeMaskFunc pointer.
*/
-std::unique_ptr<RaggedPrefillTreeMaskFunc>
ConvertRaggedPrefillTreeMaskFunc(Array<ObjectRef> args,
+std::unique_ptr<RaggedPrefillTreeMaskFunc>
ConvertRaggedPrefillTreeMaskFunc(Array<ffi::Any> args,
AttnKind attn_kind);
} // namespace vm
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 4ce854df59..bfd3fbd025 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -2470,21 +2470,21 @@ TVM_FFI_STATIC_INIT_BLOCK({
Optional<ffi::Function> f_transpose_append_mha = std::nullopt; //
args[13]
Optional<ffi::Function> f_transpose_append_mla = std::nullopt; //
args[14]
std::unique_ptr<RaggedPrefillFunc> f_attention_prefill_ragged =
- ConvertRaggedPrefillFunc(args[15].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertRaggedPrefillFunc(args[15].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<PagedPrefillFunc> f_attention_prefill =
- ConvertPagedPrefillFunc(args[16].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertPagedPrefillFunc(args[16].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<PagedDecodeFunc> f_attention_decode =
- ConvertPagedDecodeFunc(args[17].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertPagedDecodeFunc(args[17].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<PagedPrefillFunc> f_attention_prefill_sliding_window =
- ConvertPagedPrefillFunc(args[18].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertPagedPrefillFunc(args[18].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<PagedDecodeFunc> f_attention_decode_sliding_window =
- ConvertPagedDecodeFunc(args[19].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertPagedDecodeFunc(args[19].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<PagedPrefillTreeMaskFunc>
f_attention_prefill_with_tree_mask_paged_kv =
- ConvertPagedPrefillTreeMaskFunc(args[20].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertPagedPrefillTreeMaskFunc(args[20].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<RaggedPrefillTreeMaskFunc>
f_attention_prefill_with_tree_mask =
-
ConvertRaggedPrefillTreeMaskFunc(args[21].cast<Array<ObjectRef>>(),
AttnKind::kMHA);
+ ConvertRaggedPrefillTreeMaskFunc(args[21].cast<Array<ffi::Any>>(),
AttnKind::kMHA);
std::unique_ptr<PagedPrefillFunc> f_mla_prefill =
- ConvertPagedPrefillFunc(args[22].cast<Array<ObjectRef>>(),
AttnKind::kMLA);
+ ConvertPagedPrefillFunc(args[22].cast<Array<ffi::Any>>(),
AttnKind::kMLA);
Array<ffi::Function> f_merge_inplace =
args[23].cast<Array<ffi::Function>>();
ffi::Function f_split_rotary = args[24].cast<ffi::Function>();
ffi::Function f_copy_single_page = args[25].cast<ffi::Function>();
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 55a5a87d27..4a026f9dad 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -1051,7 +1051,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}
}
- std::unordered_map<std::string, ObjectRef> metrics;
+ std::unordered_map<std::string, ffi::Any> metrics;
metrics["Argument Shapes"] = profiling::ShapeString(arrs);
// If a suitable device is found, enable profiling.
diff --git a/src/script/ir_builder/relax/frame.cc
b/src/script/ir_builder/relax/frame.cc
index 0cde34879e..b0475e4fb0 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -72,7 +72,7 @@ void FunctionFrameNode::ExitWithScope() {
Expr body =
this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks,
output.value()));
// if the function is not private, add a global symbol to its attributes
- if (!is_private.value_or(Bool(false))->value && name.defined() &&
+ if (!is_private.value_or(Bool(false))->value && name.has_value() &&
!attrs.count(tvm::attr::kGlobalSymbol)) {
attrs.Set(tvm::attr::kGlobalSymbol, name.value());
}
@@ -89,8 +89,8 @@ void FunctionFrameNode::ExitWithScope() {
builder->result = func;
} else if (Optional<IRModuleFrame> opt_frame =
builder->FindFrame<IRModuleFrame>()) {
// Case 1. A global function of an IRModule
- CHECK(name.defined()) << "ValueError: The function name must be defined
before exiting the "
- "function scope, if it's defined in a Module";
+ CHECK(name.has_value()) << "ValueError: The function name must be defined
before exiting the "
+ "function scope, if it's defined in a Module";
const IRModuleFrame& frame = opt_frame.value();
const String& func_name = name.value_or("");
if (!frame->global_var_map.count(func_name)) {
diff --git a/src/script/ir_builder/relax/ir.cc
b/src/script/ir_builder/relax/ir.cc
index 0bb73abf4f..b845434e91 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -78,7 +78,7 @@ tvm::relax::Var Arg(const String& name, const
tvm::relax::StructInfo& struct_inf
void FuncName(const String& name) {
FunctionFrame frame = FindFunctionFrame("R.func_name");
- if (frame->name.defined()) {
+ if (frame->name.has_value()) {
LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \""
<< frame->name.value()
<< "\"";
}
diff --git a/src/script/ir_builder/tir/frame.cc
b/src/script/ir_builder/tir/frame.cc
index 1eb46f70eb..931e7e77d1 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -52,7 +52,7 @@ void PrimFuncFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
// if the prim func is not private and there isn't already a global symbol,
// add a global symbol
- if (!is_private && name.defined() && !attrs.count(tvm::attr::kGlobalSymbol))
{
+ if (!is_private && name.has_value() &&
!attrs.count(tvm::attr::kGlobalSymbol)) {
attrs.Set(tvm::attr::kGlobalSymbol, name.value());
}
@@ -68,8 +68,8 @@ void PrimFuncFrameNode::ExitWithScope() {
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has
already been set";
builder->result = func;
} else if (Optional<ir::IRModuleFrame> opt_frame =
builder->FindFrame<ir::IRModuleFrame>()) {
- CHECK(name.defined()) << "ValueError: The function name must be defined
before exiting the "
- "function scope, if it's defined in a Module";
+ CHECK(name.has_value()) << "ValueError: The function name must be defined
before exiting the "
+ "function scope, if it's defined in a Module";
const ir::IRModuleFrame& frame = opt_frame.value();
const String& func_name = name.value_or("");
if (!frame->global_var_map.count(func_name)) {
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index e8c8d62c9b..9d5d9dade5 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -87,7 +87,7 @@ Buffer Arg(String name, Buffer buffer) {
void FuncName(String name) {
PrimFuncFrame frame = FindPrimFuncFrame("T.func_name");
- if (frame->name.defined()) {
+ if (frame->name.has_value()) {
LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " <<
frame->name.value();
}
frame->name = name;
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index 5137266fa4..23a2e94a7f 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -79,7 +79,7 @@ StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
this->data_ = std::move(n);
}
-LiteralDoc::LiteralDoc(ObjectRef value, const Optional<ObjectPath>&
object_path) {
+LiteralDoc::LiteralDoc(ffi::Any value, const Optional<ObjectPath>&
object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
if (object_path.defined()) {
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc
b/src/script/printer/doc_printer/python_doc_printer.cc
index a6b8a8db09..8c352298c1 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -252,7 +252,7 @@ class PythonDocPrinter : public DocPrinter {
}
void MaybePrintCommentInline(const StmtDoc& stmt) {
- if (stmt->comment.defined()) {
+ if (stmt->comment.has_value()) {
const std::string& comment = stmt->comment.value();
bool has_newline = std::find(comment.begin(), comment.end(), '\n') !=
comment.end();
CHECK(!has_newline) << "ValueError: the comment string of " <<
stmt->GetTypeKey()
@@ -265,7 +265,7 @@ class PythonDocPrinter : public DocPrinter {
}
void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) {
- if (stmt->comment.defined()) {
+ if (stmt->comment.has_value()) {
std::vector<std::string> comment_lines =
support::Split(stmt->comment.value(), '\n');
bool first_line = true;
size_t start_pos = output_.tellp();
@@ -313,8 +313,8 @@ class PythonDocPrinter : public DocPrinter {
};
void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
- const ObjectRef& value = doc->value;
- if (!value.defined()) {
+ const ffi::Any& value = doc->value;
+ if (value == nullptr) {
output_ << "None";
} else if (const auto* int_imm = value.as<IntImmNode>()) {
if (int_imm->dtype.is_bool()) {
@@ -354,7 +354,7 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc)
{
} else if (const auto* string_obj = value.as<ffi::StringObj>()) {
output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size)
<< "\"";
} else {
- LOG(FATAL) << "TypeError: Unsupported literal value type: " <<
value->GetTypeKey();
+ LOG(FATAL) << "TypeError: Unsupported literal value type: " <<
value.GetTypeKey();
}
}
@@ -682,7 +682,7 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc&
doc) {
output_ << ":";
- if (doc->comment.defined()) {
+ if (doc->comment.has_value()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
@@ -696,20 +696,20 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc)
{
PrintDoc(doc->name);
output_ << ":";
- if (doc->comment.defined()) {
+ if (doc->comment.has_value()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
}
void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
- if (doc->comment.defined()) {
+ if (doc->comment.has_value()) {
MaybePrintCommenMultiLines(doc, false);
}
}
void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) {
- if (doc->comment.defined() && !doc->comment.value().empty()) {
+ if (doc->comment.has_value() && !doc->comment.value().empty()) {
PrintDocString(doc->comment.value());
}
}
diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc
index caa5cbe895..8288016d3e 100644
--- a/src/script/printer/ir/misc.cc
+++ b/src/script/printer/ir/misc.cc
@@ -22,14 +22,6 @@ namespace tvm {
namespace script {
namespace printer {
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<String>("", [](String s, ObjectPath p, IRDocsifier d) -> Doc
{
- if (HasMultipleLines(s)) {
- return d->AddMetadata(s);
- }
- return LiteralDoc::Str(s, p);
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Array<Any>>( //
"", [](Array<Any> array, ObjectPath p, IRDocsifier d) -> Doc {
diff --git a/src/script/printer/ir_docsifier.cc
b/src/script/printer/ir_docsifier.cc
index 0eb5c951e5..33c0076c30 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -81,11 +81,13 @@ Optional<ExprDoc> IRDocsifierNode::GetVarDoc(const
ObjectRef& obj) const {
return it->second.creator();
}
-ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
- ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata";
- String key = obj->GetTypeKey();
- Array<ObjectRef>& array = metadata[key];
- int index = std::find(array.begin(), array.end(), obj) - array.begin();
+ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) {
+ ICHECK(obj != nullptr) << "TypeError: Cannot add nullptr to metadata";
+ String key = obj.GetTypeKey();
+ Array<ffi::Any>& array = metadata[key];
+ int index = std::find_if(array.begin(), array.end(),
+ [&](const ffi::Any& a) { return ffi::AnyEqual()(a,
obj); }) -
+ array.begin();
if (index == static_cast<int>(array.size())) {
array.push_back(obj);
}
@@ -104,7 +106,7 @@ bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj)
const { return obj2info
void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
auto it = obj2info.find(obj);
ICHECK(it != obj2info.end()) << "No such object: " << obj;
- if (it->second.name.defined()) {
+ if (it->second.name.has_value()) {
defined_names.erase(it->second.name.value());
}
obj2info.erase(it);
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index d2f02f7908..d0b14753cc 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -326,8 +326,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
bool IsNumber(const ExprDoc& e) {
if (const auto* n = e.as<LiteralDocNode>()) {
- if (n->value.defined()) {
- return n->value->IsInstance<IntImmNode>() ||
n->value->IsInstance<FloatImmNode>();
+ if (n->value != nullptr) {
+ return n->value.as<IntImmNode>() || n->value.as<FloatImmNode>();
}
}
return false;
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 239d9f6721..50756bceb7 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -415,7 +415,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ObjectPath body_p = stmt_p->Attr("body");
if (stmt->attr_key == "realize_scope") {
if (const auto* realize = stmt->body.as<tir::BufferRealizeNode>())
{
- if (realize->buffer.same_as(stmt->node)) {
+ // TODO(tqchen): add any.same_as(ObjectRef)
+ if (realize->buffer.same_as(stmt->node.cast<ObjectRef>())) {
rhs = DocsifyBufferRealize(
realize,
/*value=*/d->AsDoc<ExprDoc>(stmt->value,
stmt_p->Attr("value")),
@@ -426,7 +427,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
}
if (stmt->attr_key == "thread_extent" || stmt->attr_key ==
"virtual_thread") {
- if (stmt->node->IsInstance<tir::IterVarNode>()) {
+ if (stmt->node.as<tir::IterVarNode>()) {
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
}
}
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 12b08b209b..0becec1f3f 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -184,9 +184,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("testing.AcceptsVariant",
[](Variant<String, Integer> arg) -> String {
if (auto opt_str = arg.as<String>()) {
- return opt_str.value()->GetTypeKey();
+ return ffi::StringObj::_type_key;
} else {
- return arg.get<Integer>()->GetTypeKey();
+ return arg.get<Integer>().GetTypeKey();
}
})
.def("testing.AcceptsBool", [](bool arg) -> bool { return arg; })
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index b04d71da31..b85b51e3d2 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -143,7 +143,7 @@ void CodeGenCPU::Init(const std::string& module_name,
LLVMTarget* llvm_target,
t_void_p_, t_int_},
false);
// initialize TVM runtime API
- if (system_lib_prefix_.defined() && !target_c_runtime) {
+ if (system_lib_prefix_.has_value() && !target_c_runtime) {
// We will need this in environment for backward registration.
// Defined in include/tvm/runtime/c_backend_api.h:
// int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
@@ -153,7 +153,7 @@ void CodeGenCPU::Init(const std::string& module_name,
LLVMTarget* llvm_target,
} else {
f_tvm_register_system_symbol_ = nullptr;
}
- if (dynamic_lookup || system_lib_prefix_.defined()) {
+ if (dynamic_lookup || system_lib_prefix_.has_value()) {
f_tvm_ffi_func_call_ =
llvm::Function::Create(ftype_tvm_ffi_func_call_,
llvm::Function::ExternalLinkage,
"TVMFFIFunctionCall", module_.get());
diff --git a/src/target/llvm/codegen_hexagon.cc
b/src/target/llvm/codegen_hexagon.cc
index f0f1797a6c..6f90da3d8a 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -495,7 +495,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined());
+ ICHECK(global_symbol.has_value());
entry_func = global_symbol.value();
}
}
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index ed636b82a2..bcea45cfa7 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -313,7 +313,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode*
op) {
int GetCUDAComputeVersion(const Target& target) {
Optional<String> mcpu = target->GetAttr<String>("mcpu");
- ICHECK(mcpu.defined()) << "InternalError: \"-mcpu\" is undefined in the
NVPTX target";
+ ICHECK(mcpu.has_value()) << "InternalError: \"-mcpu\" is undefined in the
NVPTX target";
std::string sm_version = mcpu.value();
return std::stoi(sm_version.substr(3));
}
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index b3e5249c02..924f520082 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -352,7 +352,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const
Target& target) {
// ICHECK(funcs.size() > 0);
// TODO(tqchen): remove the entry function behavior as it does not
// makes sense when we start to use multiple modules.
- cg->Init("TVMMod", llvm_target.get(), system_lib_prefix,
system_lib_prefix.defined(), false);
+ cg->Init("TVMMod", llvm_target.get(), system_lib_prefix,
system_lib_prefix.has_value(), false);
cg->SetFastMathFlags(llvm_target->GetFastMathFlags());
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
if (entry_func.length() != 0) {
diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc
index c28d80133f..ee9bf814d3 100644
--- a/src/target/parsers/cpu.cc
+++ b/src/target/parsers/cpu.cc
@@ -44,7 +44,7 @@ TargetJSON ParseTarget(TargetJSON target) {
Optional<String> mcpu = Downcast<Optional<String>>(target.Get("mcpu"));
// Try to fill in the blanks by detecting target information from the system
- if (kind == "llvm" && !mtriple.defined() && !mcpu.defined()) {
+ if (kind == "llvm" && !mtriple.has_value() && !mcpu.has_value()) {
String system_triple = DetectSystemTriple().value_or("");
target.Set("mtriple", system_triple);
}
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index 2333534122..2e808738ef 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -73,7 +73,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const
PrimFunc& func,
emit_fwd_func_decl_ = emit_fwd_func_decl;
CodeGenC::AddFunction(gvar, func);
if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- ICHECK(global_symbol.defined())
+ ICHECK(global_symbol.has_value())
<< "CodeGenCHost: The entry func must have the global_symbol
attribute, "
<< "but function " << gvar << " only has attributes " << func->attrs;
diff --git a/src/target/source/codegen_metal.cc
b/src/target/source/codegen_metal.cc
index 962bd777f1..3cd4a6ed0d 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -78,7 +78,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const
PrimFunc& func) {
// add to alloc buffer type.
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
+ ICHECK(global_symbol.has_value())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header.
@@ -443,7 +443,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only
take PrimFunc";
auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined());
+ ICHECK(global_symbol.has_value());
std::string func_name = global_symbol.value();
source_maker << "// Function: " << func_name << "\n";
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index 31f9899016..f5bfd80fee 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -137,7 +137,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f, bool skip_re
// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
+ ICHECK(global_symbol.has_value())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
header_stream << "//----------------------------------------\n"
@@ -767,7 +767,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenWebGPU: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
+ ICHECK(global_symbol.has_value())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol
attribute";
std::string f_name = global_symbol.value();
cg.Init(output_ssa);
diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc
index 6afd087e5d..f0226466f6 100644
--- a/src/target/spirv/spirv_utils.cc
+++ b/src/target/spirv/spirv_utils.cc
@@ -130,7 +130,7 @@ std::pair<std::unordered_map<std::string,
runtime::SPIRVShader>, std::string> Lo
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenSPIRV: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
+ ICHECK(global_symbol.has_value())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol.value();
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index f65566109f..b0457a1239 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -91,7 +91,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
<< ", but get " << s->GetTypeKey();
const IterVarNode* iv = op->node.as<IterVarNode>();
ICHECK(iv != nullptr) << "Expected type to be IterVarNode"
- << ", but get " << op->node->GetTypeKey();
+ << ", but get " << op->node.GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 3b91c5e84b..7b3c951587 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -329,7 +329,7 @@ IndexMap IndexMap::RenameVariables(
}
visited.emplace(obj.get());
Var var = Downcast<Var>(obj);
- if (Optional<String> opt_name = f_name_map(var); opt_name.defined()) {
+ if (Optional<String> opt_name = f_name_map(var); opt_name.has_value())
{
String name = opt_name.value();
ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false));
name_supply->ReserveName(name, /*add_prefix=*/false);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 17c763c6e4..6803e01f50 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -84,7 +84,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
TVM_REGISTER_NODE_TYPE(LetStmtNode);
// AttrStmt
-AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body,
Span span) {
+AttrStmt::AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body,
Span span) {
auto n = make_object<AttrStmtNode>();
n->node = node;
n->attr_key = std::move(attr_key);
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index 78cfd004dd..1dbbe75528 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -40,7 +40,7 @@ void TIRVisitorWithPath::Visit(const IRModule& mod,
ObjectPath path) {
std::unordered_set<GlobalVar> externally_exposed;
for (const auto& [gvar, func] : mod->functions) {
gvars.push_back(gvar);
- if (func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ if (func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value()) {
externally_exposed.insert(gvar);
}
}
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 691ce8ebd1..3aacfa1583 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -949,7 +949,7 @@ StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>&
srefs) {
}
bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) {
- return tir::GetAnn<String>(block_sref,
tir::attr::meta_schedule_tiling_structure).defined();
+ return tir::GetAnn<String>(block_sref,
tir::attr::meta_schedule_tiling_structure).has_value();
}
std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const
ScheduleState& self,
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 0b8aeec82c..c00c946852 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -311,9 +311,9 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name,
const Optional<String
Array<Block> blocks_;
};
GlobalVar gv = NullValue<GlobalVar>();
- if (func_name.defined()) {
+ if (func_name.has_value()) {
gv = state_->mod->GetGlobalVar(func_name.value());
- } else if (func_working_on_.defined()) {
+ } else if (func_working_on_.has_value()) {
gv = this->func_working_on_.value();
} else {
LOG(FATAL) << "ValueError: `get_block` does not know which function to be
working on. Please "
diff --git a/src/tir/schedule/instruction_traits.h
b/src/tir/schedule/instruction_traits.h
index cbd5185ff8..5507c02bfe 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -541,7 +541,7 @@ void PythonAPICall::OutputList(Array<String> outputs) {
String PythonAPICall::Str() const {
std::ostringstream os;
- if (output_.defined()) {
+ if (output_.has_value()) {
os << output_.value() << " = ";
}
os << "sch." << method_name_ << '(';
diff --git a/src/tir/schedule/primitive/for_kind.cc
b/src/tir/schedule/primitive/for_kind.cc
index f1e035d92e..6dd1eafcc0 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -164,13 +164,13 @@ void ParallelizeComputation(const ScheduleState& self,
const StmtSRef& loop_sref
// Step 2. Check whether the loop can be parallelized/vectorized/bound with
regard to each
// underlying block.
CheckParallelizability(self, GetRef<For>(loop), for_kind,
- thread_axis.defined() ?
runtime::ThreadScope::Create(thread_axis.value())
- : runtime::ThreadScope{-1, -1});
+ thread_axis.has_value() ?
runtime::ThreadScope::Create(thread_axis.value())
+ : runtime::ThreadScope{-1,
-1});
// Step 3. Loop update and IR replacement
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
new_loop->kind = for_kind;
- if (thread_axis.defined()) {
+ if (thread_axis.has_value()) {
new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr),
//
/*var=*/Var(thread_axis.value(),
loop->loop_var.dtype()), //
/*iter_type=*/kThreadIndex,
//
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index 8160035e3d..6efb17de25 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -240,7 +240,7 @@ Array<String> TranslateAddOutputRVs(
ICHECK(!rv_names->count(output.cast<ObjectRef>()))
<< "ValueError: The random variable has been produced once: "
<< rv_names->at(output.cast<ObjectRef>());
- String result{ffi::ObjectPtr<ffi::StringObj>{nullptr}};
+ String result;
if (output == nullptr) {
result = "_";
} else if (output.as<BlockRVNode>()) {
@@ -320,8 +320,8 @@ void TraceNode::ApplyToSchedule(
ObjectRef TraceNode::AsJSON(bool remove_postproc) const {
std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual>
rv_names;
- Array<ObjectRef> json_insts;
- Array<ObjectRef> json_decisions;
+ Array<ffi::Any> json_insts;
+ Array<ffi::Any> json_decisions;
json_insts.reserve(this->insts.size());
json_decisions.reserve(this->insts.size());
@@ -331,7 +331,7 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const {
if (remove_postproc && kind->IsPostproc()) {
break;
}
- json_insts.push_back(Array<ObjectRef>{
+ json_insts.push_back(Array<ffi::Any>{
/* 0: inst name */ kind->name,
/* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names),
/* 2: attrs */ kind->f_attrs_as_json != nullptr ?
kind->f_attrs_as_json(inst->attrs)
@@ -346,7 +346,7 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const {
}
++i;
}
- return Array<ObjectRef>{
+ return Array<ffi::Any>{
/* 0: trace */ std::move(json_insts),
/* 1: decision */ std::move(json_decisions),
};
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index d3e77e0e3b..b9718c1a5f 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -118,7 +118,7 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const
BlockRV& block_rv,
BlockRV TracedScheduleNode::GetBlock(const String& name, const
Optional<String>& func_name) {
GlobalVar gv = NullValue<GlobalVar>();
- if (func_name.defined()) {
+ if (func_name.has_value()) {
gv = state_->mod->GetGlobalVar(func_name.value());
} else if (func_working_on_.defined()) {
gv = this->func_working_on_.value();
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index deedfd6f68..0c35c5f043 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -290,7 +290,7 @@ inline Optional<TObjectRef> GetAnn(const StmtSRef& sref,
const String& ann_key)
*/
inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String&
ann_val) {
Optional<String> result = GetAnn<String>(sref, ann_key);
- return result.defined() && result.value() == ann_val;
+ return result.has_value() && result.value() == ann_val;
}
/*!
diff --git a/src/tir/transforms/bind_target.cc
b/src/tir/transforms/bind_target.cc
index 281249f4ad..46a40228ea 100644
--- a/src/tir/transforms/bind_target.cc
+++ b/src/tir/transforms/bind_target.cc
@@ -71,7 +71,7 @@ class FunctionClassifierVisitor : public StmtExprVisitor {
// Only analyze externally exposed functions as potential callers
// since they represent the entry points where host/device calls originate
for (const auto& [gvar, func] : mod->functions) {
- bool is_externally_exposed =
func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_externally_exposed =
func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
const auto* prim_func = func.as<PrimFuncNode>();
if (is_externally_exposed && prim_func != nullptr) {
@@ -268,7 +268,7 @@ IRModule BindTarget(IRModule mod, const Target& target) {
}
auto prim_func = GetRef<PrimFunc>(prim_func_node);
- bool is_externally_exposed =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_externally_exposed =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
// Rule 1: If the function has a target, and the target has a host, and
the function does not
@@ -341,7 +341,7 @@ IRModule BindTarget(IRModule mod, const Target& target) {
continue;
}
- bool is_externally_exposed =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_externally_exposed =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (is_externally_exposed) {
// Update calls in externally exposed functions to use host duplicates
PrimFunc new_func = substitutor.Substitute(Downcast<PrimFunc>(func));
diff --git a/src/tir/transforms/compact_buffer_region.cc
b/src/tir/transforms/compact_buffer_region.cc
index 6f5a496d1f..a1e99313b6 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -415,7 +415,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
if (iter->iter_type != IterVarType::kThreadIndex) {
return false;
}
- ICHECK(iter->thread_tag.defined());
// When there is warp memory
// threadIdx.x must be set to be warp index.
return CanRelaxStorageUnderThread(scope,
runtime::ThreadScope::Create((iter->thread_tag)));
diff --git a/src/tir/transforms/inject_permuted_layout.cc
b/src/tir/transforms/inject_permuted_layout.cc
index 02bdfcbfed..f90752e264 100644
--- a/src/tir/transforms/inject_permuted_layout.cc
+++ b/src/tir/transforms/inject_permuted_layout.cc
@@ -104,9 +104,9 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
}
static bool CheckAnnotation(const Any& annotation) {
- if (auto* node = annotation.as<ffi::StringObj>()) {
+ if (auto opt_str = annotation.as<String>()) {
// Support string annotation for backward compatibility
- return GetRef<String>(node) != "";
+ return *opt_str != "";
} else if (auto* node = annotation.as<IntImmNode>()) {
return node->value != 0;
} else if (auto opt_val = annotation.try_cast<int64_t>()) {
diff --git a/src/tir/transforms/inline_private_functions.cc
b/src/tir/transforms/inline_private_functions.cc
index 9e87ffe5b2..8521607f89 100644
--- a/src/tir/transforms/inline_private_functions.cc
+++ b/src/tir/transforms/inline_private_functions.cc
@@ -103,7 +103,7 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const
PrimFunc& prim_func,
// Only inline private functions. Externally-exposed functions
// must be preserved so to avoid breaking callsites outside of
// the IRModule.
- bool is_exposed =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_exposed =
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (is_exposed) return false;
// We do not currently implement any analysis for termination of
diff --git a/src/tir/transforms/loop_partition.cc
b/src/tir/transforms/loop_partition.cc
index b3825908b7..d29c380b35 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -155,9 +155,9 @@ class CandidateSelector final : public StmtExprVisitor {
} else if (op->attr_key == attr::pragma_loop_partition_hint) {
if (analyzer_.CanProve(op->value)) {
const VarNode* var = nullptr;
- if (op->node->IsInstance<VarNode>()) {
+ if (op->node.as<VarNode>()) {
var = op->node.as<VarNode>();
- } else if (op->node->IsInstance<IterVarNode>()) {
+ } else if (op->node.as<IterVarNode>()) {
var = op->node.as<IterVarNode>()->var.get();
}
ICHECK(var);
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index 9f0228dc16..d95a02a0ba 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -187,7 +187,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc& func) {
// Internal function calls do not need the ffi::Function API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- if (!global_symbol.defined()) {
+ if (!global_symbol.has_value()) {
return std::nullopt;
}
@@ -196,7 +196,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc& func) {
PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
- if (!global_symbol.defined()) {
+ if (!global_symbol.has_value()) {
return func;
}
std::string name_hint = global_symbol.value();
@@ -365,7 +365,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
- ObjectRef node = String("default");
+ ffi::Any node = ffi::String("default");
seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));
diff --git a/src/tir/transforms/make_unpacked_api.cc
b/src/tir/transforms/make_unpacked_api.cc
index 898e781906..8276d26fcf 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -103,7 +103,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
// Internal function calls do not need API updates
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- if (!global_symbol.defined()) {
+ if (!global_symbol.has_value()) {
return func;
}
@@ -128,7 +128,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
// Setup device context
Integer device_type(target_device_type);
Integer device_id(0);
- ObjectRef node = String("default");
+ ffi::Any node = ffi::String("default");
const Stmt nop = Evaluate(0);
std::vector<Stmt> device_init;
diff --git a/src/tir/transforms/primfunc_utils.cc
b/src/tir/transforms/primfunc_utils.cc
index 274199a6c4..b1f3476eab 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -47,7 +47,7 @@ transform::Pass AnnotateEntryFunc() {
bool has_external_non_primfuncs = false;
IRModule with_annotations;
for (const auto& [gvar, base_func] : mod->functions) {
- bool is_external =
base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+ bool is_external =
base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
if (is_external) {
if (auto ptr = base_func.as<PrimFuncNode>()) {
with_annotations->Add(gvar,