jwfromm commented on a change in pull request #7366:
URL: https://github.com/apache/tvm/pull/7366#discussion_r568111211
##########
File path: tests/python/driver/tvmc/test_common.py
##########
@@ -149,3 +149,27 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
     assert expected_host == actual_host
     assert expected_port == actual_port
+
+
+def test_shape_parser():
+    # Check that a valid input is parsed correctly
+    shape_string = "input:10x10x10"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10]}
+    # Check that multiple valid input shapes are parsed correctly
+    shape_string = "input:10x10x10,input2:20x20x20x20"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+    # Check that alternate syntax parses correctly
+    shape_string = "input:10X10X10, input2:20X20X20X20"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+
+    # Check that invalid pattern raises expected error.
+    shape_string = "input:ax10"
+    with pytest.raises(argparse.ArgumentTypeError):
+        tvmc.common.parse_shape_string(shape_string)
+    # Check that input with invalid separators raises error.
+    shape_string = "input:5,10 input2:10,10"
+    with pytest.raises(argparse.ArgumentTypeError):
+        tvmc.common.parse_shape_string(shape_string)
Review comment:
I actually really prefer having these grouped under a single test, since
they all exercise the same simple function, and splitting them into separate
tests would add a lot of bloat for little benefit. @comaniac, which way do you
prefer?
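One way to keep the cases in a single test while making each one a one-line entry is a table-driven layout. The sketch below is only an illustration: `parse_shape_string` here is a hypothetical stand-in whose behavior is inferred solely from the assertions in the diff (`x` or `X` between dimensions, commas between inputs, an optional space after a comma, `argparse.ArgumentTypeError` on malformed input). It is not the actual `tvmc.common` code, and the `VALID_CASES`/`INVALID_CASES` names are made up for the example.

```python
import argparse
import re

# Hypothetical stand-in for tvmc.common.parse_shape_string, written only so this
# sketch runs on its own; it is NOT TVM's implementation.
# One entry looks like "name:DxDxD" (x or X), optionally padded with whitespace.
_SHAPE_RE = re.compile(r"^\s*([^:,\s]+):(\d+(?:[xX]\d+)*)\s*$")


def parse_shape_string(inputs_string):
    """Parse 'name:AxBxC,name2:DxE' into {'name': [A, B, C], 'name2': [D, E]}."""
    shape_dict = {}
    for entry in inputs_string.split(","):
        match = _SHAPE_RE.match(entry)
        if match is None:
            # Same exception type the tests in the diff expect.
            raise argparse.ArgumentTypeError(f"invalid shape specification: {entry!r}")
        name, dims = match.groups()
        shape_dict[name] = [int(d) for d in re.split(r"[xX]", dims)]
    return shape_dict


# Each case from the diff becomes one row instead of three lines of test code.
VALID_CASES = [
    ("input:10x10x10", {"input": [10, 10, 10]}),
    ("input:10x10x10,input2:20x20x20x20", {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}),
    ("input:10X10X10, input2:20X20X20X20", {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}),
]
INVALID_CASES = ["input:ax10", "input:5,10 input2:10,10"]


def test_shape_parser():
    for shape_string, expected in VALID_CASES:
        # Passing shape_string as the assert message shows which row failed.
        assert parse_shape_string(shape_string) == expected, shape_string
    for shape_string in INVALID_CASES:
        # Stdlib equivalent of `with pytest.raises(argparse.ArgumentTypeError):`
        try:
            parse_shape_string(shape_string)
        except argparse.ArgumentTypeError:
            continue
        raise AssertionError(f"expected ArgumentTypeError for {shape_string!r}")


if __name__ == "__main__":
    test_shape_parser()
    print("all shape-parser cases passed")
```

If separate test IDs per case turn out to be preferred, the same tables could be passed to `pytest.mark.parametrize("shape_string, expected", VALID_CASES)` on a small test function, which costs one decorator per table rather than one function per case.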