elvin-n commented on code in PR #14010:
URL: https://github.com/apache/tvm/pull/14010#discussion_r1124308155


##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -21,7 +22,88 @@
 from tvm.driver.tvmc import TVMCException
 
 
-def convert_graph_layout(mod, desired_layout):
+def generate_mixed_precision_rule(acc_dtype):
+    def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: 
str):
+        return [
+            relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+            acc_dtype,
+            mixed_precision_type,
+        ]
+
+    return _mixed_precision_rule
+
+
+class MixedPrecision(object):
+    """Temporarily changes attr of ops to enable required precision."""
+
+    def __init__(self, ops, acc_type):
+        """Saves the required info for RAII pattern usage.
+
+        Parameters
+        ----------
+        ops : list
+            list of operators
+        acc_type: str
+            Output or accumulation precision to be used.
+        """
+        self.older_attr = {}
+        self.ops = ops
+        self.acc_type = acc_type
+        self.attr_key = "FTVMMixedPrecisionConversionType"
+
+    def __enter__(self):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            self.older_attr[op_name] = op.get_attr(self.attr_key)
+            op.reset_attr(self.attr_key)
+            op.set_attr(self.attr_key, 
generate_mixed_precision_rule(self.acc_type))
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            op.reset_attr(self.attr_key)
+            if self.older_attr[op_name]:
+                op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(mod, ops=None, calculation_type="float16", 
acc_type="float16"):
+    """Converts the operator datatypes
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    ops : list
+        List of operators to be precision converted.
+    calculation_type: str

Review Comment:
   This parameter is not used. Either remove it from the parameters of 
`convert_to_mixed_precision` or pass it to `ToMixedPrecision`.



##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -376,12 +375,28 @@ def tune_model(
         If using the autoscheduler, write the estimated latency at each step 
of tuning to file.
     additional_target_options: Optional[Dict[str, Dict[str, Any]]]
         Additional target options in a dictionary to combine with initial 
Target arguments
+    desired_layout: str, optional
+        Can be one of "NCHW" or "NHWC". When specified, compatible operations 
in the graph
+        will have their layout set to this format. Tasks will then be tuned 
using this
+        specified layout.
+    desired_layout_ops: list[str], optional

Review Comment:
   Could you please point out where it is used? I was not able to find its usage in this PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to