eric-haibin-lin commented on a change in pull request #17270: [WIP] Dynamic custom operator GPU support
URL: https://github.com/apache/incubator-mxnet/pull/17270#discussion_r366082428
########## File path: src/c_api/c_api.cc ##########

@@ -720,8 +751,11 @@ int MXLoadLib(const char *path) {
       gradOp.set_attr<bool>("TIsLayerOpBackward", true, plevel);
       gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>",
                                           fstateful_backward, plevel);
+      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>",

Review comment:

What are the target supported contexts for this feature? Do we target just CPU and GPU, or do we want to support other hardware backends, too?

Currently the dispatch logic lives inside FCompute, which differs from the experience existing MXNet users have: usually FCompute only declares the computation and leaves dispatch to the MXNet executor. It is also unclear how this supports the case where the same op is extended by one library for Intel CPUs and by another for NVIDIA GPUs; each library may hard-code dispatch logic that only cares about its own hardware. How do we handle such conflicts?

Furthermore, infer_shape/infer_dtype is currently not context-aware, i.e. CPU and GPU infer the same dtype. That may not hold in practice (e.g. the CPU backend supports fp32 and bfloat16, while the GPU backend supports fp32 and fp16). How do we handle these attribute conflicts? I had a short discussion with @yzhliu and we see two potential fixes:

1. Make infer_shape/infer_dtype context-aware. This way we can have different infer_dtype functions for CPU and GPU, and MXNet dispatches to the right one based on the current context. For example, `op.set_attr<FInferType>("FInferType<cpu>", my_infer_type_function)` for CPU-specific type inference, and `op.set_attr<FInferType>("FInferType<gpu>", my_infer_type_function_gpu)` for GPU.

2. Register the ops under different names (e.g. 'cpu_gemm' and 'gpu_gemm'). This way they can have different infer_attr functions. But we don't want users to modify the model definition in their training scripts to use these names. To mitigate that, we could add an API that lets the user provide a mapping (e.g. {'gemm' -> 'cpu_gemm'}) so MXNet can map an op to another op registered in the backend.

Finally, is there a plan to support dynamic custom contexts? :P @samskalicky

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services