lazycal opened a new pull request #10106: URL: https://github.com/apache/tvm/pull/10106
The previous logic for importing `where` only considered the shape of the tensor with the largest rank. This is incorrect: for example, when `cond`, `x`, and `y` have shapes `[3, 1]`, `[2]`, and `[2]`, the resulting shape should be `[3, 2]`, but the original logic gives `[3, 1]`. Below is a code snippet to reproduce the issue.

```python
import numpy as np
from onnx import TensorProto, helper
from tvm import relay


def get_onnx_model(condition, x, y):
    # Build a one-node ONNX graph: out = Where(cond, x, y).
    outdata = np.where(condition, x, y)
    dtype = TensorProto.FLOAT
    node = helper.make_node("Where", inputs=["cond", "x", "y"], outputs=["out"])
    graph = helper.make_graph(
        [node],
        "where_test",
        inputs=[
            helper.make_tensor_value_info("cond", TensorProto.BOOL, list(condition.shape)),
            helper.make_tensor_value_info("x", dtype, list(x.shape)),
            helper.make_tensor_value_info("y", dtype, list(y.shape)),
        ],
        outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))],
    )
    return helper.make_model(graph, producer_name="where_test")


def main():
    # cond has rank 2 ([3, 1]); x and y have rank 1 ([2]).
    condition = np.random.uniform(size=(3, 1)) < 0.5
    x = np.random.uniform(size=(2,)).astype(np.float32)
    y = np.random.uniform(size=(2,)).astype(np.float32)
    model = get_onnx_model(condition, x, y)
    mod, params = relay.frontend.from_onnx(model, freeze_params=True)
    res = relay.build_module.create_executor("graph", mod).evaluate()(
        **{"cond": condition, "x": x, "y": y}
    )
    # The imported model should match NumPy's broadcasting exactly.
    assert np.allclose(res.asnumpy(), np.where(condition, x, y), rtol=0, atol=0)


main()
```

This PR simply delegates the broadcast logic to `relay.where` instead of handling it during import.
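For reference, the shape discrepancy can be shown with NumPy alone, since ONNX `Where` follows NumPy-style multidirectional broadcasting. This sketch contrasts broadcasting all three input shapes together against the buggy heuristic of keeping only the largest-rank shape (`np.broadcast_shapes` requires NumPy >= 1.20):

```python
import numpy as np

cond_shape, x_shape, y_shape = (3, 1), (2,), (2,)

# Correct: broadcast all three shapes together, as np.where does.
correct = np.broadcast_shapes(cond_shape, x_shape, y_shape)
print(correct)  # (3, 2)

# Buggy heuristic: take only the shape of the largest-rank tensor.
largest_rank = max((cond_shape, x_shape, y_shape), key=len)
print(largest_rank)  # (3, 1)
```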
