This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 460374fed5 [TOPI] Support symbolic shape in einsum (#14521)
460374fed5 is described below
commit 460374fed5f805a0ceb2b0a8de07c190ab90d83f
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Apr 7 09:26:01 2023 -0700
[TOPI] Support symbolic shape in einsum (#14521)
* [TOPI] Support symbolic shape in einsum
* Update test_topi_einsum.py
---
src/topi/einsum.cc | 31 +++++++++++------
tests/python/topi/python/test_topi_einsum.py | 52 ++++++++++++++++++++++------
2 files changed, 62 insertions(+), 21 deletions(-)
diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc
index 892a17e58d..3e9cd358e9 100644
--- a/src/topi/einsum.cc
+++ b/src/topi/einsum.cc
@@ -98,24 +98,33 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) {
}
PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) {
- int64_t extent1_value = GetConstInt(extent1);
- int64_t extent2_value = GetConstInt(extent2);
- if (extent1_value == extent2_value) {
+ const IntImmNode* extent1_imm = extent1.as<IntImmNode>();
+ const IntImmNode* extent2_imm = extent2.as<IntImmNode>();
+ if (extent1_imm != nullptr && extent2_imm != nullptr) {
+ if (extent1_imm->value == extent2_imm->value) {
+ return extent1;
+ } else if (extent1_imm->value == 1 || extent2_imm->value == 1) {
+ return Integer(std::max(extent1_imm->value, extent2_imm->value));
+ }
+ LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
+ throw;
+ } else if (extent1_imm != nullptr) {
+ return extent2;
+ } else if (extent2_imm != nullptr) {
return extent1;
- } else if (extent1_value == 1 || extent2_value == 1) {
- return Integer(std::max(extent1_value, extent2_value));
+ } else {
+ return max(extent1, extent2);
}
- LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
- throw;
}
PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent,
const PrimExpr& broadcasted_extent) {
- if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) {
- return index;
- } else {
- return Integer(0);
+ // Check if current dimension is being broadcasted to `broadcasted_extent` (symbolic shape is
+ // handled)
+ if (is_one(extent) && !is_one(broadcasted_extent)) {
+ return make_zero(index.dtype());
}
+ return index;
}
/*! \brief The compute builder for Einsum */
diff --git a/tests/python/topi/python/test_topi_einsum.py b/tests/python/topi/python/test_topi_einsum.py
index d6dc43e4da..a84cbaffc1 100644
--- a/tests/python/topi/python/test_topi_einsum.py
+++ b/tests/python/topi/python/test_topi_einsum.py
@@ -23,39 +23,59 @@ from tvm import topi
from tvm.topi.utils import get_const_tuple
-def with_tvm(lam, *args):
+def with_tvm(lam, shapes, ops, out_shape):
"""Take numpy arrays as args, convert them to TVM tensors and call `lam`.
Result of lambda is converted back to numpy array and returned.
"""
dev = tvm.cpu(0)
pls = [] # placeholders
vals_nd = [] # initial values
- for i, arg in enumerate(args):
- pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
+ for i, (shape, arg) in enumerate(zip(shapes, ops)):
+ pls.append(te.placeholder(shape, name="pl" + str(i)))
vals_nd.append(tvm.nd.array(arg, dev))
out = lam(*pls)
- out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)
+ out_nd = tvm.nd.array(np.zeros(out_shape).astype(out.dtype), device=dev)
s = te.create_schedule([out.op])
m = tvm.build(s, pls + [out], "llvm")
m(*(vals_nd + [out_nd]))
return out_nd.numpy()
-def verify_einsum(subscripts, shapes):
- ops = []
+def verify_einsum(subscripts, shapes, shape_dict={}):
+ ops = [] # ndarrays to be used as inputs
+ symbolic_shapes = [] # shapes to declare the placeholders
+ name_to_var = {}
+
+ def get_concrete_shape(shape):
+ return [shape_dict[s] if isinstance(s, str) else s for s in shape]
+
+ def get_symblic_shape_var(name, dtype="int32"):
+ if name not in name_to_var:
+ name_to_var[name] = te.var(name, dtype=dtype)
+ return name_to_var[name]
+
+ def get_symbolic_shape(shape):
+ return [get_symblic_shape_var(s) if isinstance(s, str) else s for s in shape]
+
for shape in shapes:
- tmp = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
+ concrete_shape = get_concrete_shape(shape)
+ tmp = np.random.uniform(low=-1.0, high=1.0, size=concrete_shape).astype(np.float32)
ops.append(tmp)
+ symbolic_shape = get_symbolic_shape(shape)
+ symbolic_shapes.append(symbolic_shape)
c1 = np.einsum(subscripts, *ops)
+ out_shape = c1.shape
if len(ops) == 1:
- c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
+ c2 = with_tvm(lambda A: topi.einsum(subscripts, A), symbolic_shapes, ops, out_shape)
elif len(ops) == 2:
- c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
+ c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), symbolic_shapes, ops, out_shape)
elif len(ops) == 3:
- c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)
+ c2 = with_tvm(
+ lambda A, B, C: topi.einsum(subscripts, A, B, C), symbolic_shapes, ops, out_shape
+ )
tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
@@ -82,5 +102,17 @@ def test_einsum(equation, inputs):
verify_einsum(equation, inputs)
+@pytest.mark.parametrize(
+ "equation,inputs,shape_dict",
+ [
+ ("ij,jk->ik", [(2, "K"), (1, "N")], {"K": 3, "N": 4}),
+ ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 3}),
+ ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 1}),
+ ],
+)
+def test_einsum_symblic_shape(equation, inputs, shape_dict):
+ verify_einsum(equation, inputs, shape_dict)
+
+
if __name__ == "__main__":
tvm.testing.main()