elvin-n commented on code in PR #14010:
URL: https://github.com/apache/tvm/pull/14010#discussion_r1124308155
##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -21,7 +22,88 @@
from tvm.driver.tvmc import TVMCException
-def convert_graph_layout(mod, desired_layout):
+def generate_mixed_precision_rule(acc_dtype):
+ def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type:
str):
+ return [
+ relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+ acc_dtype,
+ mixed_precision_type,
+ ]
+
+ return _mixed_precision_rule
+
+
+class MixedPrecision(object):
+ """Temporarily changes attr of ops to enable required precision."""
+
+ def __init__(self, ops, acc_type):
+ """Saves the required info for RAII pattern usage.
+
+ Parameters
+ ----------
+ ops : list
+ list of operators
+ acc_type: str
+ Output or accumulation precision to be used.
+ """
+ self.older_attr = {}
+ self.ops = ops
+ self.acc_type = acc_type
+ self.attr_key = "FTVMMixedPrecisionConversionType"
+
+ def __enter__(self):
+ for op_name in self.ops:
+ op = relay.op.get(op_name)
+ self.older_attr[op_name] = op.get_attr(self.attr_key)
+ op.reset_attr(self.attr_key)
+ op.set_attr(self.attr_key,
generate_mixed_precision_rule(self.acc_type))
+ return self
+
+ def __exit__(self, ptype, value, trace):
+ for op_name in self.ops:
+ op = relay.op.get(op_name)
+ op.reset_attr(self.attr_key)
+ if self.older_attr[op_name]:
+ op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(mod, ops=None, calculation_type="float16",
acc_type="float16"):
+ """Converts the operator datatypes
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The relay module to convert.
+ ops : list
+ List of operators to be precision converted.
+ calculation_type: str
Review Comment:
This parameter is not used. It should either be removed from the parameters of
`convert_to_mixed_precision` or passed through to `ToMixedPrecision`.
##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -376,12 +375,28 @@ def tune_model(
If using the autoscheduler, write the estimated latency at each step
of tuning to file.
additional_target_options: Optional[Dict[str, Dict[str, Any]]]
Additional target options in a dictionary to combine with initial
Target arguments
+ desired_layout: str, optional
+ Can be one of "NCHW" or "NHWC". When specified, compatible operations
in the graph
+ will have their layout set to this format. Tasks will then be tuned
using this
+ specified layout.
+ desired_layout_ops: list[str], optional
Review Comment:
Could you please point out where this is used? I was not able to find any usage of it in this PR.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]