This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s1
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 8d2d64aa8de2ef13e48810e8f8821509e4ab169c
Author: tqchen <[email protected]>
AuthorDate: Sat Apr 12 20:51:16 2025 -0400

    fix msc and cpptests
---
 src/contrib/msc/core/ir/graph_builder.cc            |  2 +-
 src/contrib/msc/core/ir/graph_builder.h             | 16 ++++++++++++++++
 src/contrib/msc/core/printer/msc_base_printer.cc    | 15 +++++++++------
 src/contrib/msc/core/transform/set_expr_layout.cc   | 14 +++++++-------
 src/contrib/msc/core/transform/set_expr_name.cc     |  2 +-
 src/contrib/msc/framework/tensorflow/codegen.cc     |  2 +-
 src/contrib/msc/framework/tensorrt/codegen.cc       |  2 +-
 .../msc/framework/tensorrt/transform_tensorrt.cc    |  6 +++---
 src/contrib/msc/framework/torch/codegen.cc          |  2 +-
 src/contrib/msc/framework/tvm/codegen.cc            |  2 +-
 src/target/llvm/intrin_rule_hexagon.cc              | 21 ++++++++-------------
 tests/cpp/target_test.cc                            |  4 ++--
 tests/cpp/tir_scalable_datatype.cc                  |  2 +-
 13 files changed, 52 insertions(+), 38 deletions(-)

diff --git a/src/contrib/msc/core/ir/graph_builder.cc 
b/src/contrib/msc/core/ir/graph_builder.cc
index 35a85f1e49..853f75216f 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -725,7 +725,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* 
binding, const CallNode*
     AddNode(GetRef<Call>(call_node), binding->var, name);
   } catch (runtime::InternalError& err) {
     LOG(WARNING) << "Failed to add node from " << binding->var << " : " << 
binding->value
-                 << ", reason: " << err.message();
+                 << ", reason: " << err.what();
     throw err;
   }
 }
diff --git a/src/contrib/msc/core/ir/graph_builder.h 
b/src/contrib/msc/core/ir/graph_builder.h
index ce3c7bd022..00582fab4b 100644
--- a/src/contrib/msc/core/ir/graph_builder.h
+++ b/src/contrib/msc/core/ir/graph_builder.h
@@ -126,6 +126,22 @@ class AttrGetter : public AttrVisitor {
 
   void Visit(const char* key, std::string* value) final { attrs_->Set(key, 
*value); }
 
+  void Visit(const char* key, Optional<double>* value) final {
+    if (value->has_value()) {
+      attrs_->Set(key, std::to_string(value->value()));
+    } else {
+      attrs_->Set(key, "");
+    }
+  }
+
+  void Visit(const char* key, Optional<int64_t>* value) final {
+    if (value->has_value()) {
+      attrs_->Set(key, std::to_string(value->value()));
+    } else {
+      attrs_->Set(key, "");
+    }
+  }
+
   void Visit(const char* key, DataType* value) final {
     attrs_->Set(key, runtime::DLDataTypeToString(*value));
   }
diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc 
b/src/contrib/msc/core/printer/msc_base_printer.cc
index 289c1b79fd..5c3832fb43 100644
--- a/src/contrib/msc/core/printer/msc_base_printer.cc
+++ b/src/contrib/msc/core/printer/msc_base_printer.cc
@@ -97,12 +97,14 @@ void MSCBasePrinter::PrintDoc(const Doc& doc, bool 
new_line) {
 }
 
 void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) {
-  const ObjectRef& value = doc->value;
-  if (!value.defined()) {
+  const Any& value = doc->value;
+  if (value == nullptr) {
     output_ << "\"\"";
-  } else if (const auto* int_imm = value.as<IntImmNode>()) {
+  } else if (auto opt_int_imm = value.as<IntImm>()) {
+    IntImm int_imm = *std::move(opt_int_imm);
     output_ << int_imm->value;
-  } else if (const auto* float_imm = value.as<FloatImmNode>()) {
+  } else if (auto opt_float_imm = value.as<FloatImm>()) {
+    FloatImm float_imm = *std::move(opt_float_imm);
     output_.precision(config_.float_precision);
     if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) {
       output_ << '"' << float_imm->value << '"';
@@ -110,9 +112,10 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) {
       output_ << float_imm->value;
     }
   } else if (const auto* string_obj = value.as<StringObj>()) {
-    output_ << "\"" << tvm::support::StrEscape(string_obj->data, 
string_obj->size) << "\"";
+    output_ << "\"" << tvm::support::StrEscape(
+      string_obj->bytes.data, string_obj->bytes.size) << "\"";
   } else {
-    LOG(FATAL) << "TypeError: Unsupported literal value type: " << 
value->GetTypeKey();
+    LOG(FATAL) << "TypeError: Unsupported literal value type: " << 
value.GetTypeKey();
   }
 }
 
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc 
b/src/contrib/msc/core/transform/set_expr_layout.cc
index 00e4c5dcfc..68d55a1d2f 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -266,7 +266,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& 
call,
   if (attrs->keepdims) {
     return InferLayoutOutput({input_layout}, {input_layout}, Attrs());
   }
