This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e6a654a  [Feature] Add a function to build c-dlpack for torch (#192)
e6a654a is described below

commit e6a654aaaad469ca455057821db01a995f312e2f
Author: Yaoyao Ding <[email protected]>
AuthorDate: Tue Oct 28 16:52:05 2025 -0400

    [Feature] Add a function to build c-dlpack for torch (#192)
    
    This PR adds a script
    `python/tvm_ffi/utils/_build_optional_c_dlpack.py`, which a standalone
    python module that is used to build a shared library as a patch to make
    old version of pytorch to support the dlpack protocol.
    
    Two kinds of usage:
    
    1. used as a standalone python module to build the shared library for
    each combination of `<python-version>` and `<pytorch-version>` and build
    a package to contain the built libraries.
    2. used as a jit module in tvm-ffi when there is not a prebuilt patch
    found.
    
    ## AOT Usage
    
    ```bash
    python python/tvm_ffi/utils/_build_optional_c_dlpack.py --build_dir 
<build-dir> [--build_with_cuda]
    ```
    
    Then there will be a shared library under the given `<build-dir>`
    (either `libtorch_c_dlpack_addon.so` or `libtorch_c_dlpack_addon.dll`,
    depends on the platform). The built shared library is specific to the
    current
    python version and pytorch version.
    
    ## JIT Usage
    
    We launched a python interpreter to run the build script. It will store
    the built shared library in tvm-ffi cache. See
    `python/tvm-ffi/_optional_torch_c_dlpack.py`.
    
    ---------
    
    Signed-off-by: Yaoyao Ding <[email protected]>
---
 .github/workflows/ci_test.yml                      |  17 +
 .github/workflows/utils/locate_vsdevcmd_bat.py     |  66 +++
 pyproject.toml                                     |   8 +-
 python/tvm_ffi/_optional_torch_c_dlpack.py         | 607 ++-------------------
 python/tvm_ffi/cpp/load_inline.py                  |   4 +-
 .../_build_optional_c_dlpack.py}                   | 419 +++++++++-----
 tests/python/test_dlpack_exchange_api.py           |   6 +
 tests/python/test_optional_torch_c_dlpack.py       |  79 +++
 8 files changed, 484 insertions(+), 722 deletions(-)

diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml
index 380e97e..4ac2a9c 100644
--- a/.github/workflows/ci_test.yml
+++ b/.github/workflows/ci_test.yml
@@ -125,6 +125,16 @@ jobs:
           cmake --build build_test --clean-first --config Debug --target 
tvm_ffi_tests &&
           ctest -V -C Debug --test-dir build_test --output-on-failure
 
+      - name: Locate and Set VsDevCmd Path [windows]
+        if: ${{ matrix.os == 'windows-latest' }}
+        shell: pwsh
+        run: |
+          # Captures the output (path) from your Python script
+          $vsPath = python .github/workflows/utils/locate_vsdevcmd_bat.py
+
+          # Sets an environment variable for all subsequent steps in this job
+          "VS_DEV_CMD_PATH=$vsPath" | Out-File -FilePath $env:GITHUB_ENV 
-Append
+
       # Run Python tests
       - name: Setup Python ${{ matrix.python_version }}
         uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4  # 
v6.7.0
@@ -137,7 +147,14 @@ jobs:
         run: |
           uv pip install --reinstall --verbose -e ".[test]"
       - name: Run python tests
+        if: ${{ matrix.os != 'windows-latest' }}
+        run: |
+          pytest -vvs tests/python
+      - name: Run python tests [windows]
+        if: ${{ matrix.os == 'windows-latest' }}
+        shell: cmd
         run: |
+          call "%VS_DEV_CMD_PATH%"
           pytest -vvs tests/python
 
       # Run Rust tests, must happen after installing the pip package.
diff --git a/.github/workflows/utils/locate_vsdevcmd_bat.py 
b/.github/workflows/utils/locate_vsdevcmd_bat.py
new file mode 100644
index 0000000..a1487b2
--- /dev/null
+++ b/.github/workflows/utils/locate_vsdevcmd_bat.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Locate the VsDevCmd.bat file for the current Visual Studio installation."""
+
+import os
+import subprocess
+from pathlib import Path
+
+
+def main() -> None:
+    """Locate the VsDevCmd.bat file for the current Visual Studio installation.
+
+    Raise exception if not found. If found, print the path to stdout.
+    """
+    # Path to vswhere.exe
+    vswhere_path = str(
+        Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)"))
+        / "Microsoft Visual Studio"
+        / "Installer"
+        / "vswhere.exe"
+    )
+
+    if not Path(vswhere_path).exists():
+        raise FileNotFoundError("vswhere.exe not found.")
+
+    # Find the Visual Studio installation path
+    vs_install_path = subprocess.run(
+        [
+            vswhere_path,
+            "-latest",
+            "-prerelease",
+            "-products",
+            "*",
+            "-property",
+            "installationPath",
+        ],
+        capture_output=True,
+        text=True,
+        check=True,
+    ).stdout.strip()
+
+    if not vs_install_path:
+        raise FileNotFoundError("No Visual Studio installation found.")
+
+    # Construct the path to the VsDevCmd.bat file
+    vsdevcmd_path = str(Path(vs_install_path) / "Common7" / "Tools" / 
"VsDevCmd.bat")
+    print(vsdevcmd_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 57de316..43b09ba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,13 @@ GitHub = "https://github.com/apache/tvm-ffi";
 torch = ["torch", "setuptools", "ninja"]
 cpp = ["ninja"]
 # note pytorch does not yet ship with 3.14t
-test = ["pytest", "numpy", "ninja", "torch; python_version < '3.14'"]
+test = [
+  "pytest",
+  "numpy",
+  "ninja",
+  "torch; python_version < '3.14'",
+  "setuptools",
+]
 
 [dependency-groups]
 dev = [
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py 
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 6bdff08..c491027 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -32,11 +32,14 @@ subsequent calls will be much faster.
 
 from __future__ import annotations
 
+import ctypes
+import os
+import subprocess
+import sys
 import warnings
+from pathlib import Path
 from typing import Any
 
-from . import libinfo
-
 
 def load_torch_c_dlpack_extension() -> Any:
     try:
@@ -51,587 +54,43 @@ def load_torch_c_dlpack_extension() -> Any:
         return None
 
     """Load the torch c dlpack extension."""
-    cpp_source = """
-#include <dlpack/dlpack.h>
-#include <ATen/DLConvertor.h>
-#include <ATen/Functions.h>
-
-#ifdef BUILD_WITH_CUDA
-#include <c10/cuda/CUDAStream.h>
-#endif
-
-using namespace std;
-namespace at {
-namespace {
-
-DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
-  DLDataType dtype;
-  dtype.lanes = 1;
-  dtype.bits = t.element_size() * 8;
-  switch (t.scalar_type()) {
-    case ScalarType::UInt1:
-    case ScalarType::UInt2:
-    case ScalarType::UInt3:
-    case ScalarType::UInt4:
-    case ScalarType::UInt5:
-    case ScalarType::UInt6:
-    case ScalarType::UInt7:
-    case ScalarType::Byte:
-    case ScalarType::UInt16:
-    case ScalarType::UInt32:
-    case ScalarType::UInt64:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
-    case ScalarType::Int1:
-    case ScalarType::Int2:
-    case ScalarType::Int3:
-    case ScalarType::Int4:
-    case ScalarType::Int5:
-    case ScalarType::Int6:
-    case ScalarType::Int7:
-    case ScalarType::Char:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-#endif
-    case ScalarType::Double:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case ScalarType::Float:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case ScalarType::Int:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case ScalarType::Long:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case ScalarType::Short:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case ScalarType::Half:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case ScalarType::Bool:
-      dtype.code = DLDataTypeCode::kDLBool;
-      break;
-    case ScalarType::ComplexHalf:
-    case ScalarType::ComplexFloat:
-    case ScalarType::ComplexDouble:
-      dtype.code = DLDataTypeCode::kDLComplex;
-      break;
-    case ScalarType::BFloat16:
-      dtype.code = DLDataTypeCode::kDLBfloat;
-      break;
-     case ScalarType::Float8_e5m2:
-      dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
-      break;
-    case ScalarType::Float8_e5m2fnuz:
-      dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz;
-      break;
-    case ScalarType::Float8_e4m3fn:
-      dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn;
-      break;
-    case ScalarType::Float8_e4m3fnuz:
-      dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
-      break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
-    case ScalarType::Float8_e8m0fnu:
-      dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
-      break;
-    case ScalarType::Float4_e2m1fn_x2:
-      dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
-      dtype.lanes = 2;
-      dtype.bits = 4;
-      break;
-#endif
-   default:
-      TORCH_CHECK(false, "Unsupported scalar type: ");
-  }
-  return dtype;
-}
-
-DLDevice torchDeviceToDLDeviceForDLPackv1(at::Device device) {
-  DLDevice ctx;
-
-  ctx.device_id = (device.is_cuda() || device.is_privateuseone())
-      ? static_cast<int32_t>(static_cast<unsigned char>(device.index()))
-      : 0;
-
-  switch (device.type()) {
-    case DeviceType::CPU:
-      ctx.device_type = DLDeviceType::kDLCPU;
-      break;
-    case DeviceType::CUDA:
-#ifdef USE_ROCM
-      ctx.device_type = DLDeviceType::kDLROCM;
-#else
-      ctx.device_type = DLDeviceType::kDLCUDA;
-#endif
-      break;
-    case DeviceType::OPENCL:
-      ctx.device_type = DLDeviceType::kDLOpenCL;
-      break;
-    case DeviceType::HIP:
-      ctx.device_type = DLDeviceType::kDLROCM;
-      break;
-    case DeviceType::XPU:
-      ctx.device_type = DLDeviceType::kDLOneAPI;
-      ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device);
-      break;
-    case DeviceType::MAIA:
-      ctx.device_type = DLDeviceType::kDLMAIA;
-      break;
-    case DeviceType::PrivateUse1:
-      ctx.device_type = DLDeviceType::kDLExtDev;
-      break;
-    case DeviceType::MPS:
-      ctx.device_type = DLDeviceType::kDLMetal;
-      break;
-    default:
-      TORCH_CHECK(false, "Cannot pack tensors on " + device.str());
-  }
-
-  return ctx;
-}
-
-template <class T>
-struct ATenDLMTensor {
-  Tensor handle;
-  T tensor{};
-};
-
-template <class T>
-void deleter(T* arg) {
-  delete static_cast<ATenDLMTensor<T>*>(arg->manager_ctx);
-}
-
-// Adds version information for DLManagedTensorVersioned.
-// This is a no-op for the other types.
-template <class T>
-void fillVersion(T* tensor) {}
-
-template <>
-void fillVersion<DLManagedTensorVersioned>(
-    DLManagedTensorVersioned* tensor) {
-  tensor->flags = 0;
-  tensor->version.major = DLPACK_MAJOR_VERSION;
-  tensor->version.minor = DLPACK_MINOR_VERSION;
-}
-
-// This function returns a shared_ptr to memory managed DLpack tensor
-// constructed out of ATen tensor
-template <class T>
-T* toDLPackImpl(const Tensor& src) {
-  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
-  atDLMTensor->handle = src;
-  atDLMTensor->tensor.manager_ctx = atDLMTensor;
-  atDLMTensor->tensor.deleter = &deleter<T>;
-  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
-  atDLMTensor->tensor.dl_tensor.device = 
torchDeviceToDLDeviceForDLPackv1(src.device());
-  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
-  atDLMTensor->tensor.dl_tensor.dtype = getDLDataTypeForDLPackv1(src);
-  atDLMTensor->tensor.dl_tensor.shape = 
const_cast<int64_t*>(src.sizes().data());
-  atDLMTensor->tensor.dl_tensor.strides = 
const_cast<int64_t*>(src.strides().data());
-  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
-  fillVersion(&atDLMTensor->tensor);
-  return &(atDLMTensor->tensor);
-}
-
-static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex 
index, void* data = nullptr) {
-  switch (type) {
-    case DLDeviceType::kDLCPU:
-      return at::Device(DeviceType::CPU);
-#ifndef USE_ROCM
-    // if we are compiled under HIP, we cannot do cuda
-    case DLDeviceType::kDLCUDA:
-      return at::Device(DeviceType::CUDA, index);
-#endif
-    case DLDeviceType::kDLOpenCL:
-      return at::Device(DeviceType::OPENCL, index);
-    case DLDeviceType::kDLROCM:
-#ifdef USE_ROCM
-      // this looks funny, we need to return CUDA here to masquerade
-      return at::Device(DeviceType::CUDA, index);
-#else
-      return at::Device(DeviceType::HIP, index);
-#endif
-    case DLDeviceType::kDLOneAPI:
-      TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU 
data.");
-      return at::detail::getXPUHooks().getDeviceFromPtr(data);
-    case DLDeviceType::kDLMAIA:
-      return at::Device(DeviceType::MAIA, index);
-    case DLDeviceType::kDLExtDev:
-      return at::Device(DeviceType::PrivateUse1, index);
-    case DLDeviceType::kDLMetal:
-      return at::Device(DeviceType::MPS, index);
-    default:
-      TORCH_CHECK(
-          false, "Unsupported device_type: ", std::to_string(type));
-  }
-}
-
-ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
-  ScalarType stype = ScalarType::Undefined;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
-  if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
-    TORCH_CHECK(
-        dtype.lanes == 1,
-        "ATen does not support lanes != 1 for dtype code", 
std::to_string(dtype.code));
-  }
-#endif
-  switch (dtype.code) {
-    case DLDataTypeCode::kDLUInt:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Byte;
-          break;
-        case 16:
-          stype = ScalarType::UInt16;
-          break;
-        case 32:
-          stype = ScalarType::UInt32;
-          break;
-        case 64:
-          stype = ScalarType::UInt64;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLInt:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Char;
-          break;
-        case 16:
-          stype = ScalarType::Short;
-          break;
-        case 32:
-          stype = ScalarType::Int;
-          break;
-        case 64:
-          stype = ScalarType::Long;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kInt bits ", std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLFloat:
-      switch (dtype.bits) {
-        case 16:
-          stype = ScalarType::Half;
-          break;
-        case 32:
-          stype = ScalarType::Float;
-          break;
-        case 64:
-          stype = ScalarType::Double;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLBfloat:
-      switch (dtype.bits) {
-        case 16:
-          stype = ScalarType::BFloat16;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLComplex:
-      switch (dtype.bits) {
-        case 32:
-          stype = ScalarType::ComplexHalf;
-          break;
-        case 64:
-          stype = ScalarType::ComplexFloat;
-          break;
-        case 128:
-          stype = ScalarType::ComplexDouble;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLBool:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Bool;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLFloat8_e5m2:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Float8_e5m2;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e5m2 bits ", 
std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLFloat8_e5m2fnuz:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Float8_e5m2fnuz;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e5m2fnuz bits ", 
std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLFloat8_e4m3fn:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Float8_e4m3fn;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e4m3fn bits ", 
std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLFloat8_e4m3fnuz:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Float8_e4m3fnuz;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e4m3fnuz bits ", 
std::to_string(dtype.bits));
-      }
-      break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
-    case DLDataTypeCode::kDLFloat8_e8m0fnu:
-      switch (dtype.bits) {
-        case 8:
-          stype = ScalarType::Float8_e8m0fnu;
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e8m0fnu bits ", 
std::to_string(dtype.bits));
-      }
-      break;
-    case DLDataTypeCode::kDLFloat4_e2m1fn:
-      switch (dtype.bits) {
-        case 4:
-          switch (dtype.lanes) {
-            case 2:
-              stype = ScalarType::Float4_e2m1fn_x2;
-              break;
-            default:
-              TORCH_CHECK(
-                false, "Unsupported kDLFloat4_e2m1fn lanes ", 
std::to_string(dtype.lanes));
-          }
-          break;
-        default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat4_e2m1fn bits ", 
std::to_string(dtype.bits));
-      }
-      break;
-#endif
-    default:
-      TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
-  }
-  return stype;
-}
-
-// This function constructs a Tensor from a memory managed DLPack which
-// may be represented as either: DLManagedTensor and DLManagedTensorVersioned.
-template <class T>
-at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
-  if (!deleter) {
-    deleter = [src](void* self [[maybe_unused]]) {
-      if (src->deleter) {
-        src->deleter(src);
-      }
-    };
-  }
-
-  DLTensor& dl_tensor = src->dl_tensor;
-  Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, 
dl_tensor.device.device_id, dl_tensor.data);
-  ScalarType stype = toScalarTypeForDLPackv1(dl_tensor.dtype);
-
-  if (!dl_tensor.strides) {
-    return at::from_blob(
-        dl_tensor.data,
-        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
-        std::move(deleter),
-        at::device(device).dtype(stype),
-        {device});
-  }
-  return at::from_blob(
-      dl_tensor.data,
-      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
-      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
-      deleter,
-      at::device(device).dtype(stype),
-      {device});
-}
-
-void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
-  // Fill in the pre-allocated DLTensor struct with direct pointers
-  // This is a non-owning conversion - the caller owns the tensor
-  // and must keep it alive for the duration of DLTensor usage
-  out.data = tensor.data_ptr();
-  out.device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
-  out.ndim = static_cast<int32_t>(tensor.dim());
-  out.dtype = getDLDataTypeForDLPackv1(tensor);
-  // sizes() and strides() return pointers to TensorImpl's stable storage
-  // which remains valid as long as the tensor is alive
-  out.shape = const_cast<int64_t*>(tensor.sizes().data());
-  out.strides = const_cast<int64_t*>(tensor.strides().data());
-  out.byte_offset = 0;
-}
-
-} // namespace
-} // namespace at
-
-struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
-  TorchDLPackExchangeAPI() {
-    header.version.major = DLPACK_MAJOR_VERSION;
-    header.version.minor = DLPACK_MINOR_VERSION;
-    header.prev_api = nullptr;
-    managed_tensor_allocator = ManagedTensorAllocator;
-    managed_tensor_from_py_object_no_sync = ManagedTensorFromPyObjectNoSync;
-    managed_tensor_to_py_object_no_sync = ManagedTensorToPyObjectNoSync;
-    dltensor_from_py_object_no_sync = DLTensorFromPyObjectNoSync;
-    current_work_stream = CurrentWorkStream;
-  }
-
-  static const DLPackExchangeAPI* Global() {
-    static TorchDLPackExchangeAPI inst;
-    return &inst;
-  }
-
- private:
-  static int DLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) {
-    try {
-      // Use handle (non-owning) to avoid unnecessary refcount operations
-      py::handle handle(static_cast<PyObject*>(py_obj));
-      at::Tensor tensor = handle.cast<at::Tensor>();
-      at::toDLPackNonOwningImpl(tensor, *out);
-      return 0;
-    } catch (const std::exception& e) {
-      PyErr_SetString(PyExc_RuntimeError, e.what());
-      return -1;
-    }
-  }
-
-  static int ManagedTensorFromPyObjectNoSync(void* py_obj, 
DLManagedTensorVersioned** out) {
-    try {
-      py::handle handle(static_cast<PyObject*>(py_obj));
-      at::Tensor tensor = handle.cast<at::Tensor>();
-      *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
-      return 0;
-    } catch (const std::exception& e) {
-      PyErr_SetString(PyExc_RuntimeError, e.what());
-      return -1;
-    }
-  }
-
-  static int ManagedTensorToPyObjectNoSync(DLManagedTensorVersioned* src, 
void** py_obj_out) {
-    try {
-      at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, 
nullptr);
-      *py_obj_out = THPVariable_Wrap(tensor);
-      return 0;
-    } catch (const std::exception& e) {
-      PyErr_SetString(PyExc_RuntimeError, e.what());
-      return -1;
-    }
-  }
-
-  static int ManagedTensorAllocator(
-      DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
-      void (*SetError)(void* error_ctx, const char* kind, const char* message)
-  ) {
-    try {
-      at::IntArrayRef shape(prototype->shape, prototype->shape + 
prototype->ndim);
-      at::TensorOptions options = at::TensorOptions()
-        .dtype(at::toScalarType(prototype->dtype))
-        .device(at::getATenDeviceForDLPackv1(prototype->device.device_type, 
prototype->device.device_id));
-      at::Tensor tensor = at::empty(shape, options);
-      *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
-      return 0;
-    } catch (const std::exception& e) {
-      SetError(error_ctx, "TorchDLPackManagedTensorAllocator", e.what());
-      return -1;
-    }
-  }
-
-  // Get current CUDA/ROCm work stream
-  static int CurrentWorkStream(
-      DLDeviceType device_type,
-      int32_t device_id,
-      void** out_stream) {
-    try {
-#ifdef BUILD_WITH_CUDA
-      if (device_type == kDLCUDA || device_type == kDLROCM) {
-        *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
-        return 0;
-      }
-#endif
-      // For CPU and other devices, return NULL (no stream concept)
-      *out_stream = nullptr;
-      return 0;
-    } catch (const std::exception& e) {
-      PyErr_SetString(PyExc_RuntimeError, e.what());
-      return -1;
-    }
-  }
-};
-
-int64_t TorchDLPackExchangeAPIPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
-}
-    """
     try:
-        # optionally import torch
-        import torch  # noqa: PLC0415
-        from torch.utils import cpp_extension  # noqa: PLC0415
-
-        include_paths = libinfo.include_paths()
-        extra_cflags = ["-O3"]
-
-        if torch.cuda.is_available():
-            include_paths += cpp_extension.include_paths("cuda")
-            extra_cflags += ["-DBUILD_WITH_CUDA"]
-
-        mod = cpp_extension.load_inline(
-            name="c_dlpack",
-            cpp_sources=cpp_source,
-            functions=[
-                "TorchDLPackExchangeAPIPtr",
-            ],
-            extra_cflags=extra_cflags,
-            extra_include_paths=include_paths,
+        # todo: check whether a prebuilt package is installed, if so, use it.
+        ...
+
+        # check whether a JIT shared library is built in cache
+        cache_dir = Path(os.environ.get("TVM_FFI_CACHE_DIR", 
"~/.cache/tvm-ffi")).expanduser()
+        addon_build_dir = cache_dir / "torch_c_dlpack_addon"
+        lib_path = addon_build_dir / (
+            "libtorch_c_dlpack_addon" + (".dll" if sys.platform == "win32" 
else ".so")
         )
+        if not lib_path.exists():
+            build_script_path = Path(__file__).parent / "utils" / 
"_build_optional_c_dlpack.py"
+            args = [sys.executable, str(build_script_path), "--build_dir", 
str(addon_build_dir)]
+            if torch.cuda.is_available():
+                args.append("--build_with_cuda")
+            subprocess.run(
+                args,
+                check=True,
+            )
+            assert lib_path.exists(), "Failed to build torch c dlpack addon."
+
+        lib = ctypes.CDLL(str(lib_path))
+        func = lib.TorchDLPackExchangeAPIPtr
+        func.restype = ctypes.c_uint64
+        func.argtypes = []
+
         # Set the DLPackExchangeAPI pointer on the class
-        setattr(torch.Tensor, "__c_dlpack_exchange_api__", 
mod.TorchDLPackExchangeAPIPtr())
-        return mod
+        setattr(torch.Tensor, "__c_dlpack_exchange_api__", func())
     except ImportError:
         pass
     except Exception as e:
         warnings.warn(
-            f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator 
will not be enabled."
+            f"Failed to load torch c dlpack extension, EnvTensorAllocator will 
not be enabled:\n  {e}"
         )
     return None
 
 
-# keep alive
-_mod = load_torch_c_dlpack_extension()
-
-
 def patch_torch_cuda_stream_protocol() -> Any:
     """Load the torch cuda stream protocol for older versions of torch."""
     try:
@@ -650,4 +109,6 @@ def patch_torch_cuda_stream_protocol() -> Any:
         pass
 
 
-patch_torch_cuda_stream_protocol()
+if os.environ.get("TVM_FFI_DISABLE_TORCH_C_DLPACK", "0") == "0":
+    load_torch_c_dlpack_extension()
+    patch_torch_cuda_stream_protocol()
diff --git a/python/tvm_ffi/cpp/load_inline.py 
b/python/tvm_ffi/cpp/load_inline.py
index 4d7087c..50cc7f0 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -318,7 +318,7 @@ def _generate_ninja_build(  # noqa: PLR0915
     return "\n".join(ninja)
 
 
-def _build_ninja(build_dir: str) -> None:
+def build_ninja(build_dir: str) -> None:
     """Build the module in the given build directory using ninja."""
     command = ["ninja", "-v"]
     num_workers = os.environ.get("MAX_JOBS", None)
@@ -561,7 +561,7 @@ def build_inline(  # noqa: PLR0915, PLR0912
             _maybe_write(str(build_dir / "cuda.cu"), cuda_source)
         _maybe_write(str(build_dir / "build.ninja"), ninja_source)
         # build the module
-        _build_ninja(str(build_dir))
+        build_ninja(str(build_dir))
         # Use appropriate extension based on platform
         ext = ".dll" if IS_WINDOWS else ".so"
         return str((build_dir / f"{name}{ext}").resolve())
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py 
b/python/tvm_ffi/utils/_build_optional_c_dlpack.py
similarity index 58%
copy from python/tvm_ffi/_optional_torch_c_dlpack.py
copy to python/tvm_ffi/utils/_build_optional_c_dlpack.py
index 6bdff08..166a69e 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_c_dlpack.py
@@ -14,47 +14,36 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Optional module to support faster DLPack conversion.
-
-This is an optional module to support faster DLPack conversion for torch.
-Some of the changes are merged but not yet released, so it is used
-as a stop gap to support faster DLPack conversion.
-
-This file contains source code from PyTorch:
-License: licenses/LICENSE.pytorch.txt
-
-This module only serves as temp measure and will
-likely be phased away and deleted after changes landed and released in pytorch.
-
-This module will load slowly at first time due to JITing,
-subsequent calls will be much faster.
-"""
+"""Build Torch C DLPack Addon."""
 
 from __future__ import annotations
 
