gyshi commented on a change in pull request #16257: [Numpy] add numpy op
bitwise_xor, hsplit, moveaxis, rot90
URL: https://github.com/apache/incubator-mxnet/pull/16257#discussion_r329309685
##########
File path: src/operator/numpy/np_matrix_op.cc
##########
@@ -612,5 +614,215 @@ NNVM_REGISTER_OP(_backward_npi_flip)
})
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>);
+bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_attrs,
+ mxnet::ShapeVector *out_attrs) {
+ const NumpyMoveaxisParam& param =
nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ mxnet::TShape& shp = (*in_attrs)[0];
+ CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
+ CHECK_EQ(param.source.ndim(), param.destination.ndim())
+ << "source and destination not equal.";
+ mxnet::TShape ret(shp.ndim(), -1);
+ mxnet::TShape axes(shp.ndim(), -1);
+ std::vector<bool> state_axes(shp.ndim(), false);
+ mxnet::TShape real_src(param.source.ndim(), -1);
+ mxnet::TShape real_des(param.destination.ndim(), -1);
+ for (int i = 0; i < param.source.ndim(); ++i) {
+ if (param.source[i] >= 0) {
+ CHECK_LT(static_cast<size_t>(param.source[i]), shp.ndim());
+ real_src[i] = param.source[i];
+ } else {
+ CHECK_LT(param.source[i] + shp.ndim(), shp.ndim());
+ real_src[i] = param.source[i] + shp.ndim();
+ }
+ if (param.destination[i] >= 0) {
+ CHECK_LT(static_cast<size_t>(param.destination[i]), shp.ndim());
+ real_des[i] = param.destination[i];
+ } else {
+ CHECK_LT(param.destination[i] + shp.ndim(), shp.ndim());
+ real_des[i] = param.destination[i] + shp.ndim();
+ }
+ }
+ if (shp.ndim() > 1) {
+ for (int i = 0; i < param.source.ndim() - 1; ++i) {
+ for (int j = i + 1; j < param.source.ndim(); ++j) {
+ CHECK_NE(real_src[i], real_src[j])
+ << "repeated axis in `source` argument";
+ CHECK_NE(real_des[i], real_des[j])
+ << "repeated axis in `destination` argument";
+ }
+ }
+ }
+ for (int i = 0; i < param.source.ndim(); ++i) {
+ axes[real_des[i]] = real_src[i];
+ state_axes[real_src[i]] = true;
+ }
+ for (int i = 0; i < axes.ndim(); ++i) {
+ if (axes[i] < 0) {
+ for (int j = 0; j < axes.ndim(); ++j) {
+ if (state_axes[j] == false) {
+ axes[i] = j;
+ state_axes[j] = true;
+ break;
+ }
+ }
+ }
+ }
Review comment:
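See NumPy's `moveaxis`: the axes that are not moved must keep their original relative order, which is why the remaining output positions are filled with the unused source axes in ascending order.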
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services