yongwww commented on code in PR #14493:
URL: https://github.com/apache/tvm/pull/14493#discussion_r1159067973


##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -1359,5 +1359,119 @@ TVM_REGISTER_OP("relax.cumsum")
     .add_argument("data", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCumsum);
 
+/* relax.scatter_elements */
+TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
+
+Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String 
reduction) {
+  auto attrs = make_object<ScatterElementsAttrs>();
+  attrs->axis = std::move(axis);
+  attrs->reduction = std::move(reduction);
+  static const Op& op = Op::Get("relax.scatter_elements");
+  return Call(op, {data, indices, updates}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.scatter_elements").set_body_typed(scatter_elements);
+
+StructInfo InferStructInfoScatterElements(const Call& call, const 
BlockBuilder& ctx) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* indices_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* updates_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String 
type_key) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ScatterEelemtns requires the input " << name

Review Comment:
   Typo in the error message: `ScatterEelemtns` should be `ScatterElements`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to