This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 959cff1  [Relay] Fix interpreter for dynamic shape input of ndarray_size (#6086)
959cff1 is described below

commit 959cff1c786e0eb33b99007be66de61d2275d7a5
Author: lixiaoquan <radiohe...@163.com>
AuthorDate: Sat Jul 25 23:16:06 2020 +0800

    [Relay] Fix interpreter for dynamic shape input of ndarray_size (#6086)
---
 src/relay/backend/interpreter.cc | 14 ++------------
 tests/python/relay/test_any.py   | 22 ++++++++++++++++++++--
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 9a75c0a..08c5a7c 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -213,11 +213,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const 
Expr& n)>,
                     PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> 
{
  public:
   Interpreter(IRModule mod, DLContext context, Target target)
-      : mod_(mod),
-        context_(context),
-        target_(target),
-        debug_op_(Op::Get("debug")),
-        shape_of_op_(Op::Get("shape_of")) {
+      : mod_(mod), context_(context), target_(target), 
debug_op_(Op::Get("debug")) {
     engine_ = CompileEngine::Global();
   }
 
@@ -481,12 +477,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const 
Expr& n)>,
 
     Array<Shape> out_shapes;
     auto ret_type = func->body->checked_type();
-    bool is_dyn = IsDynamic(func->checked_type());
-    if (call_node->op == shape_of_op_) {
-      // The output shape of shape_of must be static since Relay doesn't 
support
-      // dynamic rank tensors.
-      is_dyn = false;
-    }
+    bool is_dyn = IsDynamic(ret_type);
 
     if (is_dyn) {
       CHECK(func->HasNonzeroAttr(attr::kPrimitive));
@@ -722,7 +713,6 @@ class Interpreter : public ExprFunctor<ObjectRef(const 
Expr& n)>,
   CompileEngine engine_;
   // Cache ops that need to be frequently used later to reduce lookup overhead.
   const Op& debug_op_;
-  const Op& shape_of_op_;
 };
 
 TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext 
context, Target target) {
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index bf28ee1..0e8a328 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -814,7 +814,7 @@ def test_mixed_input_type():
         assert result.asnumpy().shape == ref_out_shape, \
             "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), 
str(result.asnumpy().shape))
 
-def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, 
crop_size, 
+def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, 
crop_size,
                                layout, static_boxes, static_box_indices_shape, 
ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
@@ -872,6 +872,24 @@ def test_any_mirror_pad():
         static_data_shape=(1, 256, 232, 232),
         ref_out_shape=(1, 256, 234, 234))
 
+def verify_any_ndarray_size(data_np_shape):
+    v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32')
+    n = relay.ndarray_size(v, dtype='int32')
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([v], n)
+    np_data = np.zeros(data_np_shape, dtype='float32')
+    ref_res = np.size(np_data)
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(np_data)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_res)
+
+def test_any_ndarray_size():
+    verify_any_ndarray_size((2,))
+    verify_any_ndarray_size((2, 2))
+    verify_any_ndarray_size((1, 2, 3, 4))
+
 if __name__ == "__main__":
     test_any_full()
     test_any_full_like()
@@ -908,4 +926,4 @@ if __name__ == "__main__":
     test_mixed_input_type()
     test_any_crop_and_resize()
     test_any_mirror_pad()
-
+    test_any_ndarray_size()

Reply via email to