-import warnings
-from typing import Any
+import argparse
+import os
+import shutil
+import sys
+import sysconfig
+from collections.abc import Sequence
+from pathlib import Path
 
-from . import libinfo
+import torch
+import torch.torch_version
+import torch.utils.cpp_extension
 
+# we need to set the following env to avoid tvm_ffi to build the torch 
c-dlpack addon during importing
+os.environ["TVM_FFI_DISABLE_TORCH_C_DLPACK"] = "1"
 
-def load_torch_c_dlpack_extension() -> Any:
-    try:
-        import torch  # noqa: PLC0415
+from tvm_ffi.cpp.load_inline import build_ninja
+from tvm_ffi.libinfo import find_dlpack_include_path
+from tvm_ffi.utils.lockfile import FileLock
 
-        if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
-            # skip loading the extension if the __c_dlpack_exchange_api__
-            # attribute is already set so we don't have to do it in
-            # newer version of PyTorch
-            return None
-    except ImportError:
-        return None
+IS_WINDOWS = sys.platform == "win32"
 
-    """Load the torch c dlpack extension."""
-    cpp_source = """
+cpp_source = """
 #include <dlpack/dlpack.h>
 #include <ATen/DLConvertor.h>
 #include <ATen/Functions.h>
+#include <torch/extension.h>
 
 #ifdef BUILD_WITH_CUDA
 #include <c10/cuda/CUDAStream.h>
