This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b9fa576ab3 [BugFix] Use shape dtype on ArgReduce to determine return type (#12083)
b9fa576ab3 is described below

commit b9fa576ab35ef67e0a7fbaf669109f010e66e20c
Author: Everton Constantino <[email protected]>
AuthorDate: Thu Jul 14 14:54:56 2022 -0300

    [BugFix] Use shape dtype on ArgReduce to determine return type (#12083)
    
    Fix ArgReduce automatic return type inference by forcing it to use the
    datatype of the shape of the Tensor instead of the fixed Int32.
    
    Including additional tests.
---
 src/relay/op/tensor/reduce.cc         |  2 +-
 tests/python/relay/test_type_infer.py | 38 +++++++++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index fba2a60cec..2b1afc6e55 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -338,7 +338,7 @@ bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& att
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
+  reporter->Assign(types[1], TensorType(oshape, data->shape[0].dtype()));
   return true;
 }
 /*!
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index af64ce714d..b0b7ef0481 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -23,6 +23,8 @@ from tvm import IRModule, parser, relay, te
 from tvm.relay import analysis, op, transform
 from tvm.relay.op import op as _op
 
+import numpy as np
+
 
 def infer_mod(mod, annotate_spans=True):
     if annotate_spans:
@@ -544,6 +546,42 @@ def test_repeat_register():
         assert "Operator custom_log3 is registered before" in str(cm.execption)
 
 
+def test_argreduce_infer_return_type():
+    x_shape = (1, 1)
+    broadcast_shape = [1, 1]
+    shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))]
+
+    # Testing with argmax
+    for (sdtype, conv) in shape_dtypes:
+        x = relay.var("data", relay.TensorType(x_shape, "float32"))
+        broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
+        argmax = relay.op.argmax(broadcast_to, axis=[1])
+
+        f = relay.Function([x], argmax)
+        assert_has_type(
+            f,
+            relay.FuncType(
+                [relay.TensorType(broadcast_shape, "float32")],
+                relay.TensorType([conv(1)], dtype=sdtype),
+            ),
+        )
+
+    # Testing with argmin
+    for (sdtype, conv) in shape_dtypes:
+        x = relay.var("data", relay.TensorType(x_shape, "float32"))
+        broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
+        argmin = relay.op.argmin(broadcast_to, axis=[1])
+
+        f = relay.Function([x], argmin)
+        assert_has_type(
+            f,
+            relay.FuncType(
+                [relay.TensorType(broadcast_shape, "float32")],
+                relay.TensorType([conv(1)], dtype=sdtype),
+            ),
+        )
+
+
 if __name__ == "__main__":
     import sys
 

Reply via email to