anirudh2290 commented on a change in pull request #11631: Fix batchnorm problem with sparse matrices when fix_gamma=True
URL: https://github.com/apache/incubator-mxnet/pull/11631#discussion_r201799149
##########
File path: src/operator/nn/batch_norm.cc
##########
@@ -452,19 +453,66 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 5);
CHECK_EQ(out_attrs->size(), 3);
- return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
- in_attrs, out_attrs);
+ const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
+
+ if ((common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
+ common::ContainsStorageType(*in_attrs, kCSRStorage)) &&
+ param.fix_gamma) {
+ LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
+ }
+ for (int& v : *in_attrs)
+ if (v == - 1) v = kDefaultStorage;
+ bool dispatched = false;
+ if (!dispatched && (common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
+ common::ContainsStorageType(*in_attrs, kCSRStorage))) {
+ dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+ }
+#if MXNET_USE_MKLDNN == 1
Review comment:
Do we need this if/else here? MKLDNNStorageType already handles both cases.
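For reference, a simplified, self-contained sketch of the storage-type inference this hunk performs: reject sparse inputs when fix_gamma=True, default any undefined input storage, and fall back to dense compute when a sparse input is present. Everything here is a hypothetical stand-in, not MXNet's API: the enums only mimic MXNet's storage and dispatch constants, `InferBatchNormDispatch` and `ContainsSparse` are illustrative names, the MKL-DNN path is collapsed into a single dense dispatch mode, and `assert`/`throw` replace `CHECK_EQ`/`LOG(FATAL)`. It isolates the BatchNorm-specific part of the change, which is roughly what remains once the dense cases are left to a shared helper, as the comment suggests.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for MXNet's storage-type and dispatch-mode enums.
enum StorageType {
  kUndefinedStorage = -1,
  kDefaultStorage = 0,
  kRowSparseStorage = 1,
  kCSRStorage = 2
};
enum class DispatchMode { kFCompute, kFComputeFallback };

// True if any input is row-sparse or CSR.
bool ContainsSparse(const std::vector<int>& attrs) {
  return std::any_of(attrs.begin(), attrs.end(), [](int s) {
    return s == kRowSparseStorage || s == kCSRStorage;
  });
}

// Sketch of the hunk's control flow (hypothetical helper, not MXNet code).
DispatchMode InferBatchNormDispatch(std::vector<int>& in_attrs,
                                    std::vector<int>& out_attrs,
                                    bool fix_gamma) {
  assert(in_attrs.size() == 5);   // mirrors CHECK_EQ(in_attrs->size(), 5)
  assert(out_attrs.size() == 3);  // mirrors CHECK_EQ(out_attrs->size(), 3)
  const bool sparse = ContainsSparse(in_attrs);
  if (sparse && fix_gamma) {
    throw std::invalid_argument(
        "fix_gamma=True is not supported for sparse ndarrays");
  }
  // Undefined input storage types default to dense.
  for (int& v : in_attrs) {
    if (v == kUndefinedStorage) v = kDefaultStorage;
  }
  // Outputs are dense on both paths; only the dispatch mode differs.
  for (int& v : out_attrs) v = kDefaultStorage;
  return sparse ? DispatchMode::kFComputeFallback : DispatchMode::kFCompute;
}
```

With all-dense inputs the sketch returns `kFCompute`; a CSR or row-sparse input with fix_gamma=false yields `kFComputeFallback`, and the same input with fix_gamma=true throws.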