@@ -123,7 +112,7 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
-     case ScalarType::Float8_e5m2:
+    case ScalarType::Float8_e5m2:
       dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
       break;
     case ScalarType::Float8_e5m2fnuz:
@@ -145,7 +134,7 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
       dtype.bits = 4;
       break;
 #endif
-   default:
+    default:
       TORCH_CHECK(false, "Unsupported scalar type: ");
   }
   return dtype;
@@ -155,8 +144,8 @@ DLDevice torchDeviceToDLDeviceForDLPackv1(at::Device 
device) {
   DLDevice ctx;
 
   ctx.device_id = (device.is_cuda() || device.is_privateuseone())
-      ? static_cast<int32_t>(static_cast<unsigned char>(device.index()))
-      : 0;
+                      ? static_cast<int32_t>(static_cast<unsigned 
char>(device.index()))
+                      : 0;
 
   switch (device.type()) {
     case DeviceType::CPU:
@@ -212,8 +201,7 @@ template <class T>
 void fillVersion(T* tensor) {}
 
 template <>
-void fillVersion<DLManagedTensorVersioned>(
-    DLManagedTensorVersioned* tensor) {
+void fillVersion<DLManagedTensorVersioned>(DLManagedTensorVersioned* tensor) {
   tensor->flags = 0;
   tensor->version.major = DLPACK_MAJOR_VERSION;
   tensor->version.minor = DLPACK_MINOR_VERSION;
@@ -238,7 +226,8 @@ T* toDLPackImpl(const Tensor& src) {
   return &(atDLMTensor->tensor);
 }
 
-static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex 
index, void* data = nullptr) {
+static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex 
index,
+                                       void* data = nullptr) {
   switch (type) {
     case DLDeviceType::kDLCPU:
       return at::Device(DeviceType::CPU);
@@ -266,8 +255,7 @@ static Device getATenDeviceForDLPackv1(DLDeviceType type, 
c10::DeviceIndex index
     case DLDeviceType::kDLMetal:
       return at::Device(DeviceType::MPS, index);
     default:
-      TORCH_CHECK(
-          false, "Unsupported device_type: ", std::to_string(type));
+      TORCH_CHECK(false, "Unsupported device_type: ", std::to_string(type));
   }
 }
 
@@ -275,9 +263,8 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
   ScalarType stype = ScalarType::Undefined;
 #if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
   if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
-    TORCH_CHECK(
-        dtype.lanes == 1,
-        "ATen does not support lanes != 1 for dtype code", 
std::to_string(dtype.code));
+    TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1 for dtype 
code",
+                std::to_string(dtype.code));
   }
 #endif
   switch (dtype.code) {
@@ -296,8 +283,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::UInt64;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kUInt bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLInt:
@@ -315,8 +301,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Long;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kInt bits ", std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kInt bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat:
@@ -331,8 +316,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Double;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kFloat bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLBfloat:
@@ -341,8 +325,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::BFloat16;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kFloat bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLComplex:
@@ -357,8 +340,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::ComplexDouble;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kFloat bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLBool:
@@ -367,8 +349,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Bool;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLBool bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat8_e5m2:
@@ -377,8 +358,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Float8_e5m2;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e5m2 bits ", 
std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLFloat8_e5m2 bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat8_e5m2fnuz:
@@ -387,8 +367,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Float8_e5m2fnuz;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e5m2fnuz bits ", 
std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLFloat8_e5m2fnuz bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat8_e4m3fn:
@@ -397,8 +376,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Float8_e4m3fn;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e4m3fn bits ", 
std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLFloat8_e4m3fn bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat8_e4m3fnuz:
@@ -407,8 +385,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Float8_e4m3fnuz;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e4m3fnuz bits ", 
std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLFloat8_e4m3fnuz bits ", 
std::to_string(dtype.bits));
       }
       break;
 #if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
@@ -418,8 +395,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) 
{
           stype = ScalarType::Float8_e8m0fnu;
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat8_e8m0fnu bits ", 
std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLFloat8_e8m0fnu bits ", 
std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat4_e2m1fn:
@@ -430,13 +406,12 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& 
dtype) {
               stype = ScalarType::Float4_e2m1fn_x2;
               break;
             default:
-              TORCH_CHECK(
-                false, "Unsupported kDLFloat4_e2m1fn lanes ", 
std::to_string(dtype.lanes));
+              TORCH_CHECK(false, "Unsupported kDLFloat4_e2m1fn lanes ",
+                          std::to_string(dtype.lanes));
           }
           break;
         default:
