This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8da3de1522 [Redo][Unity] Split DecomposeOpsForTraining into two steps (#16465)
8da3de1522 is described below
commit 8da3de1522e2caa080c61a0dbe3faa9ab93bacf9
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Feb 6 08:02:40 2024 -0600
[Redo][Unity] Split DecomposeOpsForTraining into two steps (#16465)
* [Support] Add PackedFunc "tvm.support.regex_match"
This function should be used instead of `std::regex` within C++ call
sites, to avoid ABI incompatibilities with pytorch.
Currently, the pytorch wheels available through pip install use the
pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to use
the pre-C++11 ABI, this would cause breakages with dynamically-linked
LLVM environments.
Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex` is
called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first. This restriction can be
removed when a version of pytorch compiled using `-DUSE_CXX11_ABI=1`
is available from PyPI.
[0] https://github.com/pytorch/pytorch/issues/51039
* [Redo][Unity] Split DecomposeOpsForTraining into two steps
This is a reapplication of https://github.com/apache/tvm/pull/15954,
after resolving the breakages that required reverting in
https://github.com/apache/tvm/pull/16442. The regex matching is now
implemented without the `#include <regex>` from the C++ stdlib, to
avoid ABI incompatibility with pytorch.
Prior to this commit, the `DecomposeOpsForTraining` transform directly
replaced `relax.nn.batch_norm` with more primitive relax operations.
This required the decomposed form of `relax.nn.batch_norm` to be
duplicated in `DecomposeOpsForInference`. This commit refactors the
pass into two steps: first apply training-specific mutations, then
decompose.
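A pure-Python, single-channel sketch of that split, assuming biased (divide-by-n) batch variance and the momentum update used in this commit, `(1 - momentum) * moving + momentum * batch`. The function names are hypothetical. Training computes batch statistics, reuses the inference formula unchanged, and adds the moving-average update:

```python
import math
from typing import List, Tuple


def batch_norm_inference(
    x: List[float], mean: float, var: float, gamma: float, beta: float, eps: float = 1e-5
) -> List[float]:
    """Shared decomposition: (x - mean) / sqrt(var + eps) * gamma + beta."""
    denom = math.sqrt(var + eps)
    return [(xi - mean) / denom * gamma + beta for xi in x]


def batch_norm_training(
    x: List[float],
    gamma: float,
    beta: float,
    moving_mean: float,
    moving_var: float,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tuple[List[float], float, float]:
    # Training-specific step 1: statistics of the current batch
    # (biased variance, an assumption of this sketch).
    n = len(x)
    batch_mean = sum(x) / n
    batch_var = sum((xi - batch_mean) ** 2 for xi in x) / n
    # Reuse the inference formula, fed with batch statistics instead of moving ones.
    out = batch_norm_inference(x, batch_mean, batch_var, gamma, beta, eps)
    # Training-specific step 2: exponential moving average of the statistics.
    new_mean = (1 - momentum) * moving_mean + momentum * batch_mean
    new_var = (1 - momentum) * moving_var + momentum * batch_var
    return out, new_mean, new_var


out, new_mean, new_var = batch_norm_training([1.0, 2.0, 3.0, 4.0], 1.0, 0.0, 0.0, 1.0)
print(new_mean, new_var)  # approximately 0.25 and 1.025
```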
Having a dedicated `DecomposeOps` pass also provides a single, clear
location for operator decomposition, which may later be migrated into
the operator definitions, similar to `FLegalize`.
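A toy model of the resulting pass structure: the decomposition is written once, and the training pipeline only prepends a mutation step, mirroring `Sequential({MutateOpsForTraining(), DecomposeOps()})`. The list-of-op-names IR, the schematic op lists, and every function name here are hypothetical; this is not the Relax API:

```python
from typing import Callable, Dict, List

# Hypothetical toy IR: each function is a flat list of operator names.
Module = Dict[str, List[str]]
Pass = Callable[[Module], Module]


def function_pass(rewrite: Callable[[str], List[str]]) -> Pass:
    """Lift a per-operator rewrite into a pass over every function in a module."""

    def run(mod: Module) -> Module:
        return {name: [new_op for op in ops for new_op in rewrite(op)] for name, ops in mod.items()}

    return run


def sequential(*passes: Pass) -> Pass:
    """Run passes in order, feeding each one the previous result."""

    def run(mod: Module) -> Module:
        for p in passes:
            mod = p(mod)
        return mod

    return run


def mutate_for_training(op: str) -> List[str]:
    # Training-only: compute batch statistics and schedule the moving-average
    # update, but leave batch_norm itself for the shared decomposition.
    if op == "batch_norm":
        return ["mean", "variance", "batch_norm", "update_moving_stats"]
    return [op]


def decompose(op: str) -> List[str]:
    # Shared by inference and training (schematic op list).
    if op == "batch_norm":
        return ["subtract", "divide", "multiply", "add"]
    return [op]


mutate_ops_for_training = function_pass(mutate_for_training)
decompose_ops = function_pass(decompose)

decompose_ops_for_inference = decompose_ops
decompose_ops_for_training = sequential(mutate_ops_for_training, decompose_ops)

mod = {"main": ["conv2d", "batch_norm", "relu"]}
print(decompose_ops_for_inference(mod)["main"])
print(decompose_ops_for_training(mod)["main"])
```

Because the decomposition exists in one place, a change to it automatically applies to both pipelines.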
---
include/tvm/ir/transform.h | 25 ++++
python/tvm/ir/transform.py | 43 ++++++
python/tvm/support.py | 45 +++++-
src/ir/transform.cc | 37 +++++
src/relax/transform/decompose_ops.cc | 156 ++++++++++-----------
tests/python/relax/test_transform_decompose_ops.py | 71 +++++-----
6 files changed, 255 insertions(+), 122 deletions(-)
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index ec151d9d75..adf3325250 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -525,6 +525,31 @@ TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
String name, Array<runtime::String> required, bool traceable = false);
+/*!
+ * \brief Utility to apply a pass to specific functions in an IRModule
+ *
+ * TVM uses IRModule to IRModule transformations at all stages of
+ * lowering. These transformations may be useful when hand-writing an
+ * optimized model, or to perform optimizations on specific kernels
+ * within an IRModule. This utility allows a pass to be applied to a
+ * specified function, without altering other functions in the module.
+ *
+ * \param pass The IRModule to IRModule pass to be applied.
+ *
+ * \param func_name_regex A regex used to select the functions to be
+ * updated. The pass will be applied to all functions whose name
+ * matches the regex.
+ *
+ * \param error_if_no_function_matches_regex Specifies the behavior if
+ * an IRModule does not contain any function matching the provided
+ * regex. If true, an error will be raised. If false (default),
+ * the IRModule will be returned unmodified.
+ *
+ * \return The modified IRModule to IRModule pass.
+ */
+TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+ bool error_if_no_function_matches_regex = false);
+
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index c8e9f5dc76..b2937acaa7 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -445,3 +445,46 @@ def PrintIR(header="", show_meta_data=False):
The pass
"""
return _ffi_transform_api.PrintIR(header, show_meta_data)
+
+
+def ApplyPassToFunction(
+ transform: Pass,
+ func_name_regex: str,
+ error_if_no_function_matches_regex: bool = False,
+) -> Pass:
+ """Utility to apply a pass to specific functions in an IRModule
+
+ TVM uses IRModule to IRModule transformations at all stages of
+ lowering. These transformations may be useful when hand-writing an
+ optimized model, or to perform optimizations on specific kernels
+ within an IRModule. This utility allows a pass to be applied to a
+ specified function, without altering other functions in the module.
+
+ Parameters
+ ----------
+ transform: Pass
+
+ The IRModule to IRModule pass to be applied.
+
+ func_name_regex: str
+
+ A regex used to select the functions to be updated. The pass
+ will be applied to all functions whose name matches the regex.
+
+ error_if_no_function_matches_regex: bool
+
+ Specifies the behavior if an IRModule does not contain any
+ function matching the provided regex. If true, an error will
+ be raised. If false (default), the IRModule will be returned
+ unmodified.
+
+ Returns
+ -------
+ new_transform: Pass
+
+ The modified IRModule to IRModule pass.
+
+ """
+ return _ffi_transform_api.ApplyPassToFunction(
+ transform, func_name_regex, error_if_no_function_matches_regex
+ )
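The selection logic can be modelled in plain Python. This is a sketch over a hypothetical dict-based module, not TVM's `IRModule`. It mirrors the C++ implementation in `src/ir/transform.cc` (select by name, run the pass on the subset, merge back only if something changed). The error branch follows the docstring, since the C++ lambda in this commit does not capture `error_if_no_function_matches_regex`:

```python
import re
from typing import Callable, Dict

# Hypothetical toy module: function name -> function body (any value).
Module = Dict[str, object]
Pass = Callable[[Module], Module]


def apply_pass_to_function(
    transform: Pass,
    func_name_regex: str,
    error_if_no_function_matches_regex: bool = False,
) -> Pass:
    """Toy analogue of ApplyPassToFunction: run `transform` on matching functions only."""

    def run(mod: Module) -> Module:
        # Select by name with re.match, as tvm.support.regex_match does.
        subset = {name: body for name, body in mod.items() if re.match(func_name_regex, name)}
        if not subset:
            if error_if_no_function_matches_regex:
                raise ValueError(f"no function matches {func_name_regex!r}")
            return mod  # documented default: module returned unmodified
        updated = transform(subset)
        if updated is subset:
            return mod  # the pass changed nothing, so skip the copy
        merged = dict(mod)  # copy-on-write: never mutate the caller's module
        merged.update(updated)
        return merged

    return run


def double_all(m: Module) -> Module:
    return {name: body * 2 for name, body in m.items()}


mod = {"main": 1, "main_inner": 2, "helper": 3}
print(apply_pass_to_function(double_all, "main")(mod))
# {'main': 2, 'main_inner': 4, 'helper': 3}
```

Note that the pattern `main` also selects `main_inner`, because `re.match` anchors only at the start of the name.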
diff --git a/python/tvm/support.py b/python/tvm/support.py
index ccd6f59e32..4fa95fac89 100644
--- a/python/tvm/support.py
+++ b/python/tvm/support.py
@@ -19,6 +19,7 @@ import json
import textwrap
import ctypes
import os
+import re
import sys
import tvm
@@ -26,6 +27,8 @@ import tvm._ffi
from .runtime.module import Module
from . import get_global_func
+tvm._ffi._init_api("support", __name__)
+
def libinfo():
"""Returns a dictionary containing compile-time info, including cmake
flags and git commit hash
@@ -87,4 +90,44 @@ class FrontendTestModule(Module):
self.add_function(key, value)
-tvm._ffi._init_api("support", __name__)
+@tvm._ffi.register_func("tvm.support.regex_match")
+def _regex_match(regex_pattern: str, match_against: str) -> bool:
+ """Check if a string matches a regular expression
+
+ This function should be used instead of `std::regex` within C++
+ call sites, to avoid ABI incompatibilities with pytorch.
+
+ Currently, the pytorch wheels available through pip install use
+ the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
+ use the pre-C++11 ABI, this would cause breakages with
+ dynamically-linked LLVM environments.
+
+ Use of the `<regex>` header in TVM should be avoided, as its
+ implementation is not supported by gcc's dual ABI. This ABI
+ incompatibility results in runtime errors either when `std::regex`
+ is called from TVM, or when `std::regex` is called from pytorch,
+ depending on which library was loaded first. This restriction can
+ be removed when a version of pytorch compiled using
+ `-DUSE_CXX11_ABI=1` is available from PyPI.
+
+ [0] https://github.com/pytorch/pytorch/issues/51039
+
+ Parameters
+ ----------
+ regex_pattern: str
+
+ The regular expression
+
+ match_against: str
+
+ The string against which to match the regular expression
+
+ Returns
+ -------
+ match_result: bool
+
+ True if `match_against` matches the pattern defined by
+ `regex_pattern`, and False otherwise.
+ """
+ match = re.match(regex_pattern, match_against)
+ return match is not None
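Because `_regex_match` uses `re.match`, a pattern is anchored at the start of the name but not at the end. A short standard-library illustration (the function names in the list are made up):

```python
import re

names = ["main", "main_inner", "fused_main", "decode"]

# re.match anchors only at the start of the string, so "main" acts as a prefix test.
prefix_hits = [n for n in names if re.match("main", n) is not None]

# re.fullmatch (or a pattern ending in "$") requires the whole name to match.
exact_hits = [n for n in names if re.fullmatch("main", n) is not None]

# re.search accepts a match anywhere in the name.
anywhere_hits = [n for n in names if re.search("main", n) is not None]

print(prefix_hits)    # ['main', 'main_inner']
print(exact_hits)     # ['main']
print(anywhere_hits)  # ['main', 'main_inner', 'fused_main']
```

When exactly one function is intended, a pattern such as `main$` restricts the selection while still going through `re.match`.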
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index f838120943..766bd28875 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -531,6 +531,43 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
return ModulePass(pass_func, pass_info);
}
+Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+ bool error_if_no_function_matches_regex) {
+ auto pass_name =
+ static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex)
+ .str();
+
+ auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
+ const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
+ CHECK(regex_match_func)
+ << "RuntimeError: "
+ << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
+ << "This can occur if the TVM Python library has not yet been imported.";
+
+ IRModule subset;
+
+ for (const auto& [gvar, func] : mod->functions) {
+ std::string name = gvar->name_hint;
+ if ((*regex_match_func)(func_name_regex, name)) {
+ subset->Add(gvar, func);
+ }
+ }
+
+ if (subset->functions.size()) {
+ IRModule new_subset = pass(subset);
+ if (!new_subset.same_as(subset)) {
+ mod.CopyOnWrite()->Update(new_subset);
+ }
+ }
+
+ return mod;
+ };
+
+ return CreateModulePass(pass_func, 0, pass_name, {});
+}
+
+TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction);
+
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc
index 899c80c1c4..1a4cd21625 100644
--- a/src/relax/transform/decompose_ops.cc
+++ b/src/relax/transform/decompose_ops.cc
@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
return expand_dims(data, expand_axes);
}
-Tuple SimplifyBatchNormInference(const Call& call) {
+Tuple DecomposeBatchNorm(const Call& call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
ICHECK_NOTNULL(attrs);
@@ -75,14 +75,18 @@ Tuple SimplifyBatchNormInference(const Call& call) {
return Tuple({out, call->args[3], call->args[4]});
}
-Tuple SimplifyBatchNormTraining(const Call& call) {
+Expr MutateBatchNormForTraining(Call call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
ICHECK_NOTNULL(attrs);
+ ICHECK_EQ(call->args.size(), 5);
Expr data = call->args[0];
- TensorStructInfo sinfo = MatchTensorStructInfo(data);
Expr gamma = call->args[1];
Expr beta = call->args[2];
+ Expr moving_mean = call->args[3];
+ Expr moving_var = call->args[4];
+
+ TensorStructInfo sinfo = MatchTensorStructInfo(data);
Array<Integer> reduce_axes;
for (int i = 0; i < sinfo->ndim; ++i) {
@@ -92,35 +96,21 @@ Tuple SimplifyBatchNormTraining(const Call& call) {
}
Expr data_mean = mean(data, reduce_axes, false);
- Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
Expr data_var = variance(data, reduce_axes, false);
- Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});
-
- // output = (x - mean) / sqrt(var + epsilon) * gamma + beta
- Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
- Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
- Expr out = divide(subtract(data, data_mean_rs), sqrt_var);
- if (attrs->scale) {
- out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
- }
- if (attrs->center) {
- out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
- }
-
- Expr moving_mean = call->args[3];
- Expr moving_var = call->args[4];
Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
- return Tuple({
- out,
- add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
- add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
- });
+ Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
+ Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
+
+ call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
+ // return call;
+
+ return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
}
-Expr SimplifyLayerNorm(const Call& call) {
+Expr DecomposeLayerNorm(const Call& call) {
auto attrs = call->attrs.as<LayerNormAttrs>();
ICHECK_NOTNULL(attrs);
@@ -172,92 +162,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
return ShapeExpr(shape_var);
}
-class OpDecomposer : public ExprMutator {
- public:
- constexpr static const char* kModeInference = "inference";
- constexpr static const char* kModeTraining = "training";
+/*! \brief Update operators that have a training-specific form
+ *
+ * Some operators, such as relax.op.batch_norm, need additional
+ * processing when being run for training. This mutator applies any mutations required.
+ */
+class TrainingOperatorMutator : public ExprMutator {
+ private:
+ using ExprMutator::VisitExpr_;
- explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
- CHECK(mode == kModeInference || mode == kModeTraining)
- << "The argument mode must be one of the following values: \"inference\", \"training\".";
+ Expr VisitExpr_(const CallNode* call_node) final {
+ Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+ if (call->op == batch_norm_op_) {
+ return MutateBatchNormForTraining(call);
+ } else if (call->op == layer_norm_op_) {
+ // Here we only decompose LayerNorm in training because it is more efficient as a single op.
+ // In the future maybe we can also remove this decomposition during training.
+ return DecomposeLayerNorm(call);
+ } else {
+ return call;
+ }
}
+ /* composite operator list */
+ const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
+ const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
+};
+
+class OpDecomposer : public ExprMutator {
private:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call_node) final {
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == batch_norm_op_) {
- if (mode_ == kModeInference) {
- return SimplifyBatchNormInference(call);
- } else {
- ICHECK_EQ(mode_, kModeTraining);
- return SimplifyBatchNormTraining(call);
- }
- } else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
- // Here we only decompose LayerNorm in training because it is more efficient as a single op.
- // In the future maybe we can also remove this decomposition during training.
- return SimplifyLayerNorm(call);
+ return DecomposeBatchNorm(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call, builder_);
}
return call;
}
- const String mode_;
-
/* composite opeartor list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
- const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
};
-IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
- auto op_decomposer = OpDecomposer(mode);
-
- IRModuleNode* new_module = mod.CopyOnWrite();
+namespace transform {
- if (!func_name.defined()) { // simplify all functions
- Map<GlobalVar, BaseFunc> functions = mod->functions;
- for (const auto& func_pr : functions) {
- if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
- Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
- new_module->Update(func_pr.first, f);
- }
- }
- } else { // simplify specified function
- auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
- CHECK(func_ptr) << func_name.value() << "is not a Relax Function";
- auto gvar = mod->GetGlobalVar(func_name.value());
- auto func = GetRef<Function>(func_ptr);
- func = Downcast<Function>(op_decomposer(func));
- new_module->Update(gvar, func);
- }
+Pass MutateOpsForTraining() {
+ auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+ TrainingOperatorMutator mutator;
+ return Downcast<Function>(mutator(func));
+ };
+ return CreateFunctionPass(/*pass_function=*/pass_func,
+ /*opt_level=*/0,
+ /*pass_name=*/"MutateOpsForTraining",
+ /*required=*/{});
+}
- return GetRef<IRModule>(new_module);
+Pass DecomposeOps() {
+ auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+ OpDecomposer mutator;
+ return Downcast<Function>(mutator(func));
+ };
+ return CreateFunctionPass(/*pass_function=*/pass_func,
+ /*opt_level=*/0,
+ /*pass_name=*/"DecomposeOps",
+ /*required=*/{});
}
-namespace transform {
Pass DecomposeOpsForInference(Optional<String> func_name) {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
- PassContext pc) {
- return Decompose(mod, func_name, OpDecomposer::kModeInference);
- };
- return CreateModulePass(/*pass_function=*/pass_func,
- /*opt_level=*/0,
- /*pass_name=*/"DecomposeOpsForInference",
- /*required=*/{});
+ if (func_name) {
+ return ApplyPassToFunction(DecomposeOps(), func_name.value());
+ } else {
+ return DecomposeOps();
+ }
}
Pass DecomposeOpsForTraining(Optional<String> func_name) {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
- PassContext pc) {
- return Decompose(mod, func_name, OpDecomposer::kModeTraining);
- };
- return CreateModulePass(/*pass_function=*/pass_func,
- /*opt_level=*/0,
- /*pass_name=*/"DecomposeOpsForTraining",
- /*required=*/{});
+ auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
+ "DecomposeOpsForTraining");
+ if (func_name) {
+ return ApplyPassToFunction(module_pass, func_name.value());
+ } else {
+ return module_pass;
+ }
}
TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py
index 85657ab245..4e5bcb82e9 100644
--- a/tests/python/relax/test_transform_decompose_ops.py
+++ b/tests/python/relax/test_transform_decompose_ops.py
@@ -137,44 +137,39 @@ def test_batch_norm_training():
R.Tensor((64,), dtype="float32"),
):
with R.dataflow():
- lv: R.Tensor((64,), dtype="float32") = R.mean(x, axis=[0, 2, 3], keepdims=False)
- lv1: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv, axis=[0, 2, 3])
- lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.subtract(x, lv1)
- lv3: R.Tensor((64,), dtype="float32") = R.variance(
- x, axis=[0, 2, 3], keepdims=False
- )
- lv4: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv3, axis=[0, 2, 3])
- lv5: R.Tensor((1, 64, 1, 1), dtype="float32") = R.add(
- lv4, R.const(9.9999997473787516e-06, "float32")
- )
- lv6: R.Tensor((1, 64, 1, 1), dtype="float32") = R.sqrt(lv5)
- lv7: R.Tensor((1, 64, 112, 112), dtype="float32") = R.divide(lv2, lv6)
- lv8: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(gamma, axis=[0, 2, 3])
- lv9: R.Tensor((1, 64, 112, 112), dtype="float32") = R.multiply(lv7, lv8)
- lv10: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(beta, axis=[0, 2, 3])
- lv11: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv9, lv10)
- lv12: R.Tensor((64,), dtype="float32") = R.multiply(
- R.const(0.89999997615814209, "float32"), moving_mean
- )
- lv13: R.Tensor((64,), dtype="float32") = R.multiply(
- R.const(0.10000000149011612, "float32"), lv
- )
- lv14: R.Tensor((64,), dtype="float32") = R.add(lv12, lv13)
- lv15: R.Tensor((64,), dtype="float32") = R.multiply(
- R.const(0.89999997615814209, "float32"), moving_var
- )
- lv16: R.Tensor((64,), dtype="float32") = R.multiply(
- R.const(0.10000000149011612, "float32"), lv3
- )
- lv17: R.Tensor((64,), dtype="float32") = R.add(lv15, lv16)
- bn: R.Tuple(
- R.Tensor((1, 64, 112, 112), dtype="float32"),
- R.Tensor((64,), dtype="float32"),
- R.Tensor((64,), dtype="float32"),
- ) = (lv11, lv14, lv17)
- gv0: R.Tensor((1, 64, 112, 112), dtype="float32") = bn[0]
- gv1: R.Tensor((64,), dtype="float32") = bn[1]
- gv2: R.Tensor((64,), dtype="float32") = bn[2]
+ # This portion is training-specific, computing the
+ # mean/variance of the current batch.
+ lv = R.mean(x, axis=[0, 2, 3], keepdims=False)
+ lv3 = R.variance(x, axis=[0, 2, 3], keepdims=False)
+
+ # This portion is identical to the batch_norm run during inference
+ lv1 = R.expand_dims(lv, axis=[0, 2, 3])
+ lv2 = R.subtract(x, lv1)
+ lv4 = R.expand_dims(lv3, axis=[0, 2, 3])
+ lv5 = R.add(lv4, R.const(9.9999997473787516e-06, "float32"))
+ lv6 = R.sqrt(lv5)
+ lv7 = R.divide(lv2, lv6)
+ lv8 = R.expand_dims(gamma, axis=[0, 2, 3])
+ lv9 = R.multiply(lv7, lv8)
+ lv10 = R.expand_dims(beta, axis=[0, 2, 3])
+ lv11 = R.add(lv9, lv10)
+ inner_tuple = (lv11, lv, lv3)
+ # This is the result that would be returned from a
+ # batch_norm at inference.
+
+ # However, at training we need to update the moving
+ # mean/variance, and to return those updated values.
+ inner_res = inner_tuple[0]
+ lv12 = R.multiply(R.const(0.89999997615814209, "float32"), moving_mean)
+ lv13 = R.multiply(R.const(0.10000000149011612, "float32"), lv)
+ lv14 = R.add(lv12, lv13)
+ lv15 = R.multiply(R.const(0.89999997615814209, "float32"), moving_var)
+ lv16 = R.multiply(R.const(0.10000000149011612, "float32"), lv3)
+ lv17 = R.add(lv15, lv16)
+ bn = (inner_res, lv14, lv17)
+ gv0 = bn[0]
+ gv1 = bn[1]
+ gv2 = bn[2]
R.output(gv0, gv1, gv2)
return (gv0, gv1, gv2)
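The constants in the expected IR appear to be float32 roundings of 0.9 (`1 - momentum`), 0.1 (`momentum`) and 1e-5 (`epsilon`), printed with 17 significant digits. A quick standard-library check, assuming those are the intended attribute values:

```python
import struct


def as_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32, returned as float64."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# 17 significant digits are enough to identify any float64 value uniquely.
for value in (0.9, 0.1, 1e-5):
    print("%.17g" % as_float32(value))
# 0.89999997615814209
# 0.10000000149011612
# 9.9999997473787516e-06
```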