lazycal opened a new pull request #10106:
URL: https://github.com/apache/tvm/pull/10106


   The previous logic for importing ONNX `Where` only considered the shape of 
the input tensor with the largest rank. This is incorrect whenever the inputs 
broadcast against each other. For instance, when `cond`, `x`, and `y` have 
shapes [3,1], [2], and [2], the result should have shape [3,2], but the 
original logic gives [3,1]. Below is a code snippet that reproduces the issue.
   
   ```python
   import numpy as np
   from onnx import TensorProto, helper
   from tvm import relay
   
   
   def get_onnx_model(condition, x, y):
       # Build a single-node ONNX graph: out = Where(cond, x, y).
       outdata = np.where(condition, x, y)
       dtype = TensorProto.FLOAT
       node = helper.make_node("Where", inputs=["cond", "x", "y"], outputs=["out"])
       graph = helper.make_graph(
           [node],
           "where_test",
           inputs=[
               helper.make_tensor_value_info(
                   "cond", TensorProto.BOOL, list(condition.shape)),
               helper.make_tensor_value_info("x", dtype, list(x.shape)),
               helper.make_tensor_value_info("y", dtype, list(y.shape)),
           ],
           # The declared output shape is the broadcast shape NumPy computes.
           outputs=[helper.make_tensor_value_info(
               "out", dtype, list(outdata.shape))],
       )
       return helper.make_model(graph, producer_name="where_test")
   
   
   def main():
       # cond has shape [3, 1]; x and y have shape [2]; they broadcast to [3, 2].
       condition = np.random.uniform(size=(3, 1)) < 0.5
       x = np.random.uniform(size=(2,)).astype(np.float32)
       y = np.random.uniform(size=(2,)).astype(np.float32)
       model = get_onnx_model(condition, x, y)
       mod, params = relay.frontend.from_onnx(model, freeze_params=True)
   
       res = relay.build_module.create_executor("graph", mod).evaluate()(
           cond=condition, x=x, y=y)
       # Exact comparison that also checks shapes (np.allclose would broadcast
       # a [3, 1] result against [3, 2] instead of rejecting it).
       np.testing.assert_array_equal(res.numpy(), np.where(condition, x, y))
   
   
   if __name__ == "__main__":
       main()
   ```
   
   This PR delegates the broadcasting logic to `relay.where` instead of 
handling it during import.

