leandron commented on a change in pull request #7366:
URL: https://github.com/apache/tvm/pull/7366#discussion_r567214252



##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)
+    if result is None:
+        raise argparse.ArgumentTypeError(
+            "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2"
+        )
+    d = {}
+    # Break apart each specific input string
+    inputs = inputs.split(",")
+    for string in inputs:
+        # Split name from shape string.
+        string = string.split(":")
+        shapelist = []
+        # Separate each dimension in the shape.
+        string[1] = string[1].lower().split("x")
+        # Parse each dimension into an integer.
+        for x in string[1]:
+            x = int(x)
+            # Negative numbers are converted to dynamic axes.
+            if x < 0:
+                x = relay.Any()
+            shapelist.append(x)
+        # Assign dictionary key value pair.
+        d[string[0]] = shapelist
+    return d

Review comment:
       I suggest renaming `d` to `shape_dict` to match this function's docstring.
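For illustration, a minimal sketch of the function with the suggested rename applied. It is not the PR's code: to run without TVM it takes a hypothetical `dynamic_dim` argument (default `None`) in place of `relay.Any()`, and it writes the pattern as a raw string, since `"\w"` and `"\d"` in a plain string literal are invalid escape sequences that recent Python versions warn about. It also widens `\d+` to `-?\d+`, because the pattern in the diff only matches non-negative digits, so a negative dimension is rejected before it can reach the `relay.Any()` branch.

```python
import argparse
import re

# Same structure as the pattern in the diff, written as a raw string;
# \d+ is widened to -?\d+ so that negative (dynamic) dimensions validate.
VALID_PATTERN = re.compile(r"\w+:(-?\d+[xX])*-?\d+(,\w+:(-?\d+[xX])*-?\d+)*")


def parse_shape_string(inputs, dynamic_dim=None):
    """Parse "name:AxBxC,name2:DxE" into {"name": [A, B, C], "name2": [D, E]}.

    `dynamic_dim` is a hypothetical stand-in for relay.Any(), so this sketch
    runs without TVM.
    """
    inputs = inputs.replace(" ", "")
    if VALID_PATTERN.fullmatch(inputs) is None:
        raise argparse.ArgumentTypeError(
            "--shapes argument must be of the form "
            "'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2'"
        )
    shape_dict = {}
    for entry in inputs.split(","):
        # Split the input name from its shape string.
        name, shape_str = entry.split(":")
        # "x" and "X" both separate dimensions; negative values become dynamic_dim.
        shape_dict[name] = [
            dynamic_dim if int(dim) < 0 else int(dim)
            for dim in shape_str.lower().split("x")
        ]
    return shape_dict


print(parse_shape_string("input:10x10x10,input2:1x-1x3"))
# {'input': [10, 10, 10], 'input2': [1, None, 3]}
```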

##########
File path: tests/python/driver/tvmc/test_common.py
##########
@@ -149,3 +149,27 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
 
     assert expected_host == actual_host
     assert expected_port == actual_port
+
+
+def test_shape_parser():
+    # Check that a valid input is parsed correctly
+    shape_string = "input:10x10x10"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10]}
+    # Check that multiple valid input shapes are parsed correctly
+    shape_string = "input:10x10x10,input2:20x20x20x20"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+    # Check that alternate syntax parses correctly
+    shape_string = "input:10X10X10, input2:20X20X20X20"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
+
+    # Check that invalid pattern raises expected error.
+    shape_string = "input:ax10"
+    with pytest.raises(argparse.ArgumentTypeError):
+        tvmc.common.parse_shape_string(shape_string)
+    # Check that input with invalid separators raises error.
+    shape_string = "input:5,10 input2:10,10"
+    with pytest.raises(argparse.ArgumentTypeError):
+        tvmc.common.parse_shape_string(shape_string)

Review comment:
       I suggest splitting the test cases above into separate unit tests; I see 5 self-contained unit tests here.
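One way the split could look, as a sketch: each case from `test_shape_parser` becomes its own test, so a failure points at the exact case. The test names are hypothetical (they follow the double-underscore style of `test_tracker_host_port_from_cli__only_hostname__default_port_is_9090`), and the block defines a simplified, hypothetical stand-in for the parser (no dynamic-axis handling) only so that it runs without TVM; in `test_common.py` the tests would call `tvmc.common.parse_shape_string` instead.

```python
import argparse
import re

import pytest


def parse_shape_string(inputs):
    """Hypothetical TVM-free stand-in; the real tests would call
    tvmc.common.parse_shape_string instead."""
    inputs = inputs.replace(" ", "")
    if re.fullmatch(r"\w+:(\d+[xX])*\d+(,\w+:(\d+[xX])*\d+)*", inputs) is None:
        raise argparse.ArgumentTypeError("invalid --shapes value")
    shape_dict = {}
    for entry in inputs.split(","):
        name, dims = entry.split(":")
        shape_dict[name] = [int(dim) for dim in dims.lower().split("x")]
    return shape_dict


def test_parse_shape_string__single_input():
    assert parse_shape_string("input:10x10x10") == {"input": [10, 10, 10]}


def test_parse_shape_string__multiple_inputs():
    shape_dict = parse_shape_string("input:10x10x10,input2:20x20x20x20")
    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}


def test_parse_shape_string__uppercase_separator_and_spaces():
    shape_dict = parse_shape_string("input:10X10X10, input2:20X20X20X20")
    assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}


def test_parse_shape_string__non_numeric_dimension_raises():
    with pytest.raises(argparse.ArgumentTypeError):
        parse_shape_string("input:ax10")


def test_parse_shape_string__invalid_separators_raises():
    with pytest.raises(argparse.ArgumentTypeError):
        parse_shape_string("input:5,10 input2:10,10")
```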

##########
File path: tests/python/driver/tvmc/test_compiler.py
##########
@@ -56,6 +53,15 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
     assert type(dumps) is dict
 
 
+def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
+    # Check default compilation.

Review comment:
       ```suggestion
   def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
       # some CI environments won't offer tflite, so skip in case it is not present
       pytest.importorskip("tflite")
   ```
   
   Not all CI environments will offer `tflite`.

##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -210,6 +210,13 @@ def add_tune_parser(subparsers):
     #     can be improved in future to add integration with a modelzoo
     #     or URL, for example.
     parser.add_argument("FILE", help="path to the input model file")
+    parser.add_argument(
+        "--shapes",

Review comment:
       Question: is the terminology `--shapes` specific enough for people to always infer that it means "input shapes"? I'm asking because in the PR we discussed previously, we called it `--input-shapes`, and I just want to hear what others think.
   
   cc @comaniac @ekalda 
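For context on how the option name reaches users, a minimal sketch of wiring a shape parser into argparse. It assumes the new argument uses `type=parse_shape_string` (the hunk above is cut off after `"--shapes",`), and it defines a simplified, hypothetical stand-in for the parser so it runs without TVM; the `prog` name and help text are also illustrative. argparse calls the `type` function on the raw string and turns an `argparse.ArgumentTypeError` into a usage error that exits with status 2.

```python
import argparse
import re


def parse_shape_string(inputs):
    """Hypothetical, simplified stand-in for tvmc.common.parse_shape_string."""
    inputs = inputs.replace(" ", "")
    if re.fullmatch(r"\w+:(\d+[xX])*\d+(,\w+:(\d+[xX])*\d+)*", inputs) is None:
        raise argparse.ArgumentTypeError(
            "--shapes argument must be of the form "
            "'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2'"
        )
    shape_dict = {}
    for entry in inputs.split(","):
        name, dims = entry.split(":")
        shape_dict[name] = [int(dim) for dim in dims.lower().split("x")]
    return shape_dict


parser = argparse.ArgumentParser(prog="tune-sketch")
parser.add_argument("FILE", help="path to the input model file")
# Assumed wiring: argparse passes the raw --shapes string to parse_shape_string.
parser.add_argument("--shapes", type=parse_shape_string, help="input shapes")

args = parser.parse_args(["model.tflite", "--shapes", "input:1x224x224x3"])
print(args.shapes)  # {'input': [1, 224, 224, 3]}

# An invalid value is reported as "argument --shapes: ..." and exits with status 2.
try:
    parser.parse_args(["model.tflite", "--shapes", "input:ax10"])
except SystemExit as exc:
    print("exit status:", exc.code)  # exit status: 2
```

With this wiring the value is exposed as `args.shapes`; if the option were named `--input-shapes`, argparse would expose it as `args.input_shapes`.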




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
