gemini-code-assist[bot] commented on code in PR #19539:
URL: https://github.com/apache/tvm/pull/19539#discussion_r3221709822


##########
src/runtime/contrib/vllm/attention_kernels.cu:
##########
@@ -756,10 +757,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def(
       "tvm.contrib.vllm.single_query_cached_kv_attention",
-      [](const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache,
-         const DLTensor* block_tables, const DLTensor* context_lens, int block_size,
-         const DLTensor* max_context_len_tensor,  // TODO(masahi): pass integer
-         DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor* out) {
+      [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+         Tensor context_lens, int block_size,
+         Tensor max_context_len_tensor,  // TODO(masahi): pass integer
+         Tensor exp_sums, Tensor max_logits, Tensor tmp_out, Tensor out) {

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The use of `Tensor` here is ambiguous and likely leads to a compilation 
error. In the `tvm::runtime` namespace, `Tensor` is an alias for 
`tvm::runtime::NDArray`, which does not have a `GetDLTensorPtr()` method. Based 
on the subsequent calls to `GetDLTensorPtr()` in the lambda body (e.g., line 
773), it appears that `tvm::ffi::Tensor` was intended. Explicitly using 
`ffi::Tensor` will resolve this ambiguity and ensure the code compiles 
correctly.
   
   ```
         [](ffi::Tensor query, ffi::Tensor key_cache, ffi::Tensor value_cache, ffi::Tensor block_tables,
            ffi::Tensor context_lens, int block_size,
            ffi::Tensor max_context_len_tensor,  // TODO(masahi): pass integer
            ffi::Tensor exp_sums, ffi::Tensor max_logits, ffi::Tensor tmp_out, ffi::Tensor out) {
   ```
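For illustration only, here is a minimal C++17 sketch of the lookup problem described above. It uses stand-in types, and `demo::ffi::Tensor`, `demo::runtime::NDArrayLike`, `DLTensorLike`, `QueryNDim` and the `HasGetDLTensorPtr` trait are all hypothetical, not TVM's real headers. Inside a `runtime` namespace, an unqualified `Tensor` names the runtime alias, which (as the comment states) has no `GetDLTensorPtr()`. Qualifying the parameter as `ffi::Tensor` selects the type that does have it.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

namespace demo {

// Hypothetical stand-in for a DLPack tensor descriptor (not dlpack.h).
struct DLTensorLike {
  int ndim;
};

namespace ffi {
// Hypothetical stand-in for the FFI tensor type: exposes GetDLTensorPtr().
struct Tensor {
  DLTensorLike desc{2};
  const DLTensorLike* GetDLTensorPtr() const { return &desc; }
};
}  // namespace ffi

namespace runtime {
// Hypothetical stand-in for the runtime-side alias: no GetDLTensorPtr().
struct NDArrayLike {
  int ndim = 2;
};
using Tensor = NDArrayLike;

// Here an unqualified `Tensor` means runtime::Tensor, so
// `[](Tensor q) { return q.GetDLTensorPtr()->ndim; }` would not compile.
// Qualifying the parameter type selects the FFI stand-in instead.
inline int QueryNDim(ffi::Tensor query) {
  auto body = [](ffi::Tensor q) { return q.GetDLTensorPtr()->ndim; };
  return body(query);
}
}  // namespace runtime

// Detects whether T has a callable GetDLTensorPtr() member.
template <typename T, typename = void>
struct HasGetDLTensorPtr : std::false_type {};
template <typename T>
struct HasGetDLTensorPtr<
    T, std::void_t<decltype(std::declval<const T&>().GetDLTensorPtr())>>
    : std::true_type {};

static_assert(HasGetDLTensorPtr<ffi::Tensor>::value,
              "the FFI stand-in exposes GetDLTensorPtr()");
static_assert(!HasGetDLTensorPtr<runtime::Tensor>::value,
              "the runtime alias stand-in does not");

}  // namespace demo
```

The `static_assert`s make the difference between the two spellings visible at compile time without leaving broken code in the file.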



##########
src/runtime/contrib/vllm/attention_kernels.cu:
##########
@@ -784,9 +791,27 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("tvm.contrib.vllm.single_query_cached_kv_attention_v1",
-           single_query_cached_kv_attention_v1)
+           [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+              Tensor context_lens, int block_size, Tensor max_context_len_tensor, Tensor out) {
+             single_query_cached_kv_attention_v1(
+                 query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(), value_cache.GetDLTensorPtr(),
+                 block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(), block_size,
+                 max_context_len_tensor.GetDLTensorPtr(),
+                 const_cast<DLTensor*>(out.GetDLTensorPtr()));
+           })
       .def("tvm.contrib.vllm.single_query_cached_kv_attention_v2",
-           single_query_cached_kv_attention_v2);
+           [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+              Tensor context_lens, int block_size, Tensor max_context_len_tensor, Tensor exp_sums,
+              Tensor max_logits, Tensor tmp_out, Tensor out) {

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Explicitly use `ffi::Tensor` to match the intended FFI type and ensure 
`GetDLTensorPtr()` is accessible for the output tensors.
   
   ```
              [](ffi::Tensor query, ffi::Tensor key_cache, ffi::Tensor value_cache, ffi::Tensor block_tables,
                 ffi::Tensor context_lens, int block_size, ffi::Tensor max_context_len_tensor, ffi::Tensor exp_sums,
                 ffi::Tensor max_logits, ffi::Tensor tmp_out, ffi::Tensor out) {
   ```
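The following hypothetical sketch shows the adapter shape used in these registrations: a lambda takes tensor handles and forwards raw descriptors to a legacy pointer-based entry point. It assumes, as the `const_cast` in the diff suggests, that `GetDLTensorPtr()` returns a pointer-to-const. It also assumes the tensor type is a shared handle, so a by-value copy still refers to the caller's buffer. `demo::ffi::Tensor`, `DLTensorLike`, `legacy_kernel` and `attention_adapter` are stand-ins, not TVM APIs.

```cpp
#include <cassert>
#include <memory>

namespace demo {

// Hypothetical stand-in for a DLPack tensor descriptor (not dlpack.h).
struct DLTensorLike {
  float value;
};

namespace ffi {
// Hypothetical stand-in for a reference-counted tensor handle: copies share
// the same descriptor, and GetDLTensorPtr() hands out a pointer-to-const.
struct Tensor {
  std::shared_ptr<DLTensorLike> handle;
  const DLTensorLike* GetDLTensorPtr() const { return handle.get(); }
};
}  // namespace ffi

// Hypothetical legacy entry point: read-only input, writable output.
inline void legacy_kernel(const DLTensorLike* query, DLTensorLike* out) {
  out->value = query->value * 2.0f;
}

// Adapter in the same shape as the registered lambdas. The const_cast is
// well-defined here because the underlying object was not created const.
inline const auto attention_adapter = [](ffi::Tensor query, ffi::Tensor out) {
  legacy_kernel(query.GetDLTensorPtr(),
                const_cast<DLTensorLike*>(out.GetDLTensorPtr()));
};

}  // namespace demo
```

Because the handle is shared, the write made through the output pointer is visible to the caller even though `out` is taken by value.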



##########
src/runtime/contrib/vllm/attention_kernels.cu:
##########
@@ -784,9 +791,27 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("tvm.contrib.vllm.single_query_cached_kv_attention_v1",
-           single_query_cached_kv_attention_v1)
+           [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+              Tensor context_lens, int block_size, Tensor max_context_len_tensor, Tensor out) {

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Similar to the previous registration, `Tensor` should be replaced with 
`ffi::Tensor` to ensure the `GetDLTensorPtr()` method is available and to avoid 
conflicts with `tvm::runtime::Tensor`.
   
   ```
              [](ffi::Tensor query, ffi::Tensor key_cache, ffi::Tensor value_cache, ffi::Tensor block_tables,
                 ffi::Tensor context_lens, int block_size, ffi::Tensor max_context_len_tensor, ffi::Tensor out) {
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
