This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 173b4fc Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)
173b4fc is described below
commit 173b4fc4c46499056ebc5682c20fcff2582bc9db
Author: pankratz <[email protected]>
AuthorDate: Thu Mar 12 10:35:36 2020 -0600
Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)
---
src/arith/const_fold.h | 2 ++
src/tir/ir/op.cc | 15 +++++++++++++++
tests/python/unittest/test_lang_basic.py | 19 +++++++++++++++++--
tests/python/unittest/test_tvm_intrin.py | 11 +++++++++++
4 files changed, 45 insertions(+), 2 deletions(-)
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index bae34bd..a440af9 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
+ CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
}
if (pa) {
@@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
+ CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
}
if (pa) {
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index 452c3bb..2882fea 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
}
PrimExpr floor(PrimExpr x) {
+ if (x.dtype().is_int() || x.dtype().is_uint()) {
+ return x;
+ }
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
@@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) {
}
PrimExpr ceil(PrimExpr x) {
+ if (x.dtype().is_int() || x.dtype().is_uint()) {
+ return x;
+ }
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
@@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) {
}
PrimExpr round(PrimExpr x) {
+ if (x.dtype().is_int() || x.dtype().is_uint()) {
+ return x;
+ }
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
@@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) {
}
PrimExpr nearbyint(PrimExpr x) {
+ if (x.dtype().is_int() || x.dtype().is_uint()) {
+ return x;
+ }
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
@@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) {
}
PrimExpr trunc(PrimExpr x) {
+ if (x.dtype().is_int() || x.dtype().is_uint()) {
+ return x;
+ }
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py
index cd532a0..c279194 100644
--- a/tests/python/unittest/test_lang_basic.py
+++ b/tests/python/unittest/test_lang_basic.py
@@ -187,14 +187,14 @@ def test_bitwise():
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype ==
"int8x2"
+
def test_float_bitwise():
t = tvm.tir.const(1.5,dtype='float32')
for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs,
lambda lhs, rhs : lhs | rhs,
lambda lhs, rhs : lhs ^ rhs,
- lambda lhs, rhs : lhs & rhs
- ]:
+ lambda lhs, rhs : lhs & rhs]:
try:
test(t,10.0)
assert False
@@ -206,6 +206,20 @@ def test_float_bitwise():
except RuntimeError:
pass
+
+def test_divide_by_zero():  # Every integer div/mod op below must raise tvm.TVMError (not crash) for a zero divisor.
+ for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),  # floor- and trunc-rounding div/mod variants, plus tvm.tir.div
+ lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
+ lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs),
+ lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs),
+ lambda lhs, rhs : tvm.tir.div(lhs,rhs)]:
+ try:
+ test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32'))  # two int32 constants with a zero divisor
+ assert False  # reached only if no error was raised; AssertionError is not caught by the handler below
+ except tvm.TVMError:
+ pass
+
+
def test_isnan():
x = te.var('x', 'float32')
assert str(tvm.tir.isnan(x)) == 'isnan(x)'
@@ -250,6 +264,7 @@ if __name__ == "__main__":
test_all()
test_bitwise()
test_float_bitwise()
+ test_divide_by_zero()
test_isnan()
test_equality()
test_equality_string_imm()
diff --git a/tests/python/unittest/test_tvm_intrin.py b/tests/python/unittest/test_tvm_intrin.py
index 0054273..52ae440 100644
--- a/tests/python/unittest/test_tvm_intrin.py
+++ b/tests/python/unittest/test_tvm_intrin.py
@@ -44,6 +44,16 @@ def test_nearbyint():
tvm.testing.assert_allclose(
a_rounded.asnumpy(), np.rint(a.asnumpy()))
+def test_round_intrinsics_on_int():  # Rounding intrinsics must return integer/bool inputs unchanged instead of crashing.
+ i = tvm.te.var("i", 'int32')  # non-constant int32 expression
+ for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil,
+ tvm.tir.floor, tvm.tir.nearbyint]:
+ assert op(tvm.tir.const(10,'int32')).value == 10
+ assert op(tvm.tir.const(True,'bool')).value == True  # NOTE(review): presumably passes because TVM's bool dtype reports is_uint() — confirm
+ assert op(i).same_as(i)  # the identical expression object is returned, not a rebuilt copy
+
+ assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False  # isnan of an integer constant yields a false constant
+
def test_unary_intrin():
test_funcs = [
@@ -75,3 +85,4 @@ def test_unary_intrin():
if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
+ test_round_intrinsics_on_int()