This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2bf3a0a428 [Unity][MSC][M3.1] Add distiller for distill model (#16264)
2bf3a0a428 is described below

commit 2bf3a0a4287069ac55ee3304c285b08592d3d1bc
Author: Archermmt <[email protected]>
AuthorDate: Tue Dec 26 18:53:26 2023 +0800

    [Unity][MSC][M3.1] Add distiller for distill model (#16264)
    
    * add distiller
    
    * remove useless
---
 python/tvm/contrib/msc/core/runtime/runner.py      |  32 ++-
 python/tvm/contrib/msc/core/tools/__init__.py      |   1 +
 .../tvm/tools => core/tools/distill}/__init__.py   |   7 +-
 .../contrib/msc/core/tools/distill/distiller.py    | 261 +++++++++++++++++++++
 .../tvm/contrib/msc/core/tools/distill/method.py   |  72 ++++++
 python/tvm/contrib/msc/core/utils/expr.py          |  20 ++
 python/tvm/contrib/msc/core/utils/file.py          |   9 +-
 .../msc/framework/tensorflow/tools/__init__.py     |   1 +
 .../tools => tensorflow/tools/distill}/__init__.py |   6 +-
 .../tensorflow/tools/distill/distiller.py          |  55 +++++
 .../msc/framework/tensorrt/tools/__init__.py       |   1 +
 .../tools => tensorrt/tools/distill}/__init__.py   |   6 +-
 .../framework/tensorrt/tools/distill/distiller.py  |  55 +++++
 .../contrib/msc/framework/torch/tools/__init__.py  |   1 +
 .../{tvm/tools => torch/tools/distill}/__init__.py |   7 +-
 .../msc/framework/torch/tools/distill/distiller.py | 144 ++++++++++++
 .../msc/framework/torch/tools/distill/method.py    | 116 +++++++++
 .../contrib/msc/framework/tvm/tools/__init__.py    |   1 +
 .../framework/tvm/tools/{ => distill}/__init__.py  |   6 +-
 .../msc/framework/tvm/tools/distill/distiller.py   |  55 +++++
 python/tvm/contrib/msc/pipeline/manager.py         |   4 +
 tests/python/contrib/test_msc/test_tools.py        |  34 ++-
 22 files changed, 868 insertions(+), 26 deletions(-)

diff --git a/python/tvm/contrib/msc/core/runtime/runner.py 
b/python/tvm/contrib/msc/core/runtime/runner.py
index 5228b06b10..4b84037994 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -26,7 +26,7 @@ import numpy as np
 import tvm
 from tvm.contrib.msc.core.ir import MSCGraph
 from tvm.contrib.msc.core.frontend import from_relax
-from tvm.contrib.msc.core.tools import BaseTool, ToolType, create_tool, 
remove_tools
+from tvm.contrib.msc.core.tools import BaseTool, ToolType, ToolScope, 
create_tool, remove_tools
 from tvm.contrib.msc.core.utils.namespace import MSCFramework
 from tvm.contrib.msc.core.utils.message import MSCStage
 from tvm.contrib.msc.core import utils as msc_utils
@@ -180,9 +180,24 @@ class BaseRunner(object):
 
         # Generate model
         if not self._model:
-            # Generate normal model
-            self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir)
-            self._model = self._generate_model()
+            distiller = self.get_tool(ToolType.DISTILLER)
+            if distiller and not distiller.distilled:
+                build_root = self._generate_config["build_folder"]
+
+                def _build_scope_model(scope: str):
+                    self._update_codegen({"tools_scope": scope})
+                    self._generate_config["build_folder"] = 
build_root.create_dir(scope)
+                    return self._generate_model()
+
+                # Generate distill model
+                teacher_model = _build_scope_model(ToolScope.TEACHER)
+                self._graphs, self._weights = 
self.reset_tools(cache_dir=cache_dir)
+                student_model = _build_scope_model(ToolScope.STUDENT)
+                self._model = distiller.build_model(teacher_model, 
student_model)
+            else:
+                # Generate normal model
+                self._graphs, self._weights = 
self.reset_tools(cache_dir=cache_dir)
+                self._model = self._generate_model()
 
             # Log generate info
             generate_msg = "Generate model({})".format(self.framework)
