arcadiaphy edited a comment on issue #14522: mx.nd.Custom conflicts with memory management
URL: https://github.com/apache/incubator-mxnet/issues/14522#issuecomment-477870929

@anirudh2290 @wkcn @YutingZhang I finally figured out the reason:

Normally, when an exception is thrown in a spawned thread, it should be captured in a `std::exception_ptr` and re-thrown in the main thread, to ensure proper exception handling and avoid `std::terminate` (a standalone sketch of this pattern is appended at the end of this comment). This mechanism was introduced in #9681 to handle exceptions in operators, but there are still two problems in the exception handling of custom ops:

1. An exception thrown in the custom op's thread is never caught in the main thread, crashing the program.
2. When an exception happens in a custom op, it is caught and re-thrown in the sync function `WaitForVar`. But `WaitForVar` then deadlocks: the pushed `CustomOperation` is skipped instead of being run (https://github.com/apache/incubator-mxnet/blob/master/src/operator/custom/custom-inl.h#L128), so the write dependency of the ndarray being waited on never completes.

Adding `c.wait_to_read()` forces the exception to be re-thrown, but the program still crashes due to Problem 1.

```python
import mxnet as mx


class MyMulMax(mx.operator.CustomOp):
    def __init__(self):
        super(MyMulMax, self).__init__()

    def forward(self, is_train, req, in_data, out_data, aux):
        a, b = in_data[0:2]
        c = mx.nd.batch_dot(a, b)
        # force re-throw of the exception; the program still crashes
        # due to the uncaught exception (Problem 1)
        # c.wait_to_read()
        d = mx.nd.max(c, axis=-1, keepdims=True)
        self.assign(out_data[0], req[0], d)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        self.assign(in_grad[0], req[0], 0)
        self.assign(in_grad[1], req[1], 0)


@mx.operator.register("MyMulMax")
class MyMulMaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(MyMulMaxProp, self).__init__()

    def list_arguments(self):
        return ['a', 'b']

    def list_outputs(self):
        return ['d']

    def infer_shape(self, in_shape):
        return in_shape, [list(in_shape[0][:-1] + [1])]

    def create_operator(self, ctx, shapes, dtypes):
        return MyMulMax()


def custom(n):
    with mx.Context(mx.gpu(0)):
        a = mx.nd.random.uniform(shape=(n, 6000, 1))
        b = mx.nd.random.uniform(shape=(n, 1, 7000))
        d = mx.nd.Custom(a, b, op_type="MyMulMax")
        d.wait_to_read()


def direct(n):
    with mx.Context(mx.gpu(0)):
        a = mx.nd.random.uniform(shape=(n, 6000, 1))
        b = mx.nd.random.uniform(shape=(n, 1, 7000))
        c = mx.nd.batch_dot(a, b)
        d = mx.nd.max(c, axis=-1, keepdims=True)
        # deadlock due to skipped CustomOperation function
        d.wait_to_read()


if __name__ == "__main__":
    custom(100)
```
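For reference, the capture-and-rethrow pattern mentioned above looks roughly like this. It is a minimal standalone C++ sketch of the general mechanism, not MXNet's actual engine code, and all names in it are illustrative:

```cpp
#include <exception>
#include <iostream>
#include <stdexcept>
#include <thread>

// Sketch of the capture-and-rethrow pattern: the worker thread stores
// any exception in a std::exception_ptr instead of letting it escape,
// and the main thread rethrows it after joining.
int main() {
  std::exception_ptr eptr = nullptr;

  std::thread worker([&eptr]() {
    try {
      throw std::runtime_error("failure in worker thread");
    } catch (...) {
      // Capture instead of propagating; an exception escaping a
      // thread's top-level function would call std::terminate.
      eptr = std::current_exception();
    }
  });
  worker.join();

  if (eptr) {
    try {
      std::rethrow_exception(eptr);  // re-throw in the main thread
    } catch (const std::exception& e) {
      std::cout << "caught in main thread: " << e.what() << std::endl;
    }
  }
  return 0;
}
```

Problem 1 amounts to the custom op's worker thread missing this capture step, so the exception escapes the thread and terminates the process instead of surfacing as a catchable error.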
