This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d8c973e674 [RELAX][LAYOUT] Support for dynamic layout specification (#18675)
d8c973e674 is described below
commit d8c973e6747f98ef59277fcd84a9eb0671da819f
Author: Siva <[email protected]>
AuthorDate: Tue Jan 20 19:03:15 2026 +0530
[RELAX][LAYOUT] Support for dynamic layout specification (#18675)
This allows a user-defined callback to specify layouts dynamically
based on the call description. This is helpful for altering layouts
based on operator shapes or attributes.
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
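For context, a minimal usage sketch of the new callback path follows. It assumes
mod is an existing Relax IRModule, and the groups-based heuristic is purely
illustrative, not part of this commit:

    import tvm
    from tvm.relax.transform import ConvertLayout, Normalize

    def layout_cb(call: tvm.relax.Call):
        # The callback receives the Call node, so its attributes and argument
        # struct info (shapes, dtypes) can drive the layout choice per call.
        if call.op.name == "relax.nn.conv2d" and int(call.attrs.groups) == 1:
            return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
        return {"relax.nn.conv2d": ["NHWC", "OHWI"]}

    # When layout_cb is given, it takes precedence over the static desired_layouts map.
    mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]}, layout_cb)(mod)
    mod = Normalize()(mod)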
---
include/tvm/relax/transform.h | 5 +-
python/tvm/relax/transform/transform.py | 10 ++-
src/relax/transform/convert_layout.cc | 28 +++++--
.../python/relax/test_transform_convert_layout.py | 95 +++++++++++++++++++++-
4 files changed, 124 insertions(+), 14 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 786dfdcdf9..0e660292c4 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -41,6 +41,7 @@ using PassContext = tvm::transform::PassContext;
using Function = tvm::relax::Function;
using DataflowBlock = tvm::relax::DataflowBlock;
using tvm::transform::CreateModulePass;
+using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, ffi::Array<ffi::String>>(Call)>;
/*!
* \brief Create a function pass.
@@ -606,10 +607,12 @@ TVM_DLL Pass AlterOpImpl(
/*!
* \brief Layout conversion pass.
* \param desired_layouts The desired layouts for some operators.
+ * \param layout_cb A custom callback to define layouts dynamically.
* \return The Pass.
* \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
*/
-TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts);
+TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+ LayoutCb layout_cb);
/*!
* \brief A pass that converts consecutive dataflow operations
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 46efc17e3d..bfd7dbf87d 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1367,7 +1367,10 @@ def AlterOpImpl(
) # type: ignore
-def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
+def ConvertLayout(
+ desired_layouts: Dict[str, List[str]],
+ layout_cb: Callable = None,
+) -> tvm.ir.transform.Pass:
"""Automatic layout conversion pass.
Parameters
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
of the desired feature map, weight and output. For example, if we want to convert the
layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+ layout_cb : Callable
+ A user-defined callback function that can dynamically handle operator layouts
+ based on the Call description. desired_layouts will be ignored if layout_cb is defined.
Returns
-------
ret : tvm.transform.Pass
The registered pass for layout conversion.
"""
- return _ffi_api.ConvertLayout(desired_layouts) # type: ignore
+ return _ffi_api.ConvertLayout(desired_layouts, layout_cb) # type: ignore
def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index c543799e3b..27684313de 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -38,6 +38,7 @@ namespace relax {
using tir::IndexMap;
using tir::Layout;
+using LayoutCb = tvm::relax::transform::LayoutCb;
/*!
* \brief Main logic to convert the layout of conv2d. Other ops
@@ -79,8 +80,8 @@ using tir::Layout;
class LayoutConvertMutator : public ExprMutator {
public:
explicit LayoutConvertMutator(
- const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts)
- : desired_layouts_(desired_layouts) {}
+ const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, LayoutCb layout_cb)
+ : desired_layouts_(desired_layouts), layout_cb_(layout_cb) {}
private:
ffi::Array<Integer> LayoutToIntegers(const Layout& layout) {
@@ -201,7 +202,7 @@ class LayoutConvertMutator : public ExprMutator {
ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
const CallNode* call_node,
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
- const VarLayoutMap& var_layout_map) {
+ const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
const OpNode* op_node = call_node->op.as<OpNode>();
if (op_node == nullptr) return std::nullopt;
Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
@@ -209,7 +210,13 @@ class LayoutConvertMutator : public ExprMutator {
if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
// If the op has FRelaxInferLayout, and all the input tensors have known ndim
FRelaxInferLayout f = attr_map[op];
- return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+ auto call = ffi::GetRef<Call>(call_node);
+ if (layout_cb != nullptr) {
+ auto custom_layouts = layout_cb(call);
+ return f(call, custom_layouts, var_layout_map);
+ } else {
+ return f(call, desired_layouts, var_layout_map);
+ }
} else {
// Otherwise, we use the default policy.
return std::nullopt;
@@ -218,7 +225,7 @@ class LayoutConvertMutator : public ExprMutator {
void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
ffi::Optional<InferLayoutOutput> res =
- GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+ GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_);
ObjectPtr<CallNode> new_call = ffi::make_object<CallNode>(*call_node);
new_call->struct_info_ = std::nullopt;
if (!res.defined() ||
@@ -335,20 +342,23 @@ class LayoutConvertMutator : public ExprMutator {
std::unordered_map<Var, NLayout> var_layout_map_;
ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts_;
+ LayoutCb layout_cb_;
}; // namespace relax
DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
- ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
- LayoutConvertMutator mutator(desired_layouts);
+ ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+ LayoutCb layout_cb) {
+ LayoutConvertMutator mutator(desired_layouts, layout_cb);
return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
}
namespace transform {
-Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
+Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+ LayoutCb layout_cb) {
ffi::TypedFunction<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
[=](DataflowBlock df_block, IRModule m, PassContext pc) {
- return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts));
+ return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts, layout_cb));
};
return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
}
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 84fa9e70c7..fe412fd93b 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -21,10 +21,10 @@ from tvm.relax.transform import ConvertLayout, Normalize
from tvm.script.parser import ir as I, relax as R, tir as T
-def verify(input, expected, extra_ops={}):
+def verify(input, expected, extra_ops={}, cb=None):
desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
desired_layouts.update(extra_ops)
- mod = ConvertLayout(desired_layouts)(input)
+ mod = ConvertLayout(desired_layouts, cb)(input)
mod = Normalize()(mod)
tvm.ir.assert_structural_equal(mod, expected)
@@ -5487,5 +5487,96 @@ def test_conv2d_gather_elements():
verify(Input, Expected)
+def test_layout_cb():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 28, 28), "float32"),
+ w: R.Tensor((4, 4, 3, 3), "float32"),
+ bias: R.Tensor((2, 4, 26, 26), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+ gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+ gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
+ R.output(gv4)
+ return gv4
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform(
+ x,
+ index_map=T.index_map(
+ lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+ ),
+ )
+ lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+ w,
+ index_map=T.index_map(
+ lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+ ),
+ )
+ gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW4c",
+ kernel_layout="OIHW4o",
+ out_layout="NCHW4c",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform(
+ bias,
+ index_map=T.index_map(
+ lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+ ),
+ )
+ gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+ gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
+ lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+ w,
+ index_map=T.index_map(
+ lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+ ),
+ )
+ lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+ gv3,
+ lv3,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW4c",
+ kernel_layout="OIHW4o",
+ out_layout="NCHW4c",
+ out_dtype="float32",
+ )
+ gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform(
+ lv4,
+ index_map=T.index_map(
+ lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32"
+ ),
+ )
+ R.output(gv4)
+ return gv4
+
+ def layout_cb(call: tvm.relax.Call):
+ return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
+
+ verify(Input, Expected, cb=layout_cb)
+
+
if __name__ == "__main__":
tvm.testing.main()