@@ -422,6 +437,15 @@ class BaseRunner(object):
                     self.run(inputs, ret_type="native")
                 quantizer.calibrate()
             plan = quantizer.finalize()
+        elif tool_type == ToolType.DISTILLER:
+            distiller = self.get_tool(ToolType.DISTILLER)
+            while not distiller.distilled:
+                assert data_loader, "data_loader should be given to plan distill"
+                for inputs in data_loader():
+                    loss = self.run(inputs, ret_type="native")
+                    distiller.learn(loss)
+                distiller.distill()
+            plan = distiller.finalize()
         else:
             plan = self.get_tool(tool_type).finalize()
         assert plan, "Failed to create plan for {}".format(tool_type)
diff --git a/python/tvm/contrib/msc/core/tools/__init__.py 
b/python/tvm/contrib/msc/core/tools/__init__.py
index e97771cf6c..3c563fbea0 100644
--- a/python/tvm/contrib/msc/core/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/__init__.py
@@ -20,4 +20,5 @@ from .tool import *
 from .execute import *
 from .prune import *
 from .quantize import *
+from .distill import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/core/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/core/tools/distill/__init__.py
index ddfd41f3c8..8714eae4e4 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/core/tools/distill/__init__.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.core.tools.distill"""
 
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
+from .method import *
diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py 
b/python/tvm/contrib/msc/core/tools/distill/distiller.py
new file mode 100644
index 0000000000..f5c2ca2f88
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.tools.distill.distiller"""
+
+import os
+from typing import List, Any, Dict, Tuple
+
+import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, ToolStrategy
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class BaseDistiller(BaseTool):
+    """Base distiller for all"""
+
+    def setup(self) -> dict:
+        """Setup the tool
+
+        Returns
+        -------
+        info: dict
+            The setup info.
+        """
+
+        self._max_iter = self._options.get("max_iter", 5)
+        self._save_step = self._options.get("save_step", 50)
+        self._weights_folder = 
msc_utils.get_weights_dir().create_dir("Distill")
+        self._weights_path = 
self._weights_folder.relpath("distill_{}.bin".format(self._max_iter))
+        self._distilled = os.path.isfile(self._weights_path)
+        return super().setup()
+
+    def _reset(
+        self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
+    ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
+        """Reset the tool
+
+        Parameters
+        ----------
+        graphs: list<MSCgraph>
+            The msc graphs.
+        weights: list<dict<str, tvm.nd.array>>
+            The weights
+
+        Returns
+        -------
+        graphs: list<MSCgraph>
+            The msc graphs.
+        weights: list<dict<str, tvm.nd.array>>
+            The weights
+        """
+
+        self._current_iter = 0
+        self._total_loss = 0
+        if self._distilled:
+            with open(self._weights_path, "rb") as f:
+                distilled_weights = tvm.runtime.load_param_dict(f.read())
+            for sub_weights in weights:
+                sub_weights.update({k: v for k, v in distilled_weights.items() 
if k in sub_weights})
+            self._logger.info("Update %d distilled weights", 
len(distilled_weights))
+        return super()._reset(graphs, weights)
+
+    def build_model(self, teacher: Any, student: Any) -> Any:
+        """Build the model with teacher and student
+
+        Parameters
+        ----------
+        teacher: Any
+            The teacher model
+        student: Any
+            The student model
+
+        Returns
+        -------
+        model: Any
+            The built model.
+        """
+
+        raise NotImplementedError("build_model is not implemented in 
BaseDistiller")
+
+    def learn(self, loss: Any):
+        """Learn after forward
+
+        Parameters
+        ----------
+        loss: Any
+            The loss after forward
+        """
+
+        if self.on_debug(3):
+            self._logger.debug("%sStart Learn", self.msg_mark())
+        self._total_loss += float(self._learn(loss))
+
+    def _learn(self, loss: Any):
+        """Learn after forward
+
+        Parameters
+        ----------
+        loss: Any
+            The loss after forward
+        """
+
+        raise NotImplementedError("_learn is not implemented in BaseDistiller")
+
+    def distill(self) -> Dict[str, Any]:
+        """Distill the knowledge
+
+        Returns
+        -------
+        weights: dict<str, Any>
+            The distilled weights.
+        """
+
+        weights = self._distill()
+        if self._current_iter >= self._max_iter or (
+            self._current_iter > 0 and self._current_iter % self._save_step == 0
+        ):
+            self._save_weights(weights)
+        if self._current_iter >= self._max_iter:
+            self._distilled = True
+            self._plan = {n: msc_utils.inspect_array(d, False) for n, d in 
weights.items()}
+        self._logger.info(
+            "Distill[%d] loss(%d batch) %f", self._current_iter, 
self._forward_cnt, self._total_loss
+        )
+        self._current_iter += 1
+        self._total_loss, self._forward_cnt = 0, 0
+        return weights
+
+    def _distill(self) -> Dict[str, Any]:
+        """Distill the knowledge
+
+        Returns
+        -------
+        weights: dict<str, Any>
+            The distilled weights.
+        """
+
+        raise NotImplementedError("_distill is not implemented in 
BaseDistiller")
+
+    def _save_weights(self, weights: Dict[str, Any]):
+        """Save the distilled weights
+
+        Parameters
+        ----------
+        weights: dict<str, Any>
+            The distilled weights.
+        """
+
+        weights = {n: tvm.nd.array(msc_utils.cast_array(d)) for n, d in 
weights.items()}
+        weights_path = 
self._weights_folder.relpath("distill_{}.bin".format(self._current_iter))
+        with open(weights_path, "wb") as f_params:
+            f_params.write(tvm.runtime.save_param_dict(weights))
+        if self.on_debug(2, in_forward=False):
+            self._logger.debug("Save weights[%d] to %s", self._current_iter, 
weights_path)
+
+    def _support_scope(self, scope: str) -> bool:
+        """Check if the scope is supported
+
+        Parameters
+        ----------
+        scope: str
+            The scope mark, should be null or ToolScope
+
+        Returns
+        -------
+        valid: bool
+            Whether to process the tensor.
+        """
+
+        return True
+
+    def _process_tensor(
+        self, tensor: Any, name: str, consumer: str, scope: str, strategys: 
List[ToolStrategy]
+    ) -> Any:
+        """Process tensor
+
+        Parameters
+        ----------
+        tensor: Any
+            Tensor in framework
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scope: str
+            The scope mark: teacher | student | null.
+        strategys: list<ToolStrategy>
+            The strategys for the tensor.
+
+        Returns
+        -------
+        tensor: Any
+            The processed tensor.
+        """
+
+        if self._distilled:
+            return tensor
+        return self._distill_tensor(tensor, name, consumer, scope, strategys)
+
+    def _distill_tensor(
+        self, tensor: Any, name: str, consumer: str, scope: str, strategys: 
List[ToolStrategy]
+    ) -> Any:
+        """Process tensor
+
+        Parameters
+        ----------
+        tensor: Any
+            Tensor in framework
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        scope: str
+            The scope mark: teacher | student | null.
+        strategys: list<ToolStrategy>
+            The strategys for the tensor.
+
+        Returns
+        -------
+        tensor: Any
+            The processed tensor.
+        """
+
+        if name not in self._plan:
+            self._plan[name] = {}
+        plan = {}
+        for strategy in strategys:
+            plan.update(strategy(self, tensor, name, consumer, scope))
+        self._plan[name][scope] = plan
+        return tensor
+
+    @property
+    def distilled(self):
+        return self._distilled
+
+    @classmethod
+    def tool_type(cls):
+        return ToolType.DISTILLER
+
+
+class DefaultDistiller(BaseDistiller):
+    @classmethod
+    def tool_style(cls):
+        return "default"
+
+
+msc_utils.register_tool_cls(DefaultDistiller)
diff --git a/python/tvm/contrib/msc/core/tools/distill/method.py 
b/python/tvm/contrib/msc/core/tools/distill/method.py
new file mode 100644
index 0000000000..0f3fd0fe48
--- /dev/null
+++ b/python/tvm/contrib/msc/core/tools/distill/method.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.tools.distill.method"""
+
+from typing import List
+import numpy as np
+
+from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class DistillMethod(object):
+    """Default distill method"""
+
+    @classmethod
+    def loss_lp_norm(
+        cls,
+        distiller: BaseTool,
+        t_outputs: List[np.ndarray],
+        s_outputs: List[np.ndarray],
+        power: int = 2,
+    ):
+        """Calculate loss with lp norm
+
+        Parameters
+        ----------
+        distiller: BaseDistiller
+            The distiller
+        t_outputs: list<np.ndarray>
+            The teacher outputs.
+        s_outputs: list<np.ndarray>
+            The student outputs.
+        power: int
+            The power factor.
+
+        Returns
+        -------
+        loss: float
+            The loss.
+        """
+
+        loss = 0
+        for t_out, s_out in zip(t_outputs, s_outputs):
+            loss += np.mean(np.power(np.abs(t_out - s_out), power))
+        return loss
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.MSC
+
+    @classmethod
+    def tool_type(cls):
+        return ToolType.DISTILLER
+
+
+msc_utils.register_tool_method(DistillMethod)
diff --git a/python/tvm/contrib/msc/core/utils/expr.py 
b/python/tvm/contrib/msc/core/utils/expr.py
index 8cebd1494f..9158381eb9 100644
--- a/python/tvm/contrib/msc/core/utils/expr.py
+++ b/python/tvm/contrib/msc/core/utils/expr.py
@@ -44,6 +44,26 @@ def get_expr_name(expr: relax.Expr) -> str:
     return name
 
 
+def set_expr_name(expr: relax.Expr, name: str):
+    """Set the name for expr
+
+    Parameters
+    ----------
+    expr: Expr
+        The Expr of relax.
+    name: str
+        The name.
+
+    Returns
+    -------
+    expr: Expr
+        The expr with name.
+    """
+
+    expr.span = _ffi_api.SpanSetAttr(expr.span, "name", name)
+    return expr
+
+
 def get_span_attrs(mod: tvm.IRModule) -> dict:
     """Extract the span attributes from relax.Function.
 
diff --git a/python/tvm/contrib/msc/core/utils/file.py 
b/python/tvm/contrib/msc/core/utils/file.py
index 446efd4724..ada9745ff6 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -216,9 +216,14 @@ class MSCDirectory(object):
             shutil.rmtree(f_path)
         return f_path
 
-    def listdir(self) -> List[str]:
+    def listdir(self, as_abs: bool = False) -> List[str]:
         """List contents in the dir.
 
+        Parameters
+        ----------
+        as_abs: bool
+            Whether to show abs path.
+
         Returns
         -------
         names: list
@@ -227,6 +232,8 @@ class MSCDirectory(object):
 
         if not os.path.isdir(self._path):
             return []
+        if as_abs:
+            return [os.path.join(self._path, f) for f in 
os.listdir(self._path)]
         return os.listdir(self._path)
 
     def destory(self):
diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
index e5ebe50956..22f1821661 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py
@@ -18,4 +18,5 @@
 
 from .prune import *
 from .quantize import *
+from .distill import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py
index ddfd41f3c8..1c89122c0a 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py
@@ -14,8 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tensorflow.tools.distill"""
 
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
diff --git 
a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py 
b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py
new file mode 100644
index 0000000000..0385c6d941
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorflow.tools.distill.distiller"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorflowDistillerFactory(object):
+    """Distiller factory for tensorflow"""
+
+    def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+        """Create adaptive distiller
+
+        Parameters
+        ----------
+        base_cls: BaseDistiller
+            The base distiller class
+
+        Returns
+        -------
+        distiller_cls: BaseDistiller
+            The distiller class.
+        """
+
+        class Distiller(base_cls):
+            """Adaptive distiller for tensorflow"""
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TENSORFLOW
+
+        return Distiller
+
+
+factory = TensorflowDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, 
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
index c010a42004..7454dba712 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
@@ -18,4 +18,5 @@
 
 from .prune import *
 from .quantize import *
+from .distill import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py
index ddfd41f3c8..4d14e35c41 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py
@@ -14,8 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tensorrt.tools.distill"""
 
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
diff --git 
a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py 
b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py
new file mode 100644
index 0000000000..bc9ead6dcc
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tensorrt.tools.distill.distiller"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TensorRTDistillerFactory(object):
+    """Distiller factory for tensorrt"""
+
+    def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+        """Create adaptive distiller
+
+        Parameters
+        ----------
+        base_cls: BaseDistiller
+            The base distiller class
+
+        Returns
+        -------
+        distiller_cls: BaseDistiller
+            The distiller class.
+        """
+
+        class Distiller(base_cls):
+            """Adaptive distiller for tensorrt"""
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TENSORRT
+
+        return Distiller
+
+
+factory = TensorRTDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, 
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/torch/tools/__init__.py 
b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
index f8fee73d69..f7dd3d489b 100644
--- a/python/tvm/contrib/msc/framework/torch/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/__init__.py
@@ -18,4 +18,5 @@
 
 from .prune import *
 from .quantize import *
+from .distill import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py
index ddfd41f3c8..61ff8cc3ef 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.torch.tools.distill"""
 
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
+from .method import *
diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py 
b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py
new file mode 100644
index 0000000000..b2fa414aca
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.torch.tools.distill.distiller"""
+
+from typing import Any, Dict
+
+import torch
+from torch import optim
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchDistillerFactory(object):
+    """Distiller factory for torch"""
+
+    def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+        """Create adaptive distiller
+
+        Parameters
+        ----------
+        base_cls: BaseDistiller
+            The base distiller class
+
+        Returns
+        -------
+        distiller_cls: BaseDistiller
+            The distiller class.
+        """
+
+        class Distiller(base_cls):
+            """Adaptive distiller for torch"""
+
+            def build_model(self, teacher: Any, student: Any) -> Any:
+                """Build the model with teacher and student
+
+                Parameters
+                ----------
+                teacher: Any
+                    The teacher model
+                student: Any
+                    The student model
+
+                Returns
+                -------
+                model: Any
+                    The built model.
+                """
+
+                optimizer = self._options.get("optimizer", "sgd")
+                opt_config = {"lr": 0.0001, "weight_decay": 1e-4}
+                opt_config.update(self._options.get("opt_config", {}))
+                self._logger.debug(
+                    "%s build model with optimizer %s(%s)",
+                    self.tool_type().upper(),
+                    optimizer,
+                    opt_config,
+                )
+                if optimizer == "sgd":
+                    self._optimizer = optim.SGD(student.parameters(), 
**opt_config)
+                elif optimizer == "adam":
+                    self._optimizer = optim.Adam(student.parameters(), 
**opt_config)
+                else:
+                    raise NotImplementedError("optimizer {} is not 
supported".format(optimizer))
+
+                # Get loss function
+                loss_strategy = self._strategys.get("loss.all")
+                assert loss_strategy, "Can not find loss.all in strategys"
+
+                def get_loss(teacher_outputs, student_outputs):
+                    return loss_strategy(self, teacher_outputs, 
student_outputs)
+
+                # Build model
+                class DistillModel(torch.nn.Module):
+                    """Common distill model class"""
+
+                    def __init__(self):
+                        super(DistillModel, self).__init__()
+                        self.teacher = teacher
+                        self.student = student
+
+                    def forward(self, *inputs):
+                        with torch.no_grad():
+                            teacher_outputs = self.teacher.forward(*inputs)
+                        student_outputs = self.student.forward(*inputs)
+                        return get_loss(teacher_outputs, student_outputs)
+
+                self._model = DistillModel()
+                return self._model
+
+            def _learn(self, loss: torch.Tensor):
+                """Learn after forward
+
+                Parameters
+                ----------
+                loss: torch.Tensor
+                    The loss after forward
+                """
+
+                loss.backward()
+                self._optimizer.step()
+                return loss
+
+            def _distill(self) -> Dict[str, Any]:
+                """Distill the knowledge
+
+                Returns
+                -------
+                weights: dict<str, Any>
+                    The distilled weights.
+                """
+
+                state_dict = self._model.student.state_dict()
+                return {
+                    n: state_dict.get(self.find_tensor(n).alias, d)
+                    for n, d in self._weights.items()
+                }
+
+            @classmethod
+            def framework(cls):
+                return MSCFramework.TORCH
+
+        return Distiller
+
+
+factory = TorchDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, 
ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/method.py 
b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py
new file mode 100644
index 0000000000..7de3fdbbac
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.framework.torch.tools.distill.method"""
+
+from typing import List
+
+import torch
+from tvm.contrib.msc.core.tools.distill import DistillMethod, BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TorchDistillMethod(DistillMethod):
+    """Default distill method for torch"""
+
+    @classmethod
+    def loss_kl_divergence(
+        cls,
+        distiller: BaseDistiller,
+        t_outputs: List[torch.Tensor],
+        s_outputs: List[torch.Tensor],
+        temperature: int = 5,
+        softmax_dim: int = -1,
+    ):
+        """Calculate loss with KL divergence
+
+        Parameters
+        ----------
+        distiller: BaseDistiller
+            The distiller
+        t_outputs: list<torch.Tensor>
+            The teacher outputs.
+        s_outputs: list<torch.Tensor>
+            The student outputs.
+        temperature: int
+            The temperature factor.
+        softmax_dim: int
+            If >=0, use softmax_dim for softmax loss
+
+        Returns
+        -------
+        loss: float
+            The loss.
+        """
+
+        kd_loss, loss = torch.nn.KLDivLoss(), 0
+        if softmax_dim >= 0:
+            log_softmax = torch.nn.LogSoftmax(dim=softmax_dim)
+            softmax = torch.nn.Softmax(dim=softmax_dim)
+
+        def _distill_loss(t_out, s_out):
+            # KLDivLoss expects log-probabilities as its first (student)
+            # argument and probabilities as its second (teacher) argument;
+            # the temperature^2 factor keeps gradient magnitudes comparable
+            # across temperature choices.
+            if softmax_dim >= 0:
+                return (
+                    temperature
+                    * temperature
+                    * kd_loss(log_softmax(s_out / temperature), softmax(t_out 
+/ temperature))
+                )
+            return kd_loss(s_out / temperature, t_out / temperature)
+
+        # Sum the per-output losses over the paired teacher/student outputs.
+        for t_out, s_out in zip(t_outputs, s_outputs):
+            loss += _distill_loss(t_out, s_out)
+        return loss
+
+    @classmethod
+    def loss_lp_norm(
+        cls,
+        distiller: BaseDistiller,
+        t_outputs: List[torch.Tensor],
+        s_outputs: List[torch.Tensor],
+        power: int = 2,
+    ):
+        """Calculate loss with lp norm (mse when power is 2)
+
+        Parameters
+        ----------
+        distiller: BaseDistiller
+            The distiller
+        t_outputs: list<torch.Tensor>
+            The teacher outputs.
+        s_outputs: list<torch.Tensor>
+            The student outputs.
+        power: int
+            The power factor.
+
+        Returns
+        -------
+        loss: float
+            The loss.
+        """
+
+        # Mean of |t - s|^power per output, summed over the output pairs.
+        loss = 0
+        for t_out, s_out in zip(t_outputs, s_outputs):
+            loss += torch.pow((t_out - s_out).abs(), power).mean()
+        return loss
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.TORCH
+
+
+msc_utils.register_tool_method(TorchDistillMethod)
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
index ddfd41f3c8..06b0f4a9d8 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
@@ -18,4 +18,5 @@
 
 from .prune import *
 from .quantize import *
+from .distill import *
 from .track import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py 
b/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py
similarity index 87%
copy from python/tvm/contrib/msc/framework/tvm/tools/__init__.py
copy to python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py
index ddfd41f3c8..8d4b7dfb61 100644
--- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py
@@ -14,8 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""tvm.contrib.msc.framework.tvm.tools"""
+"""tvm.contrib.msc.framework.tvm.tools.distill"""
 
-from .prune import *
-from .quantize import *
-from .track import *
+from .distiller import *
diff --git a/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py 
b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py
new file mode 100644
index 0000000000..9cfc99dc1a
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tvm.tools.distill.distiller"""
+
+from tvm.contrib.msc.core.tools.tool import ToolType
+from tvm.contrib.msc.core.tools.distill import BaseDistiller
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+class TVMDistillerFactory(object):
+    """Distiller factory for tvm"""
+
+    def create(self, base_cls: BaseDistiller) -> BaseDistiller:
+        """Create adaptive distiller
+
+        Parameters
+        ----------
+        base_cls: BaseDistiller
+            The base distiller class
+
+        Returns
+        -------
+        distiller_cls: BaseDistiller
+            The distiller class.
+        """
+
+        class Distiller(base_cls):
+            """Adaptive distiller for tvm"""
+
+            @classmethod
+            def framework(cls):
+                # Tag the subclass so tool registration keys it to tvm.
+                return MSCFramework.TVM
+
+        return Distiller
+
+
+# Wrap every MSC-registered distiller class into a tvm-specific subclass
+# and register it, so tvm runners can look up a framework-matched tool.
+factory = TVMDistillerFactory()
+tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, 
+ToolType.DISTILLER, tool_style="all")
+for tool in tools.values():
+    msc_utils.register_tool_cls(factory.create(tool))
diff --git a/python/tvm/contrib/msc/pipeline/manager.py 
b/python/tvm/contrib/msc/pipeline/manager.py
index 8a37ef951f..d19c7995ed 100644
--- a/python/tvm/contrib/msc/pipeline/manager.py
+++ b/python/tvm/contrib/msc/pipeline/manager.py
@@ -382,6 +382,10 @@ class BaseManager(object):
         if _tool_enabled(ToolType.QUANTIZER):
             self._apply_tool(ToolType.QUANTIZER, stage_config)
 
+        # run distill
+        if _tool_enabled(ToolType.DISTILLER):
+            self._apply_tool(ToolType.DISTILLER, stage_config)
+
         # optimize and get the runner
         msc_utils.time_stamp(MSCStage.OPTIMIZE)
         return self._create_runner(
diff --git a/tests/python/contrib/test_msc/test_tools.py 
b/tests/python/contrib/test_msc/test_tools.py
index 7e981d348b..cda6e231e8 100644
--- a/tests/python/contrib/test_msc/test_tools.py
+++ b/tests/python/contrib/test_msc/test_tools.py
@@ -68,7 +68,7 @@ def _get_config(
     }
 
 
-def get_tool_config(tool_type):
+def get_tool_config(tool_type, use_distill=False):
     """Get config for the tool"""
     config = {}
     if tool_type == ToolType.PRUNER:
@@ -128,6 +128,17 @@ def get_tool_config(tool_type):
                 }
             ],
         }
