AndrewZhaoLuo commented on a change in pull request #8517:
URL: https://github.com/apache/tvm/pull/8517#discussion_r673516044



##########
File path: src/tir/op/op.cc
##########
@@ -112,12 +112,23 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
     ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
   }
   if (lhs.dtype() == rhs.dtype()) return;
-  // Only do very simple type coversion
+
+  // We keep casting pretty simple
+  // Two different floating point types will upconvert the lower bit floating point
+  // to the same type as the higher bit version. E.g. fp16 + fp32 --> fp32 + fp32.
+  // Furthermore:
   // int->float, DataType::Int(32)->int(64)
   // require the types to be relatively consistent
   // This will the reduce amount code generated by operators
   // and also help user to find potential type conversion problems.

Review comment:
       I've reorganized the comments.
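       For reference, the promotion rules listed in that comment (float + float widens the narrower float, int + float casts the int to the float type, int32 + int64 widens to int64) can be modeled in a few lines. This is a hypothetical sketch over plain dtype strings, not TVM's `DataType` API, and it deliberately ignores uint, bfloat, custom datatypes, and lanes.

```python
# Hypothetical model of the promotion rules described in the
# BinaryOpMatchTypes comment. dtypes are plain strings like "float16";
# this is NOT TVM's API, only an illustration of the stated rules.
import re


def parse(dtype):
    """Split a scalar dtype string such as 'float32' into ('float', 32)."""
    m = re.fullmatch(r"(float|int)(\d+)", dtype)
    if m is None:
        raise ValueError(f"unsupported dtype in this sketch: {dtype}")
    return m.group(1), int(m.group(2))


def match_types(lhs, rhs):
    """Return the common dtype both operands would be cast to."""
    if lhs == rhs:
        return lhs  # mirrors the early `return` when dtypes already match
    (lkind, lbits), (rkind, rbits) = parse(lhs), parse(rhs)
    if lkind == "float" and rkind == "float":
        # fp16 + fp32 --> fp32 + fp32: widen the narrower float.
        return lhs if lbits > rbits else rhs
    if lkind == "int" and rkind == "float":
        return rhs  # int -> float: the int side takes the float type
    if lkind == "float" and rkind == "int":
        return lhs
    # int + int: widen the narrower int, e.g. int32 -> int64.
    return lhs if lbits > rbits else rhs


print(match_types("float16", "float32"))  # float32
print(match_types("int64", "float16"))    # float16
print(match_types("int32", "int64"))      # int64
```

       Note that under these rules an int64 operand mixed with float16 still ends up as float16, since the int side always takes the float type regardless of bit width.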

##########
File path: src/tir/op/op.cc
##########
@@ -112,12 +112,23 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
     ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
   }
   if (lhs.dtype() == rhs.dtype()) return;
-  // Only do very simple type coversion
+
+  // We keep casting pretty simple
+  // Two different floating point types will upconvert the lower bit floating point
+  // to the same type as the higher bit version. E.g. fp16 + fp32 --> fp32 + fp32.
+  // Furthermore:
   // int->float, DataType::Int(32)->int(64)
   // require the types to be relatively consistent
   // This will the reduce amount code generated by operators
   // and also help user to find potential type conversion problems.
-  if (!lhs.dtype().is_float() &&
+  if (lhs.dtype().is_float() && rhs.dtype().is_float()) {
+    int max_num_bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
+    if (lhs.dtype().bits() != max_num_bits) {
+      lhs = cast(rhs.dtype(), lhs);
+    } else {
+      rhs = cast(lhs.dtype(), rhs);
+    }

Review comment:
       If they have the same bits, then it should hit the `return` statement above.
   
   That being said, your code is probably easier to read, so I will replace it.
   
   Done!
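       To make the point about the early `return` concrete, here is a hypothetical Python mirror of the float/float branch, using a small `DType(kind, bits, lanes)` record in place of TVM's `DataType` and assuming lanes were already matched earlier in the function. Once two float dtypes with equal lanes are known to differ, their bit widths must differ, so the `else` branch only runs when `rhs` is strictly narrower.

```python
# Hypothetical mirror of the float/float branch in this hunk, using a
# small DType record (kind, bits, lanes) instead of TVM's DataType.
from collections import namedtuple

DType = namedtuple("DType", ["kind", "bits", "lanes"])


def match_float_types(lhs, rhs):
    """Return the (lhs, rhs) dtypes after matching two float operands."""
    if lhs == rhs:
        # Same kind, bits and lanes: the early `return` fires.
        return lhs, rhs
    assert lhs.kind == "float" and rhs.kind == "float"
    assert lhs.lanes == rhs.lanes, "assumed to be checked before this point"
    # Equal bits would have compared equal above, so they must differ here.
    assert lhs.bits != rhs.bits
    max_num_bits = max(lhs.bits, rhs.bits)
    if lhs.bits != max_num_bits:
        return rhs, rhs  # cast lhs to rhs's dtype
    return lhs, lhs      # cast rhs to lhs's dtype


f16, f32 = DType("float", 16, 1), DType("float", 32, 1)
print(match_float_types(f16, f32) == (f32, f32))  # True
print(match_float_types(f32, f16) == (f32, f32))  # True
```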

##########
File path: tests/python/unittest/test_tir_base.py
##########
@@ -30,15 +32,25 @@ def build_tir_func(func):
 
 
 def test_scalar_add():
-    a = tir.Var("a", "float32")
-    b = tir.Var("b", "float32")
-    c = a + b
-    c = tir.ret(c)
-    c = tir.Evaluate(c)
-    func = tir.PrimFunc([a, b], c)
-    func = build_tir_func(func)
-    out = func(1.0, 2.0)
-    assert out == 3.0
+    # All these types should be interchangeable with each other
+    # E.g. float16 + float32 upconverts the float16 --> float32
+    # Meanwhile if an int or float or together the int will be
+    # cast to the float type.
+    lhs_types = ["float32", "float16", "int32", "int64"]
+    rhs_types = ["float32", "float16"]
+    for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types):
+        # Input vars should be float32, we will cast to test for upcasting between them
+        lhs_input = tir.Var("lhs", "float32")
+        rhs_input = tir.Var("rhs", "float32")
+        lhs = tir.Cast(lhs_type, lhs_input)
+        rhs = tir.Cast(rhs_type, rhs_input)
+        output = lhs + rhs
+        output = tir.ret(output)
+        output = tir.Evaluate(output)
+        func = tir.PrimFunc([lhs_input, rhs_input], output)
+        func = build_tir_func(func)
+        out = func(1.0, 2.0)
+        assert out == 3.0

Review comment:
       It should work except when mixing fp32 and fp16. I tested the old version with fp32 and fp16, and it does indeed fail.
   
   The old code already converts between floating point and integer types as expected.
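       One sanity check on the test itself: `assert out == 3.0` is only a fair check for every pair if 1.0, 2.0 and their sum survive each cast exactly. The sketch below uses the standard library's `struct` half (`e`) and single (`f`) formats to confirm that for the eight `(lhs_type, rhs_type)` pairs. `cast` here is a hypothetical stand-in for `tir.Cast` on Python scalars, and the sum is taken in Python double precision, so this only rules out rounding as the cause of a mismatch; it does not exercise TVM.

```python
# Check with the standard library that the test's inputs and expected sum
# are exactly representable in every type the test casts to.
import itertools
import struct


def round_to(fmt, x):
    """Round a Python float through an IEEE format: 'e' = half, 'f' = single."""
    return struct.unpack("<" + fmt, struct.pack("<" + fmt, x))[0]


def cast(dtype, x):
    """Hypothetical stand-in for tir.Cast applied to a Python scalar."""
    if dtype.startswith("int"):
        return float(int(x))
    if dtype == "float16":
        return round_to("e", x)
    if dtype == "float32":
        return round_to("f", x)
    raise ValueError(f"unsupported dtype in this sketch: {dtype}")


lhs_types = ["float32", "float16", "int32", "int64"]
rhs_types = ["float32", "float16"]
results = {}
for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types):
    # Inputs mirror func(1.0, 2.0); the sum itself is a Python double here.
    results[(lhs_type, rhs_type)] = cast(lhs_type, 1.0) + cast(rhs_type, 2.0)

print(len(results), all(v == 3.0 for v in results.values()))  # 8 True
print(round_to("e", 3.0) == 3.0)  # True: 3.0 is exact even in float16
```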




-- 
This is an automated message from the Apache Git Service.