gemini-code-assist[bot] commented on code in PR #111:
URL: https://github.com/apache/tvm-ffi/pull/111#discussion_r2427358318


##########
tests/python/test_dlpack_exchange_api.py:
##########
@@ -0,0 +1,465 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file to
+# you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from __future__ import annotations
+
+import pytest
+
+try:
+    import torch  # type: ignore[no-redef]
+
+    # Import tvm_ffi to load the DLPack exchange API extension
+    # This sets torch.Tensor.__c_dlpack_exchange_api__
+    import tvm_ffi  # noqa: F401
+    from torch.utils import cpp_extension  # type: ignore
+    from tvm_ffi import libinfo
+except ImportError:
+    torch = None
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_torch_has_dlpack_exchange_api() -> None:
+    """Test that torch.Tensor has __c_dlpack_exchange_api__ attribute."""
+    assert torch is not None
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+    assert isinstance(api_ptr, int), "API pointer should be an integer"
+    assert api_ptr != 0, "API pointer should not be NULL"
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_dlpack_exchange_api_version() -> None:
+    assert torch is not None
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+
+    source = """
+    #include <torch/extension.h>
+    #include <dlpack/dlpack.h>
+
+    void test_api_version(int64_t api_ptr_int) {
+        DLPackExchangeAPI* api = 
reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+
+        TORCH_CHECK(api != nullptr, "API pointer is NULL");
+
+        TORCH_CHECK(api->header.version.major == DLPACK_MAJOR_VERSION,
+                    "Expected major version ", DLPACK_MAJOR_VERSION, ", got ", 
api->header.version.major);
+        TORCH_CHECK(api->header.version.minor == DLPACK_MINOR_VERSION,
+                    "Expected minor version ", DLPACK_MINOR_VERSION, ", got ", 
api->header.version.minor);
+    }
+    """
+
+    mod = cpp_extension.load_inline(
+        name="test_api_version",
+        cpp_sources=[source],
+        functions=["test_api_version"],
+        extra_include_paths=libinfo.include_paths(),
+    )
+
+    mod.test_api_version(api_ptr)
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_dlpack_exchange_api_function_pointers_not_null() -> None:
+    assert torch is not None
+
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+
+    source = """
+    #include <torch/extension.h>
+    #include <dlpack/dlpack.h>
+
+    void test_function_pointers_not_null(int64_t api_ptr_int) {
+        DLPackExchangeAPI* api = 
reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+
+        TORCH_CHECK(api != nullptr, "API pointer is NULL");
+
+        // Check that required function pointers are not NULL
+        TORCH_CHECK(api->managed_tensor_allocator != nullptr,
+                    "managed_tensor_allocator is NULL");
+        TORCH_CHECK(api->managed_tensor_from_py_object_no_sync != nullptr,
+                    "managed_tensor_from_py_object_no_sync is NULL");
+        TORCH_CHECK(api->managed_tensor_to_py_object_no_sync != nullptr,
+                    "managed_tensor_to_py_object_no_sync is NULL");
+        TORCH_CHECK(api->current_work_stream != nullptr,
+                    "current_work_stream is NULL");
+    }
+    """
+
+    mod = cpp_extension.load_inline(
+        name="test_function_pointers_not_null",
+        cpp_sources=[source],
+        functions=["test_function_pointers_not_null"],
+        extra_include_paths=libinfo.include_paths(),
+    )
+
+    mod.test_function_pointers_not_null(api_ptr)
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_managed_tensor_allocator() -> None:
+    assert torch is not None
+
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+
+    source = """
+    #include <torch/extension.h>
+    #include <dlpack/dlpack.h>
+
+    void test_allocator(int64_t api_ptr_int) {
+        DLPackExchangeAPI* api = 
reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+
+        // Create a prototype DLTensor
+        DLTensor prototype;
+        prototype.device.device_type = kDLCPU;
+        prototype.device.device_id = 0;
+        prototype.ndim = 3;
+
+        int64_t shape[3] = {3, 4, 5};
+        prototype.shape = shape;
+        prototype.strides = nullptr;
+
+        DLDataType dtype;
+        dtype.code = kDLFloat;
+        dtype.bits = 32;
+        dtype.lanes = 1;
+        prototype.dtype = dtype;
+
+        prototype.data = nullptr;
+        prototype.byte_offset = 0;
+
+        // Call allocator
+        DLManagedTensorVersioned* out_tensor = nullptr;
+        int result = api->managed_tensor_allocator(
+            &prototype,
+            &out_tensor,
+            nullptr,  // error_ctx
+            nullptr   // SetError
+        );
+
+        TORCH_CHECK(result == 0, "Allocator failed with code ", result);
+        TORCH_CHECK(out_tensor != nullptr, "Allocator returned NULL");
+
+        // Check shape
+        TORCH_CHECK(out_tensor->dl_tensor.ndim == 3, "Wrong ndim");
+        TORCH_CHECK(out_tensor->dl_tensor.shape[0] == 3, "Wrong shape[0]");
+        TORCH_CHECK(out_tensor->dl_tensor.shape[1] == 4, "Wrong shape[1]");
+        TORCH_CHECK(out_tensor->dl_tensor.shape[2] == 5, "Wrong shape[2]");
+
+        // Check dtype
+        TORCH_CHECK(out_tensor->dl_tensor.dtype.code == kDLFloat, "Wrong dtype 
code");
+        TORCH_CHECK(out_tensor->dl_tensor.dtype.bits == 32, "Wrong dtype 
bits");
+
+        // Check device
+        TORCH_CHECK(out_tensor->dl_tensor.device.device_type == kDLCPU, "Wrong 
device type");
+
+        // Call deleter to clean up
+        if (out_tensor->deleter) {
+            out_tensor->deleter(out_tensor);
+        }
+    }
+    """
+
+    mod = cpp_extension.load_inline(
+        name="test_allocator",
+        cpp_sources=[source],
+        functions=["test_allocator"],
+        extra_include_paths=libinfo.include_paths(),
+    )
+
+    mod.test_allocator(api_ptr)
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_managed_tensor_from_py_object() -> None:
+    assert torch is not None
+
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+
+    source = """
+    #include <torch/extension.h>
+    #include <dlpack/dlpack.h>
+
+    void test_from_py_object(at::Tensor tensor, int64_t api_ptr_int) {
+        DLPackExchangeAPI* api = 
reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+
+        TORCH_CHECK(api->managed_tensor_from_py_object_no_sync != nullptr,
+                    "managed_tensor_from_py_object_no_sync is NULL");
+
+        // Get PyObject* from at::Tensor
+        PyObject* py_obj = THPVariable_Wrap(tensor);
+        TORCH_CHECK(py_obj != nullptr, "Failed to wrap tensor to PyObject");
+
+        // Call from_py_object_no_sync
+        DLManagedTensorVersioned* out_tensor = nullptr;
+        int result = api->managed_tensor_from_py_object_no_sync(
+            py_obj,
+            &out_tensor
+        );
+
+        TORCH_CHECK(result == 0, "from_py_object_no_sync failed with code ", 
result);
+        TORCH_CHECK(out_tensor != nullptr, "from_py_object_no_sync returned 
NULL");
+
+        // Check version
+        TORCH_CHECK(out_tensor->version.major == DLPACK_MAJOR_VERSION,
+                    "Expected major version ", DLPACK_MAJOR_VERSION);
+        TORCH_CHECK(out_tensor->version.minor == DLPACK_MINOR_VERSION,
+                    "Expected minor version ", DLPACK_MINOR_VERSION);
+
+        // Check shape
+        TORCH_CHECK(out_tensor->dl_tensor.ndim == 3, "Wrong ndim");
+        TORCH_CHECK(out_tensor->dl_tensor.shape[0] == 2, "Wrong shape[0]");
+        TORCH_CHECK(out_tensor->dl_tensor.shape[1] == 3, "Wrong shape[1]");
+        TORCH_CHECK(out_tensor->dl_tensor.shape[2] == 4, "Wrong shape[2]");
+
+        // Check dtype (float32)
+        TORCH_CHECK(out_tensor->dl_tensor.dtype.code == kDLFloat, "Wrong dtype 
code");
+        TORCH_CHECK(out_tensor->dl_tensor.dtype.bits == 32, "Wrong dtype 
bits");
+
+        // Check data pointer is not NULL
+        TORCH_CHECK(out_tensor->dl_tensor.data != nullptr, "Data pointer is 
NULL");
+
+        // Call deleter to clean up
+        if (out_tensor->deleter) {
+            out_tensor->deleter(out_tensor);
+        }
+
+        // Decrement refcount of the wrapped PyObject
+        Py_DECREF(py_obj);
+    }
+    """
+
+    mod = cpp_extension.load_inline(
+        name="test_from_py_object",
+        cpp_sources=[source],
+        functions=["test_from_py_object"],
+        extra_include_paths=libinfo.include_paths(),
+    )
+
+    tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
+    mod.test_from_py_object(tensor, api_ptr)
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_managed_tensor_to_py_object() -> None:
+    assert torch is not None
+
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+
+    source = """
+    #include <torch/extension.h>
+    #include <dlpack/dlpack.h>
+
+    void test_to_py_object(at::Tensor tensor, int64_t api_ptr_int) {
+        DLPackExchangeAPI* api = 
reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+
+        TORCH_CHECK(api->managed_tensor_from_py_object_no_sync != nullptr);

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   For consistency and better error reporting, it's good practice to include a 
descriptive message in all `TORCH_CHECK` assertions. This check is missing a 
message, unlike other checks in this file.
   
   ```suggestion
           TORCH_CHECK(api->managed_tensor_from_py_object_no_sync != nullptr, "managed_tensor_from_py_object_no_sync is NULL");
   ```



##########
tests/python/test_dlpack_exchange_api.py:
##########
@@ -0,0 +1,465 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file to
+# you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from __future__ import annotations
+
+import pytest
+
+try:
+    import torch  # type: ignore[no-redef]
+
+    # Import tvm_ffi to load the DLPack exchange API extension
+    # This sets torch.Tensor.__c_dlpack_exchange_api__
+    import tvm_ffi  # noqa: F401
+    from torch.utils import cpp_extension  # type: ignore
+    from tvm_ffi import libinfo
+except ImportError:
+    torch = None
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_torch_has_dlpack_exchange_api() -> None:
+    """Test that torch.Tensor has __c_dlpack_exchange_api__ attribute."""
+    assert torch is not None
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
+    assert isinstance(api_ptr, int), "API pointer should be an integer"
+    assert api_ptr != 0, "API pointer should not be NULL"
+
+
[email protected](torch is None, reason="PyTorch not available")
+def test_dlpack_exchange_api_version() -> None:
+    assert torch is not None
+    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__"), (
+        "torch.Tensor does not have __c_dlpack_exchange_api__"
+    )
+
+    api_ptr = torch.Tensor.__c_dlpack_exchange_api__

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   There is a lot of repeated setup code across the tests in this file. This 
includes checking if `torch` is available, asserting the presence of 
`__c_dlpack_exchange_api__`, and retrieving the `api_ptr`. This boilerplate can 
be eliminated by using a `pytest` fixture. A fixture would centralize this 
setup logic, making the tests cleaner, more readable, and easier to maintain.
   
   Here's an example of how you could define a fixture and refactor this test:
   
   ```python
   @pytest.fixture(scope="module")
   def api_ptr() -> int:
       """Fixture to get the DLPack exchange API pointer from torch.Tensor."""
       if torch is None:
           pytest.skip("PyTorch not available")
   
       if not hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
           pytest.skip("torch.Tensor does not have __c_dlpack_exchange_api__")
   
       ptr = torch.Tensor.__c_dlpack_exchange_api__
       assert isinstance(ptr, int), "API pointer should be an integer"
       assert ptr != 0, "API pointer should not be NULL"
       return ptr
   
   
   # The @pytest.mark.skipif decorator is no longer needed on the test function
   # as the fixture handles skipping.
   def test_dlpack_exchange_api_version(api_ptr: int) -> None:
       # The test body can now start directly, assuming api_ptr is valid.
       source = """..."""
       # ... rest of the test
       mod.test_api_version(api_ptr)
   ```
   
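   Applying this pattern to all tests in this file would significantly reduce 
code duplication. Note one behavioral difference: the fixture above calls 
`pytest.skip` when `__c_dlpack_exchange_api__` is missing, whereas the current 
tests `assert` its presence and therefore fail. Because a missing attribute 
means that importing `tvm_ffi` did not register the API, you may want the 
fixture to keep reporting it as a failure (via `assert` or `pytest.fail`) 
rather than a skip, so such a regression is not silently hidden.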



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to