mbrookhart commented on pull request #7910:
URL: https://github.com/apache/tvm/pull/7910#issuecomment-825041882
The subgraph that's causing the `where` issues is this:
```
%p0: Tensor[(?), int64]
%p1: Tensor[(2), int64]
%0 = less(%p0, 0 /* ty=int64 */) /* ty=Tensor[(?), bool] */;
%1 = take(%p1, 0 /* ty=int32 */) /* ty=int64 */;
%2 = add(%p0, %1) /* ty=Tensor[(?), int64] */;
%3 = where(%0, %2, %p0) /* ty=Tensor[(?), int64] */;
```
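A NumPy-only sketch may make the subgraph's semantics easier to follow. It mirrors `%0` through `%3` op by op and does not exercise TVM. Treating `%p1` as a shape tensor, with element 0 being the extent of the indexed axis, is an assumption that is not stated above. Under that reading, the pattern wraps negative indices into range. The helper name `subgraph_reference` is hypothetical.

```python
import numpy as np


def subgraph_reference(p0: np.ndarray, p1: np.ndarray) -> np.ndarray:
    """NumPy mirror of the Relay subgraph %0..%3 (hypothetical helper name)."""
    # %0 = less(%p0, 0): elementwise mask of negative entries, dtype bool.
    cond = np.less(p0, 0)
    # %1 = take(%p1, 0): a scalar int64 pulled out of the length-2 tensor.
    offset = np.take(p1, 0)
    # %2 = add(%p0, %1): the scalar broadcasts across the dynamic-length vector.
    shifted = np.add(p0, offset)
    # %3 = where(%0, %2, %p0): use the shifted value only where p0 was negative.
    return np.where(cond, shifted, p0)


# Assumed reading: p1 holds a shape such as (5, 3), so negative indices into an
# axis of length 5 are wrapped into [0, 5).
p0 = np.array([-1, 0, 2, -5], dtype="int64")
p1 = np.array([5, 3], dtype="int64")
print(subgraph_reference(p0, p1))  # [4 0 2 0]
```

In the Relay graph, `%1` is a 0-d value broadcast against a dynamically shaped `%p0`. The NumPy version handles this through ordinary broadcasting.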
And this unit test reproduces it:
```
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index ef02b6f10..7b21c01b2 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1512,6 +1512,22 @@ def test_any_where():
         any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4),
         y_np_shape_invalid=(2, 4)
     )
+    # Test scalar where in a dynamically shaped graph
+    x_np = np.random.randn(2).astype("int64")
+    y_np = np.random.randn(2, 6).astype("float32")
+    expected = y_np[:, 4]
+    x = relay.var("x", shape=any_dims(1), dtype="int64")
+    y = relay.var("y", shape=any_dims(2), dtype="float32")
+
+    left = relay.take(x, relay.const(1, dtype="int32")) + relay.const(4, "int64")
+    right = relay.const(4, "int64")
+    where = relay.where(relay.const(False, "bool"), left, right)
+    z = relay.take(y, where, axis=1)
+
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x, y], z)
+    check_result([x_np, y_np], mod, expected)
+
 
 @tvm.testing.uses_gpu
 def test_non_max_suppression():
```
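The test's `expected` value can be checked by hand. Because the `where` condition is the constant `False`, the graph always selects `right` (the constant 4), so `z` reduces to `take(y, 4, axis=1)`, which is `y_np[:, 4]`. The sketch below mirrors the Relay graph with NumPy to make that reasoning concrete. It does not run TVM, and the helper name `reference_test_graph` is hypothetical.

```python
import numpy as np


def reference_test_graph(x_np: np.ndarray, y_np: np.ndarray) -> np.ndarray:
    """NumPy mirror of the Relay graph built in the test (hypothetical name)."""
    # left = take(x, 1) + 4: a 0-d int64 value derived from the dynamic input.
    left = np.take(x_np, np.int32(1)) + np.int64(4)
    # right = const(4, "int64")
    right = np.int64(4)
    # A scalar False condition always selects `right`, whatever `left` holds.
    index = np.where(np.bool_(False), left, right)
    # z = take(y, index, axis=1): a 0-d index drops axis 1, giving shape (rows,).
    return np.take(y_np, index, axis=1)


x_np = np.random.randn(2).astype("int64")
y_np = np.random.randn(2, 6).astype("float32")
z = reference_test_graph(x_np, y_np)
print(z.shape)  # (2,)
```

Because the condition is a scalar constant, the selected value does not depend on `x_np`. The interesting part of the test is therefore how the dynamic shapes of `x` and `y` propagate through a scalar `where`, not the arithmetic itself.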