-          TORCH_CHECK(
-              false, "Unsupported kDLFloat4_e2m1fn bits ", 
std::to_string(dtype.bits));
+          TORCH_CHECK(false, "Unsupported kDLFloat4_e2m1fn bits ", 
std::to_string(dtype.bits));
       }
       break;
 #endif
@@ -459,24 +434,17 @@ at::Tensor fromDLPackImpl(T* src, 
std::function<void(void*)> deleter) {
   }
 
   DLTensor& dl_tensor = src->dl_tensor;
-  Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, 
dl_tensor.device.device_id, dl_tensor.data);
+  Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, 
dl_tensor.device.device_id,
+                                           dl_tensor.data);
   ScalarType stype = toScalarTypeForDLPackv1(dl_tensor.dtype);
 
   if (!dl_tensor.strides) {
-    return at::from_blob(
-        dl_tensor.data,
-        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
-        std::move(deleter),
-        at::device(device).dtype(stype),
-        {device});
+    return at::from_blob(dl_tensor.data, IntArrayRef(dl_tensor.shape, 
dl_tensor.ndim),
+                         std::move(deleter), at::device(device).dtype(stype), 
{device});
   }
-  return at::from_blob(
-      dl_tensor.data,
-      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
-      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
-      deleter,
-      at::device(device).dtype(stype),
-      {device});
+  return at::from_blob(dl_tensor.data, IntArrayRef(dl_tensor.shape, 
dl_tensor.ndim),
+                       IntArrayRef(dl_tensor.strides, dl_tensor.ndim), deleter,
+                       at::device(device).dtype(stype), {device});
 }
 
 void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
