This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 225d067fb8 [Unity] Support Padding Reversal in Alter-Op pass (#15679)
225d067fb8 is described below

commit 225d067fb85e37723bbb3a5ca51f872c0e0f6abf
Author: rutkoor <120498024+rutk...@users.noreply.github.com>
AuthorDate: Thu Sep 28 23:19:01 2023 +0530

    [Unity] Support Padding Reversal in Alter-Op pass (#15679)

    * Support for padding Reversal in Alter-op pass

    * Removing lambda from remove_pad te.compute

    * Applying clang-format on .cc file

    * Removing extra line from manipulate.h file
---
 .../tvm/relax/transform/legalize_ops/manipulate.py | 10 ++-
 python/tvm/relax/transform/transform.py            |  2 +-
 src/relax/transform/alter_op_impl.cc               | 81 +++++++++++++++++++---
 src/te/operation/compute_op.cc                     |  6 +-
 tests/python/relax/test_transform_alter_op_impl.py | 80 +++++++++++++++++----
 5 files changed, 153 insertions(+), 26 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 4e06a0df39..e56240dc0d 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -182,7 +182,15 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
         )

     index_map: tvm.tir.IndexMap = call.attrs.index_map
-    pad_value = call.attrs.pad_value.value
+    pad_value = call.attrs.pad_value
+    if pad_value is not None:
+        pad_value = pad_value.value
+    else:
+        if "int" in call.args[0].struct_info.dtype:
+            pad_value = int(0)
+        else:
+            pad_value = float(0.0)
+
     axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
     # Convert to list from array
     axis_separators = list(map(lambda x: x.value, axis_separators))
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 2a06d5098e..72a9966a4b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1034,7 +1034,7 @@ def AlterOpImpl(
         l = []
         for transform in transform_list:
             if isinstance(transform, Callable):
-                transform = IndexMap.from_func(transform)
+                transform = IndexMap.from_func_with_separators(transform)[0]
             l.append(transform)
         op_buffer_transforms[operator_name] = l

diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 9813c4ed24..98d64dd7a8 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -30,7 +30,11 @@
 #include <tvm/relax/attrs/manipulate.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
+#include <tvm/te/operation.h>
 #include <tvm/tir/transform.h>
+#include <tvm/topi/tags.h>
+
+#include "../../te/operation/create_primfunc.h"

 namespace tvm {
 namespace relax {
@@ -162,8 +166,18 @@ class AlterOpImplMutator : public ExprMutator {
     return arr_tensor_sinfo;
   }

+  bool IsScalarConstant(const Expr& expr) {
+    if (expr->IsInstance<ConstantNode>() && expr.as<ConstantNode>()->is_scalar()) {
+      return true;
+    }
+    return false;
+  }
+
   Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
-                       const Array<IntImm> axis_separators) {
+                       const Array<IntImm>& axis_separators) {
+    if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+      return expr;
+    }
     ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
     // We want to avoid two layout_transform ops to share the same index map even if they are
     // identical. The scope of vars used in index map initial indices is local to the op. Not doing
@@ -173,19 +187,70 @@ class AlterOpImplMutator : public ExprMutator {
     return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
   }

+  /*!
+   * \brief Adds the \p remove_pad op to the module if it has not already been added before.
+   * \returns The global var associated with the remove_pad PrimFunc.
+   */
+  GlobalVar GetOrCreateRemovePadOp(const Array<PrimExpr>& old_shape, const DataType& dtype) {
+    int t_shape = old_shape.size();
+    if (remove_pad_map_.count(t_shape) != 0) {
+      return remove_pad_map_[t_shape];
+    }
+    // Create dynamic shapes for input and output tensors
+    Array<PrimExpr> dyn_padded_shape, dyn_old_shape;
+    for (int i = 0; i < t_shape; i++) {
+      tir::Var var1("p" + std::to_string(i), old_shape[i].dtype());
+      tir::Var var2("i" + std::to_string(i), old_shape[i].dtype());
+      dyn_padded_shape.push_back(var1);
+      dyn_old_shape.push_back(var2);
+    }
+
+    // Input tensor of remove_pad op
+    te::Tensor placeholder_tensor = te::placeholder(dyn_padded_shape, dtype, "input");
+    // Output tensor of remove_pad op
+    te::Tensor output_tensor = te::compute(
+        dyn_old_shape,
+        [&placeholder_tensor](const Array<tir::Var>& indices) {
+          return placeholder_tensor(indices);
+        },
+        "output", topi::kElementWise);
+
+    String op_name = "remove_pad";
+    // Create PrimFunc and add op_name to func.attrs
+    PrimFunc remove_pad_with_frozen_layout =
+        WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), kOperatorName, op_name);
+    // Add PrimFunc to module
+    GlobalVar gv_remove_pad = builder_->AddFunction(remove_pad_with_frozen_layout, op_name);
+    // Mark the remove_pad PrimFunc as private by removing it from global scope
+    builder_->UpdateFunction(gv_remove_pad,
+                             WithoutAttr(remove_pad_with_frozen_layout, "global_symbol"));
+
+    remove_pad_map_[t_shape] = gv_remove_pad;
+    return gv_remove_pad;
+  }
+
   Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
                               const TensorStructInfo& old_tensor_sinfo,
                               const Array<IntImm>& axis_separator) {
+    if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+      return expr;
+    }
     Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
     Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
     arith::Analyzer analyzer;
     auto [inverse_index_map, padding_predicate] =
         index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
