masahi commented on a change in pull request #5961:
URL: https://github.com/apache/incubator-tvm/pull/5961#discussion_r448621954



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1269,6 +1269,93 @@ RELAY_REGISTER_OP("repeat")
     .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
     .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+// meshgrid operator: takes a tuple of scalar or 1-D tensors and produces one
+// grid per input, all sharing a single shape. Registers the MeshgridAttrs node
+// type, which carries the `indexing` mode ("xy" swaps the first two grid axes).
+TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
+
+// Type relation for meshgrid.
+//
+// types[0] is the input. It must be a tuple of tensors, each either a scalar or a
+// 1-D vector, and all with the same dtype. types[1] is the output. It is assigned
+// a tuple holding one grid per input. Every grid has the common dtype and the
+// shape (len(x0), len(x1), ...), where a scalar contributes an extent of 1. With
+// indexing == "xy", the first two extents are swapped.
+//
+// Returns false while the input, or any field of the input tuple, is still an
+// IncompleteType, so the relation is left unresolved for now. Throws Error for
+// inputs that can never type-check: a non-tuple, an empty tuple, a non-tensor
+// field, a dtype mismatch, or a tensor of rank > 1.
+bool MeshgridRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
+                 const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const MeshgridAttrs* attrs = raw_attrs.as<MeshgridAttrs>();
+  CHECK(attrs != nullptr);
+
+  // Check for IncompleteType before requiring a tuple. An IncompleteType is not
+  // a TupleTypeNode, so testing in the reverse order would raise a hard error for
+  // an input that simply has not been inferred yet.
+  if (types[0].as<IncompleteTypeNode>() != nullptr) {
+    return false;
+  }
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    throw Error(
+        ErrorBuilder() << "meshgrid requires a tuple of tensors as the first argument, found "
+                       << PrettyPrint(types[0]));
+  }
+  const int data_length = static_cast<int>(tensor_tuple->fields.size());
+  if (data_length == 0) {
+    // Without this guard, the reference dtype lookup below would index an empty tuple.
+    throw Error("relay.meshgrid requires at least one input tensor");
+  }
+
+  // Every field must be an inferred tensor before its shape/dtype can be read.
+  // Incomplete fields defer the relation. Anything else gets a readable error
+  // instead of a failed Downcast.
+  std::vector<const TensorTypeNode*> tensors;
+  tensors.reserve(data_length);
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>() != nullptr) {
+      return false;
+    }
+    const auto* tensor = ele.as<TensorTypeNode>();
+    if (tensor == nullptr) {
+      throw Error(ErrorBuilder() << "relay.meshgrid requires all inputs to be tensors, found "
+                                 << PrettyPrint(ele));
+    }
+    tensors.push_back(tensor);
+  }
+
+  // The first input's dtype is the reference every other input must match.
+  const DataType dtype = tensors[0]->dtype;
+
+  // Size of the output grid: one axis per input.
+  std::vector<IndexExpr> grid_shape;
+  grid_shape.reserve(data_length);
+  for (const TensorTypeNode* e : tensors) {
+    if (e->dtype != dtype) {
+      throw Error("relay.meshgrid requires all tensors have the same dtype");
+    }
+    const size_t e_ndim = e->shape.size();
+    if (e_ndim == 0) {
+      // A scalar contributes a length-1 axis.
+      grid_shape.emplace_back(1);
+    } else if (e_ndim == 1) {
+      grid_shape.emplace_back(e->shape[0]);
+    } else {
+      throw Error("relay.meshgrid requires all tensors be either scalars or 1-D vectors.");
+    }
+  }
+
+  // "xy" mode swaps first two dimensions
+  if (attrs->indexing == "xy" && grid_shape.size() >= 2) {
+    std::swap(grid_shape[0], grid_shape[1]);
+  }
+
+  // There is one output grid for each input, all with the same shape and dtype,
+  // so the grid type is built once and repeated.
+  const TensorType grid_type(grid_shape, dtype);
+  std::vector<Type> grids(static_cast<size_t>(data_length), grid_type);
+  reporter->Assign(types[1], TupleType(Array<Type>(grids)));
+  return true;
+}
+
+// Compute rule for meshgrid: lowers the call to topi::meshgrid and forwards the
+// indexing mode stored in MeshgridAttrs. out_type is not consulted.
+Array<te::Tensor> MeshgridCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
+  const auto* meshgrid_attrs = attrs.as<MeshgridAttrs>();
+  CHECK(meshgrid_attrs != nullptr);
+  Array<te::Tensor> grids = topi::meshgrid(inputs, meshgrid_attrs->indexing);
+  return grids;
+}
+
+// Builds a call to the "meshgrid" op. `data` is passed as the op's single
+// argument, and `indexing` is stored in a fresh MeshgridAttrs attached to the call.
+Expr MakeMeshgrid(Expr data, String indexing) {
+  static const Op& meshgrid_op = Op::Get("meshgrid");
+  auto meshgrid_attrs = make_object<MeshgridAttrs>();
+  meshgrid_attrs->indexing = std::move(indexing);
+  return Call(meshgrid_op, {data}, Attrs(meshgrid_attrs), {});
+}
+
+// Register MakeMeshgrid as the global function "relay.op._make.meshgrid".
+TVM_REGISTER_GLOBAL("relay.op._make.meshgrid").set_body_typed(MakeMeshgrid);
+
+// Op registration: one argument (the tuple of input tensors), attributes of type
+// MeshgridAttrs, typed by MeshgridRel, and lowered by MeshgridCompute.
+RELAY_REGISTER_OP("meshgrid")
+    .describe(R"code(Create coordinate matrices from coordinate vectors.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<MeshgridAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input list of tensors.")
+    .set_support_level(3)
+    .add_type_rel("Meshgrid", MeshgridRel)
+    .set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
+    // NOTE(review): this op returns a tuple. If some of its outputs feed ops that
+    // end up in different fusion groups, the meshgrid computation may be
+    // duplicated in each group (the same concern applies to split). Confirm that
+    // kInjective is the intended pattern.
+    .set_attr<TOpPattern>("TOpPattern", kInjective);

Review comment:
       Hmm, you are right about the split op. I think if all output tensors in the
tuple are used by the following op, there is no problem.
   
    But if some output tensors are used by one op and others are used by another
op, and these two ops end up in different fusion groups, the computation of
split or meshgrid may be duplicated.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to