@@ -494,8 +462,8 @@ void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& 
out) {
   out.byte_offset = 0;
 }
 
-} // namespace
-} // namespace at
+}  // namespace
+}  // namespace at
 
 struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
   TorchDLPackExchangeAPI() {
@@ -551,15 +519,17 @@ struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
     }
   }
 
-  static int ManagedTensorAllocator(
-      DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
-      void (*SetError)(void* error_ctx, const char* kind, const char* message)
-  ) {
+  static int ManagedTensorAllocator(DLTensor* prototype, 
DLManagedTensorVersioned** out,
+                                    void* error_ctx,
+                                    void (*SetError)(void* error_ctx, const 
char* kind,
+                                                     const char* message)) {
     try {
       at::IntArrayRef shape(prototype->shape, prototype->shape + 
prototype->ndim);
-      at::TensorOptions options = at::TensorOptions()
-        .dtype(at::toScalarType(prototype->dtype))
-        .device(at::getATenDeviceForDLPackv1(prototype->device.device_type, 
prototype->device.device_id));
+      at::TensorOptions options =
+          at::TensorOptions()
+              .dtype(at::toScalarType(prototype->dtype))
+              
.device(at::getATenDeviceForDLPackv1(prototype->device.device_type,
+                                                   
prototype->device.device_id));
       at::Tensor tensor = at::empty(shape, options);
       *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
       return 0;
@@ -570,10 +540,7 @@ struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
   }
 
   // Get current CUDA/ROCm work stream
-  static int CurrentWorkStream(
-      DLDeviceType device_type,
-      int32_t device_id,
-      void** out_stream) {
+  static int CurrentWorkStream(DLDeviceType device_type, int32_t device_id, 
void** out_stream) {
     try {
 #ifdef BUILD_WITH_CUDA
       if (device_type == kDLCUDA || device_type == kDLROCM) {
@@ -591,63 +558,223 @@ struct TorchDLPackExchangeAPI : public DLPackExchangeAPI 
{
   }
 };
 
-int64_t TorchDLPackExchangeAPIPtr() {
+// defien a cross-platgorm macro to export the symbol
+#ifdef _WIN32
+#define DLL_EXPORT __declspec(dllexport)
+#else
+#define DLL_EXPORT __attribute__((visibility("default")))
+#endif
+
+extern "C" DLL_EXPORT int64_t TorchDLPackExchangeAPIPtr() {
   return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
 }
-    """
-    try:
-        # optionally import torch
-        import torch  # noqa: PLC0415
-        from torch.utils import cpp_extension  # noqa: PLC0415
-
-        include_paths = libinfo.include_paths()
-        extra_cflags = ["-O3"]
-
-        if torch.cuda.is_available():
-            include_paths += cpp_extension.include_paths("cuda")
-            extra_cflags += ["-DBUILD_WITH_CUDA"]
-
-        mod = cpp_extension.load_inline(
-            name="c_dlpack",
-            cpp_sources=cpp_source,
-            functions=[
-                "TorchDLPackExchangeAPIPtr",
-            ],
-            extra_cflags=extra_cflags,
-            extra_include_paths=include_paths,
-        )
-        # Set the DLPackExchangeAPI pointer on the class
-        setattr(torch.Tensor, "__c_dlpack_exchange_api__", 
mod.TorchDLPackExchangeAPIPtr())
-        return mod
-    except ImportError:
-        pass
-    except Exception as e:
-        warnings.warn(
-            f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator 
will not be enabled."
-        )
-    return None
-
-
-# keep alive
-_mod = load_torch_c_dlpack_extension()
-
+"""
 
-def patch_torch_cuda_stream_protocol() -> Any:
-    """Load the torch cuda stream protocol for older versions of torch."""
-    try:
-        import torch  # noqa: PLC0415
 
-        if not torch.cuda.is_available():
+def _generate_ninja_build(
+    build_dir: Path,
+    libname: str,
+    source_path: Path,
+    extra_cflags: Sequence[str],
+    extra_ldflags: Sequence[str],
+    extra_include_paths: Sequence[str],
+) -> None:
+    """Generate the content of build.ninja for building the module."""
+    if IS_WINDOWS:
+        default_cflags = [
+            "/std:c++17",
+            "/MD",
+            "/wd4819",
+            "/wd4251",
+            "/wd4244",
+            "/wd4267",
+            "/wd4275",
+            "/wd4018",
+            "/wd4190",
+            "/wd4624",
+            "/wd4067",
+            "/wd4068",
+            "/EHsc",
+        ]
+        default_ldflags = ["/DLL"]
+    else:
+        default_cflags = ["-std=c++17", "-fPIC", "-O2"]
+        default_ldflags = ["-shared", "-Wl,-rpath,$ORIGIN", 
"-Wl,--no-as-needed"]
+
+    cflags = default_cflags + [flag.strip() for flag in extra_cflags]
+    ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
+    include_paths = [find_dlpack_include_path()] + [
+        str(Path(path).resolve()) for path in extra_include_paths
+    ]
+
+    # append include paths
+    for path in include_paths:
+        cflags.append("-I{}".format(path.replace(":", "$:")))
+
+    # flags
+    ninja = []
+    ninja.append("ninja_required_version = 1.3")
+    ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS 
else "c++")))
+    ninja.append("cflags = {}".format(" ".join(cflags)))
+    ninja.append("ldflags = {}".format(" ".join(ldflags)))
+
+    # rules
+    ninja.append("")
+    ninja.append("rule compile")
+    if IS_WINDOWS:
+        ninja.append("  command = $cxx /showIncludes $cflags -c $in /Fo$out")
+        ninja.append("  deps = msvc")
+    else:
+        ninja.append("  depfile = $out.d")
+        ninja.append("  deps = gcc")
+        ninja.append("  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out")
+    ninja.append("")
+
+    ninja.append("rule link")
+    if IS_WINDOWS:
+        ninja.append("  command = $cxx $in /link $ldflags /out:$out")
+    else:
+        ninja.append("  command = $cxx $in $ldflags -o $out")
+    ninja.append("")
+
+    # build targets
+    obj_name = "main.obj" if IS_WINDOWS else "main.o"
+    ninja.append(
+        "build {}: compile {}".format(obj_name, 
str(source_path.resolve()).replace(":", "$:"))
+    )
+
+    # Use appropriate extension based on platform
+    ninja.append(f"build {libname}: link {obj_name}")
+    ninja.append("")
+
+    # default target
+    ninja.append(f"default {libname}")
+    ninja.append("")
+
+    with open(build_dir / "build.ninja", "w") as f:  # noqa: PTH123
+        f.write("\n".join(ninja))
+
+
+def get_torch_include_paths(build_with_cuda: bool) -> Sequence[str]:
+    """Get the include paths for building with torch."""
+    if torch.__version__ >= torch.torch_version.TorchVersion("2.6.0"):
+        return torch.utils.cpp_extension.include_paths(
+            device_type="cuda" if build_with_cuda else "cpu"
+        )
+    else:
+        return torch.utils.cpp_extension.include_paths(cuda=build_with_cuda)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--build_dir",
+    type=str,
+    default=str(Path("~/.cache/tvm-ffi/torch_c_dlpack_addon").expanduser()),
+    help="Directory to store the built extension library.",
+)
+parser.add_argument(
+    "--build_with_cuda",
+    action="store_true",
+    default=torch.cuda.is_available(),
+    help="Build with CUDA support.",
+)
+
+
+def main() -> None:  # noqa: PLR0912, PLR0915
+    """Build the torch c dlpack extension."""
+    args = parser.parse_args()
+    build_dir = Path(args.build_dir)
+
+    if not build_dir.exists():
+        build_dir.mkdir(parents=True, exist_ok=True)
+
+    name = "libtorch_c_dlpack_addon"
+    suffix = ".dll" if IS_WINDOWS else ".so"
+    libname = name + suffix
+    tmp_libname = name + ".tmp" + suffix
+
+    with FileLock(str(build_dir / "build.lock")):
+        if (build_dir / libname).exists():
+            # already built
             return
-        if not hasattr(torch.cuda.Stream, "__cuda_stream__"):
 
-            def __torch_cuda_stream__(self: torch.cuda.Stream) -> tuple[int, 
torch.cuda.Stream]:
-                """Return the version number and the cuda stream."""
-                return (0, self.cuda_stream)
+        # write the source
+        source_path = build_dir / "addon.cc"
+        with open(source_path, "w") as f:  # noqa: PTH123
+            f.write(cpp_source)
+
+        # resolve configs
+        include_paths = []
+        ldflags = []
+        cflags = []
+        include_paths.append(sysconfig.get_paths()["include"])
+
+        if args.build_with_cuda:
+            cflags.append("-DBUILD_WITH_CUDA")
+        include_paths.extend(get_torch_include_paths(args.build_with_cuda))
+
+        # use CXX11 ABI
+        if torch.compiled_with_cxx11_abi():
+            cflags.append("-D_GLIBCXX_USE_CXX11_ABI=1")
+        else:
+            cflags.append("-D_GLIBCXX_USE_CXX11_ABI=0")
+
+        for lib_dir in torch.utils.cpp_extension.library_paths():
+            if IS_WINDOWS:
+                ldflags.append(f"/LIBPATH:{lib_dir}")
+            else:
+                ldflags.append(f"-L{lib_dir}")
+
+        # Add all required PyTorch libraries
+        if IS_WINDOWS:
+            # On Windows, use .lib format for linking
+            ldflags.extend(["c10.lib", "torch.lib", "torch_cpu.lib", 
"torch_python.lib"])
+        else:
+            # On Unix/macOS, use -l format for linking
+            ldflags.extend(["-lc10", "-ltorch", "-ltorch_cpu", 
"-ltorch_python"])
+
+        # Add Python library linking
+        if IS_WINDOWS:
+            python_lib = f"python{sys.version_info.major}.lib"
+            python_libdir_list = [
+                sysconfig.get_config_var("LIBDIR"),
+                sysconfig.get_path("include"),
+            ]
+            if (
+                sysconfig.get_path("include") is not None
+                and (Path(sysconfig.get_path("include")).parent / 
"libs").exists()
+            ):
+                python_libdir_list.append(
+                    str((Path(sysconfig.get_path("include")).parent / 
"libs").resolve())
+                )
+            for python_libdir in python_libdir_list:
+                if python_libdir and (Path(python_libdir) / 
python_lib).exists():
+                    ldflags.append(f"/LIBPATH:{python_libdir.replace(':', 
'$:')}")
+                    ldflags.append(python_lib)
+                    break
+        else:
+            python_libdir = sysconfig.get_config_var("LIBDIR")
+            if python_libdir:
+                ldflags.append(f"-L{python_libdir}")
+                py_version = f"python{sysconfig.get_python_version()}"
+                ldflags.append(f"-l{py_version}")
+
+        # generate ninja build file
+        _generate_ninja_build(
+            build_dir=build_dir,
+            libname=tmp_libname,
+            source_path=source_path,
+            extra_cflags=cflags,
+            extra_ldflags=ldflags,
+            extra_include_paths=include_paths,
+        )
+
+        # build the shared library
+        build_ninja(build_dir=str(build_dir))
 
-            setattr(torch.cuda.Stream, "__cuda_stream__", 
__torch_cuda_stream__)
-    except ImportError:
-        pass
+        # rename the tmp file to final libname
+        shutil.move(str(build_dir / tmp_libname), str(build_dir / libname))
 
 
-patch_torch_cuda_stream_protocol()
+if __name__ == "__main__":
+    main()
diff --git a/tests/python/test_dlpack_exchange_api.py 
b/tests/python/test_dlpack_exchange_api.py
index d5be763..11f93ae 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -18,6 +18,8 @@
 
 from __future__ import annotations
 
+import sys
+
 import pytest
 
 try:
@@ -37,6 +39,10 @@ _has_dlpack_api = torch is not None and 
hasattr(torch.Tensor, "__c_dlpack_exchan
 
 @pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API 
not available")
 def test_dlpack_exchange_api() -> None:
+    # xfail the test on windows platform, it seems to be a bug in torch 
extension building on windows
+    if sys.platform.startswith("win"):
+        pytest.xfail("DLPack Exchange API test is known to fail on Windows 
platform")
+
     assert torch is not None
 
     assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__")
diff --git a/tests/python/test_optional_torch_c_dlpack.py 
b/tests/python/test_optional_torch_c_dlpack.py
new file mode 100644
index 0000000..33254e3
--- /dev/null
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+import tvm_ffi
+
+IS_WINDOWS = sys.platform.startswith("win")
+
+
[email protected](torch is None, reason="torch is not installed")
+def test_build_torch_c_dlpack_extension() -> None:
+    build_script = Path(tvm_ffi.__file__).parent / "utils" / 
"_build_optional_c_dlpack.py"
+    subprocess.run(
+        [sys.executable, str(build_script), "--build_dir", 
"./build_test_dir"], check=True
+    )
+
+    lib_path = str(
+        Path(
+            "./build_test_dir/libtorch_c_dlpack_addon.{}".format("dll" if 
IS_WINDOWS else "so")
+        ).resolve()
+    )
+    assert Path(lib_path).exists()
+
+    lib = ctypes.CDLL(lib_path)
+    func = lib.TorchDLPackExchangeAPIPtr
+    func.restype = ctypes.c_int64
+    ptr = func()
+    assert ptr != 0
+
+
[email protected](torch is None, reason="torch is not installed")
+def test_parallel_build() -> None:
+    build_script = Path(tvm_ffi.__file__).parent / "utils" / 
"_build_optional_c_dlpack.py"
+    num_processes = 4
+    build_dir = "./build_test_dir_parallel"
+    processes = []
+    for i in range(num_processes):
+        p = subprocess.Popen([sys.executable, str(build_script), 
"--build_dir", build_dir])
+        processes.append((p, build_dir))
+
+    for p, build_dir in processes:
+        p.wait()
+        assert p.returncode == 0
+    lib_path = str(
+        Path(
+            "{}/libtorch_c_dlpack_addon.{}".format(build_dir, "dll" if 
IS_WINDOWS else "so")
+        ).resolve()
+    )
+    assert Path(lib_path).exists()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

Reply via email to