This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e802d12f1 [Relax,Topi] Allow passing workspace to thrust to avoid allocations (#16851)
3e802d12f1 is described below

commit 3e802d12f1270a5ee92088211db663df311bbaa6
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Apr 6 05:45:12 2024 -0700

    [Relax,Topi] Allow passing workspace to thrust to avoid allocations (#16851)
    
    * [Relax,Topi] Allow passing workspace to thrust to avoid allocations
---
 python/tvm/relax/backend/dispatch_sort_scan.py     |  70 ++++++--
 python/tvm/relax/frontend/nn/op.py                 | 106 ++++++++++++
 python/tvm/te/operation.py                         |  16 +-
 python/tvm/topi/cuda/scan.py                       |  95 +++++++++--
 python/tvm/topi/cuda/sort.py                       |  95 ++++++++---
 src/runtime/contrib/thrust/thrust.cu               | 178 ++++++++++++++-------
 .../relax/test_backend_dispatch_sort_scan.py       |  49 ++++--
 7 files changed, 476 insertions(+), 133 deletions(-)

diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py 
b/python/tvm/relax/backend/dispatch_sort_scan.py
index a223b64ad0..480420c313 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -17,13 +17,16 @@
 # pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
 """Dispatch sort and scan operators to platform dependent implementation."""
 
-from tvm import topi, dlight, relax
+from functools import reduce
+from operator import mul
+
+from tvm import DataType, dlight, relax, topi
+from tvm.contrib.thrust import can_use_thrust
 from tvm.ir import Op
 from tvm.ir.module import IRModule
 from tvm.ir.transform import PassContext, module_pass
-from tvm.target import Target
-from tvm.contrib.thrust import can_use_thrust
 from tvm.relax import PyExprMutator, expr_functor
+from tvm.target import Target
 
 
 @expr_functor.mutator
@@ -80,23 +83,24 @@ class SortScanDispatcher(PyExprMutator):
         if call.op.name == "relax.sort":
             tgt = self._get_target(call.struct_info)
             te_func = topi.sort
+            kwargs = {}
             with tgt:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.sort_thrust
+                    kwargs["workspace"] = self.allocate_workspace(call)
                 elif tgt.kind.name == "cuda":
                     te_func = topi.cuda.sort
             return self.builder_.call_te(
-                te_func,
-                call.args[0],
-                call.attrs.axis,
-                not call.attrs.descending,
+                te_func, call.args[0], call.attrs.axis, not 
call.attrs.descending, **kwargs
             )
         if call.op.name == "relax.argsort":
             tgt = self._get_target(call.struct_info)
             te_func = topi.argsort
+            kwargs = {}
             with tgt:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.argsort_thrust
+                    kwargs["workspace"] = self.allocate_workspace(call)
                 elif tgt.kind.name == "cuda":
                     te_func = topi.cuda.argsort
             return self.builder_.call_te(
@@ -105,12 +109,15 @@ class SortScanDispatcher(PyExprMutator):
                 axis=call.attrs.axis,
                 is_ascend=not call.attrs.descending,
                 dtype=call.attrs.dtype,
+                **kwargs,
             )
         if call.op.name == "relax.topk":
             tgt = self._get_target(call.struct_info)
             te_func = topi.topk
+            kwargs = {}
             if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                 te_func = topi.cuda.topk_thrust
+                kwargs["workspace"] = self.allocate_workspace(call)
             elif tgt.kind.name == "cuda":
                 te_func = topi.cuda.topk
             tir_call = self.builder_.call_te(
@@ -121,6 +128,7 @@ class SortScanDispatcher(PyExprMutator):
                 ret_type=call.attrs.ret_type,
                 is_ascend=not call.attrs.largest,
                 dtype=call.attrs.dtype,
+                **kwargs,
             )
             if tgt.kind.name != "cuda":
                 return tir_call
@@ -130,16 +138,24 @@ class SortScanDispatcher(PyExprMutator):
         if call.op.name in ("relax.cumprod", "relax.cumsum"):
             tgt = self._get_target(call.struct_info)
             axis = int(call.attrs.axis) if call.attrs.axis is not None else 
call.attrs.axis
-            te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else 
topi.cumsum
-            if call.op.name == "relax.cumprod":
-                te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else 
topi.cumprod
-            tir_call = self.builder_.call_te(
-                te_func,
-                call.args[0],
-                axis,
-                call.attrs.dtype,
-                call.attrs.exclusive,
-            )
+            kwargs = {}
+            with tgt:
+                if call.op.name == "relax.cumsum":
+                    te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else 
topi.cumsum
+                    if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
+                        kwargs["workspace"] = self.allocate_workspace(call)
+                elif call.op.name == "relax.cumprod":
+                    te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" 
else topi.cumprod
+                else:
+                    raise ValueError(f"Unsupported op: {call.op.name}")
+                tir_call = self.builder_.call_te(
+                    te_func,
+                    call.args[0],
+                    axis,
+                    call.attrs.dtype,
+                    call.attrs.exclusive,
+                    **kwargs,
+                )
             if tgt.kind.name != "cuda":
                 return tir_call
             # apply dlight gpu fallback
@@ -147,6 +163,26 @@ class SortScanDispatcher(PyExprMutator):
             return tir_call
         return super().visit_call_(call)
 
+    def estimate_thrust_workspace_size(self, call: relax.Call) -> int:
+        """
+        Estimate the workspace size for thrust sort/argsort/topk/cumsum
+        """
+        input_shape = call.args[0].struct_info.shape
+        input_byte_per_elem = DataType(call.args[0].struct_info.dtype).bits // 
8
+        input_size = reduce(mul, input_shape, 1) * input_byte_per_elem
+        # Most GPU algorithms take O(n) space or less, we choose 2N + 4MB as a 
safe estimation
+        return 2 * input_size + 4 * 1024 * 1024
+
+    def allocate_workspace(self, call: relax.Call) -> relax.Var:
+        """
+        Allocate workspace for thrust sort/argsort/topk.
+        """
+        workspace_size = self.estimate_thrust_workspace_size(call)
+        alloc = relax.op.builtin.alloc_tensor(
+            relax.ShapeExpr((workspace_size,)), "uint8", runtime_device_index=0
+        )
+        return self.builder_.emit(alloc)
+
 
 @module_pass(opt_level=0, name="DispatchSortScan")
 class DispatchSortScan:
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index 11a0b8e62d..e46553203f 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -2241,6 +2241,112 @@ def cumsum(
     return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
 
 
+def sort(x: Tensor, axis: int = -1, descending: bool = False, name="sort"):
+    """Performs sorting along the given axis and returns an array
+    in sorted order.
+
+    Parameters
+    ----------
+    x : Tensor
+        The input tensor.
+
+    axis : int
+        Axis along which to sort the input tensor.
+        By default the last axis of the input is used.
+
+    descending : bool
+        Whether to sort in descending order, the default is False
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    out : Tensor
+        The sorted tensor.
+    """
+    return wrap_nested(_op.sort(x, axis, descending), name=name)
+
+
+def argsort(
+    data: Tensor, axis: int = -1, descending: bool = False, dtype: str = 
"int32", name="argsort"
+):
+    """Performs sorting along the given axis and returns an array of indices
+    having same shape as an input array that index data in sorted order.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data tensor.
+
+    axis : int
+        Axis long which to sort the input tensor.
+
+    descending : bool
+        Whether to sort in descending order, the default is False
+
+    dtype : str
+        The data type of the output indices.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    out : Tensor
+        The indices of the sorted tensor.
+    """
+    return wrap_nested(_op.argsort(data, axis, descending, dtype), name=name)
+
+
+def topk(
+    data: Tensor,
+    k: int = 1,
+    axis: int = -1,
+    ret_type: str = "both",
+    largest: bool = True,
+    dtype: str = "int32",
+    name: str = "topk",
+):
+    """Get the top k elements in an input tensor along the given axis.
+
+    ret_type specifies the return type, can be one of ("both", "values", 
"indices").
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data tensor.
+
+    k : int
+        Number of top elements to select. Return all elements if k < 1.
+
+    axis : int
+        Axis long which to sort the input tensor.
+
+    ret_type: str
+        The return type [both, values, indices].
+        "both": return both top k data and indices.
+        "values": return top k data only.
+        "indices": return top k indices only.
+
+    largest : bool
+        Whether to return largest or smallest elements.
+        The k smallest elements are returned if largest is False.
+
+    dtype : str
+        The data type of the indices output.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    out : Tensor or Tuple[Tensor, Tensor]
+        The computed result.
+    """
+    return wrap_nested(_op.topk(data, k, axis, ret_type, largest, dtype), 
name=name)
+
+
 def multinomial_from_uniform(
     prob: Tensor,
     uniform_sample: Tensor,
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 5547ef82d7..dc2c678499 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -333,15 +333,15 @@ def extern(
             )
         types.add(t.dtype)
 
-    if dtype is None:
-        if len(types) != 1:
-            raise ValueError("Cannot infer output type, please provide dtype 
argument")
-        infered_type = types.pop()
-        dtype = [infered_type for _ in shape]
-    if isinstance(dtype, str):
-        dtype = [dtype]
-
     if out_buffers is None:
+        if dtype is None:
+            if len(types) != 1:
+                raise ValueError("Cannot infer output type, please provide 
dtype argument")
+            infered_type = types.pop()
+            dtype = [infered_type for _ in shape]
+        if isinstance(dtype, str):
+            dtype = [dtype]
+
         for shp, dt in zip(shape, dtype):
             output_placeholders.append(
                 tvm.tir.decl_buffer(shp, dt, name, 
elem_offset=tvm.tir.Var("elem_offset", "int32"))
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 4b1bac0529..c1f2eded6b 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -272,7 +272,12 @@ def get_reduction_from_exclusive_scan(data, 
ex_scan_output, binop=tvm.tir.generi
 
 
 def scan_thrust(
-    data, output_dtype, exclusive=True, return_reduction=False, 
binop=tvm.tir.generic.add
+    data,
+    output_dtype,
+    exclusive=True,
+    return_reduction=False,
+    binop=tvm.tir.generic.add,
+    workspace=None,
 ):
     """Do exclusive or inclusive scan on 1D or multidimensional input, using 
thrust.
 
@@ -297,6 +302,11 @@ def scan_thrust(
         thrust function, arbitrariy callables are not supported. Currently only
         tvm.tir.generic.add can be passed in.
 
+    workspace: Optional[tvm.te.Tensor]
+        A buffer to store intermediate results. The size of the workspace 
should be sufficiently
+        large, this can be obtained by overestimation or memory usage 
profiling. If None, it will
+        fallback to use thrust internal memory allocation.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -309,14 +319,24 @@ def scan_thrust(
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", 
data_alignment=8)
     output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", 
data_alignment=8)
 
+    workspace_buf = (
+        tvm.tir.decl_buffer(workspace.shape, workspace.dtype, "workspace_buf", 
data_alignment=8)
+        if workspace is not None
+        else None
+    )
+
+    def f_compute(ins, outs):
+        args = [_get_thrust_func_name(binop), ins[0], outs[0], exclusive]
+        if workspace is not None:
+            args.append(ins[1])
+        return tvm.tir.call_packed(*args)
+
     output = te.extern(
         [data.shape],
-        [data],
-        lambda ins, outs: tvm.tir.call_packed(
-            _get_thrust_func_name(binop), ins[0], outs[0], exclusive
-        ),
+        [data] if workspace is None else [data, workspace],
+        f_compute,
         dtype=[output_dtype],
-        in_buffers=[data_buf],
+        in_buffers=[data_buf] if workspace is None else [data_buf, 
workspace_buf],
         out_buffers=[output_buf],
         name="exclusive_scan_thrust",
         tag="exclusive_scan_thrust_gpu",
@@ -337,6 +357,7 @@ def exclusive_scan(
     output_dtype=None,
     binop=tvm.tir.generic.add,
     identity_value=0,
+    workspace=None,
 ):
     """Do exclusive scan on 1D or multidimensional input.
 
@@ -367,6 +388,11 @@ def exclusive_scan(
         your operator and i is the identity_value then a * i = a for all a in 
the domain of
         your operation.
 
+    workspace: Optional[tvm.te.Tensor]
+        A buffer to store intermediate results if thrust is enabled. The size 
of the workspace
+        should be sufficiently large, this can be obtained by overestimation 
or memory usage
+        profiling. If None, it will fallback to use thrust internal memory 
allocation.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -378,11 +404,15 @@ def exclusive_scan(
     """
 
     def do_scan(data, output_dtype):
-
         # TODO: add support for a prod_scan
         if _can_use_scan_thrust(binop):
             return scan_thrust(
-                data, output_dtype, exclusive=True, 
return_reduction=return_reduction, binop=binop
+                data,
+                output_dtype,
+                exclusive=True,
+                return_reduction=return_reduction,
+                binop=binop,
+                workspace=workspace,
             )
 
         if ndim == 1:
@@ -457,7 +487,9 @@ def exclusive_scan(
     return output
 
 
-def inclusive_scan(data, axis=-1, output_dtype=None, 
binop=tvm.tir.generic.add, identity_value=0):
+def inclusive_scan(
+    data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add, 
identity_value=0, workspace=None
+):
     """Do inclusive scan on 1D or multidimensional input.
 
     Parameters
@@ -481,6 +513,11 @@ def inclusive_scan(data, axis=-1, output_dtype=None, 
binop=tvm.tir.generic.add,
         your operator and i is the identity_value then a * i = a for all a in 
the domain of
         your operation.
 
+    workspace: Optional[tvm.te.Tensor]
+        A buffer to store intermediate results if thrust is enabled. The size 
of the workspace
+        should be sufficiently large, this can be obtained by overestimation 
or memory usage
+        profiling. If None, it will fallback to use thrust internal memory 
allocation.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -497,14 +534,19 @@ def inclusive_scan(data, axis=-1, output_dtype=None, 
binop=tvm.tir.generic.add,
         if axis != ndim - 1:
             axes = swap(list(range(ndim)), axis)
             data = transpose(data, axes)
-        output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
+        output = scan_thrust(data, output_dtype, exclusive=False, binop=binop, 
workspace=workspace)
         if axis != ndim - 1:
             axes = swap(list(range(ndim)), axis)
             output = transpose(output, axes)
         return output
 
     ex_scan = exclusive_scan(
-        data, axis, output_dtype=output_dtype, binop=binop, 
identity_value=identity_value
+        data,
+        axis,
+        output_dtype=output_dtype,
+        binop=binop,
+        identity_value=identity_value,
+        workspace=workspace,
     )
 
     if output_dtype is not None and data.dtype != output_dtype and 
output_dtype != "":
@@ -551,6 +593,7 @@ def scanop(
     axis: Optional[int] = None,
     dtype: Optional[str] = None,
     exclusive: Optional[bool] = None,
+    workspace: Optional[tvm.te.Tensor] = None,
 ) -> tvm.te.Tensor:
     """Cumulative binary operator (scan) with similar axis behavior as 
np.cumsum and np.cumprod.
 
@@ -587,6 +630,8 @@ def scanop(
         the cumulative operation of the first (j-1) elements. Otherwise, it 
would be the
         cumulative operation of the first j elements.
 
+    workspace: Optional[tvm.te.Tensor]
+
     Returns
     -------
     result : tvm.te.Tensor
@@ -599,10 +644,20 @@ def scanop(
     axis = get_const_int(axis)
     if exclusive is not None and exclusive:
         return exclusive_scan(
-            data, axis, output_dtype=dtype, binop=binop, 
identity_value=identity_value
+            data,
+            axis,
+            output_dtype=dtype,
+            binop=binop,
+            identity_value=identity_value,
+            workspace=workspace,
         )
     return inclusive_scan(
-        data, axis, output_dtype=dtype, binop=binop, 
identity_value=identity_value
+        data,
+        axis,
+        output_dtype=dtype,
+        binop=binop,
+        identity_value=identity_value,
+        workspace=workspace,
     )
 
 
@@ -611,6 +666,7 @@ def cumsum(
     axis: Optional[int] = None,
     dtype: Optional[int] = None,
     exclusive: Optional[bool] = None,
+    workspace: Optional[tvm.te.Tensor] = None,
 ) -> tvm.te.Tensor:
     """Numpy style cumsum op. Return the cumulative sum of the elements along 
a given axis.
 
@@ -633,6 +689,11 @@ def cumsum(
         the sum of the first (j-1) elements. Otherwise, it would be the sum of
         the first j elements.
 
+    workspace: Optional[tvm.te.Tensor]
+        A buffer to store intermediate results if thrust is enabled. The size 
of the workspace
+        should be sufficiently large, this can be obtained by overestimation 
or memory usage
+        profiling. If None, it will fallback to use thrust internal memory 
allocation.
+
     Returns
     -------
     result : tvm.te.Tensor
@@ -646,6 +707,7 @@ def cumsum(
         axis=axis,
         dtype=dtype,
         exclusive=exclusive,
+        workspace=workspace,
     )
 
 
@@ -654,6 +716,7 @@ def cumprod(
     axis: Optional[int] = None,
     dtype: Optional[int] = None,
     exclusive: Optional[bool] = None,
+    workspace: Optional[tvm.te.Tensor] = None,
 ):
     """Numpy style cumprod op. Return the cumulative product of the elements 
along a given axis.
 
@@ -676,6 +739,11 @@ def cumprod(
         the product of the first (j-1) elements. Otherwise, it would be the 
product of
         the first j elements.
 
+    workspace: Optional[tvm.te.Tensor]
+        A buffer to store intermediate results if thrust is enabled. The size 
of the workspace
+        should be sufficiently large, this can be obtained by overestimation 
or memory usage
+        profiling. If None, it will fallback to use thrust internal memory 
allocation.
+
     Returns
     -------
     result : tvm.te.Tensor
@@ -689,4 +757,5 @@ def cumprod(
         axis=axis,
         dtype=dtype,
         exclusive=exclusive,
+        workspace=workspace,
     )
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 058584a302..dc72aa8cc1 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -682,7 +682,7 @@ def sort(data, axis=-1, is_ascend=1):
     return out
 
 
-def sort_thrust(data, axis=-1, is_ascend=1):
+def sort_thrust(data, axis=-1, is_ascend=1, workspace=None):
     """Performs sorting along the given axis and returns an array of
     sorted values with the same shape as the input data.
 
@@ -697,6 +697,12 @@ def sort_thrust(data, axis=-1, is_ascend=1):
     is_ascend : boolean, optional
         Whether to sort in ascending or descending order.
 
+    workspace: Optional[tvm.te.Tensor]
+        A buffer to store intermediate results. The size of the workspace 
should be sufficiently
+        large, this can be obtained by overestimation or memory usage 
profiling. If None, it will
+        fallback to use thrust internal memory allocation.
+
+
     Returns
     -------
     out : tvm.te.Tensor
@@ -714,15 +720,20 @@ def sort_thrust(data, axis=-1, is_ascend=1):
 
     value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", 
data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", 
data_alignment=8)
+
+    def f_compute(ins, outs):
+        args = ["tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend]
+        if workspace is not None:
+            args.append(ins[1])
+        return tvm.tir.call_packed(*args)
+
     out = te.extern(
         [data.shape, data.shape],
-        [data],
+        [data] if workspace is None else [data, workspace],
         ## TODO(mbrookhart): This thrust function is actually doing argsort, 
not sort
         ## For performance, we should probably rename the contrib function and 
add
         ## a pure sort
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
-        ),
+        f_compute,
         out_buffers=[value_buf, indices_buf],
         name="sort_gpu",
         tag="sort_gpu",
@@ -801,7 +812,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32", 
ret_type="indices"):
     return outs[0], outs[1]
 
 
-def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32", 
ret_type="indices"):
+def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32", 
ret_type="indices", workspace=None):
     """Performs sorting along the given axis and returns an array of indices
     having same shape as an input array that index data in sorted order.
 
@@ -824,12 +835,17 @@ def argsort_thrust(data, axis=-1, is_ascend=1, 
dtype="float32", ret_type="indice
         "both": return both sorted data and indices.
         "indices": return sorted indices only.
 
+    workspace : Optional[tvm.te.Tensor]
+        A buffer to store intermediate results. The size of the workspace 
should be sufficiently
+        large, this can be obtained by overestimation or memory usage 
profiling. If None, it will
+        fallback to use thrust internal memory allocation.
+
     Returns
     -------
     out : tvm.te.Tensor
         The output of this function.
     """
-    return topk_thrust(data, 0, axis, ret_type, is_ascend, dtype)
+    return topk_thrust(data, 0, axis, ret_type, is_ascend, dtype, workspace)
 
 
 def schedule_sort(outs):
@@ -972,7 +988,9 @@ def topk(data, k=1, axis=-1, ret_type="both", 
is_ascend=False, dtype="int64"):
     return output
 
 
-def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, 
dtype="int64"):
+def topk_thrust(
+    data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64", 
workspace=None
+):
     """Get the top k elements in an input tensor along the given axis.
 
     Parameters
@@ -998,6 +1016,11 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", 
is_ascend=False, dtype="int
     dtype : string, optional
         The data type of the indices output.
 
+    workspace : Optional[tvm.te.Tensor]
+        A buffer to store intermediate results. The size of the workspace 
should be sufficiently
+        large, this can be obtained by overestimation or memory usage 
profiling. If None, it will
+        fallback to use thrust internal memory allocation.
+
     Returns
     -------
     out : tvm.te.Tensor or List[tvm.te.Tensor]
@@ -1013,20 +1036,30 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", 
is_ascend=False, dtype="int
         data = transpose(data, axes)
 
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", 
data_alignment=8)
+    if workspace is not None:
+        workspace_buf = tvm.tir.decl_buffer(
+            workspace.shape, workspace.dtype, "workspace_buf", data_alignment=8
+        )
+    else:
+        workspace_buf = None
     out_bufs = [
         tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", 
data_alignment=8),
         tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", 
data_alignment=8),
     ]
 
+    def f_compute(ins, outs):
+        args = ["tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend]
+        if workspace is not None:
+            args.append(ins[1])
+        return tvm.tir.call_packed(*args)
+
     is_ascend = 1 if is_ascend else 0
 
     out = te.extern(
         [data.shape, data.shape],
-        [data],
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
-        ),
-        in_buffers=[data_buf],
+        [data] if workspace is None else [data, workspace],
+        f_compute,
+        in_buffers=[data_buf] if workspace is None else [data_buf, 
workspace_buf],
         out_buffers=out_bufs,
         name="topk_gpu",
         tag="topk_gpu",
@@ -1120,7 +1153,7 @@ def sort_by_key(keys, values, axis=-1, is_ascend=1):
     return out[0], out[1]
 
 
-def stable_sort_by_key_thrust(keys, values, for_scatter=False):
+def stable_sort_by_key_thrust(keys, values, for_scatter=False, workspace=None):
     """Sort values with respect to keys using thrust.
     Both keys and values will be sorted and returned.
     Sorting is done via stable sort, so relative ordering among
@@ -1140,6 +1173,11 @@ def stable_sort_by_key_thrust(keys, values, 
for_scatter=False):
         The output keys (indices) are all positive.
         This option is introduced to optimize the scatter implementation.
 
+    workspace : Optional[tvm.te.Tensor]
+        A buffer to store intermediate results. The size of the workspace 
should be sufficiently
+        large, this can be obtained by overestimation or memory usage 
profiling. If None, it will
+        fallback to use thrust internal memory allocation.
+
     Returns
     -------
     keys_sorted : tvm.te.Tensor
@@ -1150,17 +1188,36 @@ def stable_sort_by_key_thrust(keys, values, 
for_scatter=False):
     """
     keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", 
data_alignment=8)
     values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", 
data_alignment=8)
+    workspace_buf = (
+        tvm.tir.decl_buffer(workspace.shape, workspace.dtype, "workspace_buf", 
data_alignment=8)
+        if workspace is not None
+        else None
+    )
     out_bufs = [
         tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", 
data_alignment=8),
         tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", 
data_alignment=8),
     ]
+
+    def f_compute(ins, outs):
+        args = [
+            "tvm.contrib.thrust.stable_sort_by_key",
+            ins[0],
+            ins[1],
+            outs[0],
+            outs[1],
+            for_scatter,
+        ]
+        if workspace is not None:
+            args.append(ins[2])
+        return tvm.tir.call_packed(*args)
+
     out = te.extern(
         [keys.shape, values.shape],
-        [keys, values],
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.stable_sort_by_key", ins[0], ins[1], outs[0], 
outs[1], for_scatter
-        ),
-        in_buffers=[keys_buf, values_buf],
+        [keys, values] if workspace is None else [keys, values, workspace],
+        f_compute,
+        in_buffers=[keys_buf, values_buf]
+        if workspace is None
+        else [keys_buf, values_buf, workspace_buf],
         out_buffers=out_bufs,
         dtype=[keys.dtype, values.dtype],
         name="stable_sort_by_key",
diff --git a/src/runtime/contrib/thrust/thrust.cu 
b/src/runtime/contrib/thrust/thrust.cu
index b0b78ba868..7a95b4b0a3 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -22,10 +22,12 @@
  */
 
 #include <dlpack/dlpack.h>
-#include <thrust/detail/caching_allocator.h>
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/gather.h>
+#include <thrust/mr/device_memory_resource.h>
+#include <thrust/mr/disjoint_tls_pool.h>
+#include <thrust/mr/memory_resource.h>
 #include <thrust/scan.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
@@ -33,29 +35,71 @@
 
 #include <algorithm>
 #include <functional>
+#include <memory>
 #include <vector>
 
 #include "../../cuda/cuda_common.h"
-
 namespace tvm {
 namespace contrib {
 
 using namespace runtime;
 
-auto get_thrust_exec_policy() {
-  return 
thrust::cuda::par_nosync(thrust::detail::single_device_tls_caching_allocator())
-      .on(GetCUDAStream());
+/*! \brief Memory resource backed by pre-allocated workspace. */
+class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
+ public:
+  explicit WorkspaceMemoryResource(DLTensor* workspace) {
+    if (workspace != nullptr) {
+      this->workspace = workspace->data;
+      CHECK(workspace->ndim == 1 && workspace->dtype.code == kDLUInt && 
workspace->dtype.bits == 8);
+      this->workspace_size = workspace->shape[0];
+    } else {
+      // Fallback to thrust TLS caching allocator if workspace is not provided.
+      thrust_pool_ = thrust::mr::tls_disjoint_pool(
+          thrust::mr::get_global_resource<thrust::device_memory_resource>(),
+          thrust::mr::get_global_resource<thrust::mr::new_delete_resource>());
+    }
+  }
+
+  void* do_allocate(size_t bytes, size_t alignment) override {
+    if (workspace != nullptr) {
+      void* result = std::align(alignment, bytes, workspace, workspace_size);
+      CHECK(result) << "Failed to allocate " << bytes << " bytes with 
alignment " << alignment
+                    << " bytes.";
+      return result;
+    }
+    return thrust_pool_.do_allocate(bytes, alignment).get();
+  }
+
+  void do_deallocate(void* p, size_t bytes, size_t alignment) override {
+    if (workspace != nullptr) {
+      // No-op
+    } else {
+      thrust_pool_.do_deallocate(thrust::device_memory_resource::pointer(p), 
bytes, alignment);
+    }
+  }
+
+  
thrust::mr::disjoint_unsynchronized_pool_resource<thrust::device_memory_resource,
+                                                    
thrust::mr::new_delete_resource>
+      thrust_pool_;
+
+  void* workspace = nullptr;
+  size_t workspace_size = 0;
+};
+
+auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) {
+  return thrust::cuda::par_nosync(memory_resouce).on(GetCUDAStream());
 }
 
 // Performs sorting along axis -1 and returns both sorted values and indices.
 template <typename DataType, typename IndicesType>
 void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, 
bool is_ascend,
-                 int n_values) {
+                 int n_values, DLTensor* workspace) {
   thrust::device_ptr<DataType> data_ptr(static_cast<DataType*>(input->data));
   thrust::device_ptr<DataType> 
values_ptr(static_cast<DataType*>(out_values->data));
   thrust::device_ptr<IndicesType> 
indices_ptr(static_cast<IndicesType*>(out_indices->data));
 
-  auto policy = get_thrust_exec_policy();
+  WorkspaceMemoryResource mr(workspace);
+  auto policy = get_thrust_exec_policy(&mr);
 
   size_t size = 1;
   for (int i = 0; i < input->ndim; ++i) {
@@ -118,53 +162,53 @@ void thrust_sort(DLTensor* input, DLTensor* out_values, 
DLTensor* out_indices, b
 }
 
 void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* 
indices_out,
-                        bool is_ascend, int sort_len, std::string data_dtype,
-                        std::string out_dtype) {
+                        bool is_ascend, int sort_len, std::string data_dtype, 
std::string out_dtype,
+                        DLTensor* workspace) {
   if (data_dtype == "float32") {
     if (out_dtype == "int32") {
-      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "int64") {
-      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float32") {
-      thrust_sort<float, float>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<float, float>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float64") {
-      thrust_sort<float, double>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<float, double>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "float64") {
     if (out_dtype == "int32") {
-      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "int64") {
-      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float32") {
-      thrust_sort<double, float>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<double, float>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float64") {
-      thrust_sort<double, double>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<double, double>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "int32") {
     if (out_dtype == "int32") {
-      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "int64") {
-      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float32") {
-      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float64") {
-      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "int64") {
     if (out_dtype == "int32") {
-      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "int64") {
-      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float32") {
-      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "float64") {
-      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, 
sort_len);
+      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
@@ -179,24 +223,31 @@ 
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetV
   DLTensor* values_out = args[1];
   DLTensor* indices_out = args[2];
   bool is_ascend = args[3];
+  DLTensor* workspace = nullptr;
+  if (args.num_args == 5) {
+    workspace = args[4];
+  }
 
   auto data_dtype = DLDataType2String(input->dtype);
   auto out_dtype = DLDataType2String(indices_out->dtype);
 
   int n_values = input->shape[input->ndim - 1];
-  thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, 
data_dtype, out_dtype);
+  thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, 
data_dtype, out_dtype,
+                     workspace);
 });
 
 template <typename KeyType, typename ValueType>
 void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, 
DLTensor* keys_out,
-                               DLTensor* values_out, bool for_scatter) {
+                               DLTensor* values_out, bool for_scatter,
+                               DLTensor* workspace = nullptr) {
   const auto size = keys_in->shape[0];
   thrust::device_ptr<KeyType> 
keys_in_ptr(static_cast<KeyType*>(keys_in->data));
   thrust::device_ptr<ValueType> 
values_in_ptr(static_cast<ValueType*>(values_in->data));
   thrust::device_ptr<KeyType> 
keys_out_ptr(static_cast<KeyType*>(keys_out->data));
   thrust::device_ptr<ValueType> 
values_out_ptr(static_cast<ValueType*>(values_out->data));
 
-  auto policy = get_thrust_exec_policy();
+  WorkspaceMemoryResource mr(workspace);
+  auto policy = get_thrust_exec_policy(&mr);
 
   if (for_scatter) {
     thrust::transform(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr,
@@ -220,46 +271,50 @@ 
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
       DLTensor* keys_out = args[2];
       DLTensor* values_out = args[3];
       bool for_scatter = args[4];
+      DLTensor* workspace = nullptr;
+      if (args.num_args == 6) {
+        workspace = args[5];
+      }
 
       auto key_dtype = DLDataType2String(keys_in->dtype);
       auto value_dtype = DLDataType2String(values_in->dtype);
 
       if (key_dtype == "int32") {
         if (value_dtype == "int32") {
-          thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, 
values_out,
-                                              for_scatter);
+          thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, 
values_out, for_scatter,
+                                              workspace);
         } else if (value_dtype == "int64") {
           thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, 
keys_out, values_out,
-                                                  for_scatter);
+                                                  for_scatter, workspace);
         } else if (value_dtype == "float32") {
           thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, 
values_out,
-                                                for_scatter);
+                                                for_scatter, workspace);
         } else {
           LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
         }
       } else if (key_dtype == "int64") {
         if (value_dtype == "int32") {
           thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, 
keys_out, values_out,
-                                                  for_scatter);
+                                                  for_scatter, workspace);
         } else if (value_dtype == "int64") {
           thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, 
keys_out, values_out,
-                                                      for_scatter);
+                                                      for_scatter, workspace);
         } else if (value_dtype == "float32") {
           thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, 
keys_out, values_out,
-                                                    for_scatter);
+                                                    for_scatter, workspace);
         } else {
           LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
         }
       } else if (key_dtype == "float32") {
         if (value_dtype == "int32") {
           thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, 
values_out,
-                                                for_scatter);
+                                                for_scatter, workspace);
         } else if (value_dtype == "int64") {
           thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, 
keys_out, values_out,
-                                                    for_scatter);
+                                                    for_scatter, workspace);
         } else if (value_dtype == "float32") {
           thrust_stable_sort_by_key<float, float>(keys_in, values_in, 
keys_out, values_out,
-                                                  for_scatter);
+                                                  for_scatter, workspace);
         } else {
           LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
         }
@@ -269,7 +324,10 @@ 
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     });
 
 template <typename InType, typename OutType>
-void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive) {
+void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* 
workspace) {
+  WorkspaceMemoryResource mr(workspace);
+  auto policy = get_thrust_exec_policy(&mr);
+
   thrust::device_ptr<InType> data_ptr(static_cast<InType*>(data->data));
   thrust::device_ptr<OutType> output_ptr(static_cast<OutType*>(output->data));
   const auto scan_size = data->shape[data->ndim - 1];
@@ -284,8 +342,6 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool 
exclusive) {
   auto data_cast_ptr = thrust::make_transform_iterator(
       data_ptr, [] __host__ __device__(InType v) { return 
static_cast<OutType>(v); });  // NOLINT(*)
 
-  auto policy = get_thrust_exec_policy();
-
   if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
     if (exclusive && need_cast) {
       thrust::exclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, 
output_ptr);
@@ -322,69 +378,73 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool 
exclusive) {
   }
 }
 
-TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  ICHECK(args.num_args == 3 || args.num_args == 2);
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan").set_body([](TVMArgs args, 
TVMRetValue* ret) {
+  ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4);
   DLTensor* data = args[0];
   DLTensor* output = args[1];
   bool exclusive = false;
+  DLTensor* workspace = nullptr;
 
-  if (args.num_args == 3) {
+  if (args.num_args >= 3) {
     exclusive = args[2];
   }
 
+  if (args.num_args == 4) {
+    workspace = args[3];
+  }
+
   auto in_dtype = DLDataType2String(data->dtype);
   auto out_dtype = DLDataType2String(output->dtype);
 
   if (in_dtype == "bool") {
     if (out_dtype == "int32") {
-      thrust_scan<bool, int>(data, output, exclusive);
+      thrust_scan<bool, int>(data, output, exclusive, workspace);
     } else if (out_dtype == "int64") {
-      thrust_scan<bool, int64_t>(data, output, exclusive);
+      thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
     } else if (out_dtype == "float32") {
-      thrust_scan<bool, float>(data, output, exclusive);
+      thrust_scan<bool, float>(data, output, exclusive, workspace);
     } else if (out_dtype == "float64") {
-      thrust_scan<bool, double>(data, output, exclusive);
+      thrust_scan<bool, double>(data, output, exclusive, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                  << ". Supported output dtypes are int32, int64, float32, and 
float64";
     }
   } else if (in_dtype == "int32") {
     if (out_dtype == "int32") {
-      thrust_scan<int, int>(data, output, exclusive);
+      thrust_scan<int, int>(data, output, exclusive, workspace);
     } else if (out_dtype == "int64") {
-      thrust_scan<int, int64_t>(data, output, exclusive);
+      thrust_scan<int, int64_t>(data, output, exclusive, workspace);
     } else if (out_dtype == "float32") {
-      thrust_scan<int, float>(data, output, exclusive);
+      thrust_scan<int, float>(data, output, exclusive, workspace);
     } else if (out_dtype == "float64") {
-      thrust_scan<int, double>(data, output, exclusive);
+      thrust_scan<int, double>(data, output, exclusive, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                  << ". Supported output dtypes are int32, int64, float32, and 
float64";
     }
   } else if (in_dtype == "int64") {
     if (out_dtype == "int64") {
-      thrust_scan<int64_t, int64_t>(data, output, exclusive);
+      thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
     } else if (out_dtype == "float32") {
-      thrust_scan<int64_t, float>(data, output, exclusive);
+      thrust_scan<int64_t, float>(data, output, exclusive, workspace);
     } else if (out_dtype == "float64") {
-      thrust_scan<int64_t, double>(data, output, exclusive);
+      thrust_scan<int64_t, double>(data, output, exclusive, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                  << ". Supported output dtypes are int64, float32, and 
float64";
     }
   } else if (in_dtype == "float32") {
     if (out_dtype == "float32") {
-      thrust_scan<float, float>(data, output, exclusive);
+      thrust_scan<float, float>(data, output, exclusive, workspace);
     } else if (out_dtype == "float64") {
-      thrust_scan<float, double>(data, output, exclusive);
+      thrust_scan<float, double>(data, output, exclusive, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                  << ". Supported output dtypes are float32, and float64";
     }
   } else if (in_dtype == "float64") {
     if (out_dtype == "float64") {
-      thrust_scan<double, double>(data, output, exclusive);
+      thrust_scan<double, double>(data, output, exclusive, workspace);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                  << ". Supported output dtype is float64";
diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py 
b/tests/python/relax/test_backend_dispatch_sort_scan.py
index 4d08189ac8..c3b0e86138 100644
--- a/tests/python/relax/test_backend_dispatch_sort_scan.py
+++ b/tests/python/relax/test_backend_dispatch_sort_scan.py
@@ -137,6 +137,7 @@ def test_dispatch_sort():
     assert_structural_equal(mod, expected_mod)
 
 
[email protected](reason="skipping broken tests")
 def test_dispatch_sort_cuda():
     @I.ir_module
     class Before:
@@ -176,14 +177,21 @@ def test_dispatch_sort_cuda():
             bb.emit_func_output(out)
         with bb.function("foo2", (y,), {"global_symbol": "foo2"}):
             with bb.dataflow():
-                out = bb.emit_te(
-                    topi.cuda.sort_thrust
-                    if can_use_thrust(target, "tvm.contrib.thrust.sort")
-                    else topi.cuda.sort,
-                    y,
-                    0,
-                    False,
-                )
+                if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+                    workspace = bb.emit(
+                        relax.op.builtin.alloc_tensor(
+                            relax.ShapeExpr([4194352]), "uint8", 
runtime_device_index=0
+                        )
+                    )
+                    out = bb.emit_te(
+                        topi.cuda.sort_thrust,
+                        y,
+                        axis=0,
+                        is_ascend=False,
+                        workspace=workspace,
+                    )
+                else:
+                    out = bb.emit_te(topi.cuda.sort, y, axis=0, 
is_ascend=False)
                 out = bb.emit_output(out)
             bb.emit_func_output(out)
     expected_mod = bb.finalize()
@@ -261,15 +269,22 @@ def test_dispatch_argsort_cuda():
             bb.emit_func_output(out)
         with bb.function("foo2", (y,), {"global_symbol": "foo2"}):
             with bb.dataflow():
-                out = bb.emit_te(
-                    topi.cuda.argsort_thrust
-                    if can_use_thrust(target, "tvm.contrib.thrust.sort")
-                    else topi.cuda.argsort,
-                    y,
-                    0,
-                    False,
-                    "int64",
-                )
+                if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+                    workspace = bb.emit(
+                        relax.op.builtin.alloc_tensor(
+                            R.shape([4194352]), R.dtype("uint8"), 
R.prim_value(0), R.str("global")
+                        )
+                    )
+                    out = bb.emit_te(
+                        topi.cuda.argsort_thrust,
+                        y,
+                        axis=0,
+                        is_ascend=False,
+                        dtype="int64",
+                        workspace=workspace,
+                    )
+                else:
+                    out = bb.emit_te(topi.cuda.argsort, y, axis=0, 
is_ascend=False, dtype="int64")
                 out = bb.emit_output(out)
             bb.emit_func_output(out)
     expected_mod = bb.finalize()

Reply via email to