This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2bf3a0a428 [Unity][MSC][M3.1] Add distiller for distill model (#16264)
2bf3a0a428 is described below
commit 2bf3a0a4287069ac55ee3304c285b08592d3d1bc
Author: Archermmt <[email protected]>
AuthorDate: Tue Dec 26 18:53:26 2023 +0800
[Unity][MSC][M3.1] Add distiller for distill model (#16264)
* add distiller
* remove useless
---
python/tvm/contrib/msc/core/runtime/runner.py | 32 ++-
python/tvm/contrib/msc/core/tools/__init__.py | 1 +
.../tvm/tools => core/tools/distill}/__init__.py | 7 +-
.../contrib/msc/core/tools/distill/distiller.py | 261 +++++++++++++++++++++
.../tvm/contrib/msc/core/tools/distill/method.py | 72 ++++++
python/tvm/contrib/msc/core/utils/expr.py | 20 ++
python/tvm/contrib/msc/core/utils/file.py | 9 +-
.../msc/framework/tensorflow/tools/__init__.py | 1 +
.../tools => tensorflow/tools/distill}/__init__.py | 6 +-
.../tensorflow/tools/distill/distiller.py | 55 +++++
.../msc/framework/tensorrt/tools/__init__.py | 1 +
.../tools => tensorrt/tools/distill}/__init__.py | 6 +-
.../framework/tensorrt/tools/distill/distiller.py | 55 +++++
.../contrib/msc/framework/torch/tools/__init__.py | 1 +
.../{tvm/tools => torch/tools/distill}/__init__.py | 7 +-
.../msc/framework/torch/tools/distill/distiller.py | 144 ++++++++++++
.../msc/framework/torch/tools/distill/method.py | 116 +++++++++
.../contrib/msc/framework/tvm/tools/__init__.py | 1 +
.../framework/tvm/tools/{ => distill}/__init__.py | 6 +-
.../msc/framework/tvm/tools/distill/distiller.py | 55 +++++
python/tvm/contrib/msc/pipeline/manager.py | 4 +
tests/python/contrib/test_msc/test_tools.py | 34 ++-
22 files changed, 868 insertions(+), 26 deletions(-)
diff --git a/python/tvm/contrib/msc/core/runtime/runner.py
b/python/tvm/contrib/msc/core/runtime/runner.py
index 5228b06b10..4b84037994 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -26,7 +26,7 @@ import numpy as np
import tvm
from tvm.contrib.msc.core.ir import MSCGraph
from tvm.contrib.msc.core.frontend import from_relax
-from tvm.contrib.msc.core.tools import BaseTool, ToolType, create_tool,
remove_tools
+from tvm.contrib.msc.core.tools import BaseTool, ToolType, ToolScope,
create_tool, remove_tools
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core.utils.message import MSCStage
from tvm.contrib.msc.core import utils as msc_utils
@@ -180,9 +180,24 @@ class BaseRunner(object):
# Generate model
if not self._model:
- # Generate normal model
- self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir)
- self._model = self._generate_model()
+ distiller = self.get_tool(ToolType.DISTILLER)
+ if distiller and not distiller.distilled:
+ build_root = self._generate_config["build_folder"]
+
+ def _build_scope_model(scope: str):
+ self._update_codegen({"tools_scope": scope})
+ self._generate_config["build_folder"] =
build_root.create_dir(scope)
+ return self._generate_model()
+
+ # Generate distill model
+ teacher_model = _build_scope_model(ToolScope.TEACHER)
+ self._graphs, self._weights =
self.reset_tools(cache_dir=cache_dir)
+ student_model = _build_scope_model(ToolScope.STUDENT)
+ self._model = distiller.build_model(teacher_model,
student_model)
+ else:
+ # Generate normal model
+ self._graphs, self._weights =
self.reset_tools(cache_dir=cache_dir)
+ self._model = self._generate_model()
# Log generate info
generate_msg = "Generate model({})".format(self.framework)
@@ -422,6 +437,15 @@ class BaseRunner(object):
self.run(inputs, ret_type="native")
quantizer.calibrate()
plan = quantizer.finalize()
+ elif tool_type == ToolType.DISTILLER:
+ distiller = self.get_tool(ToolType.DISTILLER)
+ while not distiller.distilled:
+ assert data_loader, "data_loader should be given to plan distill"
+ for inputs in data_loader():
+ loss = self.run(inputs, ret_type="native")
+ distiller.learn(loss)
+ distiller.distill()
+ plan = distiller.finalize()
else:
plan = self.get_tool(tool_type).finalize()
assert plan, "Failed to create plan for {}".format(tool_type)
diff --git a/python/tvm/contrib/msc/core/tools/__init__.py
b/python/tvm/contrib/msc/core/tools/__init__.py
index e97771cf6c..3c563fbea0 100644
--- a/python/tvm/contrib/msc/core/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/__init__.py
@@ -20,4 +20,5 @@ from .tool import *
from .execute import *
from .prune import *
from .quantize import *
+from .distill import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/core/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/core/tools/distill/__init__.py
index ddfd41f3c8..8714eae4e4 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/distill/__init__.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.core.tools.distill"""
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
+from .method import *
diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py
b/python/tvm/contrib/msc/core/tools/distill/distiller.py
new file mode 100644
index 0000000000..f5c2ca2f88
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.tools.distill.distiller"""
+
+import os
+from typing import List, Any, Dict, Tuple
+
+import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, ToolStrategy
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class BaseDistiller(BaseTool):
+ """Base distiller for all"""
+
+ def setup(self) -> dict:
+ """Setup the tool
+
+ Returns
+ -------
+ info: dict
+ The setup info.
+ """
+
+ self._max_iter = self._options.get("max_iter", 5)
+ self._save_step = self._options.get("save_step", 50)
+ self._weights_folder =
msc_utils.get_weights_dir().create_dir("Distill")
+ self._weights_path =
self._weights_folder.relpath("distill_{}.bin".format(self._max_iter))
+ self._distilled = os.path.isfile(self._weights_path)
+ return super().setup()
+
+ def _reset(
+ self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
+ ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
+ """Reset the tool
+
+ Parameters
+ ----------
+ graphs: list<MSCgraph>
+ The msc graphs.
+ weights: list<dict<str, tvm.nd.array>>
+ The weights
+
+ Returns
+ -------
+ graphs: list<MSCgraph>
+ The msc graphs.
+ weights: list<dict<str, tvm.nd.array>>
+ The weights
+ """
+
+ self._current_iter = 0
+ self._total_loss = 0
+ if self._distilled:
+ with open(self._weights_path, "rb") as f:
+ distilled_weights = tvm.runtime.load_param_dict(f.read())
+ for sub_weights in weights:
+ sub_weights.update({k: v for k, v in distilled_weights.items()
if k in sub_weights})
+ self._logger.info("Update %d distilled weights",
len(distilled_weights))
+ return super()._reset(graphs, weights)
+
+ def build_model(self, teacher: Any, student: Any) -> Any:
+ """Build the model with teacher and student
+
+ Parameters
+ ----------
+ teacher: Any
+ The teacher model
+ student: Any
+ The student model
+
+ Returns
+ -------
+ model: Any
+ The built model.
+ """
+
+ raise NotImplementedError("build_model is not implemented in
BaseDistiller")
+
+ def learn(self, loss: Any):
+ """Learn after forward
+
+ Parameters
+ ----------
+ loss: Any
+ The loss after forward
+ """
+
+ if self.on_debug(3):
+ self._logger.debug("%sStart Learn", self.msg_mark())
+ self._total_loss += float(self._learn(loss))
+
+ def _learn(self, loss: Any):
+ """Learn after forward
+
+ Parameters
+ ----------
+ loss: Any
+ The loss after forward
+ """
+
+ raise NotImplementedError("_learn is not implemented in BaseDistiller")
+
+ def distill(self) -> Dict[str, Any]:
+ """Distill the knowledge
+
+ Returns
+ -------
+ weights: dict<str, Any>
+ The distilled weights.
+ """
+
+ weights = self._distill()
+ if self._current_iter >= self._max_iter or (
+ self._current_iter > 0 and self._current_iter % self._save_step == 0
+ ):
+ self._save_weights(weights)
+ if self._current_iter >= self._max_iter:
+ self._distilled = True
+ self._plan = {n: msc_utils.inspect_array(d, False) for n, d in
weights.items()}
+ self._logger.info(
+ "Distill[%d] loss(%d batch) %f", self._current_iter,
self._forward_cnt, self._total_loss
+ )
+ self._current_iter += 1
+ self._total_loss, self._forward_cnt = 0, 0
+ return weights
+
+ def _distill(self) -> Dict[str, Any]:
+ """Distill the knowledge
+
+ Returns
+ -------
+ weights: dict<str, Any>
+ The distilled weights.
+ """
+
+ raise NotImplementedError("_distill is not implemented in
BaseDistiller")
+
+ def _save_weights(self, weights: Dict[str, Any]):
+ """Save the distilled weights
+
+ Parameters
+ ----------
+ weights: dict<str, Any>
+ The distilled weights.
+ """
+
+ weights = {n: tvm.nd.array(msc_utils.cast_array(d)) for n, d in
weights.items()}
+ weights_path =
self._weights_folder.relpath("distill_{}.bin".format(self._current_iter))
+ with open(weights_path, "wb") as f_params:
+ f_params.write(tvm.runtime.save_param_dict(weights))
+ if self.on_debug(2, in_forward=False):
+ self._logger.debug("Save weights[%d] to %s", self._current_iter,
weights_path)
+
+ def _support_scope(self, scope: str) -> bool:
+ """Check if the scope is supported
+
+ Parameters
+ -------
+ scope: str
+ The scope mark, should be null or ToolScope
+
+ Returns
+ -------
+ valid: bool
+ Whether to process the tensor.
+ """
+
+ return True
+
+ def _process_tensor(
+ self, tensor: Any, name: str, consumer: str, scope: str, strategys:
List[ToolStrategy]
+ ) -> Any:
+ """Process tensor
+
+ Parameters
+ -------
+ tensor: Any
+ Tensor in framework
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ scope: str
+ The scope mark teacher| student| null.
+ strategys: list<ToolStrategy>
+ The strategys for the tensor.
+
+ Returns
+ -------
+ tensor: Any
+ The processed tensor.
+ """
+
+ if self._distilled:
+ return tensor
+ return self._distill_tensor(tensor, name, consumer, scope, strategys)
+
+ def _distill_tensor(
+ self, tensor: Any, name: str, consumer: str, scope: str, strategys:
List[ToolStrategy]
+ ) -> Any:
+ """Process tensor
+
+ Parameters
+ -------
+ tensor: Any
+ Tensor in framework
+ name: str
+ The name of the tensor.
+ consumer: str
+ The name of the consumer.
+ scope: str
+ The scope mark teacher| student| null.
+ strategys: list<ToolStrategy>
+ The strategys for the tensor.
+
+ Returns
+ -------
+ tensor: Any
+ The processed tensor.
+ """
+
+ if name not in self._plan:
+ self._plan[name] = {}
+ plan = {}
+ for strategy in strategys:
+ plan.update(strategy(self, tensor, name, consumer, scope))
+ self._plan[name][scope] = plan
+ return tensor
+
+ @property
+ def distilled(self):
+ return self._distilled
+
+ @classmethod
+ def tool_type(cls):
+ return ToolType.DISTILLER
+
+
+class DefaultDistiller(BaseDistiller):
+ @classmethod
+ def tool_style(cls):
+ return "default"
+
+
+msc_utils.register_tool_cls(DefaultDistiller)
diff --git a/python/tvm/contrib/msc/core/tools/distill/method.py
b/python/tvm/contrib/msc/core/tools/distill/method.py
new file mode 100644
index 0000000000..0f3fd0fe48
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/distill/method.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.tools.distill.method"""
+
+from typing import List
+import numpy as np
+
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class DistillMethod(object):
+ """Default distill method"""
+
+ @classmethod
+ def loss_lp_norm(
+ cls,
+ distiller: BaseTool,
+ t_outputs: List[np.ndarray],
+ s_outputs: List[np.ndarray],
+ power: int = 2,
+ ):
+ """Calculate loss with lp norm
+
+ Parameters
+ ----------
+ distiller: BaseDistiller
+ The distiller
+ t_outputs: list<np.ndarray>
+ The teacher outputs.
+ s_outputs: list<np.ndarray>
+ The student outputs.
+ power: int
+ The power factor.
+
+ Returns
+ -------
+ loss: float
+ The loss.
+ """
+
+ loss = 0
+ for t_out, s_out in zip(t_outputs, s_outputs):
+ loss += np.mean(np.power(np.abs(t_out - s_out), power))
+ return loss
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.MSC
+
+ @classmethod
+ def tool_type(cls):
+ return ToolType.DISTILLER
+
+
+msc_utils.register_tool_method(DistillMethod)
diff --git a/python/tvm/contrib/msc/core/utils/expr.py
b/python/tvm/contrib/msc/core/utils/expr.py
index 8cebd1494f..9158381eb9 100644
--- a/python/tvm/contrib/msc/core/utils/expr.py
+++ b/python/tvm/contrib/msc/core/utils/expr.py
@@ -44,6 +44,26 @@ def get_expr_name(expr: relax.Expr) -> str:
return name
+def set_expr_name(expr: relax.Expr, name: str):
+ """Set the name for expr
+
+ Parameters
+ ----------
+ expr: Expr
+ The Expr of relax.
+ name: str
+ The name.
+
+ Returns
+ -------
+ expr: Expr
+ The expr with name.
+ """
+
+ expr.span = _ffi_api.SpanSetAttr(expr.span, "name", name)
+ return expr
+
+
def get_span_attrs(mod: tvm.IRModule) -> dict:
"""Extract the span attributes from relax.Function.
diff --git a/python/tvm/contrib/msc/core/utils/file.py
b/python/tvm/contrib/msc/core/utils/file.py
index 446efd4724..ada9745ff6 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -216,9 +216,14 @@ class MSCDirectory(object):
shutil.rmtree(f_path)
return f_path
- def listdir(self) -> List[str]:
+ def listdir(self, as_abs: bool = False) -> List[str]:
"""List contents in the dir.
+ Parameters
+ ----------
+ as_abs: bool
+ Whether to show abs path.
+
Returns
-------
names: list
@@ -227,6 +232,8 @@ class MSCDirectory(object):
if not os.path.isdir(self._path):
return []
+ if as_abs:
+ return [os.path.join(self._path, f) for f in
os.listdir(self._path)]
return os.listdir(self._path)
def destory(self):
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
index e5ebe50956..22f1821661 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
@@ -18,4 +18,5 @@
from .prune import *
from .quantize import *
+from .distill import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py
index ddfd41f3c8..1c89122c0a 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tensorflow.tools.distill"""
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
diff --git
a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py
b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py
new file mode 100644
index 0000000000..0385c6d941
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorflow.tools.distill.distiller"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorflowDistillerFactory(object):
+ """Distiller factory for tensorflow"""
+
+ def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+ """Create adaptive distiller
+
+ Parameters
+ ----------
+ base_cls: BaseDistiller
+ The base distiller class
+
+ Returns
+ -------
+ distiller_cls: BaseDistiller
+ The distiller class.
+ """
+
+ class Distiller(base_cls):
+ """Adaptive distiller for tensorflow"""
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.TENSORFLOW
+
+ return Distiller
+
+
+factory = TensorflowDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC,
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+ msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
index c010a42004..7454dba712 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
@@ -18,4 +18,5 @@
from .prune import *
from .quantize import *
+from .distill import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py
index ddfd41f3c8..4d14e35c41 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tensorrt.tools.distill"""
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
diff --git
a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py
b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py
new file mode 100644
index 0000000000..bc9ead6dcc
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.tools.distill.distiller"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorRTDistillerFactory(object):
+ """Distiller factory for tensorrt"""
+
+ def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+ """Create adaptive distiller
+
+ Parameters
+ ----------
+ base_cls: BaseDistiller
+ The base distiller class
+
+ Returns
+ -------
+ distiller_cls: BaseDistiller
+ The distiller class.
+ """
+
+ class Distiller(base_cls):
+ """Adaptive distiller for tensorrt"""
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.TENSORRT
+
+ return Distiller
+
+
+factory = TensorRTDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC,
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+ msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/torch/tools/__init__.py
b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
index f8fee73d69..f7dd3d489b 100644
--- a/python/tvm/contrib/msc/framework/torch/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
@@ -18,4 +18,5 @@
from .prune import *
from .quantize import *
+from .distill import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py
index ddfd41f3c8..61ff8cc3ef 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.torch.tools.distill"""
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py
b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py
new file mode 100644
index 0000000000..b2fa414aca
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.tools.distill.distiller"""
+
+from typing import Any, Dict
+
+import torch
+from torch import optim
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchDistillerFactory(object):
+ """Distiller factory for torch"""
+
+ def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+ """Create adaptive distiller
+
+ Parameters
+ ----------
+ base_cls: BaseDistiller
+ The base distiller class
+
+ Returns
+ -------
+ distiller_cls: BaseDistiller
+ The distiller class.
+ """
+
+ class Distiller(base_cls):
+ """Adaptive distiller for torch"""
+
+ def build_model(self, teacher: Any, student: Any) -> Any:
+ """Build the model with teacher and student
+
+ Parameters
+ -------
+ teacher: Any
+ The teacher model
+ student: Any
+ The student model
+
+ Returns
+ -------
+ model: Any
+ The built model.
+ """
+
+ optimizer = self._options.get("optimizer", "sgd")
+ opt_config = {"lr": 0.0001, "weight_decay": 1e-4}
+ opt_config.update(self._options.get("opt_config", {}))
+ self._logger.debug(
+ "%s build model with optimizer %s(%s)",
+ self.tool_type().upper(),
+ optimizer,
+ opt_config,
+ )
+ if optimizer == "sgd":
+ self._optimizer = optim.SGD(student.parameters(),
**opt_config)
+ elif optimizer == "adam":
+ self._optimizer = optim.Adam(student.parameters(),
**opt_config)
+ else:
+ raise NotImplementedError("optimizer {} is not
supported".format(optimizer))
+
+ # Get loss function
+ loss_strategy = self._strategys.get("loss.all")
+ assert loss_strategy, "Can not find loss.all in strategys"
+
+ def get_loss(teacher_outputs, student_outputs):
+ return loss_strategy(self, teacher_outputs,
student_outputs)
+
+ # Build model
+ class DistillModel(torch.nn.Module):
+ """Common distill model class"""
+
+ def __init__(self):
+ super(DistillModel, self).__init__()
+ self.teacher = teacher
+ self.student = student
+
+ def forward(self, *inputs):
+ with torch.no_grad():
+ teacher_outputs = self.teacher.forward(*inputs)
+ student_outputs = self.student.forward(*inputs)
+ return get_loss(teacher_outputs, student_outputs)
+
+ self._model = DistillModel()
+ return self._model
+
+ def _learn(self, loss: torch.Tensor):
+ """Learn after forward
+
+ Parameters
+ -------
+ loss: torch.Tensor
+ The loss after forward
+ """
+
+ loss.backward()
+ self._optimizer.step()
+ return loss
+
+ def _distill(self) -> Dict[str, Any]:
+ """Distill the knowledge
+
+ Returns
+ -------
+ weights: dict<str, Any>
+ The distilled weights.
+ """
+
+ state_dict = self._model.student.state_dict()
+ return {
+ n: state_dict.get(self.find_tensor(n).alias, d)
+ for n, d in self._weights.items()
+ }
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.TORCH
+
+ return Distiller
+
+
+factory = TorchDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC,
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+ msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/method.py
b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py
new file mode 100644
index 0000000000..7de3fdbbac
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.torch.tools.distill.method"""
+
+from typing import List
+
+import torch
+from tvm.contrib.msc.core.tools.distill import DistillMethod, BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchDistillMethod(DistillMethod):
+ """Default distill method for torch"""
+
+ @classmethod
+ def loss_kl_divergence(
+ cls,
+ distiller: BaseDistiller,
+ t_outputs: List[torch.Tensor],
+ s_outputs: List[torch.Tensor],
+ temperature: int = 5,
+ softmax_dim: int = -1,
+ ):
+ """Calculate loss with kl divergence
+
+ Parameters
+ ----------
+ distiller: BaseDistiller
+ The distiller
+ t_outputs: list<torch.Tensor>
+ The teacher outputs.
+ s_outputs: list<torch.Tensor>
+ The student outputs.
+ temperature: int
+ The temperature factor.
+ softmax_dim: int
+ If >=0, use softmax_dim for softmax loss
+
+ Returns
+ -------
+ loss: float
+ The loss.
+ """
+
+ kd_loss, loss = torch.nn.KLDivLoss(), 0
+ if softmax_dim >= 0:
+ log_softmax = torch.nn.LogSoftmax(dim=softmax_dim)
+ softmax = torch.nn.Softmax(dim=softmax_dim)
+
+ def _distill_loss(t_out, s_out):
+ if softmax_dim >= 0:
+ return (
+ temperature
+ * temperature
+ * kd_loss(log_softmax(s_out / temperature), softmax(t_out
/ temperature))
+ )
+ return kd_loss(s_out / temperature, t_out / temperature)
+
+ for t_out, s_out in zip(t_outputs, s_outputs):
+ loss += _distill_loss(t_out, s_out)
+ return loss
+
+ @classmethod
+ def loss_lp_norm(
+ cls,
+ distiller: BaseDistiller,
+ t_outputs: List[torch.Tensor],
+ s_outputs: List[torch.Tensor],
+ power: int = 2,
+ ):
+ """Calculate loss with lp norm
+
+ Parameters
+ ----------
+ distiller: BaseDistiller
+ The distiller
+ t_outputs: list<torch.Tensor>
+ The teacher outputs.
+ s_outputs: list<torch.Tensor>
+ The student outputs.
+ power: int
+ The power factor.
+
+ Returns
+ -------
+ loss: float
+ The loss.
+ """
+
+ loss = 0
+ for t_out, s_out in zip(t_outputs, s_outputs):
+ loss += torch.pow((t_out - s_out).abs(), power).mean()
+ return loss
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.TORCH
+
+
+msc_utils.register_tool_method(TorchDistillMethod)
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
index ddfd41f3c8..06b0f4a9d8 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
@@ -18,4 +18,5 @@
from .prune import *
from .quantize import *
+from .distill import *
from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
b/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py
index ddfd41f3c8..8d4b7dfb61 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tvm.tools.distill"""
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py
b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py
new file mode 100644
index 0000000000..9cfc99dc1a
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tvm.tools.distill.distiller"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TVMDistillerFactory(object):
+ """Distiller factory for tvm"""
+
+ def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+ """Create adaptive distiller
+
+ Parameters
+ ----------
+ base_cls: BaseDistiller
+ The base distiller class
+
+ Returns
+ -------
+ distiller_cls: BaseDistiller
+ The distiller class.
+ """
+
+ class Distiller(base_cls):
+ """Adaptive distiller for tvm"""
+
+ @classmethod
+ def framework(cls):
+ return MSCFramework.TVM
+
+ return Distiller
+
+
+factory = TVMDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC,
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+ msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/pipeline/manager.py
b/python/tvm/contrib/msc/pipeline/manager.py
index 8a37ef951f..d19c7995ed 100644
--- a/python/tvm/contrib/msc/pipeline/manager.py
+++ b/python/tvm/contrib/msc/pipeline/manager.py
@@ -382,6 +382,10 @@ class BaseManager(object):
if _tool_enabled(ToolType.QUANTIZER):
self._apply_tool(ToolType.QUANTIZER, stage_config)
+ # run distill
+ if _tool_enabled(ToolType.DISTILLER):
+ self._apply_tool(ToolType.DISTILLER, stage_config)
+
# optimize and get the runner
msc_utils.time_stamp(MSCStage.OPTIMIZE)
return self._create_runner(
diff --git a/tests/python/contrib/test_msc/test_tools.py
b/tests/python/contrib/test_msc/test_tools.py
index 7e981d348b..cda6e231e8 100644
--- a/tests/python/contrib/test_msc/test_tools.py
+++ b/tests/python/contrib/test_msc/test_tools.py
@@ -68,7 +68,7 @@ def _get_config(
}
-def get_tool_config(tool_type):
+def get_tool_config(tool_type, use_distill=False):
"""Get config for the tool"""
config = {}
if tool_type == ToolType.PRUNER:
@@ -128,6 +128,17 @@ def get_tool_config(tool_type):
}
],
}
+ if use_distill:
+ distill_config = {
+ "plan_file": "msc_distiller.json",
+ "strategys": [
+ {
+ "method": "loss_lp_norm",
+ "op_types": ["loss"],
+ },
+ ],
+ }
+ return {tool_type: config, ToolType.DISTILLER: distill_config}
return {tool_type: config}
@@ -228,6 +239,16 @@ def test_tvm_tool(tool_type):
)
[email protected]("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER])
+def test_tvm_distill(tool_type):
+ """Test tools for tvm with distiller"""
+
+ tool_config = get_tool_config(tool_type, use_distill=True)
+ _test_from_torch(
+ MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM),
is_training=True
+ )
+
+
@requires_tensorrt
@pytest.mark.parametrize(
"tool_type",
@@ -253,5 +274,16 @@ def test_tensorrt_tool(tool_type):
)
+@requires_tensorrt
[email protected]("tool_type", [ToolType.PRUNER])
+def test_tensorrt_distill(tool_type):
+ """Test tools for tensorrt with distiller"""
+
+ tool_config = get_tool_config(tool_type, use_distill=True)
+ _test_from_torch(
+ MSCFramework.TENSORRT, tool_config,
get_model_info(MSCFramework.TENSORRT), is_training=False
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()