This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9e88723385 [TIR] Improved error messages for PrimExpr operator overloads (#12638)
9e88723385 is described below
commit 9e88723385f83a2d27a60432cbe50782bed2885f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 29 17:27:34 2022 -0700
[TIR] Improved error messages for PrimExpr operator overloads (#12638)
Previously, the type checks in the logical, bitwise, and shift operators on
`PrimExpr` (and in the `all`/`any` reductions) only reported that a type was
incorrect, so further investigation was required to determine which
expression caused the error.
After this commit, the error messages for these type checks include the
offending expression and its dtype.
---
src/tir/op/op.cc | 58 ++++++++++++++++++++++++++++++++++++++------------------
1 file changed, 40 insertions(+), 18 deletions(-)
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 69d1da5e8c..b9e0c3c370 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -520,10 +520,37 @@ PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) {
return tir::NE(a, b, span);
}
+namespace {
+void type_check_boolean_args(const PrimExpr& arg, const char* op) {
+ ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received "
+ << arg << " of type " << arg.dtype();
+}
+void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+ ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received "
+ << lhs << " of type " << lhs.dtype();
+ ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received "
+ << rhs << " of type " << rhs.dtype();
+}
+
+void type_check_integer_args(const PrimExpr& arg, const char* op) {
+ ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
+ << "Expected integer argument for " << op << ", but received " << arg << " of type "
+ << arg.dtype();
+}
+
+void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+ ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint())
+ << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type "
+ << lhs.dtype();
+ ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint())
+ << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
+ << rhs.dtype();
+}
+} // namespace
+
PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_bool());
- ICHECK(b.dtype().is_bool());
+ type_check_boolean_args(a, b, "&& operator (logical AND)");
PrimExpr ret = arith::TryConstFold<tir::And>(a, b);
if (ret.defined()) return ret;
return tir::And(a, b, span);
@@ -531,8 +558,7 @@ PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); }
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_bool());
- ICHECK(b.dtype().is_bool());
+ type_check_boolean_args(a, b, "|| operator (logical OR)");
PrimExpr ret = arith::TryConstFold<tir::Or>(a, b);
if (ret.defined()) return ret;
return tir::Or(a, b, span);
@@ -540,7 +566,7 @@ PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
PrimExpr operator!(PrimExpr a) { return logical_not(a); }
PrimExpr logical_not(PrimExpr a, Span span) {
- ICHECK(a.dtype().is_bool());
+ type_check_boolean_args(a, "! operator (logical NOT)");
PrimExpr ret = arith::TryConstFold<tir::Not>(a);
if (ret.defined()) return ret;
return tir::Not(a, span);
@@ -550,8 +576,8 @@ PrimExpr logical_not(PrimExpr a, Span span) {
PrimExpr operator>>(PrimExpr a, PrimExpr b) { return right_shift(a, b); }
PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_int() || a.dtype().is_uint());
- ICHECK(b.dtype().is_int() || b.dtype().is_uint());
+ type_check_integer_args(a, b, ">> operator (right shift)");
+
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -573,8 +599,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
// shift left
PrimExpr operator<<(PrimExpr a, PrimExpr b) { return left_shift(a, b); }
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_int() || a.dtype().is_uint());
- ICHECK(b.dtype().is_int() || b.dtype().is_uint());
+ type_check_integer_args(a, b, "<< operator (left shift)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -593,8 +618,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
// bitwise and
PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_int() || a.dtype().is_uint());
- ICHECK(b.dtype().is_int() || b.dtype().is_uint());
+ type_check_integer_args(a, b, "& operator (bitwise AND)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -606,8 +630,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
// bitwise_or
PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_int() || a.dtype().is_uint());
- ICHECK(b.dtype().is_int() || b.dtype().is_uint());
+ type_check_integer_args(a, b, "| operator (bitwise OR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -619,8 +642,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
// bitwise_xor
PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
- ICHECK(a.dtype().is_int() || a.dtype().is_uint());
- ICHECK(b.dtype().is_int() || b.dtype().is_uint());
+ type_check_integer_args(a, b, "^ operator (bitwise XOR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
@@ -633,7 +655,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
PrimExpr bitwise_neg(PrimExpr a, Span span) {
- ICHECK(a.dtype().is_int() || a.dtype().is_uint());
+ type_check_integer_args(a, "~ operator (bitwise NOT)");
return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
}
@@ -728,7 +750,7 @@ PrimExpr sum(PrimExpr source, Array<IterVar> rdom,
Array<PrimExpr> init, Span sp
}
PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span
span) {
- ICHECK(source.dtype().is_bool());
+ type_check_boolean_args(source, "tvm::all");
Var x("x", source.dtype(), span), y("y", source.dtype());
PrimExpr result = tir::And(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), true, span);
@@ -737,7 +759,7 @@ PrimExpr all(PrimExpr source, Array<IterVar> rdom,
Array<PrimExpr> init, Span sp
}
PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span
span) {
- ICHECK(source.dtype().is_bool());
+ type_check_boolean_args(source, "tvm::any");
Var x("x", source.dtype(), span), y("y", source.dtype(), span);
PrimExpr result = tir::Or(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), false, span);