trevor-m commented on a change in pull request #5961:
URL: https://github.com/apache/incubator-tvm/pull/5961#discussion_r448619297



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1269,6 +1269,93 @@ RELAY_REGISTER_OP("repeat")
     .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
     .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+// meshgrid operator: attrs node registration, type relation, compute, builder, op registration
+TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
+
+/*!
+ * \brief Type relation for meshgrid.
+ *
+ * Expects types = [data, result], where data is a non-empty tuple of scalar or 1-D tensors
+ * that all share one dtype. Assigns result a tuple with one tensor type per input; every
+ * output has the same shape, built from the input extents (a scalar contributes extent 1).
+ * With "xy" indexing and at least two inputs, the first two extents are swapped.
+ *
+ * \param types The input and output types: [data, result].
+ * \param num_inputs Number of inputs (not used by this relation).
+ * \param raw_attrs Operator attributes; must be MeshgridAttrs.
+ * \param reporter Used to assign the inferred result type.
+ * \return true once the result type has been assigned; false while the input tuple or any
+ *         of its fields is still an IncompleteType.
+ */
+bool MeshgridRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
+                 const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const MeshgridAttrs* attrs = raw_attrs.as<MeshgridAttrs>();
+  CHECK(attrs != nullptr);
+
+  // Test for an unresolved input before the tuple cast: an IncompleteType is never a
+  // TupleType, so testing it after the cast would be unreachable and the unresolved case
+  // would be reported as an error instead.
+  if (types[0].as<IncompleteTypeNode>() != nullptr) {
+    return false;
+  }
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    throw Error(
+        ErrorBuilder() << "meshgrid requires a tuple of tensors as the first argument, found "
+                       << PrettyPrint(types[0]));
+  }
+  const int data_length = static_cast<int>(tensor_tuple->fields.size());
+  if (data_length == 0) {
+    // The reference dtype is taken from fields[0], so an empty tuple cannot be typed.
+    throw Error("relay.meshgrid requires at least one input tensor.");
+  }
+
+  // Every field must be resolved before any of them (including the first) is downcast.
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+
+  // Get first dtype.
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  const DataType dtype = first->dtype;
+
+  // Get size of output grid.
+  std::vector<IndexExpr> grid_shape;
+  grid_shape.reserve(data_length);
+  for (const Type& ele : tensor_tuple->fields) {
+    const auto& e = Downcast<TensorType>(ele);
+    int e_ndim = static_cast<int>(e->shape.size());
+    const DataType& e_dtype = e->dtype;
+    if (e_dtype != dtype) {
+      throw Error("relay.meshgrid requires all tensors have the same dtype");
+    }
+    if (e_ndim == 0) {
+      grid_shape.emplace_back(1);
+    } else if (e_ndim == 1) {
+      grid_shape.emplace_back(e->shape[0]);
+    } else {
+      throw Error("relay.meshgrid requires all tensors be either scalars or 1-D vectors.");
+    }
+  }
+
+  // "xy" mode swaps first two dimensions
+  if (attrs->indexing == "xy" && grid_shape.size() >= 2) {
+    std::swap(grid_shape[0], grid_shape[1]);
+  }
+
+  // There is one output grid for each input, all with same shape.
+  std::vector<Type> grids;
+  grids.reserve(data_length);
+  for (int i = 0; i < data_length; i++) {
+    grids.emplace_back(TensorType(grid_shape, dtype));
+  }
+  reporter->Assign(types[1], TupleType(Array<Type>(grids)));
+  return true;
+}
+
+/*!
+ * \brief Compute function for meshgrid: forwards the inputs and the indexing mode from
+ *        MeshgridAttrs to topi::meshgrid. out_type is not used.
+ */
+Array<te::Tensor> MeshgridCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
+  const auto* param = attrs.as<MeshgridAttrs>();
+  // Fail loudly if the attrs are not MeshgridAttrs.
+  CHECK(param != nullptr);
+  const auto& indexing = param->indexing;
+  return {topi::meshgrid(inputs, indexing)};
+}
+
+/*!
+ * \brief Build a call to the "meshgrid" operator.
+ * \param data The expression passed as the operator's single argument.
+ * \param indexing The indexing mode stored in the call's MeshgridAttrs.
+ * \return The constructed Call expression.
+ */
+Expr MakeMeshgrid(Expr data, String indexing) {
+  // Looked up once and cached for later calls.
+  static const Op& meshgrid_op = Op::Get("meshgrid");
+  auto meshgrid_attrs = make_object<MeshgridAttrs>();
+  meshgrid_attrs->indexing = std::move(indexing);
+  return Call(meshgrid_op, {data}, Attrs(meshgrid_attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.meshgrid").set_body_typed(MakeMeshgrid);  // registers MakeMeshgrid under this global name
+
+RELAY_REGISTER_OP("meshgrid")
+    .describe(R"code(Create coordinate matrices from coordinate vectors.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<MeshgridAttrs>()
+    .set_num_inputs(1)  // one input: the tuple of coordinate tensors checked by MeshgridRel
+    .add_argument("data", "Tensor", "The input list of tensors.")
+    .set_support_level(3)
+    .add_type_rel("Meshgrid", MeshgridRel)  // infers the output tuple type
+    .set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)  // forwards to topi::meshgrid
+    .set_attr<TOpPattern>("TOpPattern", kInjective);  // NOTE(review): confirm kInjective suits a tuple-in/tuple-out op

Review comment:
       @masahi Thanks for the comment. I'm not very familiar with `TOpPattern` 
and how it is used. I picked injective because concat (Tuple input) and 
split (Tuple output) both use it. If anyone has any suggestions, I am happy 
to change it.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to