Kacper-Pietkun commented on code in PR #21132:
URL: https://github.com/apache/incubator-mxnet/pull/21132#discussion_r954894121
##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- mxnet::TShape new_lshape, new_rshape, new_oshape;
- int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
- inputs[1].shape(),
- outputs[0].shape(),
- &new_lshape,
- &new_rshape,
- &new_oshape);
- std::vector<NDArray> new_inputs;
- std::vector<NDArray> new_outputs;
- if (ndim_diff) {
- new_inputs = {inputs[0].Reshape(new_lshape),
inputs[1].Reshape(new_rshape)};
- new_outputs = {outputs[0].Reshape(new_oshape)};
- } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
- // BinaryBroadcastShapeCompact function doesn't reshape tensors of size
(1,1,...,1)
- // into shape (1). It is mandatory for oneDNN primitive to have this
reshape done.
- mxnet::TShape one_shape = mxnet::TShape(1, 1);
- new_inputs = {inputs[0].Reshape(one_shape),
inputs[1].Reshape(one_shape)};
- new_outputs = {outputs[0].Reshape(one_shape)};
+ // We can use more efficient sum kernel when there is no broadcast - when
shapes are the same
+ const bool same_shape = (inputs[0].shape() == inputs[1].shape());
+
+ if (same_shape && alg == dnnl::algorithm::binary_add) {
+ DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+ fwd.Execute(ctx, inputs, req, outputs);
} else {
- new_inputs = {inputs[0], inputs[1]};
- new_outputs = {outputs[0]};
- }
+ mxnet::TShape new_lshape, new_rshape, new_oshape;
+ int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+ inputs[1].shape(),
+ outputs[0].shape(),
+ &new_lshape,
+ &new_rshape,
+ &new_oshape);
+ std::vector<NDArray> new_inputs;
+ std::vector<NDArray> new_outputs;
+ if (ndim_diff) {
+ new_inputs = {inputs[0].Reshape(new_lshape),
inputs[1].Reshape(new_rshape)};
+ new_outputs = {outputs[0].Reshape(new_oshape)};
+ } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1)
{
+ // BinaryBroadcastShapeCompact function doesn't reshape tensors of size
(1,1,...,1)
+ // into shape (1). It is mandatory for oneDNN primitive to have this
reshape done.
Review Comment:
BinaryBroadcastShapeCompact does not collapse tensors whose shape is all ones
(1,1,...,1) into the one-dimensional shape (1), but oneDNN requires such
tensors to be reshaped to (1) to work properly, so we do that reshape here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]