This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6b9c277bf6 [Unity][Training] Simplify matmul patterns after gradient 
(#16082)
6b9c277bf6 is described below

commit 6b9c277bf6d5b76f4029748c23d06f6b7df7c8f2
Author: Yixin Dong <[email protected]>
AuthorDate: Wed Nov 8 20:18:25 2023 -0800

    [Unity][Training] Simplify matmul patterns after gradient (#16082)
    
    When processing matmul with transpose (permute_dims operator), the current
gradient pass will introduce additional permute_dims ops. That's because the
gradient pass visits all bindings from the last one to the first one, and
matmul and permute_dims are regarded as two separate ops. E.g.
    
    Forward is:
    ```
    out = matmul(a, transpose(b))
    ```
    Then backward is:
    ```
    grad_a = matmul(grad_out, transpose(transpose(b)))
    grad_b = transpose(matmul(transpose(a), grad_out))
    ```
    This PR introduces a new pass, GradientSimplifier, and enhances the 
Gradient pass to simplify these patterns. The example above will be simplified 
to
    ```
    grad_a = matmul(grad_out, b)
    grad_b = matmul(transpose(grad_out), a)
    ```
---
 src/relax/transform/gradient.cc                    |   9 +-
 src/relax/transform/gradient_simplifier.cc         | 200 +++++++++++++++++++++
 src/relax/transform/gradient_simplifier.h          |  46 +++++
 tests/python/relax/test_transform_gradient.py      |  51 ++++++
 .../relax/test_transform_gradient_numeric.py       |  37 ++++
 5 files changed, 341 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 2bf858c98d..70e3e37876 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -36,6 +36,7 @@
 
 #include "../op/tensor/binary.h"
 #include "../op/tensor/create.h"
+#include "gradient_simplifier.h"
 #include "utils.h"
 
 namespace tvm {
@@ -648,12 +649,16 @@ class GradientMutator : private ExprMutator {
       new_func = Downcast<Function>(RemoveAllUnused(new_func));
     }
 
-    // Step 4.3 mark the transformed function as public
+    // Step 4.3 Simplify specific patterns generated by the gradient pass. In
+    // particular, simplify transpose + matmul patterns. For details, see the
+    // documentation of SimplifyGradient.
+    new_func = SimplifyGradient(new_func);
+
+    // Step 4.4 mark the transformed function as public
     // because the original function may be public, and have gsymbol attribute 
as func_name
     auto new_func_name = func_name + "_adjoint";
     auto new_func_with_gsymbol = WithAttr(new_func, tvm::attr::kGlobalSymbol, 
new_func_name);
 
-    // Step 4.4 Add the transformed function to IRModule
+    // Step 4.5 Add the transformed function to IRModule
     builder_->AddFunction(new_func_with_gsymbol, new_func_name);
     return builder_->GetContextIRModule();
   }
diff --git a/src/relax/transform/gradient_simplifier.cc 
b/src/relax/transform/gradient_simplifier.cc
new file mode 100644
index 0000000000..966e8b7ad6
--- /dev/null
+++ b/src/relax/transform/gradient_simplifier.cc
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/gradient_simplifier.cc
+ * \brief Simplify patterns generated by the gradient pass. Only used in 
gradient.cc.
+ * \sa tvm/relax/transform/gradient.cc
+ *
+ * We will simplify these patterns:
+ * (transpose means use permute_dims to transpose the last two dimensions)
+ * 1. Forward is: out = matmul(a, transpose(b))
+ *    Then backward is:
+ *        grad_a = matmul(grad_out, transpose(transpose(b)))
+ *        grad_b = transpose(matmul(transpose(a), grad_out))
+ *    We will simplify it to:
+ *        grad_a = matmul(grad_out, b)
+ *        grad_b = matmul(transpose(grad_out), a)
+ * 2. Forward is: out = matmul(transpose(a), b)
+ *    Then backward is:
+ *        grad_a = transpose(matmul(grad_out, transpose(b)))
+ *        grad_b = matmul(transpose(transpose(a)), grad_out)
+ *    We will simplify it to:
+ *        grad_a = matmul(b, transpose(grad_out))
+ *        grad_b = matmul(a, grad_out)
+ * 3. Forward is: out = matmul(transpose(a), transpose(b))
+ *    Then backward is:
+ *        grad_a = transpose(matmul(grad_out, transpose(transpose(b))))
+ *        grad_b = transpose(matmul(transpose(transpose(a)), grad_out))
+ *    We will simplify it to:
+ *        grad_a = matmul(transpose(b), transpose(grad_out))
+ *        grad_b = matmul(transpose(grad_out), transpose(a))
+ */
+
+#include "gradient_simplifier.h"
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Simplify patterns generated by the gradient pass; in particular,
+ * simplify the matmul patterns.
+ */
+class GradientSimplifier : private ExprMutator {
+ public:
+  /*!
+   * \brief Simplify the matmul and permute_dims patterns generated by the
+   * gradient pass, then remove all bindings left unused by the rewrite.
+   *
+   * \param func The original function
+   * \return The simplified function.
+   */
+  static Function Transform(const Function& func) {
+    return 
Downcast<Function>(RemoveAllUnused(GradientSimplifier().VisitExpr(func)));
+  }
+
+ private:
+  static bool IsTransposeOp(const CallNode* call_node) {
+    if (call_node->op != Op::Get("relax.permute_dims")) {
+      return false;
+    }
+    auto sinfo = MatchStructInfo<TensorStructInfo>(call_node->args[0]);
+    if (!sinfo) {
+      return false;
+    }
+    auto ndim = sinfo.value()->ndim;
+    if (ndim == kUnknownNDim || ndim == 1) {
+      return false;
+    }
+    if (!call_node->attrs.as<PermuteDimsAttrs>()->axes.defined()) {
+      return ndim == 2;
+    }
+    auto axes = call_node->attrs.as<PermuteDimsAttrs>()->axes.value();
+    ICHECK(static_cast<int>(axes.size()) == ndim);
+    for (int i = 0; i < ndim - 2; ++i) {
+      if (axes[i] != i) {
+        return false;
+      }
+    }
+    return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2;
+  }
+
+  // Return permute_dims(expr). Generate the axes needed.
+  static Expr GetTransposeOf(const Expr& expr) {
+    auto sinfo = MatchStructInfo<TensorStructInfo>(expr);
+    ICHECK(sinfo);
+    auto ndim = sinfo.value()->ndim;
+    if (ndim == 1) {
+      return expr;
+    }
+    auto axes = Array<Integer>();
+    for (int i = 0; i < ndim - 2; ++i) {
+      axes.push_back(i);
+    }
+    axes.push_back(ndim - 1);
+    axes.push_back(ndim - 2);
+    return permute_dims(expr, axes);
+  }
+
+  // If expr is already in the form of permute_dims in previous bindings, 
return the input of the
+  // permute_dims op
+  // Else, return permute_dims(expr)
+  Expr GetTransposeAccordingToCtx(const Expr& expr) {
+    if (!expr->IsInstance<VarNode>()) {
+      return GetTransposeOf(expr);
+    }
+    auto prev_expr = builder_->LookupBinding(Downcast<Var>(expr));
+    if (!prev_expr || !prev_expr->IsInstance<CallNode>()) {
+      return GetTransposeOf(expr);
+    }
+    auto prev_call_node = prev_expr.as<CallNode>();
+    if (!IsTransposeOp(prev_call_node)) {
+      return GetTransposeOf(expr);
+    }
+    return prev_call_node->args[0];
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) 
{
+    auto result = ExprMutator::VisitExpr(GetRef<Expr>(call_node));
+    auto new_call_node = result.as<CallNode>();
+    auto reemit_and_return = [&]() {
+      ReEmitBinding(binding, result);
+      return;
+    };
+
+    if (!IsTransposeOp(new_call_node)) {
+      return reemit_and_return();
+    }
+
+    auto arg = new_call_node->args[0];
+    if (!arg->IsInstance<VarNode>()) {
+      return reemit_and_return();
+    }
+
+    auto prev_expr = builder_->LookupBinding(Downcast<Var>(arg));
+    if (!prev_expr || !prev_expr->IsInstance<CallNode>()) {
+      return reemit_and_return();
+    }
+
+    auto prev_call_node = prev_expr.as<CallNode>();
+    if (IsTransposeOp(prev_call_node)) {
+      // rewrite rule #1: permute_dims(permute_dims(a)) -> a
+      if (prev_call_node->args[0]->IsInstance<VarNode>()) {
+        var_remap_[binding->var->vid] = Downcast<Var>(prev_call_node->args[0]);
+        return;
+      } else {
+        return reemit_and_return();
+      }
+    } else if (prev_call_node->op == Op::Get("relax.matmul")) {
+      // rewrite rule #2: permute_dims(matmul(a, b)) -> 
matmul(permute_dims(b), permute_dims(a))
+      // Should "a" or "b" already be in the form of "permute_dims", the 
redundant permute_dims
+      // operation should be eliminated
+
+      // Skip matmuls with 1-dim input because in these cases we cannot simply 
transpose the input
+      auto a_dim = 
MatchStructInfo<TensorStructInfo>(prev_call_node->args[0]).value()->ndim;
+      auto b_dim = 
MatchStructInfo<TensorStructInfo>(prev_call_node->args[1]).value()->ndim;
+      if (a_dim == 1 || b_dim == 1) {
+        return reemit_and_return();
+      }
+
+      auto a = GetTransposeAccordingToCtx(prev_call_node->args[0]);
+      auto b = GetTransposeAccordingToCtx(prev_call_node->args[1]);
+      result =
+          ExprMutator::VisitExpr(matmul(b, a, 
prev_call_node->attrs.as<MatmulAttrs>()->out_dtype));
+      ReEmitBinding(binding, result);
+      return;
+    } else {
+      return reemit_and_return();
+    }
+  }
+};
+
+Function SimplifyGradient(const Function& func) { return 
GradientSimplifier::Transform(func); }
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/gradient_simplifier.h 
b/src/relax/transform/gradient_simplifier.h
new file mode 100644
index 0000000000..0c6faa0ef6
--- /dev/null
+++ b/src/relax/transform/gradient_simplifier.h
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/gradient_simplifier.h
+ * \brief Simplify patterns generated by the gradient pass. Especially, 
simplify the matmul
+ * patterns.
+ * \sa tvm/relax/transform/gradient.cc
+ */
+
+#ifndef TVM_RELAX_TRANSFORM_GRADIENT_SIMPLIFIER_H_
+#define TVM_RELAX_TRANSFORM_GRADIENT_SIMPLIFIER_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Simplify patterns generated by the gradient pass; in particular,
+ * simplify the matmul patterns.
+ * \param func The function to be simplified.
+ * \return The simplified function.
+ */
+Function SimplifyGradient(const Function& func);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_GRADIENT_SIMPLIFIER_H_
diff --git a/tests/python/relax/test_transform_gradient.py 
b/tests/python/relax/test_transform_gradient.py
index c4e2f9d526..072edea5c4 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -922,6 +922,57 @@ def test_const():
     assert_structural_equal(After, Expected)
 
 
+def test_simplify_matmul_pattern():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                lv1 = R.permute_dims(x)
+                lv2 = R.permute_dims(y)
+                lv3 = R.matmul(lv1, lv2, out_dtype="float32")
+                gv = R.sum(lv3)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 
3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), 
R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(x, 
axes=None)
+                lv2: R.Tensor((3, 3), dtype="float32") = R.permute_dims(y, 
axes=None)
+                lv3: R.Tensor((3, 3), dtype="float32") = R.matmul(lv1, lv2, 
out_dtype="float32")
+                gv: R.Tensor((), dtype="float32") = R.sum(lv3, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
+                lv3_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv: R.Tensor((3, 3), dtype="float32") = 
R.permute_dims(lv3_adjoint, axes=[1, 0])
+                lv1_1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(x, 
axes=[1, 0])
+                y_adjoint: R.Tensor((3, 3), dtype="float32") = R.matmul(lv, 
lv1_1, out_dtype="void")
+                lv2_1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(y, 
axes=[1, 0])
+                lv3_1: R.Tensor((3, 3), dtype="float32") = 
R.permute_dims(lv3_adjoint, axes=[1, 0])
+                x_adjoint: R.Tensor((3, 3), dtype="float32") = R.matmul(lv2_1, 
lv3_1, out_dtype="void")
+                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
+                y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint
+                R.output(gv, x_adjoint_out, y_adjoint_out)
+            return (gv, (x_adjoint_out, y_adjoint_out))
+        @R.function
+        def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(x, 
axes=None)
+                lv2: R.Tensor((3, 3), dtype="float32") = R.permute_dims(y, 
axes=None)
+                lv3: R.Tensor((3, 3), dtype="float32") = R.matmul(lv1, lv2, 
out_dtype="float32")
+                gv: R.Tensor((), dtype="float32") = R.sum(lv3, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
 def test_shape_expr():
     # fmt: off
     @I.ir_module
diff --git a/tests/python/relax/test_transform_gradient_numeric.py 
b/tests/python/relax/test_transform_gradient_numeric.py
index 38a63406e8..27c0ffb565 100644
--- a/tests/python/relax/test_transform_gradient_numeric.py
+++ b/tests/python/relax/test_transform_gradient_numeric.py
@@ -186,5 +186,42 @@ def test_complex(target, dev):
     check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in 
grad])
 
 
[email protected]_targets("llvm")
+def test_matmul(target, dev):
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                lv1 = R.matmul(x, y)
+                lv2 = R.permute_dims(x)
+                lv3 = R.matmul(lv2, y)
+                lv4 = R.permute_dims(y)
+                lv5 = R.matmul(x, lv4)
+                lv6 = R.permute_dims(x)
+                lv7 = R.permute_dims(y)
+                lv8 = R.matmul(lv6, lv7)
+                lv9 = lv1 + lv3 + lv5 + lv8
+                gv = R.sum(lv9)
+                R.output(gv)
+            return gv
+
+    After = relax.transform.Gradient("main")(Before)
+    args = []
+    for arg in After["main_adjoint"].params:
+        shape = [int(l) for l in arg.struct_info.shape]
+        args.append(rand("float32", *shape))
+
+    vm_before = _legalize_and_build(Before, target, dev)
+    vm_after = _legalize_and_build(After, target, dev)
+    _, grad = vm_after["main_adjoint"](*args)
+
+    def func(*inputs):
+        loss = vm_before["main"](*[tvm.nd.array(i) for i in inputs])
+        return loss.numpy()
+
+    check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in 
grad])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to