This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 460374fed5 [TOPI] Support symbolic shape in einsum (#14521)
460374fed5 is described below

commit 460374fed5f805a0ceb2b0a8de07c190ab90d83f
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Apr 7 09:26:01 2023 -0700

    [TOPI] Support symbolic shape in einsum (#14521)
    
    * [TOPI] Support symbolic shape in einsum
    
    * Update test_topi_einsum.py
---
 src/topi/einsum.cc                           | 31 +++++++++++------
 tests/python/topi/python/test_topi_einsum.py | 52 ++++++++++++++++++++++------
 2 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc
index 892a17e58d..3e9cd358e9 100644
--- a/src/topi/einsum.cc
+++ b/src/topi/einsum.cc
@@ -98,24 +98,33 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) {
 }
 
PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) {
-  int64_t extent1_value = GetConstInt(extent1);
-  int64_t extent2_value = GetConstInt(extent2);
-  if (extent1_value == extent2_value) {
+  const IntImmNode* extent1_imm = extent1.as<IntImmNode>();
+  const IntImmNode* extent2_imm = extent2.as<IntImmNode>();
+  if (extent1_imm != nullptr && extent2_imm != nullptr) {
+    if (extent1_imm->value == extent2_imm->value) {
+      return extent1;
+    } else if (extent1_imm->value == 1 || extent2_imm->value == 1) {
+      return Integer(std::max(extent1_imm->value, extent2_imm->value));
+    }
+    LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
+    throw;
+  } else if (extent1_imm != nullptr) {
+    return extent2;
+  } else if (extent2_imm != nullptr) {
     return extent1;
-  } else if (extent1_value == 1 || extent2_value == 1) {
-    return Integer(std::max(extent1_value, extent2_value));
+  } else {
+    return max(extent1, extent2);
   }
-  LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
-  throw;
 }
 
 PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent,
                                    const PrimExpr& broadcasted_extent) {
-  if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) {
-    return index;
-  } else {
-    return Integer(0);
+  // Check if current dimension is being broadcasted to `broadcasted_extent` (symbolic shape is
+  // handled)
+  if (is_one(extent) && !is_one(broadcasted_extent)) {
+    return make_zero(index.dtype());
   }
+  return index;
 }
 
 /*! \brief The compute builder for Einsum */
diff --git a/tests/python/topi/python/test_topi_einsum.py b/tests/python/topi/python/test_topi_einsum.py
index d6dc43e4da..a84cbaffc1 100644
--- a/tests/python/topi/python/test_topi_einsum.py
+++ b/tests/python/topi/python/test_topi_einsum.py
@@ -23,39 +23,59 @@ from tvm import topi
 from tvm.topi.utils import get_const_tuple
 
 
-def with_tvm(lam, *args):
+def with_tvm(lam, shapes, ops, out_shape):
     """Take numpy arrays as args, convert them to TVM tensors and call `lam`.
     Result of lambda is converted back to numpy array and returned.
     """
     dev = tvm.cpu(0)
     pls = []  # placeholders
     vals_nd = []  # initial values
-    for i, arg in enumerate(args):
-        pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
+    for i, (shape, arg) in enumerate(zip(shapes, ops)):
+        pls.append(te.placeholder(shape, name="pl" + str(i)))
         vals_nd.append(tvm.nd.array(arg, dev))
 
     out = lam(*pls)
-    out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)
+    out_nd = tvm.nd.array(np.zeros(out_shape).astype(out.dtype), device=dev)
     s = te.create_schedule([out.op])
     m = tvm.build(s, pls + [out], "llvm")
     m(*(vals_nd + [out_nd]))
     return out_nd.numpy()
 
 
-def verify_einsum(subscripts, shapes):
-    ops = []
+def verify_einsum(subscripts, shapes, shape_dict={}):
+    ops = []  # ndarrays to be used as inputs
+    symbolic_shapes = []  # shapes to declare the placeholders
+    name_to_var = {}
+
+    def get_concrete_shape(shape):
+        return [shape_dict[s] if isinstance(s, str) else s for s in shape]
+
+    def get_symblic_shape_var(name, dtype="int32"):
+        if name not in name_to_var:
+            name_to_var[name] = te.var(name, dtype=dtype)
+        return name_to_var[name]
+
+    def get_symbolic_shape(shape):
+        return [get_symblic_shape_var(s) if isinstance(s, str) else s for s in shape]
+
     for shape in shapes:
-        tmp = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
+        concrete_shape = get_concrete_shape(shape)
+        tmp = np.random.uniform(low=-1.0, high=1.0, size=concrete_shape).astype(np.float32)
         ops.append(tmp)
+        symbolic_shape = get_symbolic_shape(shape)
+        symbolic_shapes.append(symbolic_shape)
 
     c1 = np.einsum(subscripts, *ops)
+    out_shape = c1.shape
 
     if len(ops) == 1:
-        c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
+        c2 = with_tvm(lambda A: topi.einsum(subscripts, A), symbolic_shapes, ops, out_shape)
     elif len(ops) == 2:
-        c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
+        c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), symbolic_shapes, ops, out_shape)
     elif len(ops) == 3:
-        c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)
+        c2 = with_tvm(
+            lambda A, B, C: topi.einsum(subscripts, A, B, C), symbolic_shapes, ops, out_shape
+        )
 
     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
 
@@ -82,5 +102,17 @@ def test_einsum(equation, inputs):
     verify_einsum(equation, inputs)
 
 
+@pytest.mark.parametrize(
+    "equation,inputs,shape_dict",
+    [
+        ("ij,jk->ik", [(2, "K"), (1, "N")], {"K": 3, "N": 4}),
+        ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 3}),
+        ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 1}),
+    ],
+)
+def test_einsum_symblic_shape(equation, inputs, shape_dict):
+    verify_einsum(equation, inputs, shape_dict)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to