elvin-n commented on code in PR #14010:
URL: https://github.com/apache/tvm/pull/14010#discussion_r1116173448


##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language
+# pylint: disable=unused-argument
 """
 TVMC Graph Transforms
 """
 
 from tvm import relay, transform
 from tvm.driver.tvmc import TVMCException
 
+# ToMixedPrecision
+ACC_DTYPE = "float32"
 
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+    global ACC_DTYPE
+    return [
+        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+        ACC_DTYPE,
+        mixed_precision_type,
+    ]
+
+
+class MixedPrecision(object):
+    """Temporarily changes attr of ops to enable required precision."""
+
+    def __init__(self, ops):
+        """Saves the required info for RAII pattern usage.
+
+        Parameters
+        ----------
+        ops : list
+            list of operators
+        """
+        self.older_attr = {}
+        self.ops = ops
+        self.attr_key = "FTVMMixedPrecisionConversionType"
+
+    def __enter__(self):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            self.older_attr[op_name] = op.get_attr(self.attr_key)
+            op.reset_attr(self.attr_key)
+            op.set_attr(self.attr_key, mixed_precision_rule)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            op.reset_attr(self.attr_key)
+            if self.older_attr[op_name]:
+                op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+    mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
+):
+    """Converts the operator datatypes
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    ops : str
+        List of operators to be precision converted.
+    input_type: str
+        Input precision to be used.
+    output_type: str

Review Comment:
   Should this be `out_type`? The documented name `output_type` does not match the function's parameter name.



##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language
+# pylint: disable=unused-argument
 """
 TVMC Graph Transforms
 """
 
 from tvm import relay, transform
 from tvm.driver.tvmc import TVMCException
 
+# ToMixedPrecision
+ACC_DTYPE = "float32"
 
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+    global ACC_DTYPE
+    return [
+        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+        ACC_DTYPE,
+        mixed_precision_type,
+    ]
+
+
+class MixedPrecision(object):
+    """Temporarily changes attr of ops to enable required precision."""
+
+    def __init__(self, ops):
+        """Saves the required info for RAII pattern usage.
+
+        Parameters
+        ----------
+        ops : list
+            list of operators
+        """
+        self.older_attr = {}
+        self.ops = ops
+        self.attr_key = "FTVMMixedPrecisionConversionType"
+
+    def __enter__(self):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            self.older_attr[op_name] = op.get_attr(self.attr_key)
+            op.reset_attr(self.attr_key)
+            op.set_attr(self.attr_key, mixed_precision_rule)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            op.reset_attr(self.attr_key)
+            if self.older_attr[op_name]:
+                op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+    mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"

Review Comment:
   The names of the arguments `input_type` and `out_type` do not correspond to
how you are actually using them.
   There are three places where data types can be specified:
   1. The data type of the computation. It is specified as a parameter of the
`ToMixedPrecision` pass, e.g. `amp_mod =
ToMixedPrecision(mixed_precision_dtype)(mod)`.
   2. The data type for accumulation. This is exactly how you are using
`out_type`. However, results are never kept in this precision: there will
always be a cast back to the computation dtype.
   3. The data type of the last op in a sequence or, more precisely, the
attribute indicating whether we need to convert back to the original data type.
This is specified via `config={"relay.ToMixedPrecision.keep_orig_output_dtype":
True}`.
   
   I propose the following:
   1. Rename `input_type` to something like `calculation_type` and pass it to
`ToMixedPrecision`, which is not done yet.
   2. Rename `out_type` to something like `acc_type`.
   3. Do not introduce any new argument for keeping the output in its original
precision. It would be confusing, because it says nothing about the output of
the network. It relates only to the converted ops and is very
topology-specific. I can imagine usage scenarios for certain models and certain
hardware where it might be required. However, 99.99% of people will not need it,
and they will be confused by the nondeterministic behaviour.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to