This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a2a1b53402 [Unity] Split DecomposeOpsForTraining into two steps (#15954)
a2a1b53402 is described below

commit a2a1b534024e677a9407bec8d09d99f2237b2b0b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Jan 16 08:08:40 2024 -0600

    [Unity] Split DecomposeOpsForTraining into two steps (#15954)
    
    * [Unity] Split DecomposeOpsForTraining into two steps
    
    Prior to this commit, the `DecomposeOpsForTraining` transform directly
    decomposed `relax.nn.batch_norm` into more primitive relax operations.
    This required the decomposed form of `relax.nn.batch_norm` to be
    duplicated in both `DecomposeOpsForTraining` and
    `DecomposeOpsForInference`.  This commit refactors the pass into two
    steps: first applying training-specific mutations, then decomposing.
    
    Having a dedicated `DecomposeOps` pass also provides a single location
    for operator decomposition, which may later be migrated into the
    operator definitions, similar to `FLegalize`.
    
    * Updated ApplyPassToFunction utility to use a regex
---
 include/tvm/ir/transform.h                         |  25 ++++
 src/ir/transform.cc                                |  31 ++++
 src/relax/transform/decompose_ops.cc               | 156 ++++++++++-----------
 tests/python/relax/test_transform_decompose_ops.py |  71 +++++-----
 4 files changed, 162 insertions(+), 121 deletions(-)

diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index ec151d9d75..adf3325250 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -525,6 +525,31 @@ TVM_DLL Pass CreateModulePass(
     const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
     String name, Array<runtime::String> required, bool traceable = false);
 
+/*!
+ * \brief Utility to apply a pass to specific functions in an IRModule
+ *
+ * TVM uses IRModule to IRModule transformations at all stages of
+ * lowering.  These transformations may be useful when hand-writing an
+ * optimized model, or to perform optimizations on specific kernels
+ * within an IRModule.  This utility allows a pass to be applied to the
+ * selected functions, without altering other functions in the module.
+ *
+ * \param pass The IRModule to IRModule pass to be applied.
+ *
+ * \param func_name_regex A regex used to select the functions to be
+ * updated.  The pass will be applied to all functions whose name
+ * fully matches the regex (`std::regex_match` semantics).
+ *
+ * \param error_if_no_function_matches_regex Specifies the behavior if
+ *     an IRModule does not contain any function matching the provided
+ *     regex.  If true, an error will be raised.  If false (default),
+ *     the IRModule will be returned unmodified.
+ *
+ * \return A pass that applies \p pass to the matching functions only.
+ */
+TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+                                 bool error_if_no_function_matches_regex = false);
+
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
  * \param header The header to be attached to the output.
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index f838120943..3bae6be9ba 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -31,6 +31,7 @@
 
 #include <chrono>
 #include <iomanip>
+#include <regex>
 #include <stack>
 #include <unordered_set>
 
@@ -531,6 +532,36 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
   return ModulePass(pass_func, pass_info);
 }
 
+Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+                         bool error_if_no_function_matches_regex) {
+  std::string regex_str = func_name_regex;
+  std::string pass_name = "ApplyPassTo" + regex_str;
+  std::regex regex(regex_str);
+  auto pass_func = [pass, regex, regex_str,
+                    error_if_no_function_matches_regex](IRModule mod, PassContext) -> IRModule {
+    IRModule subset;
+    for (const auto& [gvar, func] : mod->functions) {
+      std::string name = gvar->name_hint;
+      if (std::regex_match(name, regex)) {
+        subset->Add(gvar, func);
+      }
+    }
+
+    if (subset->functions.empty()) {
+      CHECK(!error_if_no_function_matches_regex)
+          << "No function in the IRModule matches the regex '" << regex_str << "'";
+      return mod;
+    }
+    // `pass` sees only the selected functions; merge any changes back into `mod`.
+    IRModule new_subset = pass(subset);
+    if (!new_subset.same_as(subset)) {
+      mod.CopyOnWrite()->Update(new_subset);
+    }
+    return mod;
+  };
+  return CreateModulePass(pass_func, 0, pass_name, {});
+}
+
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
 TVM_REGISTER_GLOBAL("transform.PassInfo")
diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc
index 899c80c1c4..1a4cd21625 100644
--- a/src/relax/transform/decompose_ops.cc
+++ b/src/relax/transform/decompose_ops.cc
@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
   return expand_dims(data, expand_axes);
 }
 
-Tuple SimplifyBatchNormInference(const Call& call) {
+Tuple DecomposeBatchNorm(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
@@ -75,14 +75,18 @@ Tuple SimplifyBatchNormInference(const Call& call) {
   return Tuple({out, call->args[3], call->args[4]});
 }
 
-Tuple SimplifyBatchNormTraining(const Call& call) {
+Expr MutateBatchNormForTraining(Call call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
+  ICHECK_EQ(call->args.size(), 5);
   Expr data = call->args[0];
-  TensorStructInfo sinfo = MatchTensorStructInfo(data);
   Expr gamma = call->args[1];
   Expr beta = call->args[2];
+  Expr moving_mean = call->args[3];
+  Expr moving_var = call->args[4];
+
+  TensorStructInfo sinfo = MatchTensorStructInfo(data);
 
   Array<Integer> reduce_axes;
   for (int i = 0; i < sinfo->ndim; ++i) {
@@ -92,35 +96,21 @@ Tuple SimplifyBatchNormTraining(const Call& call) {
   }
 
   Expr data_mean = mean(data, reduce_axes, false);
-  Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
   Expr data_var = variance(data, reduce_axes, false);
-  Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});
-
-  // output = (x - mean) / sqrt(var + epsilon) * gamma + beta
-  Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
-  Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
-  Expr out = divide(subtract(data, data_mean_rs), sqrt_var);
 
-  if (attrs->scale) {
-    out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
-  }
-  if (attrs->center) {
-    out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
-  }
-
-  Expr moving_mean = call->args[3];
-  Expr moving_var = call->args[4];
   Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
   Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
 
-  return Tuple({
-      out,
-      add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
-      add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
-  });
+  Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
+  Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
+
+  // Normalize with the batch statistics instead of the running mean/variance.
+  call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
+
+  return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
 }
 
-Expr SimplifyLayerNorm(const Call& call) {
+Expr DecomposeLayerNorm(const Call& call) {
   auto attrs = call->attrs.as<LayerNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
@@ -172,92 +162,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
   return ShapeExpr(shape_var);
 }
 
-class OpDecomposer : public ExprMutator {
- public:
-  constexpr static const char* kModeInference = "inference";
-  constexpr static const char* kModeTraining = "training";
+/*! \brief Update operators that have a training-specific form
+ *
+ * Some operators, such as relax.nn.batch_norm, need additional
+ * processing when being run for training.  This mutator applies those mutations.
+ */
+class TrainingOperatorMutator : public ExprMutator {
+ private:
+  using ExprMutator::VisitExpr_;
 
-  explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
-    CHECK(mode == kModeInference || mode == kModeTraining)
-        << "The argument mode must be one of the following values: \"inference\", \"training\".";
+  Expr VisitExpr_(const CallNode* call_node) final {
+    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+    if (call->op == batch_norm_op_) {
+      return MutateBatchNormForTraining(call);
+    } else if (call->op == layer_norm_op_) {
+      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
+      // In the future maybe we can also remove this decomposition during training.
+      return DecomposeLayerNorm(call);
+    } else {
+      return call;
+    }
   }
 
+  /* composite operator list */
+  const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
+  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
+};
+
+class OpDecomposer : public ExprMutator {
  private:
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call_node) final {
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
     if (call->op == batch_norm_op_) {
-      if (mode_ == kModeInference) {
-        return SimplifyBatchNormInference(call);
-      } else {
-        ICHECK_EQ(mode_, kModeTraining);
-        return SimplifyBatchNormTraining(call);
-      }
-    } else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
-      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
-      // In the future maybe we can also remove this decomposition during training.
-      return SimplifyLayerNorm(call);
+      return DecomposeBatchNorm(call);
     } else if (call->op == tensor_to_shape_op_) {
       return TensorToShape(call, builder_);
     }
     return call;
   }
 
-  const String mode_;
-
   /* composite opeartor list */
   const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
-  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
   const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
 };
 
-IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
-  auto op_decomposer = OpDecomposer(mode);
-
-  IRModuleNode* new_module = mod.CopyOnWrite();
+namespace transform {
 
-  if (!func_name.defined()) {  // simplify all functions
-    Map<GlobalVar, BaseFunc> functions = mod->functions;
-    for (const auto& func_pr : functions) {
-      if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
-        Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
-        new_module->Update(func_pr.first, f);
-      }
-    }
-  } else {  // simplify specified function
-    auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
-    CHECK(func_ptr) << func_name.value() << "is not a Relax Function";
-    auto gvar = mod->GetGlobalVar(func_name.value());
-    auto func = GetRef<Function>(func_ptr);
-    func = Downcast<Function>(op_decomposer(func));
-    new_module->Update(gvar, func);
-  }
+Pass MutateOpsForTraining() {
+  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+    TrainingOperatorMutator mutator;
+    return Downcast<Function>(mutator(func));
+  };
+  return CreateFunctionPass(/*pass_function=*/pass_func,
+                            /*opt_level=*/0,
+                            /*pass_name=*/"MutateOpsForTraining",
+                            /*required=*/{});
+}
 
-  return GetRef<IRModule>(new_module);
+Pass DecomposeOps() {
+  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+    OpDecomposer mutator;
+    return Downcast<Function>(mutator(func));
+  };
+  return CreateFunctionPass(/*pass_function=*/pass_func,
+                            /*opt_level=*/0,
+                            /*pass_name=*/"DecomposeOps",
+                            /*required=*/{});
 }
 