-  if (!attrs->axis.defined()) {
+  if (!attrs->axis.has_value()) {
     return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs());
   }
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
@@ -274,7 +274,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& 
call,
     return InferLayoutOutput();
   }
   std::vector<size_t> axes;
-  axes.push_back(CommonUtils::GetIndex(Downcast<Integer>(attrs->axis)->value, 
input_shape.size()));
+  axes.push_back(CommonUtils::GetIndex(attrs->axis.value(), 
input_shape.size()));
   LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes);
   return InferLayoutOutput({input_layout}, {output_layout}, Attrs());
 }
@@ -1102,7 +1102,7 @@ class LayoutInfer : public ExprVisitor {
               BackwardInferLayoutCommon(call, Map<String, Array<String>>(), 
var_layout_map_);
         }
       } catch (runtime::InternalError& err) {
-        LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << 
err.message();
+        LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << 
err.what();
         infered_layout = InferLayoutOutput();
       }
       try {
@@ -1111,7 +1111,7 @@ class LayoutInfer : public ExprVisitor {
         }
       } catch (runtime::InternalError& err) {
         LOG(WARNING) << "Failed to backward set inputs layout for " << call << 
" : "
-                     << err.message();
+                     << err.what();
       }
     }
   }
@@ -1163,7 +1163,7 @@ class LayoutInfer : public ExprVisitor {
           }
         } catch (runtime::InternalError& err) {
           LOG(WARNING) << "Failed to forward infer layout for " << 
binding->var << " : "
-                       << binding->value << ", reason: " << err.message();
+                       << binding->value << ", reason: " << err.what();
           infered_layout = InferLayoutOutput();
         }
         if (infered_layout.defined() && infered_layout->output_layouts.size() 
== 1) {
@@ -1171,7 +1171,7 @@ class LayoutInfer : public ExprVisitor {
             SetExprLayout(binding->var, infered_layout->output_layouts[0]);
           } catch (runtime::InternalError& err) {
             LOG(WARNING) << "Failed to forward set output layout for " << 
binding->var << " : "
-                         << binding->value << ", reason: " << err.message();
+                         << binding->value << ", reason: " << err.what();
           }
         }
         if (set_inputs && infered_layout.defined()) {
@@ -1179,7 +1179,7 @@ class LayoutInfer : public ExprVisitor {
             SetInputLayouts(call, infered_layout->input_layouts);
           } catch (runtime::InternalError& err) {
             LOG(WARNING) << "Failed to forward set inputs layout for " << call 
<< " : "
-                         << err.message();
+                         << err.what();
           }
         }
       }
