comaniac commented on a change in pull request #8517:
URL: https://github.com/apache/tvm/pull/8517#discussion_r673511444
##########
File path: tests/python/unittest/test_tir_base.py
##########
@@ -30,15 +32,25 @@ def build_tir_func(func):
def test_scalar_add():
- a = tir.Var("a", "float32")
- b = tir.Var("b", "float32")
- c = a + b
- c = tir.ret(c)
- c = tir.Evaluate(c)
- func = tir.PrimFunc([a, b], c)
- func = build_tir_func(func)
- out = func(1.0, 2.0)
- assert out == 3.0
+ # All these types should be interchangeable with each other
+ # E.g. float16 + float32 upconverts the float16 --> float32
+ # Meanwhile if an int or float or together the int will be
+ # cast to the float type.
+ lhs_types = ["float32", "float16", "int32", "int64"]
+ rhs_types = ["float32", "float16"]
+ for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types):
+ # Input vars should be float32, we will cast to test for upcasting
between them
+ lhs_input = tir.Var("lhs", "float32")
+ rhs_input = tir.Var("rhs", "float32")
+ lhs = tir.Cast(lhs_type, lhs_input)
+ rhs = tir.Cast(rhs_type, rhs_input)
+ output = lhs + rhs
+ output = tir.ret(output)
+ output = tir.Evaluate(output)
+ func = tir.PrimFunc([lhs_input, rhs_input], output)
+ func = build_tir_func(func)
+ out = func(1.0, 2.0)
+ assert out == 3.0
Review comment:
Will this test fail (at all combinations of lhs_types and rhs_types
except for `(float32, float32)`) without the changes in `op.cc`?
##########
File path: src/tir/op/op.cc
##########
@@ -112,12 +112,23 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs,
Span span) { // NOLINT(*)
ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype <<
" vs " << rtype;
}
if (lhs.dtype() == rhs.dtype()) return;
- // Only do very simple type coversion
+
+ // We keep casting pretty simple
+ // Two different floating point types will upconvert the lower bit floating
point
+ // to the same type as the higher bit version. E.g. fp16 + fp32 --> fp32 +
fp32.
+ // Furthermore:
// int->float, DataType::Int(32)->int(64)
// require the types to be relatively consistent
// This will the reduce amount code generated by operators
// and also help user to find potential type conversion problems.
Review comment:
Better to rewrite the entire comment here. Here is my suggestion; feel
free to use your own words:
```
// We keep the dtypes relatively consistent to reduce the amount of code generated by operators.
// This can also help users find potential type conversion problems.
// Specifically, when two dtypes are inconsistent, the int or lower-bit one will be upcast to match
// the other. E.g. fp16 + fp32 --> fp32 + fp32; fp32 + int16 --> fp32 + fp32.
```
##########
File path: src/tir/op/op.cc
##########
@@ -112,12 +112,23 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs,
Span span) { // NOLINT(*)
ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype <<
" vs " << rtype;
}
if (lhs.dtype() == rhs.dtype()) return;
- // Only do very simple type coversion
+
+ // We keep casting pretty simple
+ // Two different floating point types will upconvert the lower bit floating
point
+ // to the same type as the higher bit version. E.g. fp16 + fp32 --> fp32 +
fp32.
+ // Furthermore:
// int->float, DataType::Int(32)->int(64)
// require the types to be relatively consistent
// This will the reduce amount code generated by operators
// and also help user to find potential type conversion problems.
- if (!lhs.dtype().is_float() &&
+ if (lhs.dtype().is_float() && rhs.dtype().is_float()) {
+ int max_num_bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
+ if (lhs.dtype().bits() != max_num_bits) {
+ lhs = cast(rhs.dtype(), lhs);
+ } else {
+ rhs = cast(lhs.dtype(), rhs);
+ }
Review comment:
Looks like you will still add a cast to RHS even if their bits are already
the same?
```suggestion
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
} else if (lhs.dtype().bits() > rhs.dtype().bits()) {
rhs = cast(lhs.dtype(), rhs);
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]