Mousius commented on code in PR #14010:
URL: https://github.com/apache/tvm/pull/14010#discussion_r1118604801
##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -354,10 +351,8 @@ def tune_model(
Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other
targets.
early_stopping : int, optional
When specified, stop tuning after this number of trials if results
aren't improving.
- desired_layout : str, optional
Review Comment:
Can we keep the arguments separate in these APIs? Having each argument
documented individually is helpful for the scripting interface.
##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language
+# pylint: disable=unused-argument
"""
TVMC Graph Transforms
"""
from tvm import relay, transform
from tvm.driver.tvmc import TVMCException
+# ToMixedPrecision
+ACC_DTYPE = "float32"
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+ global ACC_DTYPE
+ return [
+ relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+ ACC_DTYPE,
+ mixed_precision_type,
+ ]
+
+
+class MixedPrecision(object):
+ """Temporarily changes attr of ops to enable required precision."""
+
+ def __init__(self, ops):
+ """Saves the required info for RAII pattern usage.
+
+ Parameters
+ ----------
+ ops : list
+ list of operators
+ """
+ self.older_attr = {}
+ self.ops = ops
+ self.attr_key = "FTVMMixedPrecisionConversionType"
+
+ def __enter__(self):
+ for op_name in self.ops:
+ op = relay.op.get(op_name)
+ self.older_attr[op_name] = op.get_attr(self.attr_key)
+ op.reset_attr(self.attr_key)
+ op.set_attr(self.attr_key, mixed_precision_rule)
+ return self
+
+ def __exit__(self, ptype, value, trace):
+ for op_name in self.ops:
+ op = relay.op.get(op_name)
+ op.reset_attr(self.attr_key)
+ if self.older_attr[op_name]:
+ op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+ mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
+):
+ """Converts the operator datatypes
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The relay module to convert.
+ ops : str
+ List of operators to be precision converted.
+ input_type: str
+ Input precision to be used.
+ output_type: str
+ Output or accumulation precision to be used.
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ The converted module.
+ """
+
+ global ACC_DTYPE
+ ACC_DTYPE = out_type
Review Comment:
Can we pass this into `MixedPrecision` and generate a rule from it instead
of using a global? Something like:
```python
def generate_rule(acc_dtype):
    def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
        return [
            relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
            acc_dtype,
            mixed_precision_type,
        ]

    return _mixed_precision_rule
```
##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language
+# pylint: disable=unused-argument
"""
TVMC Graph Transforms
"""
from tvm import relay, transform
from tvm.driver.tvmc import TVMCException
+# ToMixedPrecision
+ACC_DTYPE = "float32"
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+ global ACC_DTYPE
Review Comment:
Why does this need to be declared `global`? The function only reads
`ACC_DTYPE`, so name lookup will already find it in the module scope.
##########
python/tvm/driver/tvmc/compiler.py:
##########
@@ -260,10 +256,8 @@ def compile_model(
target_host : str, optional
The target of the host machine if host-side code
needs to be generated.
- desired_layout: str, optional
Review Comment:
Similarly here: keep `desired_layout` as a separate argument, and use a list for `mixed_precision_ops` 😸
##########
tests/python/driver/tvmc/test_transform.py:
##########
@@ -70,5 +70,55 @@ def
test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch):
)
+def test_layout_transform_to_mixed_precision_pass_args(relay_conv2d,
monkeypatch):
+ """
+ Check the mixed precision arugments which are expected when
+ mixed precision arguments are provided.
+ """
+ mock_mixed_precision = MagicMock()
+ mock_mixed_precision.return_value =
tvm.driver.tvmc.transform.MixedPrecision([])
Review Comment:
This does a good job of checking that the graph transformation function is
called. Do we have any tests that call into `tvmc.transform.MixedPrecision`
directly to check that it behaves as we expect?
##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -13,15 +13,97 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language
+# pylint: disable=unused-argument
"""
TVMC Graph Transforms
"""
from tvm import relay, transform
from tvm.driver.tvmc import TVMCException
+# ToMixedPrecision
+ACC_DTYPE = "float32"
-def convert_graph_layout(mod, desired_layout):
+
+def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
+ global ACC_DTYPE
+ return [
+ relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
+ ACC_DTYPE,
+ mixed_precision_type,
+ ]
+
+
+class MixedPrecision(object):
+ """Temporarily changes attr of ops to enable required precision."""
+
+ def __init__(self, ops):
+ """Saves the required info for RAII pattern usage.
+
+ Parameters
+ ----------
+ ops : list
+ list of operators
+ """
+ self.older_attr = {}
+ self.ops = ops
+ self.attr_key = "FTVMMixedPrecisionConversionType"
+
+ def __enter__(self):
+ for op_name in self.ops:
+ op = relay.op.get(op_name)
+ self.older_attr[op_name] = op.get_attr(self.attr_key)
+ op.reset_attr(self.attr_key)
+ op.set_attr(self.attr_key, mixed_precision_rule)
+ return self
+
+ def __exit__(self, ptype, value, trace):
+ for op_name in self.ops:
+ op = relay.op.get(op_name)
+ op.reset_attr(self.attr_key)
+ if self.older_attr[op_name]:
+ op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+ mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
+):
+ """Converts the operator datatypes
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The relay module to convert.
+ ops : str
+ List of operators to be precision converted.
+ input_type: str
+ Input precision to be used.
+ output_type: str
+ Output or accumulation precision to be used.
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ The converted module.
+ """
+
+ global ACC_DTYPE
+ ACC_DTYPE = out_type
+
+ with MixedPrecision(ops.split(",")):
Review Comment:
Can we do this split somewhere in `drive_tune` / `drive_compile`?
That way the scripting interface uses the more natural
`mixed_precision_ops=["nn.conv2d", "nn.woof"]` rather than
`mixed_precision_ops="nn.conv2d,nn.woof"`.
##########
python/tvm/driver/tvmc/transform.py:
##########
@@ -58,3 +136,105 @@ def convert_graph_layout(mod, desired_layout):
return seq(mod)
except Exception as err:
raise TVMCException("Error converting layout to {0}:
{1}".format(desired_layout, str(err)))
+
+
+def apply_graph_transforms(mod, args):
+ """Alter the layout of the input graph.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The relay module to convert.
+ args : dict
+ The transform arguments.
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ The converted module.
+ """
+ if not args:
+ return mod
+
+ # AlterLayout
+ if args.get("desired_layout", False):
+ mod = convert_graph_layout(mod, args["desired_layout"])
+
+ # ToMixedPrecision
+ if args.get("mixed_precision", False):
+ mod = convert_to_mixed_precision(
+ mod,
+ args.get("mixed_precision_ops", "nn.conv2d,nn.dense"),
+ args.get("mixed_precision_input", "float16"),
+ args.get("mixed_precision_output", "float16"),
Review Comment:
Can't we rely on the argument defaults from the initial call instead? It seems
odd to duplicate them here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]