zhiics commented on a change in pull request #6729:
URL: https://github.com/apache/incubator-tvm/pull/6729#discussion_r510308268
##########
File path: python/tvm/relay/op/_transform.py
##########
@@ -798,7 +801,26 @@ def _stack_shape_func(data_shape, axis, num_inputs):
@_reg.register_shape_func("stack", False)
def stack_shape_func(attrs, inputs, _):
+ """
+ Shape func for stack.
+ """
axis = get_const_int(attrs.axis)
if axis < 0:
axis += inputs[0].shape[0] + 1
return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]
+
+
+@_reg.register_shape_func("where", False)
+def where_shape_func(attrs, inputs, _):
+ """
+ Shape func for where.
+ """
+ cond_shape = inputs[0]
+ x_shape = inputs[1]
+
Review comment:
This could be a one-liner:
```python
out_shape = x_shape if x_shape.shape else cond_shape
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org