This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 801e51b211 Fix curand. (#11901)
801e51b211 is described below
commit 801e51b21128f88a3e405c1022c41f8c25ac4118
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Jun 27 12:56:17 2022 -0700
Fix curand. (#11901)
---
src/runtime/contrib/curand/curand.cc | 26 +++++++++++++++++++++-----
1 file changed, 21 insertions(+), 5 deletions(-)
diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc
index 23282304f7..50600d913f 100644
--- a/src/runtime/contrib/curand/curand.cc
+++ b/src/runtime/contrib/curand/curand.cc
@@ -79,18 +79,34 @@ void RandomFill(DLTensor* tensor) {
static DeviceAPI* cuda_api = GetCUDADeviceAPI();
CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA)
<< "ValueError: cuRAND only works on CUDA devices";
+ int64_t tensor_size = GetTensorSize(tensor);
+ int64_t actual_size = tensor_size % 2 == 0 ? tensor_size : tensor_size + 1;
if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) {
- int64_t tensor_size = GetTensorSize(tensor);
- void* data = cuda_api->AllocWorkspace(tensor->device, tensor_size * sizeof(float));
+ // curand only works for size % 2 == 0
+ void* data = cuda_api->AllocWorkspace(tensor->device, actual_size * sizeof(float));
{
DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
- CURandGenerator().Generate32bit(data, GetTensorSize(tensor));
+ CURandGenerator().Generate32bit(data, actual_size);
ConvertFp32toFp16(/*src=*/data, /*dst=*/tensor->data, /*num=*/tensor_size);
}
} else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 32) {
- CURandGenerator().Generate32bit(tensor->data, GetTensorSize(tensor));
+ if (tensor_size % 2 == 1) {
+ void* data = cuda_api->AllocWorkspace(tensor->device, actual_size * sizeof(float));
+ DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
+ CURandGenerator().Generate32bit(data, actual_size);
+ cudaMemcpy(tensor->data, data, tensor_size * sizeof(float), cudaMemcpyDeviceToDevice);
+ } else {
+ CURandGenerator().Generate32bit(tensor->data, actual_size);
+ }
} else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 64) {
- CURandGenerator().Generate64bit(tensor->data, GetTensorSize(tensor));
+ if (tensor_size % 2 == 1) {
+ void* data = cuda_api->AllocWorkspace(tensor->device, actual_size * sizeof(double));
+ DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
+ CURandGenerator().Generate64bit(data, actual_size);
+ cudaMemcpy(tensor->data, data, tensor_size * sizeof(double), cudaMemcpyDeviceToDevice);
+ } else {
+ CURandGenerator().Generate64bit(tensor->data, actual_size);
+ }
} else {
LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype;
}