Mousius commented on code in PR #14010:
URL: https://github.com/apache/tvm/pull/14010#discussion_r1118604801


##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -354,10 +351,8 @@ def tune_model(
         Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other 
targets.
     early_stopping : int, optional
         When specified, stop tuning after this number of trials if results 
aren't improving.
-    desired_layout : str, optional

Review Comment:
   Can we keep the arguments separate in these APIs? The individual 
documentation for the scripting interface is helpful.



##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language
+# pylint: disable=unused-argument
 """
 TVMC Graph Transforms
 """
 
 from tvm import relay, transform
 from tvm.driver.tvmc import TVMCException
 
+# ToMixedPrecision
+ACC_DTYPE = "float32"
 
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+    global ACC_DTYPE
+    return [
+        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+        ACC_DTYPE,
+        mixed_precision_type,
+    ]
+
+
+class MixedPrecision(object):
+    """Temporarily changes attr of ops to enable required precision."""
+
+    def __init__(self, ops):
+        """Saves the required info for RAII pattern usage.
+
+        Parameters
+        ----------
+        ops : list
+            list of operators
+        """
+        self.older_attr = {}
+        self.ops = ops
+        self.attr_key = "FTVMMixedPrecisionConversionType"
+
+    def __enter__(self):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            self.older_attr[op_name] = op.get_attr(self.attr_key)
+            op.reset_attr(self.attr_key)
+            op.set_attr(self.attr_key, mixed_precision_rule)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            op.reset_attr(self.attr_key)
+            if self.older_attr[op_name]:
+                op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+    mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
+):
+    """Converts the operator datatypes
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    ops : str
+        List of operators to be precision converted.
+    input_type: str
+        Input precision to be used.
+    output_type: str
+        Output or accumulation precision to be used.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The converted module.
+    """
+
+    global ACC_DTYPE
+    ACC_DTYPE = out_type

Review Comment:
   Can we pass this into `MixedPrecision` and generate a rule from it instead 
of using a global? Something like:
   
   ```python
   def generate_rule(acc_dtype):
       def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
           return [
               relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
               acc_dtype,
               mixed_precision_type,
           ]
       return _mixed_precision_rule
   ```



##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language
+# pylint: disable=unused-argument
 """
 TVMC Graph Transforms
 """
 
 from tvm import relay, transform
 from tvm.driver.tvmc import TVMCException
 
+# ToMixedPrecision
+ACC_DTYPE = "float32"
 
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+    global ACC_DTYPE

Review Comment:
   Why does this need to be marked as global? Doesn't it inherit from the outer 
scope anyway?



##########
python/tvm/driver/tvmc/compiler.py:
##########
@@ -260,10 +256,8 @@ def compile_model(
     target_host : str, optional
         The target of the host machine if host-side code
         needs to be generated.
-    desired_layout: str, optional

Review Comment:
   Similarly here, keep `desired_layout`, and use an array for `mixed_precision_ops` 😸 



##########
tests/python/driver/tvmc/test_transform.py:
##########
@@ -70,5 +70,55 @@ def 
test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch):
     )
 
 
+def test_layout_transform_to_mixed_precision_pass_args(relay_conv2d, 
monkeypatch):
+    """
+    Check the mixed precision arugments which are expected when
+    mixed precision arguments are provided.
+    """
+    mock_mixed_precision = MagicMock()
+    mock_mixed_precision.return_value = 
tvm.driver.tvmc.transform.MixedPrecision([])

Review Comment:
   This seems to do a good job of checking that the graph transformation 
function is called. Do we have any tests which call into 
`tvmc.transform.MixedPrecision` to test that it behaves as we expect?



##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language
+# pylint: disable=unused-argument
 """
 TVMC Graph Transforms
 """
 
 from tvm import relay, transform
 from tvm.driver.tvmc import TVMCException
 
+# ToMixedPrecision
+ACC_DTYPE = "float32"
 
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+    global ACC_DTYPE
+    return [
+        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+        ACC_DTYPE,
+        mixed_precision_type,
+    ]
+
+
+class MixedPrecision(object):
+    """Temporarily changes attr of ops to enable required precision."""
+
+    def __init__(self, ops):
+        """Saves the required info for RAII pattern usage.
+
+        Parameters
+        ----------
+        ops : list
+            list of operators
+        """
+        self.older_attr = {}
+        self.ops = ops
+        self.attr_key = "FTVMMixedPrecisionConversionType"
+
+    def __enter__(self):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            self.older_attr[op_name] = op.get_attr(self.attr_key)
+            op.reset_attr(self.attr_key)
+            op.set_attr(self.attr_key, mixed_precision_rule)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            op.reset_attr(self.attr_key)
+            if self.older_attr[op_name]:
+                op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+    mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
+):
+    """Converts the operator datatypes
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    ops : str
+        List of operators to be precision converted.
+    input_type: str
+        Input precision to be used.
+    output_type: str
+        Output or accumulation precision to be used.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The converted module.
+    """
+
+    global ACC_DTYPE
+    ACC_DTYPE = out_type
+
+    with MixedPrecision(ops.split(",")):

Review Comment:
   Can we do this split somewhere in `drive_tune` / `drive_compile`?
   
   That way the scripting interface uses the more natural 
`mixed_precision_ops=["nn.conv2d", "nn.woof"]` rather than 
`mixed_precision_ops="nn.conv2d,nn.woof"`



##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -58,3 +136,105 @@ def convert_graph_layout(mod, desired_layout):
         return seq(mod)
     except Exception as err:
         raise TVMCException("Error converting layout to {0}: 
{1}".format(desired_layout, str(err)))
+
+
+def apply_graph_transforms(mod, args):
+    """Alter the layout of the input graph.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    args : dict
+        The transform arguments.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The converted module.
+    """
+    if not args:
+        return mod
+
+    # AlterLayout
+    if args.get("desired_layout", False):
+        mod = convert_graph_layout(mod, args["desired_layout"])
+
+    # ToMixedPrecision
+    if args.get("mixed_precision", False):
+        mod = convert_to_mixed_precision(
+            mod,
+            args.get("mixed_precision_ops", "nn.conv2d,nn.dense"),
+            args.get("mixed_precision_input", "float16"),
+            args.get("mixed_precision_output", "float16"),

Review Comment:
   Can we rely on the argument defaults from the initial call instead? It seems 
odd to duplicate them here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to