samskalicky commented on a change in pull request #12430: [MXNET-882] Support for N-d arrays added to diag op.
URL: https://github.com/apache/incubator-mxnet/pull/12430#discussion_r216770574
##########
File path: src/operator/tensor/diag_op-inl.h
##########
@@ -133,13 +167,94 @@ struct diag_gen {
auto j = unravel(i, oshape);
if (j[1] == (j[0] + k)) {
auto l = j[0] < j[1] ? j[0] : j[1];
- KERNEL_ASSIGN(out[i], req, a[l]);
- } else {
+ if (back) {
+ KERNEL_ASSIGN(out[l], req, a[i]);
+ } else {
+ KERNEL_ASSIGN(out[i], req, a[l]);
+ }
+ } else if (!back) {
KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
}
}
};
+template<typename xpu, bool back>
+void DiagOpProcess(const TBlob& in_data,
+ const TBlob& out_data,
+ const TShape& ishape,
+ const TShape& oshape,
+ int dsize,
+ const DiagParam& param,
+ mxnet_op::Stream<xpu> *s,
+ const std::vector<OpReqType>& req) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ if (ishape.ndim() > 1) {
+ // input : (leading + i, body + i, trailing)
+ int x1 = CheckAxis(param.axis1.value(), ishape.ndim());
+ int x2 = CheckAxis(param.axis2.value(), ishape.ndim());
+
+ int idim = ishape.ndim(), odim = oshape.ndim();
+
+ int minx = x1, maxx = x2;
+ if (minx > maxx)
+ std::swap(minx, maxx);
+
+ int oleading = 1, obody = 1, otrailing = 1;
+ for (int i = 0; i < minx; i ++)
+ oleading *= ishape[i];
Review comment:
This code counts elements, so the total number of elements in an array could exceed 2,147,483,647 (2^31 - 1), the largest value an `int` can hold. With 4-byte elements such as float32, though, that would be an 8GB array, which seems like a special case. If your dataset is so large that the data operated on by a single operator is 8GB, you're probably going to batch it into something smaller.
Supporting that would be a separate feature request, and we would probably want to review it separately from the feature in this PR (extending diag from 2-D to N-D).
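To make the overflow concrete, here is a minimal C++ sketch, assuming shape extents are stored as 64-bit integers. `ProductOfDims`, `ExceedsInt32` and `BytesFor` are hypothetical helpers for illustration, not MXNet code. They show that a shape whose first two axes are 65536 x 32768 already has 2^31 leading elements, one more than an `int` accumulator like `oleading` can hold, and that at 4 bytes per element this is 8GB.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper (not MXNet code): multiply the extents of axes
// [begin, end), as the PR's `oleading` loop does, but accumulate in a
// 64-bit integer so the count stays exact past 2^31 - 1.
std::int64_t ProductOfDims(const std::vector<std::int64_t>& shape,
                           int begin, int end) {
  assert(0 <= begin && begin <= end &&
         end <= static_cast<int>(shape.size()));
  std::int64_t prod = 1;
  for (int i = begin; i < end; ++i) {
    prod *= shape[i];
  }
  return prod;
}

// True when an element count would not fit in a 32-bit signed int,
// i.e. when an `int` accumulator would overflow.
bool ExceedsInt32(std::int64_t count) {
  return count > std::numeric_limits<std::int32_t>::max();
}

// Bytes needed for `count` elements of `elem_size` bytes each.
std::int64_t BytesFor(std::int64_t count, std::int64_t elem_size) {
  return count * elem_size;
}
```

Widening the accumulator only fixes the counting itself; the kernels that index with these values would need the same treatment, which is why this belongs in its own change.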