reminisce commented on a change in pull request #14409: [Numpy] Change semantics of ndim for operators in `src/operator/contrib`
URL: https://github.com/apache/incubator-mxnet/pull/14409#discussion_r265673494
##########
File path: src/operator/contrib/transformer-inl.h
##########
@@ -41,7 +41,9 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- double sqrt_dim = std::sqrt(static_cast<double>(inputs[0].shape_[inputs[0].ndim() - 1]));
+ int input_ndim = inputs[0].ndim();
+ int last_idx = (input_ndim == 0) ? (0) : (input_ndim - 1);
+ double sqrt_dim = std::sqrt(static_cast<double>(inputs[0].shape_[last_idx]));
Review comment:
If `input_ndim == 0`, `inputs[0]` is a scalar and `inputs[0].shape_` is empty, so
`shape_[0]` is still out of range. Instead of making these changes, I think we
should just add one check that rejects scalar input: `CHECK_GT(inputs[0].ndim(), 0)`?
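A minimal, self-contained sketch of the point above, written against a hypothetical `Tensor` stand-in (flat data plus a shape vector) rather than MXNet's real `TBlob`/`TShape` types. `FallbackLastIdx` mirrors the proposed `last_idx` fallback, and `DivSqrtDim` shows the guard-instead-of-fallback approach; both names are illustrative only. Throwing `std::invalid_argument` stands in for a `CHECK` failure.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a TBlob: flat data plus a shape.
// Under NumPy semantics a scalar has ndim == 0 (an empty shape)
// while still holding exactly one element.
struct Tensor {
  std::vector<double> data;
  std::vector<std::size_t> shape;
};

// Index the proposed fallback would use: 0 when ndim == 0, else ndim - 1.
// For a scalar this index is past the end of the (empty) shape.
std::size_t FallbackLastIdx(const Tensor& t) {
  const std::size_t ndim = t.shape.size();
  return ndim == 0 ? 0 : ndim - 1;
}

// Divide every element by sqrt(last dimension), rejecting 0-dim input
// up front instead of falling back to index 0. The exception plays the
// role of a CHECK failure in this sketch.
Tensor DivSqrtDim(const Tensor& in) {
  if (in.shape.empty()) {
    throw std::invalid_argument("div_sqrt_dim: input must have ndim >= 1");
  }
  const double sqrt_dim = std::sqrt(static_cast<double>(in.shape.back()));
  Tensor out = in;
  for (double& x : out.data) {
    x /= sqrt_dim;
  }
  return out;
}
```

For a scalar `Tensor{{5.0}, {}}`, `FallbackLastIdx` returns 0 even though the shape has no elements, while `DivSqrtDim` refuses the input explicitly. A 1-D tensor `Tensor{{2.0, 4.0, 6.0, 8.0}, {4}}` is still accepted and divided by `sqrt(4) = 2`.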