tqchen commented on code in PR #18271:
URL: https://github.com/apache/tvm/pull/18271#discussion_r2324994636


##########
ffi/tests/python/test_load_inline.py:
##########
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import torch
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
+def test_load_inline():
+    mod: Module = tvm_ffi.cpp.load_inline(
+        name="hello",
+        cpp_source=r"""
+            void AddOne(DLTensor* x, DLTensor* y) {
+              // implementation of a library function
+              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              DLDataType f32_dtype{kDLFloat, 32, 1};
+              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float 
tensor";
+              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float 
tensor";
+              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have 
the same shape";
+              for (int i = 0; i < x->shape[0]; ++i) {
+                static_cast<float*>(y->data)[i] = 
static_cast<float*>(x->data)[i] + 1;
+              }
+            }
+        """,
+        cuda_source=r"""
+            __global__ void AddOneKernel(float* x, float* y, int n) {
+              int idx = blockIdx.x * blockDim.x + threadIdx.x;
+              if (idx < n) {
+                y[idx] = x[idx] + 1;
+              }
+            }
+
+            void AddOneCUDA(DLTensor* x, DLTensor* y) {
+              // implementation of a library function
+              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              DLDataType f32_dtype{kDLFloat, 32, 1};
+              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float 
tensor";
+              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float 
tensor";
+              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have 
the same shape";
+
+              int64_t n = x->shape[0];
+              int64_t nthread_per_block = 256;
+              int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
+              // Obtain the current stream from the environment
+              // it will be set to torch.cuda.current_stream() when calling 
the function
+              // with torch.Tensors
+              cudaStream_t stream = static_cast<cudaStream_t>(
+                  TVMFFIEnvGetCurrentStream(x->device.device_type, 
x->device.device_id));
+              // launch the kernel
+              AddOneKernel<<<nblock, nthread_per_block, 0, 
stream>>>(static_cast<float*>(x->data),
+                                                                     
static_cast<float*>(y->data), n);
+            }
+        """,
+        cpp_functions={"add_one_cpu": "AddOne"},
+        cuda_functions={"add_one_cuda": "AddOneCUDA"},
+    )
+
+    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+    y = torch.empty_like(x)
+    mod.add_one_cpu(x, y)
+    torch.testing.assert_close(x + 1, y)
+
+    x_cuda = x.cuda()

Review Comment:
   not all ci env have cuda support, let us optionally skip that



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to