This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6b9c277bf6 [Unity][Training] Simplify matmul patterns after gradient
(#16082)
6b9c277bf6 is described below
commit 6b9c277bf6d5b76f4029748c23d06f6b7df7c8f2
Author: Yixin Dong <[email protected]>
AuthorDate: Wed Nov 8 20:18:25 2023 -0800
[Unity][Training] Simplify matmul patterns after gradient (#16082)
When processing matmul with transpose (permute_dims operator), the current
gradient pass will introduce additional permute_dims ops. That's because the
gradient pass visits all bindings from the last one to the first one, and
matmul and permute_dims are regarded as two separate ops. E.g.
Forward is:
```
out = matmul(a, transpose(b))
```
Then backward is:
```
grad_a = matmul(grad_out, transpose(transpose(b)))
grad_b = transpose(matmul(transpose(a), grad_out))
```
This PR introduces a new pass, GradientSimplifier, and enhances the
Gradient pass to simplify these patterns. The example above will be simplified
to
```
grad_a = matmul(grad_out, b)
grad_b = matmul(transpose(grad_out), a)
```
---
src/relax/transform/gradient.cc | 9 +-
src/relax/transform/gradient_simplifier.cc | 200 +++++++++++++++++++++
src/relax/transform/gradient_simplifier.h | 46 +++++
tests/python/relax/test_transform_gradient.py | 51 ++++++
.../relax/test_transform_gradient_numeric.py | 37 ++++
5 files changed, 341 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 2bf858c98d..70e3e37876 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -36,6 +36,7 @@
#include "../op/tensor/binary.h"
#include "../op/tensor/create.h"
+#include "gradient_simplifier.h"
#include "utils.h"
namespace tvm {
@@ -648,12 +649,16 @@ class GradientMutator : private ExprMutator {
new_func = Downcast<Function>(RemoveAllUnused(new_func));
}
- // Step 4.3 mark the transformed function as public
+ // Step 4.3 Simplify specific patterns generated by the gradient pass. In
particular, simplify
+ // transpose + matmul patterns. For details, see the documentation of
SimplifyGradient
+ new_func = SimplifyGradient(new_func);
+
+ // Step 4.4 mark the transformed function as public
// because the original function may be public, and have gsymbol attribute
as func_name
auto new_func_name = func_name + "_adjoint";
auto new_func_with_gsymbol = WithAttr(new_func, tvm::attr::kGlobalSymbol,
new_func_name);
- // Step 4.4 Add the transformed function to IRModule
+ // Step 4.5 Add the transformed function to IRModule
builder_->AddFunction(new_func_with_gsymbol, new_func_name);
return builder_->GetContextIRModule();
}
diff --git a/src/relax/transform/gradient_simplifier.cc
b/src/relax/transform/gradient_simplifier.cc
new file mode 100644
index 0000000000..966e8b7ad6
--- /dev/null
+++ b/src/relax/transform/gradient_simplifier.cc
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/gradient_simplifier.cc
+ * \brief Simplify patterns generated by the gradient pass. Only used in
gradient.cc.
+ * \sa tvm/relax/transform/gradient.cc
+ *
+ * We will simplify these patterns:
+ * (transpose means use permute_dims to transpose the last two dimensions)
+ * 1. Forward is: out = matmul(a, transpose(b))
+ * Then backward is:
+ * grad_a = matmul(grad_out, transpose(transpose(b)))
+ * grad_b = transpose(matmul(transpose(a), grad_out))
+ * We will simplify it to:
+ * grad_a = matmul(grad_out, b)
+ * grad_b = matmul(transpose(grad_out), a)
+ * 2. Forward is: out = matmul(transpose(a), b)
+ * Then backward is:
+ * grad_a = transpose(matmul(grad_out, transpose(b)))
+ * grad_b = matmul(transpose(transpose(a)), grad_out)
+ * We will simplify it to:
+ * grad_a = matmul(b, transpose(grad_out))
+ * grad_b = matmul(a, grad_out)
+ * 3. Forward is: out = matmul(transpose(a), transpose(b))
+ * Then backward is:
+ * grad_a = transpose(matmul(grad_out, transpose(transpose(b))))
+ * grad_b = transpose(matmul(transpose(transpose(a)), grad_out))
+ * We will simplify it to:
+ * grad_a = matmul(transpose(b), transpose(grad_out))
+ * grad_b = matmul(transpose(grad_out), transpose(a))
+ */
+
+#include "gradient_simplifier.h"
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Simplify patterns generated by the gradient pass. Especially,
simplify the matmul
+ * patterns.
+ */
+class GradientSimplifier : private ExprMutator {
+ public:
+ /*!
+ * \brief Simplify the patterns generated by the gradient pass in the given
function. In
+ * particular, simplify the transpose + matmul patterns, and remove bindings
left unused
+ * by the simplification.
+ *
+ * \param func The original function
+ * \return The simplified function, with unused bindings removed.
+ */
+ static Function Transform(const Function& func) {
+ return
Downcast<Function>(RemoveAllUnused(GradientSimplifier().VisitExpr(func)));
+ }
+
+ private:
+ static bool IsTransposeOp(const CallNode* call_node) {
+ if (call_node->op != Op::Get("relax.permute_dims")) {
+ return false;
+ }
+ auto sinfo = MatchStructInfo<TensorStructInfo>(call_node->args[0]);
+ if (!sinfo) {
+ return false;
+ }
+ auto ndim = sinfo.value()->ndim;
+ if (ndim == kUnknownNDim || ndim == 1) {
+ return false;
+ }
+ if (!call_node->attrs.as<PermuteDimsAttrs>()->axes.defined()) {
+ return ndim == 2;
+ }
+ auto axes = call_node->attrs.as<PermuteDimsAttrs>()->axes.value();
+ ICHECK(static_cast<int>(axes.size()) == ndim);
+ for (int i = 0; i < ndim - 2; ++i) {
+ if (axes[i] != i) {
+ return false;
+ }
+ }
+ return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2;
+ }
+
+ // Return permute_dims(expr). Generate the axes needed.
+ static Expr GetTransposeOf(const Expr& expr) {
+ auto sinfo = MatchStructInfo<TensorStructInfo>(expr);
+ ICHECK(sinfo);
+ auto ndim = sinfo.value()->ndim;
+ if (ndim == 1) {
+ return expr;
+ }
+ auto axes = Array<Integer>();
+ for (int i = 0; i < ndim - 2; ++i) {
+ axes.push_back(i);
+ }
+ axes.push_back(ndim - 1);
+ axes.push_back(ndim - 2);
+ return permute_dims(expr, axes);
+ }
+
+ // If expr is already in the form of permute_dims in previous bindings,
return the input of the
+ // permute_dims op
+ // Else, return permute_dims(expr)
+ Expr GetTransposeAccordingToCtx(const Expr& expr) {
+ if (!expr->IsInstance<VarNode>()) {
+ return GetTransposeOf(expr);
+ }
+ auto prev_expr = builder_->LookupBinding(Downcast<Var>(expr));
+ if (!prev_expr || !prev_expr->IsInstance<CallNode>()) {
+ return GetTransposeOf(expr);
+ }
+ auto prev_call_node = prev_expr.as<CallNode>();
+ if (!IsTransposeOp(prev_call_node)) {
+ return GetTransposeOf(expr);
+ }
+ return prev_call_node->args[0];
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node)
{
+ auto result = ExprMutator::VisitExpr(GetRef<Expr>(call_node));
+ auto new_call_node = result.as<CallNode>();
+ auto reemit_and_return = [&]() {
+ ReEmitBinding(binding, result);
+ return;
+ };
+
+ if (!IsTransposeOp(new_call_node)) {
+ return reemit_and_return();
+ }
+
+ auto arg = new_call_node->args[0];
+ if (!arg->IsInstance<VarNode>()) {
+ return reemit_and_return();
+ }
+
+ auto prev_expr = builder_->LookupBinding(Downcast<Var>(arg));
+ if (!prev_expr || !prev_expr->IsInstance<CallNode>()) {
+ return reemit_and_return();
+ }
+
+ auto prev_call_node = prev_expr.as<CallNode>();
+ if (IsTransposeOp(prev_call_node)) {
+ // rewrite rule #1: permute_dims(permute_dims(a)) -> a
+ if (prev_call_node->args[0]->IsInstance<VarNode>()) {
+ var_remap_[binding->var->vid] = Downcast<Var>(prev_call_node->args[0]);
+ return;
+ } else {
+ return reemit_and_return();
+ }
+ } else if (prev_call_node->op == Op::Get("relax.matmul")) {
+ // rewrite rule #2: permute_dims(matmul(a, b)) ->
matmul(permute_dims(b), permute_dims(a))
+ // Should "a" or "b" already be in the form of "permute_dims", the
redundant permute_dims
+ // operation should be eliminated
+
+ // Skip matmuls with 1-dim input because in these cases we cannot simply
transpose the input
+ auto a_dim =
MatchStructInfo<TensorStructInfo>(prev_call_node->args[0]).value()->ndim;
+ auto b_dim =
MatchStructInfo<TensorStructInfo>(prev_call_node->args[1]).value()->ndim;
+ if (a_dim == 1 || b_dim == 1) {
+ return reemit_and_return();
+ }
+
+ auto a = GetTransposeAccordingToCtx(prev_call_node->args[0]);
+ auto b = GetTransposeAccordingToCtx(prev_call_node->args[1]);
+ result =
+ ExprMutator::VisitExpr(matmul(b, a,
prev_call_node->attrs.as<MatmulAttrs>()->out_dtype));
+ ReEmitBinding(binding, result);
+ return;
+ } else {
+ return reemit_and_return();
+ }
+ }
+};
+
+Function SimplifyGradient(const Function& func) { return
GradientSimplifier::Transform(func); }
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/gradient_simplifier.h
b/src/relax/transform/gradient_simplifier.h
new file mode 100644
index 0000000000..0c6faa0ef6
--- /dev/null
+++ b/src/relax/transform/gradient_simplifier.h
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/gradient_simplifier.h
+ * \brief Simplify patterns generated by the gradient pass. Especially,
simplify the matmul
+ * patterns.
+ * \sa tvm/relax/transform/gradient.cc
+ */
+
+#ifndef TVM_RELAX_TRANSFORM_GRADIENT_SIMPLIFIER_H_
+#define TVM_RELAX_TRANSFORM_GRADIENT_SIMPLIFIER_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Simplify patterns generated by the gradient pass. Especially,
simplify the matmul
+ * patterns.
+ * \param func The function to be simplified.
+ * \return The simplified function.
+ */
+Function SimplifyGradient(const Function& func);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_TRANSFORM_GRADIENT_SIMPLIFIER_H_
diff --git a/tests/python/relax/test_transform_gradient.py
b/tests/python/relax/test_transform_gradient.py
index c4e2f9d526..072edea5c4 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -922,6 +922,57 @@ def test_const():
assert_structural_equal(After, Expected)
+def test_simplify_matmul_pattern():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")):
+ with R.dataflow():
+ lv1 = R.permute_dims(x)
+ lv2 = R.permute_dims(y)
+ lv3 = R.matmul(lv1, lv2, out_dtype="float32")
+ gv = R.sum(lv3)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3,
3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"),
R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(x,
axes=None)
+ lv2: R.Tensor((3, 3), dtype="float32") = R.permute_dims(y,
axes=None)
+ lv3: R.Tensor((3, 3), dtype="float32") = R.matmul(lv1, lv2,
out_dtype="float32")
+ gv: R.Tensor((), dtype="float32") = R.sum(lv3, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
+ lv3_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv: R.Tensor((3, 3), dtype="float32") =
R.permute_dims(lv3_adjoint, axes=[1, 0])
+ lv1_1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(x,
axes=[1, 0])
+ y_adjoint: R.Tensor((3, 3), dtype="float32") = R.matmul(lv,
lv1_1, out_dtype="void")
+ lv2_1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(y,
axes=[1, 0])
+ lv3_1: R.Tensor((3, 3), dtype="float32") =
R.permute_dims(lv3_adjoint, axes=[1, 0])
+ x_adjoint: R.Tensor((3, 3), dtype="float32") = R.matmul(lv2_1,
lv3_1, out_dtype="void")
+ x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
+ y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint
+ R.output(gv, x_adjoint_out, y_adjoint_out)
+ return (gv, (x_adjoint_out, y_adjoint_out))
+ @R.function
+ def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3),
dtype="float32")) -> R.Tensor((), dtype="float32"):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), dtype="float32") = R.permute_dims(x,
axes=None)
+ lv2: R.Tensor((3, 3), dtype="float32") = R.permute_dims(y,
axes=None)
+ lv3: R.Tensor((3, 3), dtype="float32") = R.matmul(lv1, lv2,
out_dtype="float32")
+ gv: R.Tensor((), dtype="float32") = R.sum(lv3, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
def test_shape_expr():
# fmt: off
@I.ir_module
diff --git a/tests/python/relax/test_transform_gradient_numeric.py
b/tests/python/relax/test_transform_gradient_numeric.py
index 38a63406e8..27c0ffb565 100644
--- a/tests/python/relax/test_transform_gradient_numeric.py
+++ b/tests/python/relax/test_transform_gradient_numeric.py
@@ -186,5 +186,42 @@ def test_complex(target, dev):
check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in
grad])
[email protected]_targets("llvm")
+def test_matmul(target, dev):
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")):
+ with R.dataflow():
+ lv1 = R.matmul(x, y)
+ lv2 = R.permute_dims(x)
+ lv3 = R.matmul(lv2, y)
+ lv4 = R.permute_dims(y)
+ lv5 = R.matmul(x, lv4)
+ lv6 = R.permute_dims(x)
+ lv7 = R.permute_dims(y)
+ lv8 = R.matmul(lv6, lv7)
+ lv9 = lv1 + lv3 + lv5 + lv8
+ gv = R.sum(lv9)
+ R.output(gv)
+ return gv
+
+ After = relax.transform.Gradient("main")(Before)
+ args = []
+ for arg in After["main_adjoint"].params:
+ shape = [int(l) for l in arg.struct_info.shape]
+ args.append(rand("float32", *shape))
+
+ vm_before = _legalize_and_build(Before, target, dev)
+ vm_after = _legalize_and_build(After, target, dev)
+ _, grad = vm_after["main_adjoint"](*args)
+
+ def func(*inputs):
+ loss = vm_before["main"](*[tvm.nd.array(i) for i in inputs])
+ return loss.numpy()
+
+ check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in
grad])
+
+
if __name__ == "__main__":
tvm.testing.main()