+    if use_distill:
+        distill_config = {
+            "plan_file": "msc_distiller.json",
+            "strategys": [
+                {
+                    "method": "loss_lp_norm",
+                    "op_types": ["loss"],
+                },
+            ],
+        }
+        return {tool_type: config, ToolType.DISTILLER: distill_config}
     return {tool_type: config}
 
 
@@ -228,6 +239,16 @@ def test_tvm_tool(tool_type):
     )
 
 
[email protected]("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER])
+def test_tvm_distill(tool_type):
+    """Test tools for tvm with distiller"""
+
+    tool_config = get_tool_config(tool_type, use_distill=True)
+    # is_training=True: distillation updates the student weights.
+    _test_from_torch(
+        MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), 
+is_training=True
+    )
+
+
 @requires_tensorrt
 @pytest.mark.parametrize(
     "tool_type",
@@ -253,5 +274,16 @@ def test_tensorrt_tool(tool_type):
     )
 
 
+@requires_tensorrt
[email protected]("tool_type", [ToolType.PRUNER])
+def test_tensorrt_distill(tool_type):
+    """Test tools for tensorrt with distiller"""
+
+    tool_config = get_tool_config(tool_type, use_distill=True)
+    # NOTE(review): is_training=False here — presumably the tensorrt runner
+    # does not run in training mode; confirm against the runner implementation.
+    _test_from_torch(
+        MSCFramework.TENSORRT, tool_config, 
+get_model_info(MSCFramework.TENSORRT), is_training=False
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to