-    ICHECK(tir::is_zero(padding_predicate))
-        << "Only bijective transformations on input/output buffers are supported, but found "
-           "padding predicate "
-        << padding_predicate << " on initial range " << initial_ranges;
-    return TransformLayout(expr, inverse_index_map, axis_separator);
+
+    if (tir::is_zero(padding_predicate)) {
+      return TransformLayout(expr, inverse_index_map, axis_separator);
+    } else {
+      auto padded_expr =
+          builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
+      const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);
+
+      GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
+      return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_sinfo});
+    }
   }

   /*!
@@ -223,8 +288,6 @@ class AlterOpImplMutator : public ExprMutator {
         axis_separator = axis_separators_value[index];
       }
       auto transform = transforms[index++];
-      ICHECK(IsTransformBijective(input, transform))
-          << "Non bijective transforms on input and output buffers are not supported.";
       updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
     }
     return Tuple(updated_inputs);
@@ -314,6 +377,8 @@ class AlterOpImplMutator : public ExprMutator {
   Map<PrimFunc, GlobalVar> cache_;
   /*! \brief Input IRModule */
   const IRModule& mod_;
+  /*! \brief Map from shape_dim.size to the remove_pad GlobalVar */
+  std::unordered_map<int, GlobalVar> remove_pad_map_;
   /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */
   const Map<String, PrimFunc>& op_impl_map_;
   /*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 3ca40c9a6b..5797d2295b 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -98,7 +98,8 @@ Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::
   for (size_t i = 0; i < ndim; ++i) {
     std::ostringstream os;
     os << "ax" << i;
-    axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+    axis.emplace_back(IterVar(Range(IntImm(shape[i]->dtype, 0), shape[i]),
+                              Var(os.str(), shape[i].dtype()), kDataPar));
     args.push_back(axis.back()->var);
   }

@@ -114,7 +115,8 @@ Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string
   for (size_t i = 0; i < ndim; ++i) {
     std::ostringstream os;
     os << "ax" << i;
-    axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+    axis.emplace_back(IterVar(Range(IntImm(shape[i]->dtype, 0), shape[i]),
+                              Var(os.str(), shape[i].dtype()), kDataPar));
     args.push_back(axis.back()->var);
   }

diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index 81bc480785..3cbba9a031 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -303,13 +303,13 @@ def test_multiple_outputs_with_axis_sep():
     )


-def test_unsupported_implicit_padding():
+def test_supported_implicit_padding():
     @I.ir_module
-    class InputModule:
+    class Before:
         @R.function
         def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
             with R.dataflow():
-                lv = R.call_tir(InputModule.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
+                lv = R.call_tir(Before.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
                 gv: R.Tensor((14,), dtype="float32") = lv
                 R.output(gv)
             return gv
@@ -324,7 +324,62 @@ def test_unsupported_implicit_padding():
                 T.writes(output[v_ax0])
                 output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))

-    before = InputModule
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    x,
+                    index_map=T.index_map(lambda i: (i % 16,)),
+                    pad_value=None,
+                    axis_separators=[],
+                )
+                lv1 = R.call_tir(
+                    Expected.relax_relu_replacement,
+                    (lv,),
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                )
+                lv2: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv1,
+                    index_map=T.index_map(lambda axis0: (axis0,)),
+                    pad_value=None,
+                    axis_separators=[],
+                )
+                lv_1 = R.call_tir(
+                    Expected.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32")
+                )
+                gv: R.Tensor((14,), dtype="float32") = lv_1
+                R.output(gv)
+            return gv
+
+        @T.prim_func(private=True)
+        def relax_relu_replacement(
+            arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
+        ):
+            T.func_attr({"operator_name": "relax.relu"})
+            # with T.block("root"):
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+        @T.prim_func(private=True)
+        def remove_pad(var_input: T.handle, var_output: T.handle):
+            T.func_attr({"operator_name": "remove_pad", "tir.noalias": T.bool(True)})
+            p0 = T.int64()
+            input = T.match_buffer(var_input, (p0,))
+            i0 = T.int64()
+            output = T.match_buffer(var_output, (i0,))
+            # with T.block("root"):
+            for ax0 in range(i0):
+                with T.block("output"):
+                    v_ax0 = T.axis.spatial(i0, ax0)
+                    T.reads(input[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = input[v_ax0]

     @T.prim_func(private=True)
     def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
@@ -338,16 +393,13 @@ def test_unsupported_implicit_padding():
     # introduces implicit padding for shape (14,)
     index_map = lambda i: (i % 16)
     operator_name = "relax.relu"
-    with pytest.raises(
-        tvm.TVMError, match="Non bijective transforms on input and output buffers are not supported"
-    ):
-        _ = relax.transform.AlterOpImpl(
-            {operator_name: relu_pad},
-            {
-                operator_name: [index_map, index_map],
-            },
-            {operator_name: None},
-        )(before)
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.relu",
+        replacement_primfunc=relu_pad,
+        layout_changes=[index_map, index_map],
+    )


 def test_multiple_call_sites():