Lunderberg commented on code in PR #16183:
URL: https://github.com/apache/tvm/pull/16183#discussion_r1704211785


##########
include/tvm/ir/expr.h:
##########
@@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const {
 
 namespace tvm {
 namespace runtime {
-// common rule for RetValue and ArgValue
+
+// Automatic conversion into IntImm, Integer, and Bool, when called
+// through the FFI.  Automatic conversions into PrimExpr are
+// registered in "tvm/tir/expr.h", as it includes conversions to the
+// TIR-only StringImm.
+//
+// While the FFI only requires the From() method, these
+// implementations also define a TryFrom() method to avoid duplicate
+// logic in the PrimExpr conversion.
+
 template <>
-struct PackedFuncValueConverter<PrimExpr> {
-  static PrimExpr From(const TVMPODValue_& val) {
-    if (val.type_code() == kTVMNullptr) {
-      return PrimExpr(ObjectPtr<Object>(nullptr));
-    }
-    if (val.type_code() == kDLInt) {
-      int64_t value = val.operator int64_t();
-      if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
-        return IntImm(runtime::DataType::Int(64), value);
-      }
-      return IntImm(runtime::DataType::Int(32), val.operator int());
-    }
-    if (val.type_code() == kDLFloat) {
-      return FloatImm(runtime::DataType::Float(32), val.operator double());
+struct PackedFuncValueConverter<tvm::IntImm> {
+  template <typename PODSubclass>
+  static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
+    if (auto opt = val.TryAsInt()) {
+      int64_t value = opt.value();
+      auto dtype =
+          (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
+              ? DataType::Int(64)
+              : DataType::Int(32);
+      return IntImm(dtype, value);
+    } else if (auto opt = val.TryAsBool()) {
+      return IntImm(DataType::Int(32), opt.value());
+    } else {
+      return NullOpt;
     }
+  }
 
-    return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
+  template <typename PODSubclass>
+  static tvm::IntImm From(const PODSubclass& val) {
+    if (auto opt = TryFrom(val)) {
+      return opt.value();
+    } else {
+      return val.template AsObjectRef<tvm::IntImm>();
+    }
   }
 };
 
 template <>
 struct PackedFuncValueConverter<tvm::Integer> {
-  static tvm::Integer From(const TVMPODValue_& val) {
-    if (val.type_code() == kTVMNullptr) {
-      return Integer(ObjectPtr<Object>(nullptr));
+  template <typename PODSubclass>
+  static tvm::Integer From(const PODSubclass& val) {
+    if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
+      return Integer(opt.value());
+    } else {
+      return val.template AsObjectRef<tvm::Integer>();
     }
-    if (val.type_code() == kTVMArgInt) {
-      return Integer(val.operator int());
-    }
-    return val.AsObjectRef<tvm::Integer>();
   }
 };
 
 template <>
 struct PackedFuncValueConverter<tvm::Bool> {
-  static tvm::Bool From(const TVMPODValue_& val) {
-    if (val.type_code() == kTVMNullptr) {
-      return Bool(ObjectPtr<Object>(nullptr));
+  template <typename PODSubclass>
+  static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
+    if (auto opt = val.TryAsBool()) {
+      return tvm::Bool(opt.value());
+    } else if (auto opt = val.TryAsInt()) {
+      int value = opt.value();
+      ICHECK(value == 0 || value == 1)
+          << "ValueError: boolean value can only be 0 or 1, but get " << value;
+      return tvm::Bool(static_cast<bool>(value));
+    } else {
+      return NullOpt;
+    }
+  }
+
+  template <typename PODSubclass>
+  static tvm::Bool From(const PODSubclass& val) {
+    if (auto opt = TryFrom(val)) {
+      return opt.value();
+    } else {
+      return val.template AsObjectRef<tvm::Bool>();
     }
-    if (val.type_code() == kTVMArgInt) {
-      int v = val.operator int();
-      ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
-      return Bool(static_cast<bool>(v));
+  }
+};
+
+template <>
+struct PackedFuncValueConverter<tvm::FloatImm> {
+  static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
+    if (auto opt = val.TryAsFloat()) {
+      return FloatImm(runtime::DataType::Float(32), opt.value());
+    } else {
+      return NullOpt;
+    }
+  }
+
+  template <typename PODSubclass>
+  static tvm::FloatImm From(const PODSubclass& val) {
+    if (auto opt = TryFrom(val)) {
+      return opt.value();
+    } else {
+      return val.template AsObjectRef<tvm::FloatImm>();
+    }
+  }
+};
+
+/* \brief Backwards compatibility wrapper for IntImm arguments
+ *
+ * In previous versions of TVM, IntImm was the default FFI type for
+ * integer arguments, instead of runtime::Int.  For backwards
+ * compatibility where the callee has been updated to expect a
+ * runtime::Int, the caller has not been updated to provide a
+ * runtime::Int (e.g. relay script parsing), and the auto-unboxing of
+ * runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
+ * allow the IntImm to be generated.
+ */
+template <>
+struct PackedFuncValueConverter<runtime::Int> {

Review Comment:
   Agreed.  I've made the draft PR https://github.com/apache/tvm/pull/17241 to 
remove the backwards-compatibility handler, in order to see which callers 
relied on it.



##########
include/tvm/ir/expr.h:
##########
@@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const {
 
 namespace tvm {
 namespace runtime {
-// common rule for RetValue and ArgValue
+
+// Automatic conversion into IntImm, Integer, and Bool, when called
+// through the FFI.  Automatic conversions into PrimExpr are
+// registered in "tvm/tir/expr.h", as it includes conversions to the
+// TIR-only StringImm.
+//
+// While the FFI only requires the From() method, these
+// implementations also define a TryFrom() method to avoid duplicate
+// logic in the PrimExpr conversion.
+
 template <>
-struct PackedFuncValueConverter<PrimExpr> {
-  static PrimExpr From(const TVMPODValue_& val) {
-    if (val.type_code() == kTVMNullptr) {
-      return PrimExpr(ObjectPtr<Object>(nullptr));
-    }
-    if (val.type_code() == kDLInt) {
-      int64_t value = val.operator int64_t();
-      if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
-        return IntImm(runtime::DataType::Int(64), value);
-      }
-      return IntImm(runtime::DataType::Int(32), val.operator int());
-    }
-    if (val.type_code() == kDLFloat) {
-      return FloatImm(runtime::DataType::Float(32), val.operator double());
+struct PackedFuncValueConverter<tvm::IntImm> {
+  template <typename PODSubclass>
+  static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
+    if (auto opt = val.TryAsInt()) {
+      int64_t value = opt.value();
+      auto dtype =
+          (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
+              ? DataType::Int(64)
+              : DataType::Int(32);
+      return IntImm(dtype, value);
+    } else if (auto opt = val.TryAsBool()) {
+      return IntImm(DataType::Int(32), opt.value());
+    } else {
+      return NullOpt;
     }
+  }
 
-    return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
+  template <typename PODSubclass>
+  static tvm::IntImm From(const PODSubclass& val) {
+    if (auto opt = TryFrom(val)) {
+      return opt.value();
+    } else {
+      return val.template AsObjectRef<tvm::IntImm>();
+    }
   }
 };
 
 template <>
 struct PackedFuncValueConverter<tvm::Integer> {
-  static tvm::Integer From(const TVMPODValue_& val) {
-    if (val.type_code() == kTVMNullptr) {
-      return Integer(ObjectPtr<Object>(nullptr));
+  template <typename PODSubclass>
+  static tvm::Integer From(const PODSubclass& val) {
+    if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
+      return Integer(opt.value());
+    } else {
+      return val.template AsObjectRef<tvm::Integer>();
     }
-    if (val.type_code() == kTVMArgInt) {
-      return Integer(val.operator int());
-    }
-    return val.AsObjectRef<tvm::Integer>();
   }
 };
 
 template <>
 struct PackedFuncValueConverter<tvm::Bool> {
-  static tvm::Bool From(const TVMPODValue_& val) {
-    if (val.type_code() == kTVMNullptr) {
-      return Bool(ObjectPtr<Object>(nullptr));
+  template <typename PODSubclass>
+  static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
+    if (auto opt = val.TryAsBool()) {
+      return tvm::Bool(opt.value());
+    } else if (auto opt = val.TryAsInt()) {
+      int value = opt.value();
+      ICHECK(value == 0 || value == 1)
+          << "ValueError: boolean value can only be 0 or 1, but get " << value;
+      return tvm::Bool(static_cast<bool>(value));
+    } else {
+      return NullOpt;
+    }
+  }
+
+  template <typename PODSubclass>
+  static tvm::Bool From(const PODSubclass& val) {
+    if (auto opt = TryFrom(val)) {
+      return opt.value();
+    } else {
+      return val.template AsObjectRef<tvm::Bool>();
     }
-    if (val.type_code() == kTVMArgInt) {
-      int v = val.operator int();
-      ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
-      return Bool(static_cast<bool>(v));
+  }
+};
+
+template <>
+struct PackedFuncValueConverter<tvm::FloatImm> {
+  static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
+    if (auto opt = val.TryAsFloat()) {
+      return FloatImm(runtime::DataType::Float(32), opt.value());
+    } else {
+      return NullOpt;
+    }
+  }
+
+  template <typename PODSubclass>
+  static tvm::FloatImm From(const PODSubclass& val) {
+    if (auto opt = TryFrom(val)) {
+      return opt.value();
+    } else {
+      return val.template AsObjectRef<tvm::FloatImm>();
+    }
+  }
+};
+
+/* \brief Backwards compatibility wrapper for IntImm arguments
+ *
+ * In previous versions of TVM, IntImm was the default FFI type for
+ * integer arguments, instead of runtime::Int.  For backwards
+ * compatibility where the callee has been updated to expect a
+ * runtime::Int, the caller has not been updated to provide a
+ * runtime::Int (e.g. relay script parsing), and the auto-unboxing of
+ * runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
+ * allow the IntImm to be generated.
+ */
+template <>
+struct PackedFuncValueConverter<runtime::Int> {

Review Comment:
   Agreed.  I've made the draft PR https://github.com/apache/tvm/pull/17241 to 
remove the backwards-compatibility handler, in order to run CI and see which 
callers relied on it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to