-namespace transform {
 Pass DecomposeOpsForInference(Optional<String> func_name) {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
-                                                                            PassContext pc) {
-    return Decompose(mod, func_name, OpDecomposer::kModeInference);
-  };
-  return CreateModulePass(/*pass_function=*/pass_func,
-                          /*opt_level=*/0,
-                          /*pass_name=*/"DecomposeOpsForInference",
-                          /*required=*/{});
+  if (!func_name) {
+    return DecomposeOps();
+  }
+  return ApplyPassToFunction(DecomposeOps(), func_name.value(),
+                             /*error_if_no_function_matches_regex=*/true);
 }
 
 Pass DecomposeOpsForTraining(Optional<String> func_name) {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
-                                                                            PassContext pc) {
-    return Decompose(mod, func_name, OpDecomposer::kModeTraining);
-  };
-  return CreateModulePass(/*pass_function=*/pass_func,
-                          /*opt_level=*/0,
-                          /*pass_name=*/"DecomposeOpsForTraining",
-                          /*required=*/{});
+  auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
+                                                "DecomposeOpsForTraining");
+  if (!func_name) {
+    return module_pass;
+  }
+  return ApplyPassToFunction(module_pass, func_name.value(),
+                             /*error_if_no_function_matches_regex=*/true);
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py
index 85657ab245..4e5bcb82e9 100644
--- a/tests/python/relax/test_transform_decompose_ops.py
+++ b/tests/python/relax/test_transform_decompose_ops.py
@@ -137,44 +137,39 @@ def test_batch_norm_training():
             R.Tensor((64,), dtype="float32"),
         ):
             with R.dataflow():
