gemini-code-assist[bot] commented on code in PR #18675:
URL: https://github.com/apache/tvm/pull/18675#discussion_r2706643382
##########
src/relax/transform/convert_layout.cc:
##########
@@ -201,15 +202,21 @@ class LayoutConvertMutator : public ExprMutator {
ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
const CallNode* call_node,
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
- const VarLayoutMap& var_layout_map) {
+ const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
const OpNode* op_node = call_node->op.as<OpNode>();
if (op_node == nullptr) return std::nullopt;
Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
const auto attr_map =
Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
// If the op has FRelaxInferLayout, and all the input tensors have known ndim
FRelaxInferLayout f = attr_map[op];
- return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+ if (layout_cb != nullptr) {
+ ffi::Map<ffi::String, ffi::Array<ffi::String>> custom_layouts;
+ custom_layouts = layout_cb(ffi::GetRef<Call>(call_node));
+ return f(ffi::GetRef<Call>(call_node), custom_layouts, var_layout_map);
+ } else {
+ return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+ }
Review Comment:

To improve readability and avoid redundant calls to
`ffi::GetRef<Call>(call_node)`, you could store the result in a local variable.
This also makes the code slightly more efficient.
```suggestion
FRelaxInferLayout f = attr_map[op];
auto call = ffi::GetRef<Call>(call_node);
if (layout_cb) {
auto custom_layouts = layout_cb(call);
return f(call, custom_layouts, var_layout_map);
} else {
return f(call, desired_layouts, var_layout_map);
}
```
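The sketch below isolates the control flow the suggestion keeps. When a callback is set, its per-call result replaces `desired_layouts` as the argument to the inference function. Otherwise `desired_layouts` is passed through unchanged. It is only an illustration, and its types are hypothetical stand-ins: `Layouts`, `Call`, `LayoutCb`, `InferFn` and `InferWithOptionalCallback` substitute plain STL types for `ffi::Map`, the relax `Call`, and `FRelaxInferLayout`. None of them is the real TVM API.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the TVM types in the diff (ffi::Map, Call,
// LayoutCb, FRelaxInferLayout); this is not the real TVM API.
using Layouts = std::map<std::string, std::vector<std::string>>;
struct Call {
  std::string op_name;
};
using LayoutCb = std::function<Layouts(const Call&)>;
using InferFn =
    std::function<std::optional<Layouts>(const Call&, const Layouts&)>;

// Same branching as the suggested change: the call is obtained once, and the
// callback's per-call layouts, if a callback is set, take the place of
// desired_layouts; otherwise desired_layouts is forwarded unchanged.
std::optional<Layouts> InferWithOptionalCallback(const Call& call,
                                                 const Layouts& desired_layouts,
                                                 const LayoutCb& layout_cb,
                                                 const InferFn& f) {
  if (layout_cb) {  // an empty std::function converts to false
    return f(call, layout_cb(call));
  }
  return f(call, desired_layouts);
}
```

Testing the callback for emptiness, rather than comparing it to `nullptr`, works for `std::function`. Whether the real `LayoutCb` type supports the same boolean test is an assumption here.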
##########
python/tvm/relax/transform/transform.py:
##########
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
of the desired feature map, weight and output. For example, if we want to convert the
layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+ layout_cb : Callable
+ A user defined call back function that can dynamically handle operator layouts
+ based on Call description. desigred_layouts will be ignored if layout_cb is defined.
Review Comment:

There's a small typo in the docstring. `desigred_layouts` should be
`desired_layouts`.
```suggestion
based on Call description. desired_layouts will be ignored if layout_cb is defined.
```
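The docstring describes `layout_cb` as choosing operator layouts "based on Call description". The sketch below shows what such a per-call policy might look like. It is purely illustrative. `CallInfo`, its `in_channels` field, `ExampleLayoutCb`, and the multiple-of-8 rule are all hypothetical. In the PR the real callback receives a relax `Call` and returns a map shaped like `desired_layouts`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins: the real callback takes a relax Call and returns
// ffi::Map<ffi::String, ffi::Array<ffi::String>>.
using Layouts = std::map<std::string, std::vector<std::string>>;
struct CallInfo {
  std::string op_name;  // e.g. "relax.nn.conv2d"
  int in_channels;      // hypothetical: would be read from the call's input shape
};

// Example policy (an assumption for illustration only): request NHWC/OHWI for
// conv2d when the input channel count is a multiple of 8; for any other call,
// return an empty map so that no layout is requested for that op.
Layouts ExampleLayoutCb(const CallInfo& call) {
  Layouts layouts;
  if (call.op_name == "relax.nn.conv2d" && call.in_channels % 8 == 0) {
    layouts[call.op_name] = {"NHWC", "OHWI"};
  }
  return layouts;
}
```

How `FRelaxInferLayout` handles an op missing from the returned map is not specified in this thread. Check the op's layout-inference function before relying on the empty-map case.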
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]