This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 801e51b211 Fix curand. (#11901)
801e51b211 is described below

commit 801e51b21128f88a3e405c1022c41f8c25ac4118
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Jun 27 12:56:17 2022 -0700

    Fix curand. (#11901)
---
 src/runtime/contrib/curand/curand.cc | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc
index 23282304f7..50600d913f 100644
--- a/src/runtime/contrib/curand/curand.cc
+++ b/src/runtime/contrib/curand/curand.cc
@@ -79,18 +79,34 @@ void RandomFill(DLTensor* tensor) {
   static DeviceAPI* cuda_api = GetCUDADeviceAPI();
   CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA)
       << "ValueError: cuRAND only works on CUDA devices";
+  int64_t tensor_size = GetTensorSize(tensor);
+  int64_t actual_size = tensor_size % 2 == 0 ? tensor_size : tensor_size + 1;
   if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) {
-    int64_t tensor_size = GetTensorSize(tensor);
-    void* data = cuda_api->AllocWorkspace(tensor->device, tensor_size * sizeof(float));
+    // curand only works for size % 2 = 0
+    void* data = cuda_api->AllocWorkspace(tensor->device, actual_size * sizeof(float));
     {
       DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
-      CURandGenerator().Generate32bit(data, GetTensorSize(tensor));
+      CURandGenerator().Generate32bit(data, actual_size);
       ConvertFp32toFp16(/*src=*/data, /*dst=*/tensor->data, /*num=*/tensor_size);
     }
   } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 32) {
-    CURandGenerator().Generate32bit(tensor->data, GetTensorSize(tensor));
+    if (tensor_size % 2 == 1) {
+      void* data = cuda_api->AllocWorkspace(tensor->device, actual_size * sizeof(float));
+      DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
+      CURandGenerator().Generate32bit(data, actual_size);
+      cudaMemcpy(tensor->data, data, tensor_size * sizeof(float), cudaMemcpyDeviceToDevice);
+    } else {
+      CURandGenerator().Generate32bit(tensor->data, actual_size);
+    }
   } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 64) {
-    CURandGenerator().Generate64bit(tensor->data, GetTensorSize(tensor));
+    if (tensor_size % 2 == 1) {
+      void* data = cuda_api->AllocWorkspace(tensor->device, actual_size * sizeof(double));
+      DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
+      CURandGenerator().Generate64bit(data, actual_size);
+      cudaMemcpy(tensor->data, data, tensor_size * sizeof(double), cudaMemcpyDeviceToDevice);
+    } else {
+      CURandGenerator().Generate64bit(tensor->data, actual_size);
+    }
   } else {
     LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype;
   }

Reply via email to