bgawrych commented on code in PR #21004:
URL: https://github.com/apache/incubator-mxnet/pull/21004#discussion_r861500749
##########
src/operator/tensor/broadcast_reduce_op.h:
##########
@@ -1354,6 +1354,94 @@ struct direct_copy {
}
};
+template <typename IType, typename OType>
+void BroadcastCPU(const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs,
+                  const mxnet::TShape& src_shape,
+                  const mxnet::TShape& dst_shape,
+                  ShapeAndStride aux_data) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  constexpr size_t ELEMENTS_THRESHOLD = 256;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+
+  std::vector<size_t> elements_to_copy(aux_data.num_broadcast_axes);
+  std::vector<size_t> preaxis_dims(aux_data.num_broadcast_axes);
+  for (int ax = 0; ax < aux_data.num_broadcast_axes; ax++) {
+    index_t axis = aux_data.axes[ax];
+
+    elements_to_copy[ax] = 1;
+    for (int i = axis + 1; i < dst_shape.ndim(); i++) {
+      elements_to_copy[ax] *= dst_shape[i];
+    }
+    preaxis_dims[ax] = 1;
+    for (int i = axis - 1; i >= 0; i--) {
+      preaxis_dims[ax] *= src_shape[i];
Review Comment:
Because dst_shape is already the broadcasted shape.
E.g. for shapes (6, 1, 4, 1, 2) broadcast to (6, 5, 4, 3, 2), the broadcast axes
are 3 and 1. For the first broadcast axis (axis 3), elements_to_copy is 2, and
for the second one (axis 1) it is 24 (4*3*2).
In preaxis_dims I am calculating the product of the src dimensions that precede
the axis: for the first broadcast axis it is 24 (6*1*4), and for the second
axis it is 6.
--
This is an automated message from the Apache Git Service.