Kathryn-cat commented on code in PR #18681:
URL: https://github.com/apache/tvm/pull/18681#discussion_r2719738855


##########
python/tvm/contrib/nvcc.py:
##########
@@ -410,10 +510,81 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
             nvrtc.nvrtcDestroyProgram(prog)
             raise RuntimeError(f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}")
 
-    # Clean up
+    # Clean up NVRTC program
     nvrtc.nvrtcDestroyProgram(prog)
 
-    return bytearray(binary_buf)
+    # link stage for NVSHMEM
+    if use_nvshmem:
+        import ctypes  # pylint: disable=import-outside-toplevel
+
+        # cuLinkCreate requires a valid CUDA context.
+        (result,) = cu.cuInit(0)
+        if result != cu.CUresult.CUDA_SUCCESS:
+            raise RuntimeError(f"Failed to initialize CUDA: {result}")
+
+        # Check if there's already a CUDA context; create one if not
+        result, context = cu.cuCtxGetCurrent()
+        if result != cu.CUresult.CUDA_SUCCESS or context is None or int(context) == 0:
+            result, device = cu.cuDeviceGet(0)
+            if result != cu.CUresult.CUDA_SUCCESS:
+                raise RuntimeError(f"Failed to get CUDA device: {result}")
+            result, context = cu.cuCtxCreate(None, 0, device)
+            if result != cu.CUresult.CUDA_SUCCESS:
+                raise RuntimeError(f"Failed to create CUDA context: {result}")

Review Comment:
   addressed



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to