This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8d2d64aa8de2ef13e48810e8f8821509e4ab169c Author: tqchen <[email protected]> AuthorDate: Sat Apr 12 20:51:16 2025 -0400 fix msc and cpptests --- src/contrib/msc/core/ir/graph_builder.cc | 2 +- src/contrib/msc/core/ir/graph_builder.h | 16 ++++++++++++++++ src/contrib/msc/core/printer/msc_base_printer.cc | 15 +++++++++------ src/contrib/msc/core/transform/set_expr_layout.cc | 14 +++++++------- src/contrib/msc/core/transform/set_expr_name.cc | 2 +- src/contrib/msc/framework/tensorflow/codegen.cc | 2 +- src/contrib/msc/framework/tensorrt/codegen.cc | 2 +- .../msc/framework/tensorrt/transform_tensorrt.cc | 6 +++--- src/contrib/msc/framework/torch/codegen.cc | 2 +- src/contrib/msc/framework/tvm/codegen.cc | 2 +- src/target/llvm/intrin_rule_hexagon.cc | 21 ++++++++------------- tests/cpp/target_test.cc | 4 ++-- tests/cpp/tir_scalable_datatype.cc | 2 +- 13 files changed, 52 insertions(+), 38 deletions(-) diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 35a85f1e49..853f75216f 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -725,7 +725,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* AddNode(GetRef<Call>(call_node), binding->var, name); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value - << ", reason: " << err.message(); + << ", reason: " << err.what(); throw err; } } diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index ce3c7bd022..00582fab4b 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -126,6 +126,22 @@ class AttrGetter : public AttrVisitor { void Visit(const char* key, std::string* value) final { attrs_->Set(key, *value); } + void Visit(const char* key, Optional<double>* value) final { + if (value->has_value()) { + attrs_->Set(key, std::to_string(value->value())); + } else { + attrs_->Set(key, ""); + } + } + + void Visit(const char* key, Optional<int64_t>* value) final { + if (value->has_value()) { + attrs_->Set(key, std::to_string(value->value())); + } else { + attrs_->Set(key, ""); + } + } + void Visit(const char* key, DataType* value) final { attrs_->Set(key, runtime::DLDataTypeToString(*value)); } diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 289c1b79fd..5c3832fb43 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -97,12 +97,14 @@ void MSCBasePrinter::PrintDoc(const Doc& doc, bool new_line) { } void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { - const ObjectRef& value = doc->value; - if (!value.defined()) { + const Any& value = doc->value; + if (value == nullptr) { output_ << "\"\""; - } else if (const auto* int_imm = value.as<IntImmNode>()) { + } else if (auto opt_int_imm = value.as<IntImm>()) { + IntImm int_imm = *std::move(opt_int_imm); output_ << int_imm->value; - } else if (const auto* float_imm = value.as<FloatImmNode>()) { + } else if (auto opt_float_imm = value.as<FloatImm>()) { + FloatImm float_imm = *std::move(opt_float_imm); output_.precision(config_.float_precision); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { output_ << '"' << float_imm->value << '"'; @@ -110,9 +112,10 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { output_ << float_imm->value; } } else if (const auto* string_obj = value.as<StringObj>()) { - output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size) << "\""; + output_ << "\"" << tvm::support::StrEscape( + string_obj->bytes.data, string_obj->bytes.size) << "\""; } else { - LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey(); + LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); } } diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 00e4c5dcfc..68d55a1d2f 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -266,7 +266,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, if (attrs->keepdims) { return InferLayoutOutput({input_layout}, {input_layout}, Attrs()); } - if (!attrs->axis.defined()) { + if (!attrs->axis.has_value()) { return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); } const auto& input_shape = ExprUtils::GetShape(call->args[0]); @@ -274,7 +274,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, return InferLayoutOutput(); } std::vector<size_t> axes; - axes.push_back(CommonUtils::GetIndex(Downcast<Integer>(attrs->axis)->value, input_shape.size())); + axes.push_back(CommonUtils::GetIndex(attrs->axis.value(), input_shape.size())); LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes); return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } @@ -1102,7 +1102,7 @@ class LayoutInfer : public ExprVisitor { BackwardInferLayoutCommon(call, Map<String, Array<String>>(), var_layout_map_); } } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.message(); + LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.what(); infered_layout = InferLayoutOutput(); } try { @@ -1111,7 +1111,7 @@ class LayoutInfer : public ExprVisitor { } } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to backward set inputs layout for " << call << " : " - << err.message(); + << err.what(); } } } @@ -1163,7 +1163,7 @@ class LayoutInfer : public ExprVisitor { } } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to forward infer layout for " << binding->var << " : " - << binding->value << ", reason: " << err.message(); + << binding->value << ", reason: " << err.what(); infered_layout = InferLayoutOutput(); } if (infered_layout.defined() && infered_layout->output_layouts.size() == 1) { @@ -1171,7 +1171,7 @@ class LayoutInfer : public ExprVisitor { SetExprLayout(binding->var, infered_layout->output_layouts[0]); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to forward set output layout for " << binding->var << " : " - << binding->value << ", reason: " << err.message(); + << binding->value << ", reason: " << err.what(); } } if (set_inputs && infered_layout.defined()) { @@ -1179,7 +1179,7 @@ class LayoutInfer : public ExprVisitor { SetInputLayouts(call, infered_layout->input_layouts); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to forward set inputs layout for " << call << " : " - << err.message(); + << err.what(); } } } diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index a2c7ae0142..f4602d2784 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -210,7 +210,7 @@ class RelaxExprNameSetter : public ExprVisitor { input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to GetInputTypes for " << GetRef<Call>(val) << " : " - << err.message(); + << err.what(); throw err; } for (size_t i = 0; i < input_types.size(); i++) { diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 634dd79698..9506d4eac8 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -145,7 +145,7 @@ const Array<Doc> TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.message(); + LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); throw err; } } diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 0472e49b61..f3def93f52 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -548,7 +548,7 @@ const Array<Doc> TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.message(); + LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); throw err; } } diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 9bbfc1c22d..f1866a9b90 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -163,8 +163,8 @@ Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call static const Op& topk_op = Op::Get("relax.topk"); auto topk_attrs = make_object<TopKAttrs>(); topk_attrs->k = 1; - if (src_attrs->axis.defined()) { - topk_attrs->axis = src_attrs->axis.value()->value; + if (src_attrs->axis.has_value()) { + topk_attrs->axis = src_attrs->axis.value(); } topk_attrs->largest = call->op == Op::Get("relax.argmax"); topk_attrs->ret_type = "both"; @@ -395,7 +395,7 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca if (in_dim != out_dim) { Array<Expr> concat_inputs(out_dim / in_dim, concat_input); auto concat_attrs = make_object<ConcatAttrs>(); - concat_attrs->axis = Integer(i); + concat_attrs->axis = i; concat_input = RewriteUtils::MakeCall( builder, ExprUtils::GetSpanName(call, "concat_" + std::to_string(i)), concat_op, {Tuple(concat_inputs)}, Attrs(concat_attrs)); diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 86351bdd06..547c1c22ba 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -146,7 +146,7 @@ const Array<Doc> TorchCodeGen::GetOpCodes(const MSCJoint& node) { try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.message(); + LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); throw err; } } diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 5443cdc96a..1d6d74d7e4 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -205,7 +205,7 @@ const Array<Doc> RelaxCodeGen::GetOpCodes(const MSCJoint& node) { try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.message(); + LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); throw err; } } diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index 2661f2fa65..58661c9978 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -57,10 +57,8 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent"); ICHECK(f != nullptr); const auto ret = (*f)(true); - const Target t = ret.AsObjectRef<Target>(); - bool useqhl = true; - if (t.defined()) { - const std::string tstring = t->str(); + if (auto opt_target = ret.as<Target>()) { + const std::string tstring = opt_target.value()->str(); useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } @@ -108,10 +106,9 @@ TVM_REGISTER_OP("tir.tanh") const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent"); ICHECK(f != nullptr); const auto ret = (*f)(true); - const Target t = ret.AsObjectRef<Target>(); bool useqhl = true; - if (t.defined()) { - const std::string tstring = t->str(); + if (auto opt_target = ret.as<Target>()) { + const std::string tstring = opt_target.value()->str(); useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } @@ -144,10 +141,9 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>( const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent"); ICHECK(f != nullptr); const auto ret = (*f)(true); - const Target t = ret.AsObjectRef<Target>(); bool useqhl = true; - if (t.defined()) { - const std::string tstring = t->str(); + if (auto opt_target = ret.as<Target>()) { + const std::string tstring = opt_target.value()->str(); useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } @@ -175,10 +171,9 @@ TVM_REGISTER_OP("tir.sigmoid") const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent"); ICHECK(f != nullptr); const auto ret = (*f)(true); - const Target t = ret.AsObjectRef<Target>(); bool useqhl = true; - if (t.defined()) { - const std::string tstring = t->str(); + if (auto opt_target = ret.as<Target>()) { + const std::string tstring = opt_target.value()->str(); useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 50c611814f..a75ba09e82 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -191,7 +191,7 @@ TEST(TargetCreation, TargetFeaturesBeforeParser) { {"mcpu", String("woof")}, {"features", features}, }; - EXPECT_THROW(Target test(config), InternalError); + EXPECT_THROW(Target test(config), ffi::Error); } TEST(TargetCreation, TargetAttrsPreProcessor) { @@ -200,7 +200,7 @@ TEST(TargetCreation, TargetAttrsPreProcessor) { } TEST(TargetCreation, ClashingTargetProcessing) { - EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError); + EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), ffi::Error); } TVM_REGISTER_TARGET_KIND("TestStringKind", kDLCPU) diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 8fda119b6a..a5eeab034f 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -99,7 +99,7 @@ TEST(ScalableDataType, TestInvalidStringToScalableDataType) { try { tvm::runtime::StringToDLDataType(scalable_type_str); } catch (const tvm::ffi::Error& e) { - EXPECT_THAT(e.what(), HasSubstr("unknown dtype int32x4xvscale")); + EXPECT_THAT(e.what(), HasSubstr("unknown dtype `int32x4xvscale`")); throw; } },
