ZiyueHuang commented on pull request #18622:
URL: https://github.com/apache/incubator-mxnet/pull/18622#issuecomment-678613328


   After this PR, training the Electra model in gluon-nlp raises the error below:
   ```
   mxnet.base.MXNetError: Traceback (most recent call last):
     File "/home/ubuntu/mxnet/src/common/cuda/rtc.cc", line 163
   MXNetError: Check failed: compileResult == NVRTC_SUCCESS (6 vs. 0) : NVRTC Compilation failed.
   The generated code was stored in mxnet_rtc_debug_code.log
   binary_scalar_kernel_kernel.cu(1118): error: more than one instance of overloaded function "isnan" matches the argument list:
               function "isnan(float)"
               function "isnan(double)"
               function "isnan(long double)"
               argument types are: (const InputType0)
             detected during instantiation of "type_util::mixed_type<DType, DType2, void>::type op::min(DType, DType2) [with DType=InputType0, DType2=InputType0]"
   (2285): here
   ```
   
   Setting `MXNET_RTC_VERBOSE=1` shows the generated code for `binary_scalar_kernel`:
   ```
   // Compile-time configuration emitted ahead of the kernel source; these are
   // the values shown by MXNET_RTC_VERBOSE=1 for this instantiation.
   using InputType0 = float32;
   using OutputType0 = float32;
   // Passed as the `aligned` template argument to VectorizedLoader/VectorizedStorer.
   const bool aligned = true;
   // Template argument of the vectorized loader/storer, and the trip count of
   // the kernel's inner unrolled loop.
   const int nvec = 4;
   // Output write mode; with kWriteTo the kernel's kAddTo branches are not taken.
   const OpReqType req = OpReqType::kWriteTo;
   // Binary operator applied as OP(element, scalar).
   // NOTE(review): the NVRTC error above names op::min, while this dump defines
   // OP as op::div -- this may not be the failing instantiation; confirm.
   #define OP op::div
   
   
   // Argument bundle passed by value to binary_scalar_kernel.
   struct binary_scalar_kernel_params {
     // Type-erased input pointers; the kernel body reads only inputs[0].
     const void *inputs[2];
     // Type-erased output pointer; cast to OutputType0* by the kernel.
     void *outputs[1];
     // Scalar operand; cast to the mixed input/output compute type before OP.
     double scalar;
   };
   
   // Elementwise "tensor op scalar" kernel: for each element x of inputs[0],
   // computes OP(x, scalar) and writes it to outputs[0], or adds it to the
   // existing output value when req == kAddTo.
   //
   // params:               input/output pointers and the scalar operand.
   // lead_dim, other_dim:  not referenced in this body.
   // N:                    forwarded to the vectorized loader/storer; presumably
   //                       the total element count -- confirm against the launcher.
   // num_aligned_elements: number of vector chunks iterated over (M below).
   __launch_bounds__(kRTCMaxThreadsPerBlock)
   __global__ void binary_scalar_kernel(const binary_scalar_kernel_params 
params,
                                        const index_t lead_dim,
                                        const index_t other_dim,
                                        const index_t N,
                                        const index_t num_aligned_elements) {
     using namespace vector;
     VectorizedLoader<InputType0, nvec, aligned> loader(
       reinterpret_cast<const InputType0*>(params.inputs[0]), N);
     VectorizedStorer<OutputType0, nvec, aligned> storer(
       reinterpret_cast<OutputType0*>(params.outputs[0]), N);
   
     // Compute-type wrappers: IType::from / OType::from convert loaded values
     // into the compute type, OType::to converts results back for storage.
     using IType = AccType<InputType0>;
     using OType = AccType<OutputType0>;
   
     const index_t M = num_aligned_elements;
   
     // Grid-stride loop: each thread starts at its global index and advances by
     // the total number of launched threads.
     for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
          tid < M;
          tid += gridDim.x * blockDim.x) {
       loader.load(tid, N);
       // The existing output chunk is only needed when accumulating into it.
       if (req == OpReqType::kAddTo) {
         storer.load(tid, N);
       }
   #pragma unroll
       for (int i = 0; i < nvec; ++i) {
         const auto input = IType::from(loader.separate()[i]);
         // enables returning different type: the scalar is cast to the mixed
         // input/output compute type, so OP's result type may differ from
         // InputType0.
         const auto temp = OP(input,
                              static_cast<typename 
type_util::mixed_type<typename IType::type,
                                                                         
typename OType::type>::type>
                                (params.scalar));
   
         if (req == OpReqType::kAddTo) {
           // temp2 may have a wider type than either temp
           // or OType
           const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
           storer.separate()[i] = OType::to(temp2);
         } else {
           storer.separate()[i] = OType::to(temp);
         }
       }
       storer.store(tid, N);
     }
   }
   
   ```
   
   @ptrendx Could you please take a look?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to