This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d8c973e674 [RELAX][LAYOUT] Support for dynamic layout specification (#18675)
d8c973e674 is described below

commit d8c973e6747f98ef59277fcd84a9eb0671da819f
Author: Siva <[email protected]>
AuthorDate: Tue Jan 20 19:03:15 2026 +0530

    [RELAX][LAYOUT] Support for dynamic layout specification (#18675)
    
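    This allows a user-defined callback to specify layouts dynamically based
    on the call description. When a callback is provided, the map it returns
    (operator name -> desired layouts) is used in place of desired_layouts
    for that call. This is helpful for altering layouts based on operator
    shapes or attributes.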
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 include/tvm/relax/transform.h                      |  5 +-
 python/tvm/relax/transform/transform.py            | 10 ++-
 src/relax/transform/convert_layout.cc              | 28 +++++--
 .../python/relax/test_transform_convert_layout.py  | 95 +++++++++++++++++++++-
 4 files changed, 124 insertions(+), 14 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 786dfdcdf9..0e660292c4 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -41,6 +41,7 @@ using PassContext = tvm::transform::PassContext;
 using Function = tvm::relax::Function;
 using DataflowBlock = tvm::relax::DataflowBlock;
 using tvm::transform::CreateModulePass;
+using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, 
ffi::Array<ffi::String>>(Call)>;
 
 /*!
  * \brief Create a function pass.
@@ -606,10 +607,12 @@ TVM_DLL Pass AlterOpImpl(
 /*!
  * \brief Layout conversion pass.
  * \param desired_layouts The desired layouts for some operators.
+ * \param layout_cb custom call back to define layouts dynamically.
  * \return The Pass.
  * \note Operates only on dataflow blocks. ConvertToDataflow may need to be 
called first.
  */
-TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> 
desired_layouts);
+TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> 
desired_layouts,
+                           LayoutCb layout_cb);
 
 /*!
  * \brief A pass that converts consecutive dataflow operations
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 46efc17e3d..bfd7dbf87d 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1367,7 +1367,10 @@ def AlterOpImpl(
     )  # type: ignore
 
 
-def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> 
tvm.ir.transform.Pass:
+def ConvertLayout(
+    desired_layouts: Dict[str, List[str]],
+    layout_cb: Callable = None,
+) -> tvm.ir.transform.Pass:
     """Automatic layout conversion pass.
 
     Parameters
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, 
List[str]]) -> tvm.ir.transform.Pas
         of the desired feature map, weight and output. For example, if we want 
to convert the
         layout of conv2d from NCHW to NHWC, we can set the desired layout of 
conv2d to be
         ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+    layout_cb : Callable
+        A user defined call back function that can dynamically handle operator 
layouts
+        based on Call description. desired_layouts will be ignored if 
layout_cb is defined.
 
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for layout conversion.
     """
-    return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
+    return _ffi_api.ConvertLayout(desired_layouts, layout_cb)  # type: ignore
 
 
 def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> 
tvm.ir.transform.Pass:
diff --git a/src/relax/transform/convert_layout.cc 
b/src/relax/transform/convert_layout.cc
index c543799e3b..27684313de 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -38,6 +38,7 @@ namespace relax {
 
 using tir::IndexMap;
 using tir::Layout;
+using LayoutCb = tvm::relax::transform::LayoutCb;
 
 /*!
  * \brief Main logic to convert the layout of conv2d. Other ops
@@ -79,8 +80,8 @@ using tir::Layout;
 class LayoutConvertMutator : public ExprMutator {
  public:
   explicit LayoutConvertMutator(
-      const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts)
-      : desired_layouts_(desired_layouts) {}
+      const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, 
LayoutCb layout_cb)
+      : desired_layouts_(desired_layouts), layout_cb_(layout_cb) {}
 
  private:
   ffi::Array<Integer> LayoutToIntegers(const Layout& layout) {
@@ -201,7 +202,7 @@ class LayoutConvertMutator : public ExprMutator {
   ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
       const CallNode* call_node,
       const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
-      const VarLayoutMap& var_layout_map) {
+      const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
     const OpNode* op_node = call_node->op.as<OpNode>();
     if (op_node == nullptr) return std::nullopt;
     Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
@@ -209,7 +210,13 @@ class LayoutConvertMutator : public ExprMutator {
     if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
       // If the op has FRelaxInferLayout, and all the input tensors have known 
ndim
       FRelaxInferLayout f = attr_map[op];
-      return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      auto call = ffi::GetRef<Call>(call_node);
+      if (layout_cb != nullptr) {
+        auto custom_layouts = layout_cb(call);
+        return f(call, custom_layouts, var_layout_map);
+      } else {
+        return f(call, desired_layouts, var_layout_map);
+      }
     } else {
       // Otherwise, we use the default policy.
       return std::nullopt;
@@ -218,7 +225,7 @@ class LayoutConvertMutator : public ExprMutator {
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) 
final {
     ffi::Optional<InferLayoutOutput> res =
-        GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+        GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, 
var_layout_map_);
     ObjectPtr<CallNode> new_call = ffi::make_object<CallNode>(*call_node);
     new_call->struct_info_ = std::nullopt;
     if (!res.defined() ||
@@ -335,20 +342,23 @@ class LayoutConvertMutator : public ExprMutator {
 
   std::unordered_map<Var, NLayout> var_layout_map_;
   ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts_;
+  LayoutCb layout_cb_;
 };  // namespace relax
 
 DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
-                                ffi::Map<ffi::String, ffi::Array<ffi::String>> 
desired_layouts) {
-  LayoutConvertMutator mutator(desired_layouts);
+                                ffi::Map<ffi::String, ffi::Array<ffi::String>> 
desired_layouts,
+                                LayoutCb layout_cb) {
+  LayoutConvertMutator mutator(desired_layouts, layout_cb);
   return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
 }
 
 namespace transform {
 
-Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> 
desired_layouts) {
+Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> 
desired_layouts,
+                   LayoutCb layout_cb) {
   ffi::TypedFunction<DataflowBlock(DataflowBlock, IRModule, PassContext)> 
pass_func =
       [=](DataflowBlock df_block, IRModule m, PassContext pc) {
-        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, 
desired_layouts));
+        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, 
desired_layouts, layout_cb));
       };
   return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
 }
diff --git a/tests/python/relax/test_transform_convert_layout.py 
b/tests/python/relax/test_transform_convert_layout.py
index 84fa9e70c7..fe412fd93b 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -21,10 +21,10 @@ from tvm.relax.transform import ConvertLayout, Normalize
 from tvm.script.parser import ir as I, relax as R, tir as T
 
 
-def verify(input, expected, extra_ops={}):
+def verify(input, expected, extra_ops={}, cb=None):
     desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
     desired_layouts.update(extra_ops)
-    mod = ConvertLayout(desired_layouts)(input)
+    mod = ConvertLayout(desired_layouts, cb)(input)
     mod = Normalize()(mod)
     tvm.ir.assert_structural_equal(mod, expected)
 
@@ -5487,5 +5487,96 @@ def test_conv2d_gather_elements():
     verify(Input, Expected)
 
 
+def test_layout_cb():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), "float32"),
+            w: R.Tensor((4, 4, 3, 3), "float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+                gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, 
out_dtype="float32")
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = 
R.layout_transform(
+                    x,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), 
index_dtype="int32"
+                    ),
+                )
+                lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = 
R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), 
index_dtype="int32"
+                    ),
+                )
+                gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = 
R.layout_transform(
+                    bias,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), 
index_dtype="int32"
+                    ),
+                )
+                gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, 
lv2)
+                gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = 
R.nn.relu(gv2)
+                lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = 
R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), 
index_dtype="int32"
+                    ),
+                )
+                lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = 
R.nn.conv2d(
+                    gv3,
+                    lv3,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = 
R.layout_transform(
+                    lv4,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), 
index_dtype="int32"
+                    ),
+                )
+                R.output(gv4)
+            return gv4
+
+    def layout_cb(call: tvm.relax.Call):
+        return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
+
+    verify(Input, Expected, cb=layout_cb)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to