junrushao commented on code in PR #15680:
URL: https://github.com/apache/tvm/pull/15680#discussion_r1319180740


##########
src/relax/op/ccl/ccl.cc:
##########
@@ -70,5 +70,55 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.ccl.scatter_from_worker0 */
+TVM_REGISTER_NODE_TYPE(ScatterFromWorker0Attrs);
+
+/*!
+ * \brief Build a call to the `relax.ccl.scatter_from_worker0` operator.
+ * \param data The input expression, passed as the call's sole argument.
+ * \param num_workers Worker count recorded in the call's ScatterFromWorker0Attrs.
+ * \return A Call node applying the operator to `data` with the attributes attached.
+ */
+Expr scatter_from_worker0(Expr data, int num_workers) {
+  ObjectPtr<ScatterFromWorker0Attrs> attrs = make_object<ScatterFromWorker0Attrs>();
+  attrs->num_workers = num_workers;
+  static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
+
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0);
+
+StructInfo InferStructInfoScatterFromWorker0(const Call& call, const 
BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+
+  const auto* attrs = call->attrs.as<ScatterFromWorker0Attrs>();
+  int num_workers = attrs->num_workers;
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* input_shape = input_sinfo->shape.as<ShapeExprNode>();
+  input_sinfo->shape.as<VarNode>();
+  CHECK(input_shape != nullptr)
+      << "input tensor of scatter_from_worker0 should have defined ShapeExpr 
as shape";

Review Comment:
   Use `TensorStructInfoNode::GetShape()` here instead of manually casting `input_sinfo->shape` to `ShapeExprNode`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to