This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 225d067fb8 [Unity] Support Padding Reversal in Alter-Op pass (#15679)
225d067fb8 is described below
commit 225d067fb85e37723bbb3a5ca51f872c0e0f6abf
Author: rutkoor <[email protected]>
AuthorDate: Thu Sep 28 23:19:01 2023 +0530
[Unity] Support Padding Reversal in Alter-Op pass (#15679)
* Support for padding reversal in Alter-Op pass
* Removing lambda from remove_pad te.compute
* Applying clang-format on .cc file
* Removing extra line from manipulate.h file
---
.../tvm/relax/transform/legalize_ops/manipulate.py | 10 ++-
python/tvm/relax/transform/transform.py | 2 +-
src/relax/transform/alter_op_impl.cc | 81 +++++++++++++++++++---
src/te/operation/compute_op.cc | 6 +-
tests/python/relax/test_transform_alter_op_impl.py | 80 +++++++++++++++++----
5 files changed, 153 insertions(+), 26 deletions(-)
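Note (illustrative, not part of the patch): the updated test exercises an index map, lambda i: (i % 16,), that implicitly pads a (14,) tensor into a (16,) buffer; the inverse transform therefore yields a padded result, which the new remove_pad PrimFunc trims back to the original extent. A minimal NumPy sketch of that shape arithmetic, assuming a zero pad value and hypothetical variable names:

    import numpy as np

    data = np.arange(14, dtype="float32")
    padded = np.zeros(16, dtype="float32")   # destination buffer implied by i % 16; slots 14-15 are padding
    padded[np.arange(14) % 16] = data        # forward layout_transform with implicit padding
    recovered = padded[:14]                  # remove_pad: copy back only the original extent
    assert np.array_equal(recovered, data)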
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 4e06a0df39..e56240dc0d 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -182,7 +182,15 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
)
index_map: tvm.tir.IndexMap = call.attrs.index_map
- pad_value = call.attrs.pad_value.value
+ pad_value = call.attrs.pad_value
+ if pad_value is not None:
+ pad_value = pad_value.value
+ else:
+ if "int" in call.args[0].struct_info.dtype:
+ pad_value = int(0)
+ else:
+ pad_value = float(0.0)
+
axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
# Convert to list from array
axis_separators = list(map(lambda x: x.value, axis_separators))
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 2a06d5098e..72a9966a4b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1034,7 +1034,7 @@ def AlterOpImpl(
l = []
for transform in transform_list:
if isinstance(transform, Callable):
- transform = IndexMap.from_func(transform)
+ transform = IndexMap.from_func_with_separators(transform)[0]
l.append(transform)
op_buffer_transforms[operator_name] = l
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 9813c4ed24..98d64dd7a8 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -30,7 +30,11 @@
#include <tvm/relax/attrs/manipulate.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
+#include <tvm/te/operation.h>
#include <tvm/tir/transform.h>
+#include <tvm/topi/tags.h>
+
+#include "../../te/operation/create_primfunc.h"
namespace tvm {
namespace relax {
@@ -162,8 +166,18 @@ class AlterOpImplMutator : public ExprMutator {
return arr_tensor_sinfo;
}
+ bool IsScalarConstant(const Expr& expr) {
+ if (expr->IsInstance<ConstantNode>() && expr.as<ConstantNode>()->is_scalar()) {
+ return true;
+ }
+ return false;
+ }
+
Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
- const Array<IntImm> axis_separators) {
+ const Array<IntImm>& axis_separators) {
+ if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+ return expr;
+ }
ObjectPtr<LayoutTransformAttrs> attrs =
make_object<LayoutTransformAttrs>();
// We want to avoid two layout_transform ops to share the same index map even if they are
// identical. The scope of vars used in index map initial indices is local to the op. Not doing
@@ -173,19 +187,70 @@ class AlterOpImplMutator : public ExprMutator {
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}
+ /*!
+ * \brief Adds the \p remove_pad op to the module if it has not already been added before.
+ * \returns The global var associated with the remove_pad PrimFunc.
+ */
+ GlobalVar GetOrCreateRemovePadOp(const Array<PrimExpr>& old_shape, const DataType& dtype) {
+ int t_shape = old_shape.size();
+ if (remove_pad_map_.count(t_shape) != 0) {
+ return remove_pad_map_[t_shape];
+ }
+ // Create dynamic shapes for input and output tensors
+ Array<PrimExpr> dyn_padded_shape, dyn_old_shape;
+ for (int i = 0; i < t_shape; i++) {
+ tir::Var var1("p" + std::to_string(i), old_shape[i].dtype());
+ tir::Var var2("i" + std::to_string(i), old_shape[i].dtype());
+ dyn_padded_shape.push_back(var1);
+ dyn_old_shape.push_back(var2);
+ }
+
+ // Input tensor of remove_pad op
+ te::Tensor placeholder_tensor = te::placeholder(dyn_padded_shape, dtype, "input");
+ // Output tensor of remove_pad op
+ te::Tensor output_tensor = te::compute(
+ dyn_old_shape,
+ [&placeholder_tensor](const Array<tir::Var>& indices) {
+ return placeholder_tensor(indices);
+ },
+ "output", topi::kElementWise);
+
+ String op_name = "remove_pad";
+ // Create PrimFunc and add op_name to func.attrs
+ PrimFunc remove_pad_with_frozen_layout =
+ WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), kOperatorName, op_name);
+ // Add PrimFunc to module
+ GlobalVar gv_remove_pad = builder_->AddFunction(remove_pad_with_frozen_layout, op_name);
+ // Mark the remove_pad PrimFunc as private by removing it from global scope
+ builder_->UpdateFunction(gv_remove_pad,
+ WithoutAttr(remove_pad_with_frozen_layout, "global_symbol"));
+
+ remove_pad_map_[t_shape] = gv_remove_pad;
+ return gv_remove_pad;
+ }
+
Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
const TensorStructInfo& old_tensor_sinfo,
const Array<IntImm>& axis_separator) {
+ if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+ return expr;
+ }
Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
arith::Analyzer analyzer;
auto [inverse_index_map, padding_predicate] =
index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
- ICHECK(tir::is_zero(padding_predicate))
- << "Only bijective transformations on input/output buffers are
supported, but found "
- "padding predicate "
- << padding_predicate << " on initial range " << initial_ranges;
- return TransformLayout(expr, inverse_index_map, axis_separator);
+
+ if (tir::is_zero(padding_predicate)) {
+ return TransformLayout(expr, inverse_index_map, axis_separator);
+ } else {
+ auto padded_expr =
+ builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
+ const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);
+
+ GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
+ return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_sinfo});
+ }
}
/*!
@@ -223,8 +288,6 @@ class AlterOpImplMutator : public ExprMutator {
axis_separator = axis_separators_value[index];
}
auto transform = transforms[index++];
- ICHECK(IsTransformBijective(input, transform))
- << "Non bijective transforms on input and output buffers are not
supported.";
updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
}
return Tuple(updated_inputs);
@@ -314,6 +377,8 @@ class AlterOpImplMutator : public ExprMutator {
Map<PrimFunc, GlobalVar> cache_;
/*! \brief Input IRModule */
const IRModule& mod_;
+ /*! \brief Map from shape_dim.size to the remove_pad GlobalVar */
+ std::unordered_map<int, GlobalVar> remove_pad_map_;
/*! \brief Map from kOperatorName attribute to the replacement PrimFunc */
const Map<String, PrimFunc>& op_impl_map_;
/*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 3ca40c9a6b..5797d2295b 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -98,7 +98,8 @@ Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
- axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+ axis.emplace_back(IterVar(Range(IntImm(shape[i]->dtype, 0), shape[i]),
+ Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
@@ -114,7 +115,8 @@ Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
- axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+ axis.emplace_back(IterVar(Range(IntImm(shape[i]->dtype, 0), shape[i]),
+ Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index 81bc480785..3cbba9a031 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -303,13 +303,13 @@ def test_multiple_outputs_with_axis_sep():
)
-def test_unsupported_implicit_padding():
+def test_supported_implicit_padding():
@I.ir_module
- class InputModule:
+ class Before:
@R.function
def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
with R.dataflow():
- lv = R.call_tir(InputModule.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
+ lv = R.call_tir(Before.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
gv: R.Tensor((14,), dtype="float32") = lv
R.output(gv)
return gv
@@ -324,7 +324,62 @@ def test_unsupported_implicit_padding():
T.writes(output[v_ax0])
output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
- before = InputModule
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
+ x,
+ index_map=T.index_map(lambda i: (i % 16,)),
+ pad_value=None,
+ axis_separators=[],
+ )
+ lv1 = R.call_tir(
+ Expected.relax_relu_replacement,
+ (lv,),
+ out_sinfo=R.Tensor((16,), dtype="float32"),
+ )
+ lv2: R.Tensor((16,), dtype="float32") = R.layout_transform(
+ lv1,
+ index_map=T.index_map(lambda axis0: (axis0,)),
+ pad_value=None,
+ axis_separators=[],
+ )
+ lv_1 = R.call_tir(
+ Expected.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32")
+ )
+ gv: R.Tensor((14,), dtype="float32") = lv_1
+ R.output(gv)
+ return gv
+
+ @T.prim_func(private=True)
+ def relax_relu_replacement(
+ arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
+ ):
+ T.func_attr({"operator_name": "relax.relu"})
+ # with T.block("root"):
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+ @T.prim_func(private=True)
+ def remove_pad(var_input: T.handle, var_output: T.handle):
+ T.func_attr({"operator_name": "remove_pad", "tir.noalias":
T.bool(True)})
+ p0 = T.int64()
+ input = T.match_buffer(var_input, (p0,))
+ i0 = T.int64()
+ output = T.match_buffer(var_output, (i0,))
+ # with T.block("root"):
+ for ax0 in range(i0):
+ with T.block("output"):
+ v_ax0 = T.axis.spatial(i0, ax0)
+ T.reads(input[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = input[v_ax0]
@T.prim_func(private=True)
def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
@@ -338,16 +393,13 @@ def test_unsupported_implicit_padding():
# introduces implicit padding for shape (14,)
index_map = lambda i: (i % 16)
operator_name = "relax.relu"
- with pytest.raises(
- tvm.TVMError, match="Non bijective transforms on input and output buffers are not supported"
- ):
- _ = relax.transform.AlterOpImpl(
- {operator_name: relu_pad},
- {
- operator_name: [index_map, index_map],
- },
- {operator_name: None},
- )(before)
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.relu",
+ replacement_primfunc=relu_pad,
+ layout_changes=[index_map, index_map],
+ )
def test_multiple_call_sites():