This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5ad8941252 [Unity] Add pass to allocate big workspace and pass it to
all functions that need temp storage (#14802)
5ad8941252 is described below
commit 5ad8941252fd0c3e7596863f594b0bb7b1bc5243
Author: masahi <[email protected]>
AuthorDate: Wed May 10 09:49:43 2023 +0900
[Unity] Add pass to allocate big workspace and pass it to all functions
that need temp storage (#14802)
* Add workspace allocation and rewriting pass for CUTLASS
* fix when workspace is not needed
* wip
* rename to AllocateWorkspace
* minor
* minor
* fixed test
* add test
* add doc
* black
* zeros -> alloc_tensor for workspace
---
include/tvm/relax/expr.h | 2 +
python/tvm/contrib/cutlass/attention_operation.py | 9 +-
python/tvm/contrib/cutlass/build.py | 19 +-
python/tvm/contrib/cutlass/gen_tensor_op.py | 9 +-
python/tvm/relax/backend/contrib/cutlass.py | 56 +++++-
python/tvm/relax/transform/transform.py | 15 ++
src/relax/backend/vm/codegen_vm.cc | 6 +-
src/relax/ir/block_builder.cc | 2 +-
src/relax/op/op_common.h | 3 +
src/relax/transform/allocate_workspace.cc | 199 +++++++++++++++++++++
tests/python/relax/test_codegen_cutlass.py | 38 ++--
.../relax/test_transform_allocate_workspace.py | 132 ++++++++++++++
12 files changed, 448 insertions(+), 42 deletions(-)
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 0788193ee7..f090610019 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -983,6 +983,8 @@ constexpr const char* kCodegen = "Codegen";
constexpr const char* kComposite = "Composite";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass.
*/
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
+/*! \brief The required workspace for an external function. */
+constexpr const char* kWorkspaceSize = "WorkspaceSize";
} // namespace attr
/*! \brief The extern function, which can represent packed function. */
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 57c9ef4f91..c728f7fe4b 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -98,10 +98,7 @@ def instantiate_attention_template(attrs):
p.output_ptr = reinterpret_cast<T *>(out0->data);
p.output_accum_ptr = nullptr;
if (Attention::kNeedsOutputAccumulatorBuffer) {
- cudaMalloc(
- &p.output_accum_ptr,
- ${output_size} * sizeof(Attention::output_accum_t)
- );
+ p.output_accum_ptr = static_cast<float*>(${workspace}->data);
}
p.num_heads = ${num_heads}; // N
@@ -131,10 +128,6 @@ def instantiate_attention_template(attrs):
CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
-
- if (Attention::kNeedsOutputAccumulatorBuffer) {
- cudaFree(p.output_accum_ptr);
- }
"""
template = substitute_template(
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 389dbf3e5c..519754d407 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -791,31 +791,38 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
arg["arg0_dtype"] = signature["arg0_dtype"]
arg["arg1_shape"] = q_shape = signature["arg1_shape"]
- if "arg2_shape" not in signature:
+ if "arg3_shape" not in signature:
+ # arg0: qkv, arg1: shape, arg2: workspace
arg["arg2_shape"] = k_shape = signature["arg1_shape"]
arg["arg3_shape"] = v_shape = signature["arg1_shape"]
else:
- assert "arg3_shape" in signature
+ # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4:
workspace
arg["arg2_shape"] = k_shape = signature["arg2_shape"]
arg["arg3_shape"] = v_shape = signature["arg3_shape"]
- if "arg4_dtype" in signature:
+ if "arg5_dtype" in signature:
+ # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4:
bias, arg5: workspace
arg["bias_dtype"] = signature["arg4_dtype"]
- if "arg4_shape" in signature:
+ if "arg5_shape" in signature:
arg["bias_shape"] = signature["arg4_shape"]
+
qkv_layout = "qkv_stacked"
else:
+ # arg0: q, arg1: k, arg2: v, arg3: bias, arg4: workspace
arg["arg0_shape"] = q_shape = signature["arg0_shape"]
arg["arg1_shape"] = k_shape = signature["arg1_shape"]
arg["arg2_shape"] = v_shape = signature["arg2_shape"]
arg["arg0_dtype"] = signature["arg0_dtype"]
arg["arg1_dtype"] = signature["arg1_dtype"]
arg["arg2_dtype"] = signature["arg2_dtype"]
- if "arg3_dtype" in signature:
+
+ if "arg4_dtype" in signature:
arg["bias_dtype"] = signature["arg3_dtype"]
- if "arg3_shape" in signature:
+ if "arg4_shape" in signature:
arg["bias_shape"] = signature["arg3_shape"]
+
qkv_layout = "default"
+
out_shape = signature["ret_shape"]
out_dtype = signature["ret_dtype"]
num_batches, num_queries, num_heads, head_dim = q_shape
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 5e5ac621ef..f94d7ef467 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -727,15 +727,20 @@ def instantiate_template(func_name, annotations, func_args):
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
attrs["qkv_layout"] = annotations["qkv_layout"]
+
+ for arg in func_args:
+ if "workspace" in arg:
+ attrs["workspace"] = arg
+
if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
- if len(func_args) > 3:
+ if len(func_args) > 4: # +1 for workspace, the last arg
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
- if len(func_args) > 4:
+ if len(func_args) > 5: # +1 for workspace, the last arg
attrs["bias"] = func_args[4]
else:
raise NotImplementedError()
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 36f43c6c21..d5940ac5e4 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -16,11 +16,13 @@
# under the License.
"""Pattern table for CUTLASS backend"""
-
+import operator
from typing import Mapping, Sequence
+from functools import reduce
+import tvm
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform, Call
+from tvm.relax import DataflowVar, Var, transform, Call, PyExprMutator,
expr_functor, Function
from tvm.relax.transform import PatternCheckContext
from tvm.relax.dpl import rewrite_call
@@ -373,6 +375,46 @@ register_patterns(
_REWRITE_PATTERNS = [*attention_rewrite_patterns()]
+@expr_functor.mutator
+class WorkspaceAnnotator(PyExprMutator):
+ """Annotate a workspace requirement for each CUTLASS-offloaded function."""
+
+ def __init__(self, mod):
+ super().__init__(mod)
+
+ def visit_function_(self, f):
+ if f.attrs is None or "Composite" not in f.attrs:
+ body = super().visit_expr(f.body)
+ new_f = Function(f.params, body, f.ret_struct_info, f.attrs,
f.span)
+
+ if f.attrs and "global_symbol" in f.attrs and "cutlass" in
f.attrs["global_symbol"]:
+ composite_func = body.blocks[0].bindings[0].value
+ if "WorkspaceSize" in composite_func.attrs:
+ return new_f.with_attr("WorkspaceSize",
composite_func.attrs["WorkspaceSize"])
+
+ return new_f
+
+ if "attention" in f.attrs["Composite"]:
+ # Workspace is needed only for larger head sizes, but for
simplicity we always allocate.
+ out_dtype = f.ret_struct_info.dtype
+ out_size_1d = reduce(operator.mul, f.ret_struct_info.shape, 1)
+ # This needs to be in sync with the actual value that the kernel
expects.
+ workspace_size_bytes = out_size_1d * {"float16": 2, "float32":
4}[out_dtype]
+ return f.with_attr("WorkspaceSize", workspace_size_bytes)
+
+ return f
+
+
[email protected]_pass(opt_level=0)
+def annotate_workspace(mod, _):
+ """Pass to annotate a workspace requirement for each CUTLASS-offloaded
function."""
+ annotator = WorkspaceAnnotator(mod)
+ for name, f in mod.functions.items():
+ new_f = annotator.visit_expr(f)
+ mod.update_func(name, new_f)
+ return mod
+
+
def partition_for_cutlass(mod, annotate_codegen=True):
"""
Partition the input module into CUTLASS-supported subgraphs.
@@ -396,6 +438,12 @@ def partition_for_cutlass(mod, annotate_codegen=True):
for pattern, rewriter in _REWRITE_PATTERNS:
mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
patterns = get_patterns_with_prefix("cutlass")
- return transform.FuseOpsByPattern(
- patterns, bind_constants=False, annotate_codegen=annotate_codegen
+ return tvm.transform.Sequential(
+ [
+ transform.FuseOpsByPattern(
+ patterns, bind_constants=False,
annotate_codegen=annotate_codegen
+ ),
+ annotate_workspace,
+ transform.AllocateWorkspace(),
+ ]
)(mod)
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index b0d2710a99..508e8bccba 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1016,6 +1016,21 @@ def RewriteCUDAGraph() -> tvm.ir.transform.Pass:
return _ffi_api.RewriteCUDAGraph() # type: ignore
+def AllocateWorkspace() -> tvm.ir.transform.Pass:
+ """Allocate a workspace, represented by a tensor of size big enough for
all external
+ functions that require a temporary storage, and append it to the arguments
of external
+ functions.
+
+ An external function can specify its workspace requirement by the
kWorkspaceSize attribute.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ The registered pass for allocating workspace.
+ """
+ return _ffi_api.AllocateWorkspace() # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 09f21cf751..c44300907f 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -70,11 +70,11 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
CodeGenVM codegen(builder, mod);
// Remove relax function and turn into TIR func.
- for (auto& p : mod->functions) {
- if (auto* func = p.second.as<FunctionNode>()) {
+ for (const auto& [gvar, f] : mod->functions) {
+ if (auto* func = f.as<FunctionNode>()) {
codegen.Codegen(GetRef<Function>(func));
} else {
- res_mod->Add(p.first, p.second);
+ res_mod->Add(gvar, f);
}
}
return res_mod;
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index fe9e9bf8a5..5f9ce63c97 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -82,7 +82,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
while (context_mod_->ContainGlobalVar(func_name)) {
func_name = GetUniqueName(func_name_hint);
}
- GlobalVar gvar = GlobalVar(func_name);
+ GlobalVar gvar(func_name);
StructInfo finfo;
if (func->struct_info_.defined()) {
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 8f5d1fbaa1..f7cff638cd 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -346,6 +346,9 @@ inline Optional<ShapeExpr> CheckNdimPerLayoutAndGetShape(const Call& call, const
return NullOpt;
}
+Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm
dtype);
+Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm
dtype);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
new file mode 100644
index 0000000000..b20f982efb
--- /dev/null
+++ b/src/relax/transform/allocate_workspace.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/allocate_workspace.cc
+ * \brief Allocate a workspace and append it to the arguments of external
functions, to
+ * satisfy their temporary storage requirement.
+ */
+
+#include <tvm/ir/name_supply.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include "../op/op_common.h"
+
+namespace tvm {
+namespace relax {
+
+class ExternFunctionRewriter : ExprMutator {
+ public:
+ using ExprMutator::VisitExpr_;
+
+ ExternFunctionRewriter(IRModule mod, size_t max_workspace_size)
+ : ExprMutator(mod), name_sup_(""),
max_workspace_size_(max_workspace_size) {}
+
+ std::unordered_map<const GlobalVarNode*, Function> Run() {
+ std::unordered_map<const GlobalVarNode*, Function> ret;
+ for (const auto& [gvar, f] : builder_->GetContextIRModule()->functions) {
+ if (f->GetAttr<Integer>(attr::kWorkspaceSize)) {
+ ret[gvar.get()] = Downcast<Function>(VisitExpr(f));
+ }
+ }
+ return ret;
+ }
+
+ Expr VisitExpr_(const FunctionNode* func_node) override {
+ if (!func_node->GetAttr<String>(attr::kCodegen) &&
+ !func_node->GetAttr<String>(attr::kComposite)) {
+ return ExprMutator::VisitExpr_(func_node);
+ }
+ if (auto workspace = func_node->GetAttr<Integer>(attr::kWorkspaceSize)) {
+ // Append the workspace parameter to this function.
+ Array<Var> new_params = func_node->params;
+
+ auto sinfo = TensorStructInfo(ShapeExpr({Integer(max_workspace_size_)}),
DataType::UInt(8));
+ Var workspace_param(name_sup_->FreshName("workspace"), sinfo);
+
+ if (func_node->GetAttr<String>(attr::kCodegen)) {
+ workspace_var_param_ = workspace_param;
+ }
+
+ new_params.push_back(workspace_param);
+ return Function(new_params, VisitExpr(func_node->body),
func_node->ret_struct_info,
+ func_node->attrs);
+ }
+ return ExprMutator::VisitExpr_(func_node);
+ }
+
+ Expr VisitExpr_(const CallNode* call_node) override {
+ auto new_op = VisitExpr(call_node->op);
+ if (auto var = new_op.as<Var>()) {
+ if (auto callee = builder_->LookupBinding(var.value());
+ callee && callee->IsInstance<FunctionNode>() &&
+
Downcast<Function>(callee.value())->GetAttr<String>(attr::kComposite)) {
+ // Append the workspace argument to this call. The callee should have
been updated to accept
+ // a workspace as the last parameter.
+ auto new_args = call_node->args;
+ ICHECK(workspace_var_param_.defined());
+ new_args.push_back(workspace_var_param_);
+ return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args,
call_node->span);
+ }
+ }
+ return ExprMutator::VisitExpr_(call_node);
+ }
+
+ private:
+ NameSupply name_sup_;
+ /*! \brief A variable that represents the workspace parameter passed from
main. */
+ Var workspace_var_param_;
+ size_t max_workspace_size_ = 0;
+};
+
+class WorkspaceProvider : ExprMutator {
+ public:
+ explicit WorkspaceProvider(IRModule mod) : ExprMutator(mod), mod_(mod) {}
+ using ExprMutator::VisitBindingBlock_;
+ using ExprMutator::VisitExpr_;
+
+ IRModule Run() {
+ for (const auto& [gvar, f] : mod_->functions) {
+ if (auto workspace = f->GetAttr<Integer>(relax::attr::kWorkspaceSize)) {
+ max_workspace_size_ = std::max<size_t>(max_workspace_size_,
workspace.value()->value);
+ }
+ }
+
+ if (max_workspace_size_ == 0) {
+ return mod_;
+ }
+
+ auto new_funcs = relax::ExternFunctionRewriter(mod_,
max_workspace_size_).Run();
+
+ for (const auto& [gvar, f] : new_funcs) {
+ auto new_gvar = builder_->AddFunction(f, gvar->name_hint);
+ // This is only required since the well-formed check requires
kGlobalSymbol to be the same
+ // as the actual name of the global variable.
+ builder_->UpdateFunction(new_gvar,
+ WithAttr(f, tvm::attr::kGlobalSymbol,
new_gvar->name_hint));
+ gvar_map_[gvar] = new_gvar;
+ builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
+ }
+
+ auto gvar = mod_->GetGlobalVar("main");
+ auto func = Downcast<Function>(mod_->Lookup(gvar));
+ auto new_func =
+ Function(func->params, VisitExpr(func->body), func->ret_struct_info,
func->attrs);
+ builder_->UpdateFunction(gvar, new_func);
+ return builder_->GetContextIRModule();
+ }
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+ builder_->BeginDataflowBlock();
+ if (!workspace_var_main_.defined()) {
+ auto shape = ShapeExpr({Integer(max_workspace_size_)});
+ auto ty = DataTypeImm(DataType::UInt(8));
+ auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
+ auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape,
ty);
+ workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
+ }
+ for (const auto& binding : block_node->bindings) {
+ this->VisitBinding(binding);
+ }
+ return builder_->EndBlock();
+ }
+
+ Expr VisitExpr_(const GlobalVarNode* gvar_node) override {
+ if (gvar_map_.count(gvar_node)) {
+ return gvar_map_[gvar_node];
+ }
+ return ExprMutator::VisitExpr_(gvar_node);
+ }
+
+ Expr VisitExpr_(const CallNode* call_node) override {
+ auto new_op = VisitExpr(call_node->op);
+
+ if (auto gv = new_op.as<GlobalVar>()) {
+ auto callee = builder_->GetContextIRModule()->Lookup(gv.value());
+ if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) {
+ auto new_args = call_node->args;
+ ICHECK(workspace_var_main_.defined());
+ new_args.push_back(workspace_var_main_);
+ return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args,
call_node->span);
+ }
+ }
+
+ return ExprMutator::VisitExpr_(call_node);
+ }
+
+ private:
+ IRModule mod_;
+ /*! \brief A variable that represents the workspace created at the beginning
of main. */
+ Var workspace_var_main_;
+ size_t max_workspace_size_ = 0;
+ /*! \brief A map from old global variables representing a function with
workspace requirement to
+ * the new ones that are transformed to take an additional workspace
parameter. This is only
+ * needed since the struct info of the global variables changes between
transformation. */
+ std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
+};
+
+} // namespace relax
+
+namespace transform {
+
+Pass AllocateWorkspace() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return
relax::WorkspaceProvider(m).Run(); };
+
+ return CreateModulePass(pass_func, 0, "AllocateWorkspace", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace);
+
+} // namespace transform
+} // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 45d66b3704..7a831c094e 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -83,9 +83,9 @@ cutlass_enabled = pytest.mark.skipif(
pytestmark = [cutlass_enabled]
-def build_and_run(mod, inputs_np, target, legalize=False):
+def build_and_run(mod, inputs_np, target, legalize=True):
if legalize:
- mod = relax.transform.LegalizeOps()(mod)
+ mod = relax.transform.LegalizeOps()(mod) # For cpu reference, nop for
cutlass.
dev = tvm.device(target, 0)
ex = relax.build(mod, target)
@@ -95,11 +95,13 @@ def build_and_run(mod, inputs_np, target, legalize=False):
return f(*inputs).numpy()
-def get_result_with_relax_cutlass_offload(mod, *args,
assert_all_bindings_fused=True):
+def get_result_with_relax_cutlass_offload(
+ mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
+):
mod = partition_for_cutlass(mod)
if assert_all_bindings_fused:
- assert len(mod["main"].body.blocks[0].bindings) == 1
+ assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}})
mod = codegen_pass(mod)
@@ -116,7 +118,7 @@ def test_kernel_sharing():
out = get_result_with_relax_cutlass_offload(
Conv2dx2, data_np, weight1_np, weight2_np,
assert_all_bindings_fused=False
)
- ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm",
legalize=True)
+ ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm")
np.testing.assert_equal(out, ref)
@@ -243,7 +245,7 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc
)
out = get_result_with_relax_cutlass_offload(mod, *args)
- ref = build_and_run(mod, args, "llvm", legalize=True)
+ ref = build_and_run(mod, args, "llvm")
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
@@ -369,7 +371,7 @@ def test_matmul_offload(
residual_activation=residual_activation,
)
out = get_result_with_relax_cutlass_offload(mod, *args)
- ref = build_and_run(mod, args, "llvm", legalize=True)
+ ref = build_and_run(mod, args, "llvm")
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -616,7 +618,7 @@ def test_attention_offload(attention_size, attention_dtype):
)
mod = get_relax_attention_module(q, k, v)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -645,7 +647,7 @@ def test_attention_bias_offload(attention_bias_size):
)
mod = get_relax_attention_module(q, k, v, bias)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -674,9 +676,9 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
mod = get_relax_attention_module(q, k, v, bias, attention_scale)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
else:
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -777,9 +779,9 @@ def test_stacked_attention_split_offload(stacked_attention_size):
)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, qkv)
+ out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=3)
else:
- out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+ out = get_result_with_relax_cutlass_offload(mod, qkv, bias,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -795,9 +797,9 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
qkv, b, s, n, h, h_v, "strided_slice", bias, scale,
single_shape=single_shape
)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, qkv)
+ out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=3)
else:
- out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+ out = get_result_with_relax_cutlass_offload(mod, qkv, bias,
num_final_bindings=3)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -966,8 +968,8 @@ def test_attention_rewrite_offload(attention_rewrite_size):
expected_out = build_and_run(expected_mod, [q, k, v], "cuda")
tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5,
atol=1e-5)
else:
- original_out = build_and_run(original_mod, [q, k, v, bias], "cuda")
- expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda")
+ original_out = build_and_run(original_mod, [q, k, v, bias], "cuda",
legalize=False)
+ expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda",
legalize=False)
tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5,
atol=1e-5)
@@ -1043,7 +1045,7 @@ def test_layer_norm(data_shape, dtype, axes):
gamma = np.random.randn(data_shape[-1]).astype(dtype)
beta = np.random.randn(data_shape[-1]).astype(dtype)
out = build_and_run(mod, [inp, gamma, beta], "cuda")
- ref = build_and_run(Module, [inp, gamma, beta], "llvm", legalize=True)
+ ref = build_and_run(Module, [inp, gamma, beta], "llvm")
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
new file mode 100644
index 0000000000..7ffbd01b05
--- /dev/null
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+
+
[email protected]_module
+class Module:
+ @R.function
+ def fused_relax_nn_attention_cutlass(
+ q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+ R.func_attr(
+ {
+ "Codegen": "cutlass",
+ "WorkspaceSize": 65536,
+ "global_symbol": "fused_relax_nn_attention_cutlass",
+ }
+ )
+
+ @R.function
+ def gv(
+ q_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ k_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ v_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+ R.func_attr({"Composite": "cutlass.attention", "Primitive": 1,
"WorkspaceSize": 65536})
+ with R.dataflow():
+ gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") =
R.nn.attention(
+ q_1, k_1, v_1, scale=None
+ )
+ R.output(gv_2)
+ return gv_2
+
+ gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v)
+ return gv1
+
+ @R.function
+ def main(
+ q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+ cls = Module
+ with R.dataflow():
+ gv: R.Tensor((32, 8, 16, 8), dtype="float16") =
cls.fused_relax_nn_attention_cutlass(
+ q, k, v
+ )
+ R.output(gv)
+ return gv
+
+
[email protected]_module
+class Expected:
+ @R.function
+ def fused_relax_nn_attention_cutlass1(
+ q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ workspace: R.Tensor((65536,), dtype="uint8"),
+ ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+ R.func_attr(
+ {
+ "Codegen": "cutlass",
+ "WorkspaceSize": 65536,
+ "global_symbol": "fused_relax_nn_attention_cutlass1",
+ }
+ )
+
+ @R.function
+ def gv(
+ q_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ k_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ v_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ workspace_1: R.Tensor((65536,), dtype="uint8"),
+ ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+ R.func_attr({"Composite": "cutlass.attention", "Primitive": 1,
"WorkspaceSize": 65536})
+ with R.dataflow():
+ gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") =
R.nn.attention(
+ q_1, k_1, v_1, scale=None
+ )
+ R.output(gv_2)
+ return gv_2
+
+ gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v, workspace)
+ return gv1
+
+ @R.function
+ def main(
+ q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+ ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+ cls = Expected
+ with R.dataflow():
+ lv: R.Object = R.vm.alloc_storage(R.shape([65536]),
R.prim_value(0), R.dtype("uint8"))
+ workspace_main: R.Tensor((65536,), dtype="uint8") =
R.vm.alloc_tensor(
+ lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+ )
+ gv: R.Tensor((32, 8, 16, 8), dtype="float16") =
cls.fused_relax_nn_attention_cutlass1(
+ q, k, v, workspace_main
+ )
+ R.output(gv)
+ return gv
+
+
+def test_single_attention():
+ rewritten = relax.transform.AllocateWorkspace()(Module)
+ tvm.ir.assert_structural_equal(rewritten, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()