This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new acdc164d1f [Target] Support CUDA device function calls (#18055)
acdc164d1f is described below

commit acdc164d1f2e0fabfcee68f99f98f4268b4a1e68
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Jul 11 11:53:13 2025 +0800

    [Target] Support CUDA device function calls (#18055)
    
    [TIR][Target] Support device call compilation
    
    This PR introduces support for device call compilation in TVM by enhancing
    the BindTarget pass to handle functions called from both host and device
    contexts. The key improvement is that the pass now automatically creates
    host-specific duplicates of such functions, ensuring correct target
    binding for heterogeneous compilation.
    
    - **Function Classification**: Analyzes call patterns to identify
      functions called from host vs. device contexts
    - **Smart Target Binding**: Automatically binds the appropriate target
      based on the calling context (sketched after this list):
      - Functions called only from host → host target
      - Functions called only from device → device target
      - Functions called from both → device target + host duplicate
    - **Call Site Updates**: Updates call sites in externally exposed
      functions to use the appropriate duplicates
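    
    As a rough sketch (illustrative only; the actual logic lives in
    `src/tir/transforms/bind_target.cc`, and `choose_target` is a hypothetical
    name), the per-function decision for non-exposed functions behaves like:
    
    ```python
    def choose_target(called_by_host: bool, called_by_device: bool) -> str:
        """Illustrative decision table; not a real TVM API."""
        if called_by_host and called_by_device:
            return "device target, plus a `<name>_host` duplicate bound to the host target"
        if called_by_host:
            return "host target"
        if called_by_device:
            return "device target"
        # Functions not reachable from either context currently keep the device target
        return "device target (current default, pending further review)"
    ```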
    
    - Improved device function extraction and kernel generation
    - Better handling of error propagation for different device types
    - Enhanced buffer declaration and parameter management
    
    - Support for `__device__` function calls in CUDA kernels
    - Proper function signature generation for device functions
    - Enhanced calling convention handling
    
    - Updated build pipeline to handle device call compilation
    - Improved target-specific compilation logic
    
    The following example demonstrates how the BindTarget pass handles
    functions called from both host and device contexts:
    
    ```python
    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def add(a: T.int32, b: T.int32) -> T.int32:
            return a + b
    
        @T.prim_func
        def main(
            A: T.Buffer((128, 128), "int32"),
            B: T.Buffer((128, 128), "int32"),
            C: T.Buffer((128, 128), "int32"),
        ):
            T.func_attr({"global_symbol": "main"})
            length: T.int32 = Module.add(64, 64)  # Call from host
            for bx in T.thread_binding(length, "blockIdx.x"):
                for tx in T.thread_binding(length, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
    ```
    
    After applying `BindTarget` with `Target("cuda", host="llvm")`, the pass
    automatically:
    1. Creates a device version of `add` with the CUDA target
    2. Creates a host duplicate `add_host` with the LLVM target
    3. Updates the main function to call `add_host` from the host context and
       `add` from the device context (see the sketch after this list)
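    
    For illustration, the transformed module then looks roughly like the
    sketch below (adapted from the `TestBindTargetWithDeviceHostCallSameFunc`
    expectation in the test diff; the `-opt-level=0` host flag used there is
    elided here):
    
    ```python
    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def add(a: T.int32, b: T.int32) -> T.int32:
            T.func_attr({"target": T.target("cuda")})  # device version
            return a + b
    
        @T.prim_func(private=True)
        def add_host(a: T.int32, b: T.int32) -> T.int32:
            T.func_attr({"target": T.target("llvm")})  # host duplicate
            return a + b
    
        @T.prim_func
        def main(
            A: T.Buffer((128, 128), "int32"),
            B: T.Buffer((128, 128), "int32"),
            C: T.Buffer((128, 128), "int32"),
        ):
            T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
            length: T.int32 = Module.add_host(64, 64)  # now calls the host duplicate
            for bx in T.thread_binding(length, "blockIdx.x"):
                for tx in T.thread_binding(length, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # still the device version
    ```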
    
    This enables seamless compilation of mixed host/device code while
    maintaining proper target-specific optimizations and code generation.
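    
    A minimal end-to-end usage sketch (mirroring the new
    `test_device_host_call_same_func` test below, which uses a C host target
    to sidestep a known `tir.ret` issue with the LLVM host):
    
    ```python
    import tvm
    
    # Compile the mixed host/device module shown above
    target = tvm.target.Target("cuda", host="c")
    lib = tvm.compile(Module, target=target)
    
    # The generated CUDA source contains the __device__ helper
    cuda_code = lib.mod.imported_modules[0].get_source()
    assert "__device__ int add(" in cuda_code
    ```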
    
    - **Automatic Target Binding**: No manual target annotation required for
      most use cases
    - **Heterogeneous Compilation**: Proper support for functions called from
      multiple contexts
    - **Code Reuse**: Shared functions can be called from both host and device
      without manual duplication
    - **Performance**: Maintains target-specific optimizations for each context
    - **Developer Experience**: Simplifies writing mixed host/device code
    
    The implementation is backward compatible and integrates seamlessly with
    existing TVM compilation pipelines.
---
 python/tvm/tir/build.py                            | 106 ++++--
 src/target/build_common.h                          |   4 +-
 src/target/opt/build_cuda_on.cc                    |   9 +-
 src/target/source/codegen_cuda.cc                  |  14 +-
 src/target/source/codegen_cuda.h                   |   3 +-
 src/tir/transforms/bind_target.cc                  | 377 +++++++++++++++++++++
 src/tir/transforms/primfunc_utils.cc               |  32 --
 src/tir/transforms/split_host_device.cc            |   4 -
 tests/python/codegen/test_target_codegen_cuda.py   |  62 ++++
 .../tir-transform/test_tir_transform_helpers.py    |  56 +++
 10 files changed, 597 insertions(+), 70 deletions(-)

diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index 1948ca193f..3eb3648533 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -17,8 +17,7 @@
 
 # pylint: disable=invalid-name
 """The build utils in python."""
-from typing import Union, Optional, Dict
-import enum
+from typing import Union, Optional, Dict, Tuple
 
 import tvm
 from tvm import ir
@@ -28,44 +27,95 @@ from tvm.ir.module import IRModule
 from tvm.target import Target
 
 
-def split_host_device_mods(mod):
+def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
     """Split an IRModule into host and device modules.
 
+    This function takes an IRModule containing functions with different target attributes
+    and separates them into host (CPU) and device (GPU/accelerator) modules. Functions
+    are categorized based on their target attribute in func_attr.
+
     Parameters
     ----------
     mod : tvm.IRModule
-        The input module to split
+        The input module to split.
+        The module should contain functions with target attributes in their func_attr.
+        Functions with "cpu" in their target string are considered host functions,
+        while others are considered device functions.
 
     Returns
     -------
     host_mod : tvm.IRModule
-        The module containing host functions
+        The module containing host functions (CPU-targeted functions)
     device_mod_dict : Dict[Target, tvm.IRModule]
-        A dict mapping targets to device modules
+        A dict mapping targets to device modules. Each device module contains
+        functions targeting the same device (e.g., CUDA GPU, OpenCL, etc.)
+
+    Examples
+    --------
+    Given an IRModule with the following functions:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @T.prim_func(private=True)
+            def add(a: T.int32, b: T.int32) -> T.int32:
+                T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
+                                                "kind": "cuda", "max_num_threads": 1024})})
+                return a + b
+
+            @T.prim_func(private=True)
+            def add_host(a: T.int32, b: T.int32) -> T.int32:
+                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"})})
+                return a + b
+
+            @T.prim_func
+            def main_kernel(A: T.handle, B: T.handle, C: T.handle, length: T.int32):
+                T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
+                                                "kind": "cuda"}),
+                            "calling_conv": 2,  # kDeviceKernelLaunch for device kernels
+                            "tir.is_global_func": True})
+                # ... kernel implementation
+
+            @T.prim_func
+            def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle):
+                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}),
+                            "calling_conv": 1,  # kCPackedFunc for entry functions
+                            "tir.is_entry_func": True})
+                # ... main function implementation
+
+    The function will return:
+    - host_mod: Contains the `add_host` and `main` functions (CPU targets)
+    - device_mod_dict: Contains a CUDA module with the `add` and `main_kernel` functions
+
+    Notes
+    -----
+    - Functions are categorized based on string matching of their target attribute
+    - Functions with "cpu" in the target string are considered host functions
+    - Device functions are grouped by their target to create separate modules
+    - The function uses string-based target matching due to target hash limitations
+    - All functions must have a `calling_conv` attribute in their func_attr:
+        - Private helper functions (private=True): use `calling_conv: 0` (kDefault, the default)
+        - Public entry functions: use `calling_conv: 1` (kCPackedFunc)
+        - Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch)
     """
 
-    class CallConv(enum.IntEnum):
-        """Enum representing different calling conventions.
-        Corresponds to the C++ tvm::ir::CallingConv enum.
-        """
-
-        kDefault = 0
-        kCPackedFunc = 1
-        kDeviceKernelLaunch = 2
-
-    host_mod = tvm.tir.transform.Filter(
-        lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
-        != int(CallConv.kDeviceKernelLaunch)
-    )(mod)
-    device_mod = tvm.tir.transform.Filter(
-        lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
-        == int(CallConv.kDeviceKernelLaunch)
-    )(mod)
-    device_mod_dict = {}
+    host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod)
+    device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))(
+        mod
+    )
+    # TODO(syfeng): Here we use str as key since target hash is not correct
+    target_str2target = {}
+    device_func_dict = {}
+    device_mod_dict: Dict[Target, IRModule] = {}
     for gv, func in device_mod.functions.items():
-        device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func})
-    for target, funcs in device_mod_dict.items():
-        device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
+        target = func.attrs.get("target", None)
+        target_str = str(target) if target is not None else ""
+        target_str2target[target_str] = target  # This might be overridden by the last one
+        device_func_dict.setdefault(target_str, dict()).update({gv: func})
+    for target_str in target_str2target.keys():
+        target = target_str2target[target_str]
+        device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
     return host_mod, device_mod_dict
 
 
@@ -162,7 +212,7 @@ def build(
     # Step 3: Bind the target to the input module
     mod = tvm.tir.transform.BindTarget(target_to_bind)(mod)
 
-    # Step 4: Apply the tir  pipeline
+    # Step 4: Apply the tir pipeline
     if pipeline is not None:
         # custom pipeline
         if isinstance(pipeline, str):
diff --git a/src/target/build_common.h b/src/target/build_common.h
index fda7e2e67c..9e52f6f8ff 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -66,7 +66,9 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
       }
     }
     auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    fmap[static_cast<std::string>(global_symbol.value())] = info;
+    if (global_symbol) {
+      fmap[static_cast<std::string>(global_symbol.value())] = info;
+    }
   }
   return fmap;
 }
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index 068f6c2f71..bdac0b8fb7 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -134,9 +134,12 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
   for (auto [gvar, base_func] : mod->functions) {
     ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
     auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+    auto calling_conv =
+        prim_func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch ||
+           calling_conv == CallingConv::kDefault)
+        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or "
+           "CallingConv::kDefault";
     functions.Set(gvar, prim_func);
   }
 
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 21fbc20f47..35e3d3cb8d 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -140,7 +140,19 @@ void CodeGenCUDA::Init(bool output_ssa) {
   ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
 }
 
-void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
+void CodeGenCUDA::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                         std::ostream& os) {
+  auto calling_conv =
+      func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
+  if (calling_conv == CallingConv::kDeviceKernelLaunch) {
+    os << "extern \"C\" __global__ ";
+  } else if (calling_conv == CallingConv::kDefault) {
+    os << "extern \"C\" __device__ ";
+  } else {
+    LOG(FATAL) << "Unsupported calling convention for cuda codegen: " << calling_conv;
+  }
+  CodeGenC::PrintFunctionSignature(function_name, func, os);
+}
 
 class ThreadIdxExtractor : public tir::StmtVisitor {
  private:
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index bfd5794f81..6441f87909 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -46,7 +46,8 @@ class CodeGenCUDA final : public CodeGenC {
             enable_fp4_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
-  void PrintFuncPrefix(std::ostream& os) final;
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                              std::ostream& os) final;
   void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;  // NOLINT(*)
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
diff --git a/src/tir/transforms/bind_target.cc b/src/tir/transforms/bind_target.cc
new file mode 100644
index 0000000000..7cb010aa9b
--- /dev/null
+++ b/src/tir/transforms/bind_target.cc
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bind_target.cc
+ * \brief Pass to bind target to primfunc for heterogeneous compilation.
+ *
+ * This pass analyzes function call patterns in an IRModule and binds appropriate
+ * targets (host/device) to each PrimFunc based on where they are called from.
+ *
+ * The pass handles the following scenarios:
+ * 1. Functions called from host code (CPU)
+ * 2. Functions called from device code (GPU/accelerator)
+ * 3. Functions called from both host and device
+ * 4. Externally exposed functions (entry points)
+ *
+ * For functions called from both host and device, the pass creates duplicates
+ * with appropriate targets and updates call sites accordingly.
+ */
+
+#include <tvm/ir/global_var_supply.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_set>
+
+#include "tvm/ir/attrs.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Visitor class to classify function calls as host or device calls.
+ *
+ * This visitor traverses the IRModule to identify which functions are called
+ * from host code vs device code. It tracks GPU scopes (thread binding loops
+ * and thread extent attributes) to determine the calling context.
+ */
+class FunctionClassifierVisitor : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Analyze function call patterns in the IRModule.
+   * \param mod The IRModule to analyze
+   * \return A tuple containing:
+   *         - Set of GlobalVarNodes called from host code
+   *         - Set of GlobalVarNodes called from device code
+   * \note A single function can be called by both host and device contexts.
+   */
+  static std::tuple<std::unordered_set<const GlobalVarNode*>,
+                    std::unordered_set<const GlobalVarNode*>>
+  GetFunctionCallers(const IRModule& mod) {
+    FunctionClassifierVisitor visitor;
+
+    // Only analyze externally exposed functions as potential callers
+    // since they represent the entry points where host/device calls originate
+    for (const auto& [gvar, func] : mod->functions) {
+      bool is_externally_exposed = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+      const auto* prim_func = func.as<PrimFuncNode>();
+
+      if (is_externally_exposed && prim_func != nullptr) {
+        visitor.VisitStmt(prim_func->body);
+      }
+    }
+
+    return std::make_tuple(visitor.host_called_global_vars_, visitor.device_called_global_vars_);
+  }
+
+ private:
+  using StmtExprVisitor::VisitStmt_;
+
+  void VisitExpr_(const CallNode* op) final {
+    const auto* global_var = op->op.as<GlobalVarNode>();
+    if (global_var != nullptr) {
+      // Classify the call based on current scope
+      if (is_under_gpu_scope_) {
+        device_called_global_vars_.insert(global_var);
+      } else {
+        host_called_global_vars_.insert(global_var);
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const ForNode* op) final {
+    if (op->kind == ForKind::kThreadBinding) {
+      // Enter GPU scope for thread binding loops
+      bool last_is_under_gpu_scope = is_under_gpu_scope_;
+      is_under_gpu_scope_ = true;
+      StmtExprVisitor::VisitStmt_(op);
+      is_under_gpu_scope_ = last_is_under_gpu_scope;
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
+      // Enter GPU scope for thread extent and virtual thread attributes
+      bool last_is_under_gpu_scope = is_under_gpu_scope_;
+      is_under_gpu_scope_ = true;
+      StmtExprVisitor::VisitStmt_(op);
+      is_under_gpu_scope_ = last_is_under_gpu_scope;
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+
+ private:
+  /*! \brief Whether the current statement is under a GPU scope */
+  bool is_under_gpu_scope_ = false;
+  /*! \brief Set of functions called from host code */
+  std::unordered_set<const GlobalVarNode*> host_called_global_vars_;
+  /*! \brief Set of functions called from device code */
+  std::unordered_set<const GlobalVarNode*> device_called_global_vars_;
+};
+
+/*!
+ * \brief Mutator class to substitute function calls in host contexts.
+ *
+ * This mutator replaces calls to functions that have been duplicated for
+ * host/device contexts. It only performs substitutions when not under
+ * GPU scope to ensure device calls remain unchanged.
+ */
+class CallSubstitutor : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Constructor with function replacement mapping.
+   * \param replacements Map from original GlobalVar to host-specific GlobalVar
+   */
+  explicit CallSubstitutor(const Map<GlobalVar, GlobalVar>& replacements)
+      : replacements_(replacements) {}
+
+  /*!
+   * \brief Substitute function calls in a PrimFunc.
+   * \param func The PrimFunc to process
+   * \return The modified PrimFunc with updated calls
+   */
+  PrimFunc Substitute(PrimFunc func) {
+    auto f = func.CopyOnWrite();
+    auto body = VisitStmt(f->body);
+
+    // Only update if the body actually changed
+    if (body.same_as(func->body)) {
+      return func;
+    }
+
+    f->body = std::move(body);
+    return func;
+  }
+
+ private:
+  using StmtExprMutator::VisitStmt_;
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
+
+    // Only substitute calls when not under GPU scope
+    if (!is_under_gpu_scope_) {
+      if (auto old_gvar = call->op.as<GlobalVar>()) {
+        if (auto new_gvar = replacements_.Get(old_gvar.value())) {
+          call.CopyOnWrite()->op = new_gvar.value();
+        }
+      }
+    }
+    return call;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    if (op->kind == ForKind::kThreadBinding) {
+      // Enter GPU scope for thread binding loops
+      bool last_is_under_gpu_scope = is_under_gpu_scope_;
+      is_under_gpu_scope_ = true;
+      auto stmt = StmtExprMutator::VisitStmt_(op);
+      is_under_gpu_scope_ = last_is_under_gpu_scope;
+      return stmt;
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
+      // Enter GPU scope for thread extent and virtual thread attributes
+      bool last_is_under_gpu_scope = is_under_gpu_scope_;
+      is_under_gpu_scope_ = true;
+      auto stmt = StmtExprMutator::VisitStmt_(op);
+      is_under_gpu_scope_ = last_is_under_gpu_scope;
+      return stmt;
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+ private:
+  /*! \brief Whether the current statement is under a GPU scope */
+  bool is_under_gpu_scope_ = false;
+  /*! \brief Mapping from original functions to host-specific duplicates */
+  Map<GlobalVar, GlobalVar> replacements_;
+};
+
+/*!
+ * \brief Bind appropriate targets to functions in an IRModule.
+ *
+ * This function analyzes the call patterns in the module and binds appropriate
+ * targets to each PrimFunc based on where they are called from. The binding
+ * follows these rules:
+ *
+ * 1. Externally exposed functions (with global symbol) get the full target
+ * 2. Functions called only from host get the host target
+ * 3. Functions called only from device get the device target
+ * 4. Functions called from both contexts get the device target, and a duplicate
+ *    is created with the host target for host callers
+ *
+ * \param mod The IRModule to process
+ * \param target The target to bind (should include both host and device)
+ * \return The modified IRModule with targets bound to functions
+ */
+IRModule BindTarget(IRModule mod, const Target& target) {
+  // Extract host and device targets
+  auto target_host = Downcast<Target>(target->host.value_or(Target("llvm")));
+  auto target_without_host = target.WithoutHost();
+
+  auto mod_copy_on_write = mod.CopyOnWrite();
+  auto new_mod = GetRef<IRModule>(mod_copy_on_write);
+
+  // Step 1: Analyze function call patterns
+  auto [host_called_global_vars, device_called_global_vars] =
+      FunctionClassifierVisitor::GetFunctionCallers(mod);
+
+  // Step 2: Bind target to functions with the following rules:
+  //  1. If the function has a target, and the target has a host, and the function does not
+  //     have a host, then add the host to the function target
+  //  2. If the function is marked as host function, bind the host target to the function
+  //  3. If the function is externally exposed (with global symbol), bind the full target
+  //  4. If the function is not externally exposed:
+  //    4.1 If the function is called by both host and device, bind the device target to the
+  //        current function and duplicate the function with the host target.
+  //    4.2 If the function is called by host only, bind the host target to the current function
+  //    4.3 If the function is called by device only, bind the device target to the current
+  //        function
+  //    4.4 If the function is not called by host or device, bind the device target to keep the
+  //        current behavior (needs further review; see the note at Rule 4.4 below)
+
+  // Track duplicated functions for call replacement
+  Map<GlobalVar, GlobalVar> host_function_replacements;
+  GlobalVarSupply gvar_supply(new_mod);
+
+  for (auto [gvar, func] : mod->functions) {
+    const auto* prim_func_node = func.as<PrimFuncNode>();
+    if (prim_func_node == nullptr) {
+      // Skip non-PrimFunc entries
+      continue;
+    }
+    auto prim_func = GetRef<PrimFunc>(prim_func_node);
+
+    bool is_externally_exposed = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+
+    if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
+      // Rule 1: If the function has a target, and the target has a host, and the function does
+      // not have a host, then add the host to the function target
+      auto func_target_host = func_target.value()->GetHost();
+      auto target_host = target->GetHost();
+
+      if (target_host && !func_target_host && is_externally_exposed) {
+        auto new_target = Target::WithHost(func_target.value(), target_host.value());
+        new_mod->Update(gvar, WithAttr(std::move(prim_func), tvm::attr::kTarget, new_target));
+      }
+      continue;
+    }
+
+    if (prim_func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
+      // Rule 2: If the function is marked as host function, bind the host target to the function
+      prim_func = WithAttr(std::move(prim_func), tvm::attr::kTarget,
+                           Target::WithHost(target_host, target_host));
+      new_mod->Update(gvar, WithoutAttr(std::move(prim_func), tvm::tir::attr::kIsHostFunc));
+      continue;
+    }
+
+    if (is_externally_exposed) {
+      // Rule 3: Externally exposed functions get the full target
+      new_mod->Update(gvar, WithAttr(std::move(prim_func), tvm::attr::kTarget, target));
+    } else {
+      const auto* gvar_node = gvar.get();
+      bool called_by_host = host_called_global_vars.count(gvar_node);
+      bool called_by_device = device_called_global_vars.count(gvar_node);
+
+      if (called_by_host && called_by_device) {
+        // Rule 4.1: Called by both host and device
+        // Bind device target to current function
+        PrimFunc host_func = RenewDefs(prim_func);
+        new_mod->Update(gvar,
+                        WithAttr(std::move(prim_func), tvm::attr::kTarget, target_without_host));
+
+        // Create duplicate with host target for host callers
+        host_func = WithAttr(std::move(host_func), tvm::attr::kTarget, target_host);
+        String host_func_name = gvar->name_hint + "_host";
+        GlobalVar host_gvar = gvar_supply->FreshGlobal(host_func_name, false);
+
+        new_mod->Add(host_gvar, host_func);
+        host_function_replacements.Set(gvar, host_gvar);
+
+      } else if (called_by_host) {
+        // Rule 4.2: Called by host only
+        new_mod->Update(gvar, WithAttr(std::move(prim_func), tvm::attr::kTarget, target_host));
+      } else if (called_by_device) {
+        // Rule 4.3: Called by device only
+        new_mod->Update(gvar,
+                        WithAttr(std::move(prim_func), tvm::attr::kTarget, target_without_host));
+      } else {
+        // Rule 4.4: Not called from any analyzed context
+        // NOTE: To keep the current behavior, we bind the device target (without host) here,
+        // but this needs further review
+        new_mod->Update(gvar,
+                        WithAttr(std::move(prim_func), tvm::attr::kTarget, target_without_host));
+      }
+    }
+  }
+
+  // Step 3: Update call sites in externally exposed functions
+  if (!host_function_replacements.empty()) {
+    CallSubstitutor substitutor(host_function_replacements);
+
+    for (auto [gvar, func] : mod->functions) {
+      const auto* prim_func = func.as<PrimFuncNode>();
+      if (prim_func == nullptr) {
+        continue;
+      }
+
+      bool is_externally_exposed = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+      if (is_externally_exposed) {
+        // Update calls in externally exposed functions to use host duplicates
+        PrimFunc new_func = substitutor.Substitute(Downcast<PrimFunc>(func));
+        new_mod->Update(gvar, new_func);
+      }
+    }
+  }
+
+  return new_mod;
+}
+
+namespace transform {
+
+/*!
+ * \brief Create a pass that binds targets to functions in an IRModule.
+ *
+ * This pass analyzes the call patterns in the module and binds appropriate
+ * targets (host/device) to each PrimFunc based on where they are called from.
+ *
+ * \param target The target to bind (should include both host and device)
+ * \return A transform pass that performs target binding
+ */
+transform::Pass BindTarget(Target target) {
+  auto fpass = [target](IRModule mod, transform::PassContext ctx) {
+    return tvm::tir::BindTarget(mod, target);
+  };
+  return tir::transform::CreateModulePass(fpass, 0, "tir.BindTarget", {});
+}
+
+TVM_FFI_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc
index ade1aea7c9..00751e3a9a 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -27,37 +27,6 @@
 namespace tvm {
 namespace tir {
 namespace transform {
-transform::Pass BindTarget(Target target) {
-  Target without_host = target.WithoutHost();
-  Target target_host = Downcast<Target>(target->host.value_or(Target("llvm")));
-
-  auto fpass = [target, target_host, without_host](tir::PrimFunc func, IRModule m,
-                                                   transform::PassContext ctx) {
-    bool is_externally_exposed = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
-
-    if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
-      auto func_target_host = func_target.value()->GetHost();
-      auto target_host = target->GetHost();
-
-      if (target_host && !func_target_host && is_externally_exposed) {
-        auto new_target = Target::WithHost(func_target.value(), target_host.value());
-        func = WithAttr(std::move(func), tvm::attr::kTarget, new_target);
-      }
-    } else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
-      func =
-          WithAttr(std::move(func), tvm::attr::kTarget, Target::WithHost(target_host, target_host));
-    } else if (is_externally_exposed) {
-      func = WithAttr(std::move(func), tvm::attr::kTarget, target);
-    } else {
-      func = WithAttr(std::move(func), tvm::attr::kTarget, without_host);
-    }
-
-    func = WithoutAttr(std::move(func), tvm::tir::attr::kIsHostFunc);
-
-    return func;
-  };
-  return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
-}
 
 transform::Pass AnnotateEntryFunc() {
   auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule {
@@ -109,7 +78,6 @@ transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {});
 }
 
-TVM_FFI_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget);
 
TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc);
 TVM_FFI_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter);
 
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 13bf018689..160d80b4da 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -32,11 +32,7 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
-#include <unordered_map>
-
-#include "../../runtime/thread_storage_scope.h"
 #include "../analysis/var_use_def_analysis.h"
-#include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 063ed0469b..2d00618eb0 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -23,6 +23,7 @@ import tvm
 import tvm.testing
 from tvm import te, topi
 from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8
+from tvm.script import ir as I
 from tvm.script import tir as T
 
 
@@ -777,5 +778,66 @@ extern "C" __global__ void __launch_bounds__(128) main_kernel(float* __restrict_
     )
 
 
+@tvm.testing.requires_cuda
+def test_cuda_device_func_call():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def add(a: T.float32, b: T.float32) -> T.float32:
+            return a + b
+
+        @T.prim_func
+        def main(
+            A: T.Buffer((1024, 1024), "float32"),
+            B: T.Buffer((1024, 1024), "float32"),
+            C: T.Buffer((1024, 1024), "float32"),
+        ):
+            for bx in T.thread_binding(1024, "blockIdx.x"):
+                for tx in T.thread_binding(1024, "threadIdx.x"):
+                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imported_modules[0].get_source()
+    assert 'extern "C" __device__ float add(float a, float b) {\n  return (a + 
b);\n}' in cuda_code
+
+
+@tvm.testing.requires_cuda
+def test_device_host_call_same_func():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def add(a: T.int32, b: T.int32) -> T.int32:
+            return a + b
+
+        @T.prim_func
+        def main(
+            A: T.Buffer((128, 128), "int32"),
+            B: T.Buffer((128, 128), "int32"),
+            C: T.Buffer((128, 128), "int32"),
+        ):
+            length: T.int32 = Module.add(64, 64)  # Call from host
+            for bx in T.thread_binding(length, "blockIdx.x"):
+                for tx in T.thread_binding(length, "threadIdx.x"):
+                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
+
+    # If we set host to llvm, it will raise an error of
+    # "the tir.ret should be transformed to return zero before the llvm code 
generation."
+    # Need to revisit this.
+    target = tvm.target.Target("cuda", host="c")
+    lib = tvm.compile(Module, target=target)
+    cuda_code = lib.mod.imported_modules[0].get_source()
+    assert 'extern "C" __device__ int add(int a, int b) {\n  return (a + 
b);\n}' in cuda_code
+
+    # Run a simple test
+    dev = tvm.cuda(0)
+    a_np = np.random.randint(0, 10, (128, 128), dtype="int32")
+    b_np = np.random.randint(0, 10, (128, 128), dtype="int32")
+    a_tvm = tvm.nd.array(a_np, device=dev)
+    b_tvm = tvm.nd.array(b_np, device=dev)
+    c_tvm = tvm.nd.empty((128, 128), dtype="int32", device=dev)
+    lib["main"](a_tvm, b_tvm, c_tvm)
+    tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tir-transform/test_tir_transform_helpers.py
index d2ea82a140..0bbd0e7160 100644
--- a/tests/python/tir-transform/test_tir_transform_helpers.py
+++ b/tests/python/tir-transform/test_tir_transform_helpers.py
@@ -209,6 +209,62 @@ class TestBindTargetMultipleFunctions(tvm.testing.CompareBeforeAfter):
         return mod
 
 
+class TestBindTargetWithDeviceHostCallSameFunc(tvm.testing.CompareBeforeAfter):
+    """BindTarget should bind the device target to the function if it is 
called from device"""
+
+    transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", 
host="llvm -opt-level=0"))
+
+    def before(self):
+        @I.ir_module
+        class Module:
+            @T.prim_func(private=True)
+            def add(a: T.int32, b: T.int32) -> T.int32:
+                return a + b
+
+            @T.prim_func
+            def main(
+                A: T.Buffer((128, 128), "int32"),
+                B: T.Buffer((128, 128), "int32"),
+                C: T.Buffer((128, 128), "int32"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                length: T.int32 = Module.add(64, 64)  # Call from host
+                for bx in T.thread_binding(length, "blockIdx.x"):
+                    for tx in T.thread_binding(length, "threadIdx.x"):
+                        C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
+
+        return Module
+
+    def expected(self):
+        @I.ir_module
+        class Module:
+            @T.prim_func(private=True)
+            def add(a: T.int32, b: T.int32) -> T.int32:
+                T.func_attr({"target": T.target("cuda")})
+                return a + b
+
+            @T.prim_func(private=True)
+            def add_host(a: T.int32, b: T.int32) -> T.int32:
+                T.func_attr({"target": T.target("llvm -opt-level=0")})
+                return a + b
+
+            @T.prim_func
+            def main(
+                A: T.Buffer((128, 128), "int32"),
+                B: T.Buffer((128, 128), "int32"),
+                C: T.Buffer((128, 128), "int32"),
+            ):
+                T.func_attr(
+                    {"global_symbol": "main", "target": T.target("cuda", 
host="llvm -opt-level=0")}
+                )
+                length: T.int32 = Module.add_host(64, 64)  # Call from host
+                for bx in T.thread_binding(length, "blockIdx.x"):
+                    for tx in T.thread_binding(length, "threadIdx.x"):
+                        C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
+
+        return Module
+
+
 def test_filter_primfunc():
     mod = MockModule
     assert mod

