anirudh2290 commented on a change in pull request #15298: Fix Cached_op with
static_shape=true
URL: https://github.com/apache/incubator-mxnet/pull/15298#discussion_r297298941
##########
File path: src/nnvm/legacy_op_util.cc
##########
@@ -110,47 +109,39 @@ class OperatorState {
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- if (!fwd_init_) {
- CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
- CHECK_EQ(outputs.size(), out_data_.size());
- // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except
that the ones
- // referred by arg_data_ptr_ will be overriden
- for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] =
inputs[i];
- for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] =
inputs[i];
- for (size_t i = 0; i < aux_data_.size(); ++i) {
- aux_data_[i] = inputs[i + in_data_fwd_.size()];
- }
- for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
- fwd_init_ = true;
+ CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+ CHECK_EQ(outputs.size(), out_data_.size());
+ // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except
that the ones
+ // referred by arg_data_ptr_ will be overriden
+ for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] =
inputs[i];
+ for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] =
inputs[i];
+ for (size_t i = 0; i < aux_data_.size(); ++i) {
+ aux_data_[i] = inputs[i + in_data_fwd_.size()];
}
+ for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
}
void Backward(const OpContext &ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- if (!bwd_init_) {
- CHECK(fwd_init_);
- CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
- // override tblobs pointed by arg_data_ptr_ since they might not contain
- // initialized data during forward pass.
- for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
- *arg_data_ptr_[i] = inputs[i];
- }
- for (size_t i = 0; i < aux_data_.size(); ++i) {
- aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
- }
- CHECK_EQ(outputs.size(), in_grad_.size());
- for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
- bwd_init_ = true;
Review comment:
This caching was first removed in #14738. I think this has certain
performance implications, since we are no longer caching the TBlobs. Is the
use case also similar — is this caused by the split operator?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services