-                lv: R.Tensor((64,), dtype="float32") = R.mean(x, axis=[0, 2, 3], keepdims=False)
-                lv1: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv, axis=[0, 2, 3])
-                lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.subtract(x, lv1)
-                lv3: R.Tensor((64,), dtype="float32") = R.variance(
-                    x, axis=[0, 2, 3], keepdims=False
-                )
-                lv4: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv3, axis=[0, 2, 3])
-                lv5: R.Tensor((1, 64, 1, 1), dtype="float32") = R.add(
-                    lv4, R.const(9.9999997473787516e-06, "float32")
-                )
-                lv6: R.Tensor((1, 64, 1, 1), dtype="float32") = R.sqrt(lv5)
-                lv7: R.Tensor((1, 64, 112, 112), dtype="float32") = R.divide(lv2, lv6)
-                lv8: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(gamma, axis=[0, 2, 3])
-                lv9: R.Tensor((1, 64, 112, 112), dtype="float32") = R.multiply(lv7, lv8)
-                lv10: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(beta, axis=[0, 2, 3])
-                lv11: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv9, lv10)
-                lv12: R.Tensor((64,), dtype="float32") = R.multiply(
-                    R.const(0.89999997615814209, "float32"), moving_mean
-                )
-                lv13: R.Tensor((64,), dtype="float32") = R.multiply(
-                    R.const(0.10000000149011612, "float32"), lv
-                )
-                lv14: R.Tensor((64,), dtype="float32") = R.add(lv12, lv13)
-                lv15: R.Tensor((64,), dtype="float32") = R.multiply(
-                    R.const(0.89999997615814209, "float32"), moving_var
-                )
-                lv16: R.Tensor((64,), dtype="float32") = R.multiply(
-                    R.const(0.10000000149011612, "float32"), lv3
-                )
-                lv17: R.Tensor((64,), dtype="float32") = R.add(lv15, lv16)
-                bn: R.Tuple(
-                    R.Tensor((1, 64, 112, 112), dtype="float32"),
-                    R.Tensor((64,), dtype="float32"),
-                    R.Tensor((64,), dtype="float32"),
-                ) = (lv11, lv14, lv17)
-                gv0: R.Tensor((1, 64, 112, 112), dtype="float32") = bn[0]
-                gv1: R.Tensor((64,), dtype="float32") = bn[1]
-                gv2: R.Tensor((64,), dtype="float32") = bn[2]
+                # This portion is training-specific, computing the
+                # mean/variance of the current batch.
+                lv = R.mean(x, axis=[0, 2, 3], keepdims=False)
+                lv3 = R.variance(x, axis=[0, 2, 3], keepdims=False)
+
+                # This portion is identical to the batch_norm run during inference
+                lv1 = R.expand_dims(lv, axis=[0, 2, 3])
+                lv2 = R.subtract(x, lv1)
+                lv4 = R.expand_dims(lv3, axis=[0, 2, 3])
+                lv5 = R.add(lv4, R.const(9.9999997473787516e-06, "float32"))
+                lv6 = R.sqrt(lv5)
+                lv7 = R.divide(lv2, lv6)
+                lv8 = R.expand_dims(gamma, axis=[0, 2, 3])
+                lv9 = R.multiply(lv7, lv8)
+                lv10 = R.expand_dims(beta, axis=[0, 2, 3])
+                lv11 = R.add(lv9, lv10)
+                # This is the result that would be returned from a
+                # batch_norm at inference.
+                inner_tuple = (lv11, lv, lv3)
+
+                # However, at training we need to update the moving
+                # mean/variance, and to return those updated values.
+                inner_res = inner_tuple[0]
+                lv12 = R.multiply(R.const(0.89999997615814209, "float32"), moving_mean)
+                lv13 = R.multiply(R.const(0.10000000149011612, "float32"), lv)
+                lv14 = R.add(lv12, lv13)
+                lv15 = R.multiply(R.const(0.89999997615814209, "float32"), moving_var)
+                lv16 = R.multiply(R.const(0.10000000149011612, "float32"), lv3)
+                lv17 = R.add(lv15, lv16)
+                bn = (inner_res, lv14, lv17)
+                gv0 = bn[0]
+                gv1 = bn[1]
+                gv2 = bn[2]
                 R.output(gv0, gv1, gv2)
             return (gv0, gv1, gv2)
 

Reply via email to