diff --git a/src/contrib/msc/core/transform/set_expr_name.cc 
b/src/contrib/msc/core/transform/set_expr_name.cc
index a2c7ae0142..f4602d2784 100644
--- a/src/contrib/msc/core/transform/set_expr_name.cc
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -210,7 +210,7 @@ class RelaxExprNameSetter : public ExprVisitor {
         input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true);
       } catch (runtime::InternalError& err) {
         LOG(WARNING) << "Failed to GetInputTypes for " << GetRef<Call>(val) << 
" : "
-                     << err.message();
+                     << err.what();
         throw err;
       }
       for (size_t i = 0; i < input_types.size(); i++) {
diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc 
b/src/contrib/msc/framework/tensorflow/codegen.cc
index 634dd79698..9506d4eac8 100644
--- a/src/contrib/msc/framework/tensorflow/codegen.cc
+++ b/src/contrib/msc/framework/tensorflow/codegen.cc
@@ -145,7 +145,7 @@ const Array<Doc> TensorflowCodeGen::GetOpCodes(const 
MSCJoint& node) {
   try {
     return it->second->GetDocs();
   } catch (runtime::InternalError& err) {
-    LOG(WARNING) << "Failed to get docs for " << node << " : " << 
err.message();
+    LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what();
     throw err;
   }
 }
diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc 
b/src/contrib/msc/framework/tensorrt/codegen.cc
index 0472e49b61..f3def93f52 100644
--- a/src/contrib/msc/framework/tensorrt/codegen.cc
+++ b/src/contrib/msc/framework/tensorrt/codegen.cc
@@ -548,7 +548,7 @@ const Array<Doc> TensorRTCodeGen::GetOpCodes(const 
MSCJoint& node) {
   try {
     return it->second->GetDocs();
   } catch (runtime::InternalError& err) {
-    LOG(WARNING) << "Failed to get docs for " << node << " : " << 
err.message();
+    LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what();
     throw err;
   }
 }
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc 
b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index 9bbfc1c22d..f1866a9b90 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -163,8 +163,8 @@ Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, 
const Call& src_call
   static const Op& topk_op = Op::Get("relax.topk");
   auto topk_attrs = make_object<TopKAttrs>();
   topk_attrs->k = 1;
-  if (src_attrs->axis.defined()) {
-    topk_attrs->axis = src_attrs->axis.value()->value;
+  if (src_attrs->axis.has_value()) {
+    topk_attrs->axis = src_attrs->axis.value();
   }
   topk_attrs->largest = call->op == Op::Get("relax.argmax");
   topk_attrs->ret_type = "both";
@@ -395,7 +395,7 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& 
var, const Call& src_ca
     if (in_dim != out_dim) {
       Array<Expr> concat_inputs(out_dim / in_dim, concat_input);
       auto concat_attrs = make_object<ConcatAttrs>();
-      concat_attrs->axis = Integer(i);
+      concat_attrs->axis = i;
       concat_input = RewriteUtils::MakeCall(
           builder, ExprUtils::GetSpanName(call, "concat_" + 
std::to_string(i)), concat_op,
           {Tuple(concat_inputs)}, Attrs(concat_attrs));
diff --git a/src/contrib/msc/framework/torch/codegen.cc 
b/src/contrib/msc/framework/torch/codegen.cc
index 86351bdd06..547c1c22ba 100644
--- a/src/contrib/msc/framework/torch/codegen.cc
+++ b/src/contrib/msc/framework/torch/codegen.cc
@@ -146,7 +146,7 @@ const Array<Doc> TorchCodeGen::GetOpCodes(const MSCJoint& 
node) {
   try {
     return it->second->GetDocs();
   } catch (runtime::InternalError& err) {
-    LOG(WARNING) << "Failed to get docs for " << node << " : " << 
err.message();
+    LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what();
     throw err;
   }
 }
diff --git a/src/contrib/msc/framework/tvm/codegen.cc 
b/src/contrib/msc/framework/tvm/codegen.cc
index 5443cdc96a..1d6d74d7e4 100644
--- a/src/contrib/msc/framework/tvm/codegen.cc
+++ b/src/contrib/msc/framework/tvm/codegen.cc
@@ -205,7 +205,7 @@ const Array<Doc> RelaxCodeGen::GetOpCodes(const MSCJoint& 
node) {
   try {
     return it->second->GetDocs();
   } catch (runtime::InternalError& err) {
-    LOG(WARNING) << "Failed to get docs for " << node << " : " << 
err.message();
+    LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what();
     throw err;
   }
 }
diff --git a/src/target/llvm/intrin_rule_hexagon.cc 
b/src/target/llvm/intrin_rule_hexagon.cc
index 2661f2fa65..58661c9978 100644
--- a/src/target/llvm/intrin_rule_hexagon.cc
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -57,10 +57,8 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) 
{
   const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
   ICHECK(f != nullptr);
   const auto ret = (*f)(true);
-  const Target t = ret.AsObjectRef<Target>();
-  bool useqhl = true;
-  if (t.defined()) {
-    const std::string tstring = t->str();
+  if (auto opt_target = ret.as<Target>()) {
+    const std::string tstring = opt_target.value()->str();
     useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
   }
 
@@ -108,10 +106,9 @@ TVM_REGISTER_OP("tir.tanh")
       const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
       ICHECK(f != nullptr);
       const auto ret = (*f)(true);
-      const Target t = ret.AsObjectRef<Target>();
       bool useqhl = true;
-      if (t.defined()) {
-        const std::string tstring = t->str();
+      if (auto opt_target = ret.as<Target>()) {
+        const std::string tstring = opt_target.value()->str();
         useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
       }
 
@@ -144,10 +141,9 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
       const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
       ICHECK(f != nullptr);
       const auto ret = (*f)(true);
-      const Target t = ret.AsObjectRef<Target>();
       bool useqhl = true;
-      if (t.defined()) {
-        const std::string tstring = t->str();
+      if (auto opt_target = ret.as<Target>()) {
+        const std::string tstring = opt_target.value()->str();
         useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
       }
 
@@ -175,10 +171,9 @@ TVM_REGISTER_OP("tir.sigmoid")
       const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
       ICHECK(f != nullptr);
       const auto ret = (*f)(true);
-      const Target t = ret.AsObjectRef<Target>();
       bool useqhl = true;
-      if (t.defined()) {
-        const std::string tstring = t->str();
+      if (auto opt_target = ret.as<Target>()) {
+        const std::string tstring = opt_target.value()->str();
         useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
       }
 
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 50c611814f..a75ba09e82 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -191,7 +191,7 @@ TEST(TargetCreation, TargetFeaturesBeforeParser) {
       {"mcpu", String("woof")},
       {"features", features},
   };
-  EXPECT_THROW(Target test(config), InternalError);
+  EXPECT_THROW(Target test(config), ffi::Error);
 }
 
 TEST(TargetCreation, TargetAttrsPreProcessor) {
@@ -200,7 +200,7 @@ TEST(TargetCreation, TargetAttrsPreProcessor) {
 }
 
 TEST(TargetCreation, ClashingTargetProcessing) {
-  EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), 
InternalError);
+  EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), 
ffi::Error);
 }
 
 TVM_REGISTER_TARGET_KIND("TestStringKind", kDLCPU)
diff --git a/tests/cpp/tir_scalable_datatype.cc 
b/tests/cpp/tir_scalable_datatype.cc
index 8fda119b6a..a5eeab034f 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -99,7 +99,7 @@ TEST(ScalableDataType, TestInvalidStringToScalableDataType) {
         try {
           tvm::runtime::StringToDLDataType(scalable_type_str);
         } catch (const tvm::ffi::Error& e) {
-          EXPECT_THAT(e.what(), HasSubstr("unknown dtype int32x4xvscale"));
+          EXPECT_THAT(e.what(), HasSubstr("unknown dtype `int32x4xvscale`"));
           throw;
         }
       },

Reply via email to