This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 521ab47edf [MSC] Reconstruct tensorrt module (#17344)
521ab47edf is described below
commit 521ab47edf1a2b25b6614d64df5d9f6133dfa329
Author: Archermmt <[email protected]>
AuthorDate: Sun Sep 8 18:40:49 2024 +0800
[MSC] Reconstruct tensorrt module (#17344)
* reconstruct tensorrt
* format fix
---
python/tvm/contrib/msc/core/frontend/translate.py | 2 +-
.../msc/framework/tensorrt/frontend/translate.py | 5 +-
.../msc/framework/tensorrt/transform/pattern.py | 31 +-
.../msc/framework/tensorrt/transform/transform.py | 13 +-
src/contrib/msc/core/transform/rewrite_utils.cc | 58 ++
src/contrib/msc/core/transform/rewrite_utils.h | 72 +++
src/contrib/msc/core/utils.cc | 19 +-
src/contrib/msc/core/utils.h | 4 +-
.../msc/framework/tensorrt/tensorrt_opcode.cc | 6 +-
.../msc/framework/tensorrt/transform_tensorrt.cc | 668 +++++++++++++--------
.../contrib/test_msc/test_translate_tensorrt.py | 47 +-
11 files changed, 642 insertions(+), 283 deletions(-)
diff --git a/python/tvm/contrib/msc/core/frontend/translate.py
b/python/tvm/contrib/msc/core/frontend/translate.py
index 63b4424524..cea021ade3 100644
--- a/python/tvm/contrib/msc/core/frontend/translate.py
+++ b/python/tvm/contrib/msc/core/frontend/translate.py
@@ -330,7 +330,7 @@ def byoc_partition(
msc_mod = _partition_mod(mod)
func_names = [var.name_hint for var, func in msc_mod.functions.items() if
_is_target_func(func)]
- if not trans_config.get("allow_incomplete", False):
+ if trans_config.get("as_complete", True):
assert len(func_names) == 1, "More than 1 target func is found: " +
str(msc_mod)
BYOCChecker().check(func_names, msc_mod[entry])
diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
index 8758fdb630..4a02b02728 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
@@ -49,7 +49,10 @@ def transform_for_tensorrt(
return tvm.transform.Sequential(
[
msc_transform.SetExprName(),
- trt_transform.TransformTensorRT(trans_config.get("version")),
+ trt_transform.TransformTensorRT(
+ version=trans_config.get("version"),
+ linear_to_conv=trans_config.get("linear_to_conv", False),
+ ),
relax.transform.FoldConstant(),
]
)(mod)
diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
index 8eea3f7081..17aee690e3 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
@@ -136,12 +136,22 @@ def _check_expr(expr: relax.Expr, dtypes: Tuple[str] =
None) -> bool:
return True
if isinstance(expr, relax.Tuple):
return all(_check_expr(field) for field in expr.fields)
- if any(i < 0 for i in expr.struct_info.shape.values):
- return False
- dtypes = dtypes or ("float32", "float16")
- if expr.struct_info.dtype not in dtypes:
- return False
- return True
+ dtypes = dtypes or ("float32", "float16", "int64", "int32", "bool")
+
+ def _check(sinfo):
+ if not sinfo.shape or sinfo.dtype not in dtypes:
+ return False
+ unknown_dim = 0
+ for s in sinfo.shape.values:
+ if isinstance(s, (tvm.tir.Var, tvm.tir.Any)):
+ unknown_dim += 1
+ elif isinstance(s, tvm.tir.IntImm) and s < 0:
+ unknown_dim += 1
+ return unknown_dim <= 1
+
+ if isinstance(expr.struct_info, relax.TupleStructInfo):
+ return all(_check(s) for s in expr.struct_info.fields)
+ return _check(expr.struct_info)
def _basic_check(context: PatternCheckContext) -> bool:
@@ -216,8 +226,7 @@ def _reshape_check(context: PatternCheckContext) -> bool:
Whether the pattern is correct.
"""
- dtypes = ("float32", "float16", "int32")
- if any(not _check_expr(context.annotated_expr[key], dtypes) for key in
["input_0", "out"]):
+ if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0",
"out"]):
return False
return True
@@ -323,16 +332,18 @@ def get_patterns(target) -> List[Pattern]:
"nn.avg_pool2d": ["input"],
"nn.conv2d": ["input", "constant"],
"nn.max_pool2d": ["input"],
+ "astype": ["input"],
"concat": ["input"],
"clip": ["input", "input", "input"],
"image.resize2d": ["input", "input"],
"matmul": ["input", "input"],
"permute_dims": ["input"],
- "strided_slice": ["input"],
+ "strided_slice": ["input", "input", "input", "input", "input"],
+ "topk": ["input"],
}
activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"]
reduce_ops = ["max", "min", "mean", "sum"]
- unary_ops = ["cos", "exp", "negative", "round", "sin", "square", "sqrt",
"tan"]
+ unary_ops = ["cos", "erf", "exp", "negative", "round", "sin", "square",
"sqrt", "tan"]
elemwise_ops = [
"add",
"divide",
diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
index d6f15c43da..cf4d4b9f33 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
@@ -25,18 +25,25 @@ from tvm.contrib.msc.core.utils import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
-def TransformTensorRT(version: List[int] = None) -> tvm.ir.transform.Pass:
+def TransformTensorRT(
+ version: List[int] = None, linear_to_conv: bool = False
+) -> tvm.ir.transform.Pass:
"""Transform the Function to fit TensorRT.
Parameters
----------
version: list<int>
The tensorrt version.
+ linear_to_conv: bool
+ Whether to cast linear to conv2d
Returns
-------
ret: tvm.ir.transform.Pass
"""
- version = version or msc_utils.get_version(MSCFramework.TENSORRT)
- return relax_api.TransformTensorRT(version) # type: ignore
+ config = {
+ "version": version or msc_utils.get_version(MSCFramework.TENSORRT),
+ "linear_to_conv": linear_to_conv,
+ }
+ return relax_api.TransformTensorRT(msc_utils.dump_dict(config)) # type:
ignore
diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc
b/src/contrib/msc/core/transform/rewrite_utils.cc
new file mode 100644
index 0000000000..20e4821e6f
--- /dev/null
+++ b/src/contrib/msc/core/transform/rewrite_utils.cc
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/rewrite_utils.cc
+ */
+#include "rewrite_utils.h"
+
+#include <set>
+#include <string>
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr&
expr) {
+ expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name);
+ return builder->Emit(expr, name);
+}
+
+Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op,
Array<Expr> args,
+ Attrs attrs) {
+ const auto& call = Call(op, args, attrs);
+ return ReEmit(builder, name, call);
+}
+
+Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name,
double value,
+ const DataType& dtype, size_t ndim) {
+ const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value));
+ Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name);
+ const auto& constant = Constant(data, NullOpt, span);
+ if (ndim == 0) {
+ return constant;
+ }
+ static const Op& reshape_op = Op::Get("relax.reshape");
+ Array<PrimExpr> exp_shape(ndim, Integer(1));
+ return MakeCall(builder, name + "_exp", reshape_op, {constant,
ShapeExpr(exp_shape)});
+}
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/msc/core/transform/rewrite_utils.h
b/src/contrib/msc/core/transform/rewrite_utils.h
new file mode 100644
index 0000000000..2693a6ccd2
--- /dev/null
+++ b/src/contrib/msc/core/transform/rewrite_utils.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/rewrite_utils.h
+ * \brief Common utilities for rewrite.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
+#define TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
+
+#include <tvm/ir/source_map.h>
+#include <tvm/relax/expr.h>
+
+#include <vector>
+
+#include "../../../../relax/transform/utils.h"
+#include "../../../../support/scalars.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using namespace tvm::relax;
+
+/*!
+ * \brief Utils for Layout.
+ */
+class RewriteUtils {
+ public:
+ /*!
+ * \brief Emit call with span name.
+ * \return The emitted var.
+ */
+ TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const
Expr& expr);
+
+ /*!
+ * \brief Make and emit a call binding with span.
+ * \return The emitted var.
+ */
+ TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr
op, Array<Expr> args,
+ Attrs attrs = Attrs());
+
+ /*!
+ * \brief Make and emit a (shaped)constant with span.
+ * \return The constant/reshape.
+ */
+ TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name,
double value,
+ const DataType& dtype, size_t ndim = 0);
+};
+
+} // namespace msc
+} // namespace contrib
+} // namespace tvm
+#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index c6e74d4284..1e846b0b3a 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -507,12 +507,25 @@ const String ExprUtils::GetSpanName(const Expr& expr,
const String& suffix) {
return name;
}
-const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
- const auto& shape_opt =
Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
- ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
+const Array<PrimExpr> ExprUtils::GetShape(const relax::TensorStructInfo&
sinfo, bool as_int) {
+ const auto& shape_opt = sinfo->GetShape();
+ if (!shape_opt.defined()) {
+ return Array<PrimExpr>();
+ }
+ if (as_int) {
+ Array<PrimExpr> shape;
+ for (const auto& s : shape_opt.value()) {
+ shape.push_back(s->IsInstance<IntImmNode>() ? s : Integer(-1));
+ }
+ return shape;
+ }
return shape_opt.value();
}
+const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr, bool as_int) {
+ return
GetShape(Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr)), as_int);
+}
+
const DataType ExprUtils::GetDataType(const Expr& expr) {
return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
}
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index d7758cc23d..7fb9c87a99 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -398,7 +398,9 @@ class ExprUtils {
* \brief Get shape of expr.
* \return The shape.
*/
- TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);
+ TVM_DLL static const Array<PrimExpr> GetShape(const relax::TensorStructInfo&
sinfo,
+ bool as_int = true);
+ TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr, bool as_int
= true);
/*!
* \brief Get dtype of expr.
diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
index a080fdd778..d90cdc35d1 100644
--- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
+++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
@@ -92,6 +92,8 @@ const String TensorRTOpCode::DType(const DataType& dtype) {
dtype_enum = "DataType::kINT8";
} else if (dtype_name == "int32") {
dtype_enum = "DataType::kINT32";
+ } else if (dtype_name == "int64") {
+ dtype_enum = "DataType::kINT32";
} else if (dtype_name == "float16") {
dtype_enum = "DataType::kHALF";
} else if (dtype_name == "float32") {
@@ -267,7 +269,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode {
void CodeGenBuild() final {
stack_.op_call()
.op_input_arg()
- .func_call("setOutput", NullOpt, DocUtils::ToPtr(IdxNode()))
+ .func_call("setOutputType", NullOpt, DocUtils::ToPtr(IdxNode()))
.call_arg(0)
.op_dtype_arg(node()->OutputAt(0)->dtype);
}
@@ -661,7 +663,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode {
protected:
void CodeGenBuild() final {
- const String& symbol = node()->GetTypeAttr<bool>("is_asend") ? "MIN" :
"MAX";
+ const String& symbol = node()->GetTypeAttr<bool>("largest") ? "MAX" :
"MIN";
stack_.op_call()
.op_input_arg()
.call_arg("TopKOperation::k" + symbol)
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index 3f85309cd8..542e15d06c 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -22,83 +22,101 @@
* \brief Pass for transform the function to tensorrt.
*/
+#include <tvm/relax/attrs/sorting.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include "../../../../relax/transform/utils.h"
#include "../../../../support/scalars.h"
+#include "../../core/transform/rewrite_utils.h"
#include "../../core/utils.h"
namespace tvm {
namespace relax {
using namespace tvm::contrib::msc;
-const Array<PrimExpr> GetShape(const Expr& var) {
- const auto& shape_opt =
Downcast<TensorStructInfo>(GetStructInfo(var))->GetShape();
- ICHECK(shape_opt.defined()) << "Shape is not defined for " << var;
- return shape_opt.value();
-}
-
-Var EmitCall(BlockBuilder builder, const Expr& expr, const Span& src_span,
const String& suffix) {
- const auto& name = SpanUtils::GetAttr(src_span, msc_attr::kName) + "_" +
suffix;
- expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name);
- return builder->Emit(expr, name);
-}
-
-Var MakeCall(BlockBuilder builder, const Span& src_span, const String& suffix,
Expr op,
- Array<Expr> args, Attrs attrs = Attrs()) {
- const auto& call = Call(op, args, attrs);
- return EmitCall(builder, call, src_span, suffix);
-}
+struct TensorRTTransConfig {
+ // Whether to cast linear to conv
+ bool linear_to_conv{true};
+ std::vector<size_t> version{0, 0, 0};
+
+ void Load(dmlc::JSONReader* reader) {
+ std::string key;
+ reader->BeginObject();
+ while (reader->NextObjectItem(&key)) {
+ if (key == "linear_to_conv") {
+ reader->Read(&linear_to_conv);
+ } else if (key == "version") {
+ reader->Read(&version);
+ } else {
+ LOG(FATAL) << "Do not support key " << key;
+ }
+ }
+ }
+};
-Expr MakeConstant(double value, const DataType& dtype, const String& name) {
- const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value));
- const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, name);
- return Constant(data, NullOpt, span);
+const TensorRTTransConfig ParseConfig(const String& config_str) {
+ TensorRTTransConfig config;
+ if (config_str.size() > 0) {
+ std::istringstream is(config_str);
+ dmlc::JSONReader reader(&is);
+ reader.Read(&config);
+ }
+ return config;
}
using FRewriteTensorRT =
runtime::TypedPackedFunc<Expr(BlockBuilder builder, const Var& var, const
Call& src_call,
- const Map<Expr, Call>& new_calls, const
Array<Integer>& version)>;
+ const Map<Expr, Call>& new_calls, const
String& config)>;
+
+const Array<PrimExpr> BroadcastShape(const Array<PrimExpr>& src_shape,
+ const Array<PrimExpr>& out_shape) {
+ size_t diff = out_shape.size() - src_shape.size();
+ Array<PrimExpr> leading_shape, tailing_shape;
+ for (size_t i = 0; i < diff; i++) {
+ leading_shape.push_back(Integer(1));
+ }
+ for (const auto& s : src_shape) {
+ tailing_shape.push_back(s);
+ leading_shape.push_back(s);
+ }
+ for (size_t i = 0; i < diff; i++) {
+ tailing_shape.push_back(Integer(1));
+ }
+ if (ArrayUtils::Broadcastable(tailing_shape, out_shape)) {
+ return tailing_shape;
+ }
+ ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape))
+ << "Only support elemwise ops with leading or tailing expand";
+ return leading_shape;
+}
Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& shape_a = GetShape(call->args[0]);
- const auto& shape_b = GetShape(call->args[1]);
+ const auto& shape_a = ExprUtils::GetShape(call->args[0]);
+ const auto& shape_b = ExprUtils::GetShape(call->args[1]);
+ const auto& shape_out = ExprUtils::GetShape(var);
static const Op& reshape_op = Op::Get("relax.reshape");
if (shape_a.size() > shape_b.size()) {
- Array<PrimExpr> exp_shape(shape_a.size(), Integer(1));
- if (shape_b.size() == 1) {
- exp_shape.Set(shape_a.size() - 1, shape_b[0]);
- } else if (shape_b.size() == 0) {
- LOG_DEBUG << "Expand scalar argument to " << exp_shape;
- } else {
- LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_b;
- }
- const auto& expand_b = MakeCall(builder, call->span, "expand_b",
reshape_op,
- {call->args[1], ShapeExpr(exp_shape)});
+ const auto& exp_shape = BroadcastShape(shape_b, shape_out);
+ const auto& expand_b =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"expand_b"), reshape_op,
+ {call->args[1], ShapeExpr(exp_shape)});
return Call(call->op, {call->args[0], expand_b}, call->attrs,
call->sinfo_args, call->span);
- }
- if (shape_a.size() < shape_b.size()) {
- Array<PrimExpr> exp_shape(shape_b.size(), Integer(1));
- if (shape_a.size() == 1) {
- exp_shape.Set(shape_b.size() - 1, shape_a[0]);
- } else if (shape_a.size() == 0) {
- LOG_DEBUG << "Expand scalar argument to " << exp_shape;
- } else {
- LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_a;
- }
- const auto& expand_a = MakeCall(builder, call->span, "expand_a",
reshape_op,
- {call->args[0], ShapeExpr(exp_shape)});
+ } else if (shape_a.size() < shape_b.size()) {
+ const auto& exp_shape = BroadcastShape(shape_a, shape_out);
+ const auto& expand_a =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"expand_a"), reshape_op,
+ {call->args[0], ShapeExpr(exp_shape)});
return Call(call->op, {expand_a, call->args[1]}, call->attrs,
call->sinfo_args, call->span);
}
return call;
}
Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
if (new_calls.count(call->args[0]) &&
new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) {
@@ -110,19 +128,20 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var,
const Call& src_call,
if (conv2d->op != Op::Get("relax.nn.conv2d")) {
return call;
}
- const auto& input_shape = GetShape(call->args[0]);
- const auto& bias_shape = GetShape(call->args[1]);
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& bias_shape = ExprUtils::GetShape(call->args[1]);
const auto* conv_attrs = conv2d->attrs.as<Conv2DAttrs>();
if (conv_attrs->data_layout == "NCHW") {
// expand bias reshape
Array<PrimExpr> exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1),
bias_shape[2]};
static const Op& reshape_op = Op::Get("relax.reshape");
- const auto& exp_bias = MakeCall(builder, call->span, "exp_bias",
reshape_op,
- {call->args[1],
ShapeExpr(exp_bias_shape)});
+ const auto& exp_bias =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_bias"), reshape_op,
+ {call->args[1], ShapeExpr(exp_bias_shape)});
// redirect to conv2d
static const Op& add_op = Op::Get("relax.add");
- const auto& exp_add =
- MakeCall(builder, call->span, "exp_add", add_op, {reshape->args[0],
exp_bias});
+ const auto& exp_add = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "exp_add"),
+ add_op, {reshape->args[0],
exp_bias});
// reduce output
return Call(reshape_op, {exp_add, ShapeExpr(input_shape)}, Attrs(),
call->sinfo_args,
call->span);
@@ -130,48 +149,50 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var,
const Call& src_call,
LOG_FATAL << "Unexpected data layout " << conv_attrs->data_layout;
}
}
- return RewriteElemwise(builder, var, call, new_calls, version);
+ return RewriteElemwise(builder, var, call, new_calls, config);
}
Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& out_dtype =
Downcast<TensorStructInfo>(GetStructInfo(var))->dtype;
+ const auto& out_dtype = ExprUtils::GetDataType(var);
const auto* src_attrs = src_call->attrs.as<ArgmaxArgminAttrs>();
- Expr raw_var;
- if (src_attrs->keepdims) {
- raw_var = EmitCall(builder, call, call->span, "raw");
- } else {
- auto new_attrs = make_object<ArgmaxArgminAttrs>();
- new_attrs->axis = src_attrs->axis;
- new_attrs->keepdims = true;
- raw_var =
- MakeCall(builder, call->span, "keepdims", call->op, {call->args[0]},
Attrs(new_attrs));
+ ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64))
+ << "Unexpected out dtype " << out_dtype;
+ static const Op& topk_op = Op::Get("relax.topk");
+ auto topk_attrs = make_object<TopKAttrs>();
+ topk_attrs->k = 1;
+ if (src_attrs->axis.defined()) {
+ topk_attrs->axis = src_attrs->axis.value()->value;
}
- static const Op& astype_op = Op::Get("relax.astype");
- auto cast_to_attrs = make_object<AstypeAttrs>();
- cast_to_attrs->dtype = DataType::Int(32);
- Expr res = MakeCall(builder, call->span, "cast_to", astype_op, {raw_var},
Attrs(cast_to_attrs));
- // reshape back
- if (!src_attrs->keepdims) {
- const auto& output_shape = GetShape(var);
- static const Op& reshape_op = Op::Get("relax.reshape");
- res = MakeCall(builder, call->span, "reshape", reshape_op, {res,
ShapeExpr(output_shape)});
+ topk_attrs->largest = call->op == Op::Get("relax.argmax");
+ topk_attrs->ret_type = "both";
+ topk_attrs->dtype = out_dtype;
+ // change to topk
+ const auto& topk = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "topk"), topk_op,
+ {call->args[0]},
Attrs(topk_attrs));
+ const auto& get_name = ExprUtils::GetSpanName(call, ".1");
+ const auto& get_item =
+ TupleGetItem(topk, 1, SpanUtils::CreateWithAttr(msc_attr::kName,
get_name));
+ if (src_attrs->keepdims) {
+ return get_item;
}
- auto cast_from_attrs = make_object<AstypeAttrs>();
- cast_from_attrs->dtype = out_dtype;
- return Call(astype_op, {res}, Attrs(cast_from_attrs), call->sinfo_args,
call->span);
+ const auto& get_item_var = builder->Emit(get_item, get_name);
+ static const Op& reshape_op = Op::Get("relax.reshape");
+ const auto& output_shape = ExprUtils::GetShape(var);
+ return Call(reshape_op, {get_item_var, ShapeExpr(output_shape)}, Attrs(),
call->sinfo_args,
+ call->span);
}
Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& in_dtype =
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+ const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
const auto* src_attrs = src_call->attrs.as<AttentionAttrs>();
// define dims
- const auto& in_q_shape = GetShape(call->args[0]);
- const auto& in_v_shape = GetShape(call->args[2]);
+ const auto& in_q_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& in_v_shape = ExprUtils::GetShape(call->args[2]);
const auto& batch_size = in_q_shape[0];
const auto& seq_len = in_q_shape[1];
const auto& num_head = in_q_shape[2];
@@ -198,50 +219,53 @@ Expr RewriteAttention(BlockBuilder builder, const Var&
var, const Call& src_call
auto permute_attrs = make_object<PermuteDimsAttrs>();
Array<Integer> axes{Integer(0), Integer(2), Integer(1), Integer(3)};
permute_attrs->axes = axes;
- const auto& q_trans = MakeCall(builder, call->span, "q_trans",
permute_dims_op, {call->args[0]},
- Attrs(permute_attrs));
- const auto& k_trans = MakeCall(builder, call->span, "k_trans",
permute_dims_op, {call->args[1]},
- Attrs(permute_attrs));
- const auto& v_trans = MakeCall(builder, call->span, "v_trans",
permute_dims_op, {call->args[2]},
- Attrs(permute_attrs));
+ const auto& q_trans =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"),
permute_dims_op,
+ {call->args[0]}, Attrs(permute_attrs));
+ const auto& k_trans =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"),
permute_dims_op,
+ {call->args[1]}, Attrs(permute_attrs));
+ const auto& v_trans =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"),
permute_dims_op,
+ {call->args[2]}, Attrs(permute_attrs));
Array<PrimExpr> q_shape({batch_size * num_head, seq_len, head_dim});
- const auto& q_reshape =
- MakeCall(builder, call->span, "q_reshape", reshape_op, {q_trans,
ShapeExpr(q_shape)});
+ const auto& q_reshape = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "q_reshape"),
+ reshape_op, {q_trans,
ShapeExpr(q_shape)});
Array<PrimExpr> k_shape({batch_size * num_head, seq_len_kv, head_dim});
- const auto& k_reshape =
- MakeCall(builder, call->span, "k_reshape", reshape_op, {k_trans,
ShapeExpr(k_shape)});
+ const auto& k_reshape = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "k_reshape"),
+ reshape_op, {k_trans,
ShapeExpr(k_shape)});
Array<PrimExpr> v_shape({batch_size * num_head, seq_len_kv, head_dim_v});
- const auto& v_reshape =
- MakeCall(builder, call->span, "v_reshape", reshape_op, {v_trans,
ShapeExpr(v_shape)});
+ const auto& v_reshape = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "v_reshape"),
+ reshape_op, {v_trans,
ShapeExpr(v_shape)});
auto reduce_permute_attrs = make_object<PermuteDimsAttrs>();
Array<Integer> v_axes{Integer(0), Integer(2), Integer(1)};
reduce_permute_attrs->axes = v_axes;
// transpose for batch_matmul
- const auto& k_reshape_trans = MakeCall(builder, call->span,
"k_reshape_trans", permute_dims_op,
- {k_reshape},
Attrs(reduce_permute_attrs));
+ const auto& k_reshape_trans =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"k_reshape_trans"),
+ permute_dims_op, {k_reshape},
Attrs(reduce_permute_attrs));
// calculate product
auto matmul_attrs = make_object<MatmulAttrs>();
matmul_attrs->out_dtype = in_dtype;
- const auto& qk_prod = MakeCall(builder, call->span, "qk_prod", matmul_op,
- {q_reshape, k_reshape_trans},
Attrs(matmul_attrs));
+ const auto& qk_prod =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"),
matmul_op,
+ {q_reshape, k_reshape_trans},
Attrs(matmul_attrs));
Expr p_scale;
if (src_attrs->scale.defined()) {
- const auto& scale =
MakeConstant(static_cast<double>(src_attrs->scale.value()->value), in_dtype,
- SpanUtils::GetAttr(call->span,
msc_attr::kName) + "_scale");
- Array<PrimExpr> exp_shape(3, Integer(1));
- const auto& exp_scale =
- MakeCall(builder, call->span, "exp_scale", reshape_op, {scale,
ShapeExpr(exp_shape)});
- p_scale = MakeCall(builder, call->span, "p_scale", multiply_op, {qk_prod,
exp_scale});
+ double value = static_cast<double>(src_attrs->scale.value()->value);
+ const auto& scale = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "scale"),
+ value, in_dtype, 3);
+ p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"p_scale"), multiply_op,
+ {qk_prod, scale});
} else {
- const auto& scale =
- MakeConstant(static_cast<double>(Downcast<Integer>(head_dim)->value),
in_dtype,
- SpanUtils::GetAttr(call->span, msc_attr::kName) +
"_scale");
- Array<PrimExpr> exp_shape(3, Integer(1));
- const auto& exp_scale =
- MakeCall(builder, call->span, "exp_scale", reshape_op, {scale,
ShapeExpr(exp_shape)});
- const auto& sqrt_scale = MakeCall(builder, call->span, "sqrt_scale",
sqrt_op, {exp_scale});
- p_scale = MakeCall(builder, call->span, "p_scale", divide_op, {qk_prod,
sqrt_scale});
+ double value = static_cast<double>(Downcast<Integer>(head_dim)->value);
+ const auto& scale = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "scale"),
+ value, in_dtype, 3);
+ const auto& sqrt_scale = RewriteUtils::MakeCall(
+ builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale});
+ p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"p_scale"), divide_op,
+ {qk_prod, sqrt_scale});
}
// bias
@@ -249,12 +273,12 @@ Expr RewriteAttention(BlockBuilder builder, const Var&
var, const Call& src_call
if (call->args.size() == 4) {
Array<PrimExpr> exp_shape{batch_size, num_head, seq_len, seq_len_kv};
Array<PrimExpr> reduce_shape{batch_size * num_head, seq_len, seq_len_kv};
- const auto& prod_exp =
- MakeCall(builder, call->span, "prod_exp", reshape_op, {prod,
ShapeExpr(exp_shape)});
- const auto& prod_add =
- MakeCall(builder, call->span, "prod_add", add_op, {prod_exp,
call->args[3]});
- prod = MakeCall(builder, call->span, "prod_reduce", reshape_op,
- {prod_add, ShapeExpr(reduce_shape)});
+ const auto& prod_exp = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "prod_exp"),
+ reshape_op, {prod,
ShapeExpr(exp_shape)});
+ const auto& prod_add = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "prod_add"),
+ add_op, {prod_exp,
call->args[3]});
+ prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"prod_reduce"), reshape_op,
+ {prod_add, ShapeExpr(reduce_shape)});
}
// causal_mask
@@ -262,7 +286,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var,
const Call& src_call
if (!src_attrs->causal_mask.defined()) {
auto softmax_attrs = make_object<SoftmaxAttrs>();
softmax_attrs->axis = 2;
- s_value = MakeCall(builder, call->span, "act", softmax_op, {prod},
Attrs(softmax_attrs));
+ s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"act"), softmax_op,
+ {prod}, Attrs(softmax_attrs));
} else {
const auto& causal_mask = src_attrs->causal_mask.value();
PrimValue tril_k;
@@ -273,41 +298,47 @@ Expr RewriteAttention(BlockBuilder builder, const Var&
var, const Call& src_call
} else {
LOG_FATAL << "Unexpected causal_mask " << causal_mask;
}
- const auto& p_masked = MakeCall(builder, call->span, "p_masked", tril_op,
{prod, tril_k});
+ const auto& p_masked = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "p_masked"),
+ tril_op, {prod, tril_k});
auto reduce_attrs = make_object<StatisticalAttrs>();
Array<Integer> axis{Integer(2)};
reduce_attrs->axis = axis;
reduce_attrs->keepdims = true;
- const auto& p_max = MakeCall(builder, call->span, "p_max", max_op, {prod},
Attrs(reduce_attrs));
- const auto& p_diff = MakeCall(builder, call->span, "p_diff", subtract_op,
{p_masked, p_max});
- const auto& p_exp = MakeCall(builder, call->span, "p_exp", exp_op,
{p_diff});
- const auto& p_masked_exp =
- MakeCall(builder, call->span, "p_masked_exp", tril_op, {p_exp,
tril_k});
+ const auto& p_max = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "p_max"),
+ max_op, {prod},
Attrs(reduce_attrs));
+ const auto& p_diff = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "p_diff"),
+ subtract_op, {p_masked,
p_max});
+ const auto& p_exp =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"),
exp_op, {p_diff});
+ const auto& p_masked_exp = RewriteUtils::MakeCall(
+ builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op,
{p_exp, tril_k});
const auto& p_masked_sum =
- MakeCall(builder, call->span, "p_masked_sum", sum_op, {p_masked_exp},
Attrs(reduce_attrs));
- s_value = MakeCall(builder, call->span, "act", divide_op, {p_masked_exp,
p_masked_sum});
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"p_masked_sum"), sum_op,
+ {p_masked_exp}, Attrs(reduce_attrs));
+ s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"act"), divide_op,
+ {p_masked_exp, p_masked_sum});
}
// final calculation
- const auto& o_prod =
- MakeCall(builder, call->span, "o_prod", matmul_op, {s_value, v_reshape},
Attrs(matmul_attrs));
+ const auto& o_prod = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "o_prod"),
+ matmul_op, {s_value, v_reshape},
Attrs(matmul_attrs));
Array<PrimExpr> o_shape{batch_size, num_head, seq_len, head_dim_v};
return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(),
call->sinfo_args, call->span);
}
Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& input_shape = GetShape(call->args[0]);
- const auto& in_dtype =
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
const auto* src_attrs = src_call->attrs.as<BatchNormAttrs>();
// define expand shape
Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]);
// create eps constant
- const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype,
- SpanUtils::GetAttr(call->span,
msc_attr::kName) + "_eps");
+ const auto& eps = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "eps"),
+ src_attrs->epsilon, in_dtype);
// create ops
static const Op& add_op = Op::Get("relax.add");
@@ -318,36 +349,43 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var&
var, const Call& src_call
static const Op& subtract_op = Op::Get("relax.subtract");
// scale factor: gamma/sqrt(var + epsilon)
- const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op,
{call->args[4], eps});
- const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
- const auto& scale_factor =
- MakeCall(builder, call->span, "scale_factor", divide_op, {call->args[1],
sqrt});
+ const auto& eps_add = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "eps_add"),
+ add_op, {call->args[4], eps});
+ const auto& sqrt =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"),
sqrt_op, {eps_add});
+ const auto& scale_factor = RewriteUtils::MakeCall(
+ builder, ExprUtils::GetSpanName(call, "scale_factor"), divide_op,
{call->args[1], sqrt});
Expr res = call->args[0];
// scale
if (src_attrs->scale) {
- const auto& exp_scale = MakeCall(builder, call->span, "exp_scale",
reshape_op,
- {scale_factor, ShapeExpr(exp_shape)});
- res = MakeCall(builder, call->span, "scale", multiply_op, {res,
exp_scale});
+ const auto& exp_scale =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_scale"), reshape_op,
+ {scale_factor, ShapeExpr(exp_shape)});
+ res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"scale"), multiply_op,
+ {res, exp_scale});
}
// offset
if (src_attrs->center) {
// offset factor: beta-mean*scale_factor
- const auto& average =
- MakeCall(builder, call->span, "average", multiply_op, {call->args[3],
scale_factor});
+ const auto& average = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "average"),
+ multiply_op, {call->args[3],
scale_factor});
const auto& offset_factor =
- MakeCall(builder, call->span, "offset_factor", subtract_op,
{call->args[2], average});
- const auto& exp_offset = MakeCall(builder, call->span, "exp_offset",
reshape_op,
- {offset_factor, ShapeExpr(exp_shape)});
- res = MakeCall(builder, call->span, "offset", add_op, {res, exp_offset});
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"offset_factor"), subtract_op,
+ {call->args[2], average});
+ const auto& exp_offset =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_offset"), reshape_op,
+ {offset_factor, ShapeExpr(exp_shape)});
+ res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"offset"), add_op,
+ {res, exp_offset});
}
return Tuple(Array<Expr>{res}, call->span);
}
Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const
Array<Integer>& version) {
+ const Map<Expr, Call>& new_calls, const String&
config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& input_shape = GetShape(call->args[0]);
- const auto& output_shape = GetShape(var);
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& output_shape = ExprUtils::GetShape(var);
Expr concat_input = call->args[0];
static const Op& concat_op = Op::Get("relax.concat");
for (size_t i = 0; i < input_shape.size(); i++) {
@@ -357,30 +395,33 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var&
var, const Call& src_ca
Array<Expr> concat_inputs(out_dim / in_dim, concat_input);
auto concat_attrs = make_object<ConcatAttrs>();
concat_attrs->axis = Integer(i);
- concat_input = MakeCall(builder, call->span, "concat_" +
std::to_string(i), concat_op,
- {Tuple(concat_inputs)}, Attrs(concat_attrs));
+ concat_input = RewriteUtils::MakeCall(
+ builder, ExprUtils::GetSpanName(call, "concat_" +
std::to_string(i)), concat_op,
+ {Tuple(concat_inputs)}, Attrs(concat_attrs));
}
}
return concat_input;
}
Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
const auto* src_attrs = src_call->attrs.as<Conv1DAttrs>();
- const auto& input_shape = GetShape(call->args[0]);
- const auto& weight_shape = GetShape(call->args[1]);
- const auto& output_shape = GetShape(var);
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& weight_shape = ExprUtils::GetShape(call->args[1]);
+ const auto& output_shape = ExprUtils::GetShape(var);
if (src_attrs->data_layout == "NCW") {
Array<Expr> new_args;
// expand inputs
Array<PrimExpr> exp_input_shape{input_shape[0], input_shape[1],
Integer(1), input_shape[2]};
Array<PrimExpr> exp_weight_shape{weight_shape[0], weight_shape[1],
Integer(1), weight_shape[2]};
static const Op& reshape_op = Op::Get("relax.reshape");
- new_args.push_back(MakeCall(builder, call->span, "exp_input", reshape_op,
- {call->args[0], ShapeExpr(exp_input_shape)}));
- new_args.push_back(MakeCall(builder, call->span, "exp_weight", reshape_op,
- {call->args[1], ShapeExpr(exp_weight_shape)}));
+ new_args.push_back(RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "exp_input"),
+ reshape_op,
+ {call->args[0],
ShapeExpr(exp_input_shape)}));
+ new_args.push_back(RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "exp_weight"),
+ reshape_op,
+ {call->args[1],
ShapeExpr(exp_weight_shape)}));
// change to conv2d
static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
auto conv_attrs = make_object<Conv2DAttrs>();
@@ -393,8 +434,8 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var,
const Call& src_call,
conv_attrs->kernel_layout = "OIHW";
conv_attrs->out_layout = "NCHW";
conv_attrs->out_dtype = src_attrs->out_dtype;
- const auto& conv2d =
- MakeCall(builder, call->span, "exp", conv2d_op, new_args,
Attrs(conv_attrs));
+ const auto& conv2d = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "exp"),
+ conv2d_op, new_args,
Attrs(conv_attrs));
// reduce output
return Call(reshape_op, {conv2d, ShapeExpr(output_shape)}, Attrs(),
call->sinfo_args,
call->span);
@@ -404,11 +445,80 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var,
const Call& src_call,
return call;
}
+Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call,
+ const Map<Expr, Call>& new_calls, const String& config) {
+ // 0.5 * x * (1 + erf(sqrt(0.5) * x))
+ const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
+ size_t in_dim = ExprUtils::GetShape(call->args[0]).size();
+ const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
+ // create ops
+ static const Op& add_op = Op::Get("relax.add");
+ static const Op& multiply_op = Op::Get("relax.multiply");
+ static const Op& erf_op = Op::Get("relax.erf");
+
+ const auto& factor = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "factor"),
+ std::sqrt(0.5), in_dtype,
in_dim);
+ const auto& mul = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mul"),
+ multiply_op, {factor,
call->args[0]});
+ const auto& erf =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "erf"),
erf_op, {mul});
+ const auto& one =
+ RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"),
1, in_dtype, in_dim);
+ const auto& add =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"),
add_op, {one, erf});
+ const auto& mul2 = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mul2"),
+ multiply_op, {call->args[0], add});
+ const auto& half = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "one"), 0.5,
+ in_dtype, in_dim);
+ return Call(multiply_op, {half, mul2}, Attrs(), call->sinfo_args,
call->span);
+}
+
+Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call&
src_call,
+ const Map<Expr, Call>& new_calls, const String& config) {
+ // 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x)))
+ const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
+ size_t in_dim = ExprUtils::GetShape(call->args[0]).size();
+ const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
+
+ // create ops
+ static const Op& add_op = Op::Get("relax.add");
+ static const Op& multiply_op = Op::Get("relax.multiply");
+ static const Op& pow_op = Op::Get("relax.power");
+ static const Op& tanh_op = Op::Get("relax.tanh");
+
+ const auto& pow_factor = RewriteUtils::MakeConstant(
+ builder, ExprUtils::GetSpanName(call, "pow_factor"), 3, in_dtype,
in_dim);
+ const auto& mul_factor = RewriteUtils::MakeConstant(
+ builder, ExprUtils::GetSpanName(call, "mul_factor"), 0.044715, in_dtype,
in_dim);
+ const auto& pi_factor = RewriteUtils::MakeConstant(
+ builder, ExprUtils::GetSpanName(call, "pi_factor"), std::sqrt(2 / M_PI),
in_dtype, in_dim);
+
+ const auto& pow = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "pow"), pow_op,
+ {call->args[0], pow_factor});
+ const auto& mul = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mul"),
+ multiply_op, {mul_factor, pow});
+ const auto& add = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "add"), add_op,
+ {mul, call->args[0]});
+ const auto& mul2 = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mul2"),
+ multiply_op, {pi_factor, add});
+ const auto& tanh =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "tanh"),
tanh_op, {mul2});
+ const auto& one =
+ RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"),
1, in_dtype, in_dim);
+ const auto& add2 =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"),
add_op, {one, tanh});
+ const auto& mul3 = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mul3"),
+ multiply_op, {call->args[0],
add2});
+ const auto& half = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "one"), 0.5,
+ in_dtype, in_dim);
+ return Call(multiply_op, {half, mul3}, Attrs(), call->sinfo_args,
call->span);
+}
+
Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& input_shape = GetShape(call->args[0]);
- const auto& in_dtype =
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
const auto* src_attrs = src_call->attrs.as<GroupNormAttrs>();
Array<PrimExpr> group_shape = input_shape;
Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
@@ -420,8 +530,8 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var,
const Call& src_call
exp_shape.Set(axis, Integer(src_attrs->num_groups));
// create eps constant
- const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype,
- SpanUtils::GetAttr(call->span,
msc_attr::kName) + "_eps");
+ const auto& eps = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "eps"),
+ src_attrs->epsilon, in_dtype);
// create ops
static const Op& add_op = Op::Get("relax.add");
@@ -434,53 +544,63 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var&
var, const Call& src_call
static const Op& subtract_op = Op::Get("relax.subtract");
// reshape input
- const auto& reshape_in = MakeCall(builder, call->span, "reshape_in",
reshape_op,
- {call->args[0], ShapeExpr(group_shape)});
+ const auto& reshape_in =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"reshape_in"), reshape_op,
+ {call->args[0], ShapeExpr(group_shape)});
// mean(input)
auto mean_attrs = make_object<StatisticalAttrs>();
mean_attrs->axis = src_attrs->axes;
mean_attrs->keepdims = true;
- const auto& mean =
- MakeCall(builder, call->span, "mean", mean_op, {reshape_in},
Attrs(mean_attrs));
+ const auto& mean = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mean"), mean_op,
+ {reshape_in}, Attrs(mean_attrs));
// variance: mean((input-mean)*(input-mean))
- const auto& diff = MakeCall(builder, call->span, "diff", subtract_op,
{reshape_in, mean});
- const auto& square = MakeCall(builder, call->span, "square", square_op,
{diff});
- const auto& variance =
- MakeCall(builder, call->span, "variance", mean_op, {square},
Attrs(mean_attrs));
+ const auto& diff = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "diff"),
+ subtract_op, {reshape_in, mean});
+ const auto& square =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"),
square_op, {diff});
+ const auto& variance = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "variance"),
+ mean_op, {square},
Attrs(mean_attrs));
// sqrt(var + epsilon)
Array<PrimExpr> exp_eps_shape(input_shape.size(), Integer(1));
- const auto& exp_eps =
- MakeCall(builder, call->span, "exp_eps", reshape_op, {eps,
ShapeExpr(exp_eps_shape)});
- const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op,
{variance, exp_eps});
- const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+ const auto& exp_eps = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "exp_eps"),
+ reshape_op, {eps,
ShapeExpr(exp_eps_shape)});
+ const auto& eps_add = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "eps_add"),
+ add_op, {variance, exp_eps});
+ const auto& sqrt =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"),
sqrt_op, {eps_add});
// diff/sqrt
- Expr res = MakeCall(builder, call->span, "divide", divide_op, {diff, sqrt});
+ Expr res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"divide"), divide_op,
+ {diff, sqrt});
// scale
if (src_attrs->scale) {
- const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma",
reshape_op,
- {call->args[1], ShapeExpr(exp_shape)});
- res = MakeCall(builder, call->span, "scale", multiply_op, {res,
exp_gamma});
+ const auto& exp_gamma =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_gamma"), reshape_op,
+ {call->args[1], ShapeExpr(exp_shape)});
+ res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"scale"), multiply_op,
+ {res, exp_gamma});
}
// offset
if (src_attrs->center) {
- const auto& exp_beta = MakeCall(builder, call->span, "exp_beta",
reshape_op,
- {call->args[2], ShapeExpr(exp_shape)});
- res = MakeCall(builder, call->span, "offset", add_op, {res, exp_beta});
+ const auto& exp_beta =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_beta"), reshape_op,
+ {call->args[2], ShapeExpr(exp_shape)});
+ res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"offset"), add_op,
+ {res, exp_beta});
}
// reshape output
return Call(reshape_op, {res, ShapeExpr(input_shape)}, Attrs(),
call->sinfo_args, call->span);
}
Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call&
src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& input_shape = GetShape(call->args[0]);
- const auto& in_dtype =
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+ const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
const auto* src_attrs = src_call->attrs.as<LayerNormAttrs>();
Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
for (const auto& a : src_attrs->axes) {
@@ -488,8 +608,8 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var,
const Call& src_call
exp_shape.Set(index, input_shape[index]);
}
// create eps constant
- const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype,
- SpanUtils::GetAttr(call->span,
msc_attr::kName) + "_eps");
+ const auto& eps = RewriteUtils::MakeConstant(builder,
ExprUtils::GetSpanName(call, "eps"),
+ src_attrs->epsilon, in_dtype);
// create ops
static const Op& add_op = Op::Get("relax.add");
@@ -505,30 +625,36 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var&
var, const Call& src_call
auto mean_attrs = make_object<StatisticalAttrs>();
mean_attrs->axis = src_attrs->axes;
mean_attrs->keepdims = true;
- const auto& mean =
- MakeCall(builder, call->span, "mean", mean_op, {call->args[0]},
Attrs(mean_attrs));
+ const auto& mean = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "mean"), mean_op,
+ {call->args[0]},
Attrs(mean_attrs));
// variance: mean((input-mean)*(input-mean))
- const auto& diff = MakeCall(builder, call->span, "diff", subtract_op,
{call->args[0], mean});
- const auto& square = MakeCall(builder, call->span, "square", square_op,
{diff});
- const auto& variance =
- MakeCall(builder, call->span, "variance", mean_op, {square},
Attrs(mean_attrs));
+ const auto& diff = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "diff"),
+ subtract_op, {call->args[0],
mean});
+ const auto& square =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"),
square_op, {diff});
+ const auto& variance = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "variance"),
+ mean_op, {square},
Attrs(mean_attrs));
// sqrt(var + epsilon)
Array<PrimExpr> exp_eps_shape(input_shape.size(), Integer(1));
- const auto& exp_eps =
- MakeCall(builder, call->span, "exp_eps", reshape_op, {eps,
ShapeExpr(exp_eps_shape)});
- const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op,
{variance, exp_eps});
- const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+ const auto& exp_eps = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "exp_eps"),
+ reshape_op, {eps,
ShapeExpr(exp_eps_shape)});
+ const auto& eps_add = RewriteUtils::MakeCall(builder,
ExprUtils::GetSpanName(call, "eps_add"),
+ add_op, {variance, exp_eps});
+ const auto& sqrt =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"),
sqrt_op, {eps_add});
// diff/sqrt
Call res = Call(divide_op, {diff, sqrt}, Attrs(), call->sinfo_args,
call->span);
// scale
if (src_attrs->scale) {
- const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma",
reshape_op,
- {call->args[1], ShapeExpr(exp_shape)});
- const auto& res_var = EmitCall(builder, res, call->span, "pre_scale");
+ const auto& exp_gamma =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_gamma"), reshape_op,
+ {call->args[1], ShapeExpr(exp_shape)});
+ const auto& res_var =
+ RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call,
"pre_scale"), res);
if (src_attrs->center) {
res = Call(multiply_op, {res_var, exp_gamma});
} else {
@@ -537,87 +663,126 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var&
var, const Call& src_call
}
// offset
if (src_attrs->center) {
- const auto& exp_beta = MakeCall(builder, call->span, "exp_beta",
reshape_op,
- {call->args[2], ShapeExpr(exp_shape)});
- const auto& res_var = EmitCall(builder, res, call->span, "pre_offset");
+ const auto& exp_beta =
+ RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call,
"exp_beta"), reshape_op,
+ {call->args[2], ShapeExpr(exp_shape)});
+ const auto& res_var =
+ RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call,
"pre_offset"), res);
res = Call(add_op, {res_var, exp_beta}, Attrs(), call->sinfo_args,
call->span);
}
return res;
}
Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
                   const Map<Expr, Call>& new_calls, const String& config) {
  // Rewrite relax.matmul for TensorRT. Two transformations:
  //  1. when enabled by config, a linear layer (constant 2-D rhs) becomes an
  //     equivalent 1x1 conv2d, which TensorRT handles well;
  //  2. otherwise, the lower-rank operand is reshaped (leading 1s prepended)
  //     so both operands have the same rank.
  const auto& options = ParseConfig(config);
  const auto& node = new_calls.count(src_call) ? new_calls[src_call] : src_call;
  const auto& lhs_shape = ExprUtils::GetShape(node->args[0]);
  const auto& rhs_shape = ExprUtils::GetShape(node->args[1]);
  static const Op& reshape_op = Op::Get("relax.reshape");
  bool is_linear = node->args[1]->IsInstance<ConstantNode>() && rhs_shape.size() == 2;
  if (is_linear && options.linear_to_conv) {
    const auto& out_shape = ExprUtils::GetShape(var);
    // collapse all leading dims into the conv batch: (prod(leading), K, 1, 1)
    PrimExpr flat_batch = ArrayUtils::Accumulate(lhs_shape, lhs_shape.size() - 1);
    Array<PrimExpr> input_4d{flat_batch, lhs_shape[lhs_shape.size() - 1], Integer(1), Integer(1)};
    const auto& exp_in = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "exp_in"),
                                                reshape_op, {node->args[0], ShapeExpr(input_4d)});
    // transpose and expand weight to OIHW
    static const Op& permute_dims_op = Op::Get("relax.permute_dims");
    auto permute_attrs = make_object<PermuteDimsAttrs>();
    permute_attrs->axes = Array<Integer>{Integer(1), Integer(0)};
    const auto& trans_weight =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "trans_weight"),
                               permute_dims_op, {node->args[1]}, Attrs(permute_attrs));
    Array<PrimExpr> weight_4d{rhs_shape[1], rhs_shape[0], Integer(1), Integer(1)};
    const auto& exp_weight =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "exp_weight"), reshape_op,
                               {trans_weight, ShapeExpr(weight_4d)});
    // emit the equivalent 1x1 conv2d
    static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
    auto conv_attrs = make_object<Conv2DAttrs>();
    conv_attrs->strides = Array<IntImm>{Integer(1), Integer(1)};
    conv_attrs->padding = Array<IntImm>{Integer(0), Integer(0), Integer(0), Integer(0)};
    conv_attrs->dilation = Array<IntImm>{Integer(1), Integer(1)};
    conv_attrs->groups = 1;
    conv_attrs->data_layout = "NCHW";
    conv_attrs->kernel_layout = "OIHW";
    conv_attrs->out_layout = "NCHW";
    conv_attrs->out_dtype = ExprUtils::GetDataType(var);
    const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "conv2d"),
                                                conv2d_op, {exp_in, exp_weight}, Attrs(conv_attrs));
    // restore the original output shape
    return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), node->sinfo_args, node->span);
  }
  if (lhs_shape.size() > rhs_shape.size()) {
    // pad rhs with leading 1s up to lhs rank
    size_t rank_diff = lhs_shape.size() - rhs_shape.size();
    Array<PrimExpr> padded(lhs_shape.size(), Integer(1));
    for (size_t i = rank_diff; i < lhs_shape.size(); i++) {
      padded.Set(i, rhs_shape[i - rank_diff]);
    }
    const auto& rhs_expanded =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "expand_b"), reshape_op,
                               {node->args[1], ShapeExpr(padded)});
    return Call(node->op, {node->args[0], rhs_expanded}, node->attrs, node->sinfo_args, node->span);
  }
  if (lhs_shape.size() < rhs_shape.size()) {
    // pad lhs with leading 1s up to rhs rank
    size_t rank_diff = rhs_shape.size() - lhs_shape.size();
    Array<PrimExpr> padded(rhs_shape.size(), Integer(1));
    for (size_t i = rank_diff; i < rhs_shape.size(); i++) {
      padded.Set(i, lhs_shape[i - rank_diff]);
    }
    const auto& lhs_expanded =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "expand_a"), reshape_op,
                               {node->args[0], ShapeExpr(padded)});
    return Call(node->op, {lhs_expanded, node->args[1]}, node->attrs, node->sinfo_args, node->span);
  }
  // ranks already match: leave the call untouched
  return node;
}
Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call,
                  const Map<Expr, Call>& new_calls, const String& config) {
  // Rewrite relax.rsqrt as 1 / sqrt(x), since TensorRT has no direct rsqrt op.
  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
  // create 1 constant (broadcast to the input rank)
  // BUGFIX: the constant was registered under span name "eps" — a copy-paste
  // from the epsilon constants in the norm rewrites; there is no epsilon here,
  // so name it "one" to match what it is and avoid misleading node names.
  const auto& one = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1,
                                               in_dtype, input_shape.size());
  // create ops
  static const Op& divide_op = Op::Get("relax.divide");
  static const Op& sqrt_op = Op::Get("relax.sqrt");
  // 1 / sqrt(x)
  const auto& sqrt = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op,
                                            {call->args[0]});
  return Call(divide_op, {one, sqrt}, Attrs(), call->sinfo_args, call->span);
}
Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call,
                 const Map<Expr, Call>& new_calls, const String& config) {
  // Rewrite relax.nn.silu as x * sigmoid(x), the ops TensorRT supports directly.
  const auto& node = new_calls.count(src_call) ? new_calls[src_call] : src_call;
  static const Op& sigmoid_op = Op::Get("relax.sigmoid");
  static const Op& multiply_op = Op::Get("relax.multiply");
  const auto& gate = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(node, "sigmoid"),
                                            sigmoid_op, {node->args[0]});
  return Call(multiply_op, {node->args[0], gate}, Attrs(), node->sinfo_args, node->span);
}
Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call,
                      const Map<Expr, Call>& new_calls, const String& config) {
  // Rewrite *_like shape ops as a plain reshape to the binding var's
  // (statically known) output shape.
  const auto& node = new_calls.count(src_call) ? new_calls[src_call] : src_call;
  static const Op& reshape_op = Op::Get("relax.reshape");
  const auto& target_shape = ShapeExpr(ExprUtils::GetShape(var));
  return Call(reshape_op, {node->args[0], target_shape}, Attrs(), node->sinfo_args, node->span);
}
Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
- const Map<Expr, Call>& new_calls, const Array<Integer>&
version) {
+ const Map<Expr, Call>& new_calls, const String& config) {
const auto& call = new_calls.count(src_call) ? new_calls[src_call] :
src_call;
- const auto& input_shape = GetShape(call->args[0]);
+ const auto& input_shape = ExprUtils::GetShape(call->args[0]);
const auto* src_attrs = src_call->attrs.as<SplitAttrs>();
size_t axis = CommonUtils::GetIndex(src_attrs->axis, input_shape.size());
std::vector<int64_t> split_begins, split_ends;
@@ -646,9 +811,16 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var,
const Call& src_call,
// create strided_slices
Array<Expr> outputs;
for (size_t i = 0; i < split_begins.size(); i++) {
- auto slice = strided_slice(call->args[0],
Tuple(Array<Expr>{PrimValue(Integer(axis))}),
-
Tuple(Array<Expr>{PrimValue(Integer(split_begins[i]))}),
-
Tuple(Array<Expr>{PrimValue(Integer(split_ends[i]))}));
+ static const Op& strided_slice_op = Op::Get("relax.strided_slice");
+ const auto& axes = Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64),
axis))});
+ const auto& begin = Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64),
split_begins[i]))});
+ const auto& end = Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64),
split_ends[i]))});
+ const auto& strides =
Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64), 1))});
+ auto attrs = make_object<StridedSliceAttrs>();
+ attrs->assume_inbound = true;
+ const auto& slice = RewriteUtils::MakeCall(
+ builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)),
strided_slice_op,
+ {call->args[0], axes, begin, end, strides}, Attrs(attrs));
outputs.push_back(slice);
}
return Tuple(outputs, call->span);
@@ -664,6 +836,9 @@ TVM_REGISTER_OP("relax.nn.batch_norm")
TVM_REGISTER_OP("relax.nn.conv1d").set_attr<FRewriteTensorRT>("FRewriteTensorRT",
RewriteConv1d);
TVM_REGISTER_OP("relax.nn.group_norm")
.set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteGroupNorm);
+TVM_REGISTER_OP("relax.nn.gelu").set_attr<FRewriteTensorRT>("FRewriteTensorRT",
RewriteGelu);
+TVM_REGISTER_OP("relax.nn.gelu_tanh")
+ .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteGeluTanh);
TVM_REGISTER_OP("relax.nn.layer_norm")
.set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteLayerNorm);
TVM_REGISTER_OP("relax.nn.silu").set_attr<FRewriteTensorRT>("FRewriteTensorRT",
RewriteSilu);
@@ -695,9 +870,9 @@
TVM_REGISTER_OP("relax.split").set_attr<FRewriteTensorRT>("FRewriteTensorRT", Re
class TensorRTTransformer : public ExprMutator {
public:
- explicit TensorRTTransformer(IRModule ctx_module, const Array<Integer>&
version)
+ explicit TensorRTTransformer(IRModule ctx_module, const String& config)
: ExprMutator(ctx_module) {
- version_ = version;
+ config_ = config;
}
void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node)
final {
@@ -707,7 +882,7 @@ class TensorRTTransformer : public ExprMutator {
if (rewrite_map.count(op)) {
const auto& call = GetRef<Call>(call_node);
FRewriteTensorRT f = rewrite_map[op];
- const auto& new_call = f(builder_, binding->var, call, new_calls_,
version_);
+ const auto& new_call = f(builder_, binding->var, call, new_calls_,
config_);
if (new_call != call) {
ReEmitBinding(binding, builder_->Normalize(new_call));
new_calls_.Set(binding->var, call);
@@ -721,20 +896,19 @@ class TensorRTTransformer : public ExprMutator {
private:
Map<Expr, Call> new_calls_;
- Array<Integer> version_;
+ String config_;
};
-Function TransformTensorRT(const Function& func, const IRModule& module,
- const Array<Integer>& version) {
- return Downcast<Function>(TensorRTTransformer(module,
version).VisitExpr(func));
+Function TransformTensorRT(const Function& func, const IRModule& module, const
String& config) {
+ return Downcast<Function>(TensorRTTransformer(module,
config).VisitExpr(func));
}
namespace transform {
-Pass TransformTensorRT(const Array<Integer>& version) {
+Pass TransformTensorRT(const String& config) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
[=](Function f, IRModule m, PassContext pc) {
- return relax::TransformTensorRT(f, m, version);
+ return relax::TransformTensorRT(f, m, config);
};
return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {});
}
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py
b/tests/python/contrib/test_msc/test_translate_tensorrt.py
index 74c25ceacf..7c8c283099 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorrt.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -87,7 +87,7 @@ def check_names(mod):
NameChecker().check(func)
-def verify_model(torch_model, input_info, allow_incomplete=False):
+def verify_model(torch_model, input_info, **trans_config):
"""Build model and verify results"""
graph_model = fx.symbolic_trace(torch_model)
@@ -100,9 +100,7 @@ def verify_model(torch_model, input_info,
allow_incomplete=False):
golden = [golden]
golden = [g.detach().cpu().numpy() for g in golden]
# partition module for tensorrt
- mod, graphs, weights = translate.partition_for_tensorrt(
- mod, trans_config={"allow_incomplete": allow_incomplete}
- )
+ mod, graphs, weights = translate.partition_for_tensorrt(mod,
trans_config=trans_config)
check_names(mod)
output_folder = msc_utils.msc_dir()
# translate to tensorrt
@@ -191,6 +189,8 @@ def test_linear():
input_info = [([1, 3, 10, 10], "float32")]
verify_model(Dense1(), input_info)
verify_model(Dense2(), input_info)
+ verify_model(Dense1(), input_info, linear_to_conv=True)
+ verify_model(Dense2(), input_info, linear_to_conv=True)
verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])
@@ -368,10 +368,10 @@ def test_embedding():
self.embedding = torch.nn.Embedding(10, 3)
def forward(self, data):
- return self.embedding(data)
+ return self.embedding(data.to(torch.int64))
- verify_model(Embedding(), [([4], "int64")], allow_incomplete=True)
- verify_model(Embedding(), [([4, 5], "int64")], allow_incomplete=True)
+ verify_model(Embedding(), [([4], "int32")])
+ verify_model(Embedding(), [([4, 5], "int32")])
@requires_tensorrt
@@ -801,14 +801,14 @@ def test_argmax():
class Argmax1(Module):
def forward(self, data):
- return torch.argmax(data, dim=-1)
+ return torch.argmax(data, dim=-1).to(torch.int32)
class Argmax2(Module):
def forward(self, data):
- return torch.argmax(data, dim=-1, keepdim=True)
+ return torch.argmax(data, dim=-1, keepdim=True).to(torch.int32)
- verify_model(Argmax1(), [([256, 256], "float32")], allow_incomplete=True)
- verify_model(Argmax2(), [([256, 256], "float32")], allow_incomplete=True)
+ verify_model(Argmax1(), [([256, 256], "float32")])
+ verify_model(Argmax2(), [([256, 256], "float32")])
@requires_tensorrt
@@ -817,14 +817,14 @@ def test_argmin():
class Argmin1(Module):
def forward(self, data):
- return torch.argmin(data, dim=-1)
+ return torch.argmin(data, dim=-1).to(torch.int32)
class Argmin2(Module):
def forward(self, data):
- return torch.argmin(data, dim=-1, keepdim=True)
+ return torch.argmin(data, dim=-1, keepdim=True).to(torch.int32)
- verify_model(Argmin1(), [([256, 256], "float32")], allow_incomplete=True)
- verify_model(Argmin2(), [([256, 256], "float32")], allow_incomplete=True)
+ verify_model(Argmin1(), [([256, 256], "float32")])
+ verify_model(Argmin2(), [([256, 256], "float32")])
@requires_tensorrt
@@ -876,5 +876,22 @@ def test_max():
verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
+@requires_tensorrt
+def test_gelu():
+ """test tensorrt translator for gelu"""
+
+ class Gelu1(Module):
+ def forward(self, data):
+ return torch.nn.functional.gelu(data)
+
+ class Gelu2(Module):
+ def forward(self, data):
+ return torch.nn.functional.gelu(data, approximate="tanh")
+
+ input_info = [([1, 3, 10, 10], "float32")]
+ verify_model(Gelu1(), input_info)
+ verify_model(Gelu2(), input_info)
+
+
if __name__ == "__main__":
tvm.testing.main()