fhieber opened a new issue #11384: BucketingModule crash with parameters conditioned on bucket key

URL: https://github.com/apache/incubator-mxnet/issues/11384

I am running into an issue with the BucketingModule when using a `sym_gen` function that uses a parameter variable only for certain bucket keys. Consider the following example code:

```python
import mxnet as mx

def sym_gen(bucket_key):
    data = mx.sym.Variable('data')
    weight = mx.sym.Variable('weight')
    if bucket_key <= 1:
        # 'weight' is only part of the graph for bucket keys <= 1
        out = mx.sym.FullyConnected(data=data, weight=weight, no_bias=True, num_hidden=2, flatten=False)
    else:
        out = data
    return out, ['data'], []

default_bucket_key = 10
mod = mx.module.BucketingModule(sym_gen=sym_gen, default_bucket_key=default_bucket_key)
data_shapes = [mx.io.DataDesc(name='data', shape=(2, default_bucket_key, 2))]
mod.bind(data_shapes=data_shapes, for_training=False, grad_req="null")
mod.init_params()
print("module initialized with default bucket key")

for bucket_key in range(1, default_bucket_key):
    print("BUCKET KEY", bucket_key)
    # batch for the current bucket_key (the first iteration, bucket_key == 1, crashes)
    data_batch = mx.io.DataBatch(data=[mx.nd.ones((2, bucket_key, 2))],
                                 label=[],
                                 provide_data=[mx.io.DataDesc(name='data', shape=(2, bucket_key, 2))],
                                 bucket_key=bucket_key)
    mod.forward(data_batch=data_batch)
    print(mod.get_outputs())
```

The above code crashes with the following output:

```
['data']
['data']
module initialized with default bucket key
BUCKET KEY 1
['data', 'weight']
libc++abi.dylib: terminating with uncaught exception of type std::out_of_range: unordered_map::at: key not found

Process finished with exit code 134 (interrupted by signal 6: SIGABRT)
```

The crash happens in the `mod.forward()` call when trying to allocate/switch to a new bucket (1). If the above `sym_gen` code is changed to use the `weight` variable for the default bucket key (e.g. change `if bucket_key <= 1:` to `if bucket_key > 1:`), everything runs fine, presumably because the default graph has then allocated memory for the `weight` variable.
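The failing condition can be detected up front by comparing the argument names each bucket's graph declares with those of the default bucket. Below is a minimal sketch in plain Python, not an MXNet API: `find_missing_params` and `example_list_args` are hypothetical helpers, and `example_list_args` only stands in for what `sym_gen(k)[0].list_arguments()` would return for the `sym_gen` above.

```python
from typing import Callable, Dict, Iterable, List


def find_missing_params(list_args: Callable[[int], List[str]],
                        default_bucket_key: int,
                        bucket_keys: Iterable[int]) -> Dict[int, List[str]]:
    """Hypothetical helper: for each bucket key, return the argument names
    its graph declares but the default bucket's graph does not."""
    default_args = set(list_args(default_bucket_key))
    missing = {}
    for key in bucket_keys:
        extra = sorted(set(list_args(key)) - default_args)
        if extra:
            missing[key] = extra
    return missing


# Hypothetical stand-in for sym_gen(k)[0].list_arguments() in the example:
# bucket keys <= 1 use 'weight', all other keys only use 'data'.
def example_list_args(bucket_key: int) -> List[str]:
    return ['data', 'weight'] if bucket_key <= 1 else ['data']


print(find_missing_params(example_list_args, 10, range(1, 10)))
# {1: ['weight']}
```

With `default_bucket_key=10` only bucket 1 is reported, which matches the bucket where `mod.forward()` aborts; with `default_bucket_key=1` the result would be empty.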
My questions are as follows:

- It took me a while to figure out the problem in my actual use case, as the low-level error message is not really helpful. It would be great if MXNet could guard against such parameter-related issues.
- What is your general opinion about this kind of code? I can work around this issue by setting the default bucket key of the module to one that uses all potential parameters (in this case `default_bucket_key=1`), but in my use case that probably hurts memory sharing between buckets: usually one sets the default bucket key so that it corresponds to the 'largest' computation graph (for example, in terms of sequence length). In this particular example, the question is what 'largest' means: largest sequence length, or largest with respect to parameters/variables.
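The workaround in the second question can be made explicit as a small selection rule: among the bucket keys whose graph declares every argument used by any bucket, take the largest one, and fail with a readable message if no key covers them all. This is only a sketch of that idea in plain Python; `pick_default_bucket_key` and `example_list_args` are hypothetical, and in MXNet `list_args` would presumably be `lambda k: sym_gen(k)[0].list_arguments()`.

```python
from typing import Callable, Iterable, List


def pick_default_bucket_key(list_args: Callable[[int], List[str]],
                            bucket_keys: Iterable[int]) -> int:
    """Hypothetical helper: return the largest bucket key whose graph
    declares every argument used by any bucket."""
    keys = list(bucket_keys)
    all_args = set()
    for key in keys:
        all_args.update(list_args(key))
    # Keys whose argument set is a superset of the union over all buckets.
    covering = [k for k in keys if set(list_args(k)) >= all_args]
    if not covering:
        raise ValueError(
            "no single bucket key declares all arguments %s; the default "
            "bucket must declare every parameter used by other buckets"
            % sorted(all_args))
    return max(covering)


# Hypothetical stand-in for the sym_gen in the example above.
def example_list_args(bucket_key: int) -> List[str]:
    return ['data', 'weight'] if bucket_key <= 1 else ['data']


print(pick_default_bucket_key(example_list_args, range(1, 11)))
# 1
```

For the example this picks 1, the shortest sequence length, which illustrates the trade-off: covering all parameters and maximizing sequence length (for memory sharing) can point to different default bucket keys.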
