ZhennanQin commented on a change in pull request #15298: Fix Cached_op with
static_shape=true
URL: https://github.com/apache/incubator-mxnet/pull/15298#discussion_r297417332
##########
File path: src/nnvm/legacy_op_util.cc
##########
@@ -110,47 +109,39 @@ class OperatorState {
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- if (!fwd_init_) {
- CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
- CHECK_EQ(outputs.size(), out_data_.size());
- // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except
that the ones
- // referred by arg_data_ptr_ will be overriden
- for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] =
inputs[i];
- for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] =
inputs[i];
- for (size_t i = 0; i < aux_data_.size(); ++i) {
- aux_data_[i] = inputs[i + in_data_fwd_.size()];
- }
- for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
- fwd_init_ = true;
+ CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+ CHECK_EQ(outputs.size(), out_data_.size());
+ // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except
that the ones
+ // referred by arg_data_ptr_ will be overriden
+ for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] =
inputs[i];
+ for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] =
inputs[i];
+ for (size_t i = 0; i < aux_data_.size(); ++i) {
+ aux_data_[i] = inputs[i + in_data_fwd_.size()];
}
+ for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
}
void Backward(const OpContext &ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- if (!bwd_init_) {
- CHECK(fwd_init_);
- CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
- // override tblobs pointed by arg_data_ptr_ since they might not contain
- // initialized data during forward pass.
- for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
- *arg_data_ptr_[i] = inputs[i];
- }
- for (size_t i = 0; i < aux_data_.size(); ++i) {
- aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
- }
- CHECK_EQ(outputs.size(), in_grad_.size());
- for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
- bwd_init_ = true;
Review comment:
When using legacy ops in Cached_op, this caching is not correct, because
even with static_alloc=true and static_shape=true, the input or output TBlobs
may change if they are the input or output of the Cached_op.
Consider a small case where the end-user hybridizes only one legacy op; then
its input is the Cached_op's input, and likewise for its output. The end-user
may then pass different NDArrays to this Cached_op, so this TBlob cache isn't
correct.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services