gemini-code-assist[bot] commented on code in PR #18638:
URL: https://github.com/apache/tvm/pull/18638#discussion_r2662298255
##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -2580,14 +2580,43 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}
-// TODO(relax-team): implement FRelaxInferLayout for scatter_elements
+InferLayoutOutput InferLayoutScatterElements(
+ const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<ScatterElementsAttrs>();
+ ICHECK(attrs) << "Invalid Call";
Review Comment:

The `ICHECK` message "Invalid Call" is a bit generic. A more specific
message would improve debuggability by indicating what was expected.
```suggestion
ICHECK(attrs) << "scatter_elements attributes not found";
```
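To make the point concrete, here is a minimal sketch of a streaming check in the spirit of `ICHECK`. It is not TVM's implementation. `MINI_CHECK`, `CheckFailure`, `ScatterElementsAttrsStub`, and `ValidateAttrs` are hypothetical names used only for illustration. When the check fails, the streamed text is the only context the caller gets, so "scatter_elements attributes not found" says much more than "Invalid Call".

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for an ICHECK-style macro (not TVM's code): it
// collects a streamed message and throws std::runtime_error on failure.
class CheckFailure {
 public:
  explicit CheckFailure(const char* condition) {
    stream_ << "Check failed: " << condition << ": ";
  }
  // The temporary dies at the end of the full expression, after the message
  // has been streamed; throwing from a destructor needs noexcept(false).
  ~CheckFailure() noexcept(false) { throw std::runtime_error(stream_.str()); }
  std::ostringstream& stream() { return stream_; }

 private:
  std::ostringstream stream_;
};

#define MINI_CHECK(cond) \
  if (cond) {            \
  } else                 \
    CheckFailure(#cond).stream()

struct ScatterElementsAttrsStub {};  // hypothetical placeholder attrs type

// Returns the error text a caller would see, or "ok" if attrs is present.
std::string ValidateAttrs(const ScatterElementsAttrsStub* attrs) {
  try {
    MINI_CHECK(attrs != nullptr) << "scatter_elements attributes not found";
  } catch (const std::runtime_error& e) {
    return e.what();
  }
  return "ok";
}
```

With the suggested message, `ValidateAttrs(nullptr)` yields `Check failed: attrs != nullptr: scatter_elements attributes not found`.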
##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -2580,14 +2580,43 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}
-// TODO(relax-team): implement FRelaxInferLayout for scatter_elements
+InferLayoutOutput InferLayoutScatterElements(
+ const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<ScatterElementsAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+ LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+ LayoutDecision layout = data_layout;
+ if (NLayoutEqual()(indices_layout, updates_layout)) {
+ layout = indices_layout;
+ }
+
+ if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+ const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+ int ndim = tensor_sinfo->ndim;
+ layout = LayoutDecision(InitialLayout(ndim));
+ }
Review Comment:

This block handles the fallback for sub-indexed layouts. Adding a comment to
explain this behavior would improve code clarity for future maintenance.
Additionally, the `ICHECK` message on line 2601 could be more specific to
aid debugging.
```suggestion
// scatter_elements currently does not support sub-indexed layouts.
// Fall back to the initial layout if the chosen layout is sub-indexed.
if (layout->layout.ndim() != layout->layout.ndim_primal()) {
const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
ICHECK(tensor_sinfo != nullptr)
<< "scatter_elements requires the first argument to be a Tensor";
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
int ndim = tensor_sinfo->ndim;
layout = LayoutDecision(InitialLayout(ndim));
}
```
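The layout choice in this function, together with the fallback the suggested comment describes, can be modeled in a few lines. This is a simplified sketch, not TVM's `Layout`/`LayoutDecision` API. `SimpleLayout`, `InitialLayoutSketch`, and `ChooseScatterLayout` are hypothetical. Layouts are compared by name rather than with `NLayoutEqual`, and `InitialLayout(ndim)` is assumed here to produce axes named `A`, `B`, `C`, and so on.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical model of a layout string such as "NCHW" or "NCHW4c":
// uppercase letters are primal axes; a lowercase letter (optionally preceded
// by a split factor) is a sub-indexed axis. Mirrors ndim()/ndim_primal().
struct SimpleLayout {
  std::string name;

  int ndim() const {  // all axes, primal and sub-indexed
    int n = 0;
    for (char c : name) if (std::isalpha(static_cast<unsigned char>(c))) ++n;
    return n;
  }
  int ndim_primal() const {  // primal (uppercase) axes only
    int n = 0;
    for (char c : name) if (std::isupper(static_cast<unsigned char>(c))) ++n;
    return n;
  }
};

// Assumed stand-in for InitialLayout(ndim): axes "A", "B", "C", ...
SimpleLayout InitialLayoutSketch(int ndim) {
  std::string name;
  for (int i = 0; i < ndim; ++i) name.push_back(static_cast<char>('A' + i));
  return {name};
}

// Prefer the indices layout when it matches the updates layout, otherwise
// keep the data layout; fall back to the initial layout if the chosen
// layout is sub-indexed (ndim != ndim_primal).
SimpleLayout ChooseScatterLayout(const SimpleLayout& data, const SimpleLayout& indices,
                                 const SimpleLayout& updates, int data_ndim) {
  SimpleLayout layout = data;
  if (indices.name == updates.name) {
    layout = indices;
  }
  if (layout.ndim() != layout.ndim_primal()) {
    layout = InitialLayoutSketch(data_ndim);
  }
  return layout;
}
```

For example, with data in `NCHW` and both indices and updates in `NCHW4c`, the indices layout is picked first. Because `NCHW4c` has 5 axes but only 4 primal axes, the sketch then falls back to `ABCD`.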
--
This is an automated message from the Apache Git Service.