cjolivier01 commented on a change in pull request #9770: eye operator, for 
default storage type
URL: https://github.com/apache/incubator-mxnet/pull/9770#discussion_r167722707
 
 

 ##########
 File path: src/operator/tensor/init_op.h
 ##########
 @@ -63,6 +63,86 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
   }
 };
 
+struct EyeParam : public dmlc::Parameter<EyeParam> {
+  nnvm::dim_t N;
+  nnvm::dim_t M;
+  nnvm::dim_t k;
+  std::string ctx;
+  int dtype;
+
+  DMLC_DECLARE_PARAMETER(EyeParam) {
+    DMLC_DECLARE_FIELD(N)
+    .describe("Number of rows in the output.");
+    DMLC_DECLARE_FIELD(M)
+    .set_default(0)
+    .describe("Number of columns in the output. If 0, defaults to N");
+    DMLC_DECLARE_FIELD(k)
+    .set_default(0)
+    .describe("Index of the diagonal. 0 (the default) refers to the main 
diagonal."
+              "A positive value refers to an upper diagonal."
+              "A negative value to a lower diagonal.");
+    DMLC_DECLARE_FIELD(ctx)
+    .set_default("")
+    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
+              "Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("int64", mshadow::kInt64)
+    .describe("Target data type.");
+  }
+};
+
+template<typename ParamType>
+inline bool InitEyeShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape> *in_attrs,
+                         std::vector<TShape> *out_attrs) {
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 0U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, param.M > 0 ? 
param.M : param.N));
+  return true;
+}
+
+template<int req>
+struct eye_dns_fill {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const nnvm::dim_t 
num_cols,
+                                  const nnvm::dim_t k) {
+    if ((i % num_cols) == ((i / num_cols) + k)) {
 
 Review comment:
   Looks like two divides per value-fill. Would it be faster (at least for CPU) to fill with zeros first and then "walk" across an offset (using only addition) to fill in the ones?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to