This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b21855758e [Relax] Implement operators to read runtime DLTensor*
information (#16563)
b21855758e is described below
commit b21855758e40057e1b4d7f10410ed7bfb36aa808
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Feb 20 14:59:13 2024 -0600
[Relax] Implement operators to read runtime DLTensor* information (#16563)
Relax is capable of expressing tensors whose element type is unknown.
However, these must typically be replaced with a known dtype prior to
compilation, as most operators require known data types prior to
legalization. This can be done by using a `relax::MatchCast` node,
such as accepting a parameter `arg: R.Tensor([16,16])`, then defining
the dtype using `R.match_cast(arg, R.Tensor([16,16],'float16'))`.
However, using a `R.match_cast` node requires knowing which data type
should be used in the new `R.Tensor`, and raises an error for an
incorrect data type. If an argument may be one of two distinct data
types, `R.match_cast` cannot be used to check which data type is in
use.
This commit adds Relax operators to read the runtime values of a
`DLTensor*` argument. These can be used to normalize arguments
prior to a compute step. For example, pre-processing a model weight
that may be provided in either `float16` or `bfloat16` format.
---
python/tvm/relax/expr.py | 186 +++++++++++++++++++
src/relax/op/tensor/inspect.cc | 351 +++++++++++++++++++++++++++++++++++
src/relax/op/tensor/inspect.h | 92 +++++++++
src/relax/transform/legalize_ops.cc | 115 +++++++++---
tests/python/relax/test_op_unpack.py | 127 +++++++++++++
5 files changed, 846 insertions(+), 25 deletions(-)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index c9780bea7e..12f08f4dbf 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -244,6 +244,192 @@ class ExprWithOp(Expr, Scriptable):
raise IndexError from err
raise
+ def _check_for_tensor_struct_info(self):
+ """Raise an error if this is something other than a Tensor
+
+ Used for early checks in `expr.dtype` and `expr.shape`
+ accessors. While invalid usage would cause errors to be
+ raised during shape inference, an earlier check makes it
+ easier to find the invalid usage.
+ """
+ if self.struct_info_ is None:
+ return
+
+ if not isinstance(self.struct_info_, tvm.relax.TensorStructInfo):
+ raise TypeError(
+ f"Runtime unpacking of DLDataType is only implemented for
tensors, "
+ f"but was applied to object {self} of type {type(self)}."
+ )
+
+ @property
+ def dtype(self) -> "_DLTensorDTypeProxy":
+ """Returns a proxy object for accessing DLTensor::dtype"""
+ self._check_for_tensor_struct_info()
+ return _DLTensorDTypeProxy(self)
+
+ @property
+ def ndim(self) -> "Expr":
+ """Returns the runtime value of DLTensor::ndim"""
+ self._check_for_tensor_struct_info()
+ op = tvm.ir.Op.get("relax.inspect.tensor_ndim")
+ return tvm.relax.Call(op, [self])
+
+ @property
+ def shape(self) -> "_DLTensorShapeProxy":
+ """Returns a proxy object for accessing DLTensor::shape"""
+ self._check_for_tensor_struct_info()
+ return _DLTensorShapeProxy(self)
+
+
+class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
+ """A proxy object for unpacking DLDataType from DLTensor
+
+ Exposes accessors for `DLDataType` fields `type_code`, `lanes`,
+ and `bits` within a `DLTensor::dtype`. Accessing these fields
+ will produce `relax.Call` expressions, representing the field's
+ runtime value. If the datatype of the tensor is known at
+ compile-time, the `relax.Call` will be normalized into a
+ `relax.PrimValue`, with no runtime cost.
+
+ Parameters
+ ----------
+ tensor: relax.Expr
+
+ The relax tensor (or a variable referring to a relax tensor),
+ whose runtime shape is being inspected.
+
+ """
+
+ def __init__(self, tensor):
+ self.tensor = tensor
+
+ def asobject(self):
+ """Provide an informative error message
+
+ This method is called when `_DLTensorDTypeProxy` is used in a
+ context that requires a `relax.Expr`. This usage is not
+ supported, and raising an error here can provide suggested
+ fixes that are not present in the default error message from
+ `tvm.runtime.convert_to_object`.
+ """
+
+ fields = [f"{self.tensor}.dtype.{field}" for field in ["type_code",
"bits", "lanes"]]
+ raise TypeError(
+ f"{self.tensor}.dtype cannot be converted to a relax expression, "
+ f"and should be used as a proxy object to access "
+ f"fields {fields}"
+ )
+
+ @property
+ def type_code(self) -> Expr:
+ """Accessor for the DLDataType::code field
+
+ Returns
+ -------
+ type_code: Expr
+
+ The type code of the DLTensor. See the `DLDataTypeCode`
+ enum in `dlpack.h` for more information.
+ """
+ op = tvm.ir.Op.get("relax.inspect.tensor_dtype_code")
+ return tvm.relax.Call(op, [self.tensor])
+
+ @property
+ def lanes(self) -> Expr:
+ """Accessor for the DLDataType::lanes field
+
+ Returns
+ -------
+ lanes: Expr
+
+ The number of lanes in the DLDataType
+ """
+ op = tvm.ir.Op.get("relax.inspect.tensor_dtype_lanes")
+ return tvm.relax.Call(op, [self.tensor])
+
+ @property
+ def bits(self) -> Expr:
+ """Accessor for the DLDataType::bits field
+
+ Returns
+ -------
+ bits: Expr
+
+ The number of bits in the DLDataType
+ """
+ op = tvm.ir.Op.get("relax.inspect.tensor_dtype_bits")
+ return tvm.relax.Call(op, [self.tensor])
+
+
+class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric):
+ """A proxy object for unpacking the shape from DLTensor
+
+ Exposes accessors for the `DLTensor::shape` field. Accessing
+ these fields will produce `relax.Call` expressions, representing
+ the field's runtime value. If the shape of the tensor is known
+ at compile-time, the `relax.Call` will be normalized into a
+ `relax.PrimValue`, with no runtime cost.
+
+ Parameters
+ ----------
+ tensor: relax.Expr
+
+ The relax tensor (or a variable referring to a relax tensor),
+ whose runtime shape is being inspected.
+ """
+
+ def __init__(self, tensor):
+ self.tensor = tensor
+
+ def asobject(self):
+ """Provide an informative error message
+
+ This method is called when `_DLTensorShapeProxy` is used in a
+ context that requires a `relax.Expr`. This usage is not
+ supported, and raising an error here can provide suggested
+ fixes that are not present in the default error message from
+ `tvm.runtime.convert_to_object`.
+ """
+ raise TypeError(
+ f"{self.tensor}.shape cannot be converted to a relax expression, "
+ f"and should be used as a proxy object to access the runtime shape
of the DLTensor. "
+ f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
+ f"and the DLTensor::shape array can be accessed as
{self.tensor}.shape[i]"
+ )
+
+ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
+ """Returns the extent of a tensor axis
+
+ Parameters
+ ----------
+ axis: Union[int, PrimExpr, Expr]
+
+ The tensor axis whose extent should be returned. For ease
+ of use, any python integers or TIR expressions are
+ converted to `relax.Expr`.
+
+ Returns
+ -------
+ extent: Expr
+
+ The extent of the tensor's axis.
+ """
+
+ if not isinstance(axis, tvm.relax.Expr):
+ axis = tvm.relax.PrimValue(axis)
+
+ if axis.struct_info_ is not None and not isinstance(
+ axis.struct_info_, tvm.relax.PrimStructInfo
+ ):
+ raise TypeError(
+ f"The index used to access {self.tensor}.shape "
+ f'must have struct info R.Prim("int64"), '
+ f"but index {axis} had struct info {axis.struct_info_}."
+ )
+
+ op = tvm.ir.Op.get("relax.inspect.tensor_shape_i")
+ return tvm.relax.Call(op, [self.tensor, axis])
+
@tvm._ffi.register_object("relax.expr.Call")
class Call(ExprWithOp):
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
new file mode 100644
index 0000000000..a40b2af5ef
--- /dev/null
+++ b/src/relax/op/tensor/inspect.cc
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inspect.cc
+ * \brief Operators to access runtime DLTensor parameters
+ */
+
+#include "inspect.h"
+
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+namespace inspect {
+
+TensorStructInfo GetTensorArgInfo(const Call& call) {
+ CHECK_EQ(call->args.size(), 1) << "TypeError: "
+ << "Operator " << call->op << " expects one
argument, "
+ << "but received " << call->args.size()
+ << " arguments: " << call->args;
+
+ const auto& arg = call->args[0];
+ auto sinfo = GetStructInfo(arg);
+
+ auto tensor_sinfo = sinfo.as<TensorStructInfo>();
+ CHECK(tensor_sinfo) << "TypeError: "
+ << "Operator " << call->op << " expects a tensor
argument, "
+ << "but argument " << arg << " has struct info " <<
sinfo;
+
+ return tensor_sinfo.value();
+}
+
+DataType GetTensorDataType(const Call& call) { return
GetTensorArgInfo(call)->dtype; }
+
+tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field,
DataType field_dtype) {
+ tir::Var dlpack_handle("dlpack_handle", DataType::Handle());
+
+ tir::Var value("value", field_dtype);
+
+ tir::LetStmt body(
+ value,
+ tir::Call(field_dtype, tir::builtin::tvm_struct_get(),
+ {dlpack_handle, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), field)}),
+ tir::Evaluate(tvm::ret(value)));
+
+ DictAttrs attrs({{"tir.is_scheduled", Bool(true)}, {"tir.is_host",
Bool(true)}});
+
+ tir::PrimFunc func(Array<tir::Var>{dlpack_handle}, body,
PrimType(field_dtype), {}, attrs);
+
+ FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)},
+ PrimStructInfo(field_dtype));
+ UpdateStructInfo(func, sinfo);
+
+ return func;
+}
+
+Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) {
+ if (auto prim_sinfo = call->struct_info_.as<PrimStructInfoNode>()) {
+ if (prim_sinfo->value.defined()) {
+ return PrimValue(prim_sinfo->value.value());
+ }
+ }
+ return call;
+}
+
+//// relax.tensor_dtype_code
+
+Expr tensor_dtype_code(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_dtype_code");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorDtypeCode(const Call& call, const
BlockBuilder&) {
+ auto dlpack_type = DataType::UInt(8);
+
+ DataType dtype = GetTensorDataType(call);
+ if (dtype.is_void()) {
+ return PrimStructInfo(dlpack_type);
+ } else {
+ return PrimStructInfo(IntImm(dlpack_type, dtype.code()));
+ }
+}
+
+Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) {
+ auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+ Expr arg = call->args[0];
+ tir::PrimFunc getter =
+ GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeCode,
field_dtype);
+
+ GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code");
+ return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_code")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoTensorDtypeCode)
+ .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeCode)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_dtype_bits
+
+Expr tensor_dtype_bits(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_dtype_bits");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorDtypeBits(const Call& call, const
BlockBuilder&) {
+ auto dlpack_type = DataType::UInt(8);
+
+ DataType dtype = GetTensorDataType(call);
+ if (dtype.is_void()) {
+ return PrimStructInfo(dlpack_type);
+ } else {
+ return PrimStructInfo(IntImm(dlpack_type, dtype.bits()));
+ }
+}
+
+Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) {
+ auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+ Expr arg = call->args[0];
+ tir::PrimFunc getter =
+ GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeBits,
field_dtype);
+
+ GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits");
+ return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoTensorDtypeBits)
+ .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeBits)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_dtype_lanes
+
+Expr tensor_dtype_lanes(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_dtype_lanes");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorDtypeLanes(const Call& call, const
BlockBuilder&) {
+ auto dlpack_type = DataType::UInt(16);
+
+ DataType dtype = GetTensorDataType(call);
+ if (dtype.is_void()) {
+ return PrimStructInfo(dlpack_type);
+ } else {
+ return PrimStructInfo(IntImm(dlpack_type, dtype.lanes()));
+ }
+}
+
+Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) {
+ auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+ Expr arg = call->args[0];
+ tir::PrimFunc getter =
+ GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeLanes,
field_dtype);
+
+ GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes");
+ return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoTensorDtypeLanes)
+ .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeLanes)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_ndim
+
+Expr tensor_ndim(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_ndim");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorNDim(const Call& call, const BlockBuilder&) {
+ auto dlpack_type = DataType::Int(32);
+
+ auto sinfo = GetTensorArgInfo(call);
+ if (sinfo->IsUnknownNdim()) {
+ return PrimStructInfo(dlpack_type);
+ } else {
+ return PrimStructInfo(IntImm(dlpack_type, sinfo->ndim));
+ }
+}
+
+Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) {
+ auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+ Expr arg = call->args[0];
+ tir::PrimFunc getter =
GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrNDim, field_dtype);
+
+ GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim");
+ return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_ndim")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorNDim)
+ .set_attr<FLegalize>("FLegalize", LegalizeTensorNDim)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_shape_i
+
+Expr tensor_shape_i(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_shape_i");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) {
+ auto dlpack_type = DataType::Int(64);
+
+ CHECK_EQ(call->args.size(), 2) << "TypeError: "
+ << "Operator " << call->op << " expects two
arguments, "
+ << "but received " << call->args.size()
+ << " arguments: " << call->args;
+ const auto& arg = call->args[0];
+ const auto& axis = call->args[1];
+
+ auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>();
+ CHECK(tensor_sinfo) << "TypeError: "
+ << "Operator " << call->op << " expects arguments
(tensor, axis), "
+ << "but the first argument " << arg << " in expression "
<< call
+ << " has struct info " << arg->struct_info_;
+
+ auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>();
+ CHECK(axis_sinfo) << "TypeError: "
+ << "Operator " << call->op << " expects arguments (tensor,
axis), "
+ << "but the second argument " << arg << " in expression "
<< call
+ << " has struct info " << axis->struct_info_;
+
+ auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
+
+ if (int_imm_axis) {
+ CHECK_GE(int_imm_axis->value, 0);
+ }
+ if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) {
+ CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim)
+ << "ValueError: "
+ << "Expression " << call << " attempts to access " << arg << ".shape["
+ << int_imm_axis->value << "]"
+ << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << "
elements";
+ }
+
+ auto tensor_shape = tensor_sinfo->GetShape();
+ if (int_imm_axis && tensor_shape.defined()) {
+ return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]);
+ } else {
+ return PrimStructInfo(dlpack_type);
+ }
+}
+
+Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) {
+ auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+ tir::PrimFunc getter = [&]() -> tir::PrimFunc {
+ tir::Var dlpack_handle("dlpack_handle", DataType::Handle());
+ tir::Var axis("axis", DataType::Int(64));
+
+ tir::Var ndim("ndim", DataType::Int(32));
+
+ tir::Buffer shape_buffer = tir::decl_buffer({ndim}, field_dtype, "shape");
+
+ tir::Var extent("extent", field_dtype);
+
+ tir::Stmt body = tir::Evaluate(tvm::ret(extent));
+
+ body = tir::LetStmt(extent, tir::BufferLoad(shape_buffer, {axis}), body);
+ body = tir::DeclBuffer(shape_buffer, body);
+ body = tir::LetStmt(
+ shape_buffer->data,
+ tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(),
+ {dlpack_handle, IntImm(DataType::Int(32), 0),
+ IntImm(DataType::Int(32),
tir::builtin::TVMStructFieldKind::kArrShape)}),
+ body);
+
+ body = tir::AssertStmt(
+ axis < tvm::cast(axis->dtype, ndim),
+ tir::StringImm("Specified axis may not be larger than the tensor's
dimensionality"), body);
+
+ body = tir::LetStmt(
+ ndim,
+ tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(),
+ {dlpack_handle, IntImm(DataType::Int(32), 0),
+ IntImm(DataType::Int(32),
tir::builtin::TVMStructFieldKind::kArrNDim)}),
+ body);
+
+ body = tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not
be negative"), body);
+
+ DictAttrs attrs({{"tir.is_scheduled", Bool(true)}, {"tir.is_host",
Bool(true)}});
+
+ tir::PrimFunc func({dlpack_handle, axis}, body, PrimType(field_dtype), {},
attrs);
+
+ FuncStructInfo sinfo(
+ {TensorStructInfo(DataType::Void(), kUnknownNDim),
PrimStructInfo(axis->dtype)},
+ PrimStructInfo(field_dtype));
+ UpdateStructInfo(func, sinfo);
+ return func;
+ }();
+
+ GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_shape_i");
+ return Call(gvar_getter, call->args);
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_shape_i")
+ .set_num_inputs(2)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .add_argument("axis", "Prim(int64)", "The axis whose extent should be
returned")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorShape)
+ .set_attr<FLegalize>("FLegalize", LegalizeTensorShape)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+} // namespace inspect
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h
new file mode 100644
index 0000000000..0225b00fb3
--- /dev/null
+++ b/src/relax/op/tensor/inspect.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inspect.h
+ * \brief Operators to access runtime DLTensor parameters
+ */
+#ifndef TVM_RELAX_OP_TENSOR_INSPECT_H_
+#define TVM_RELAX_OP_TENSOR_INSPECT_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+namespace inspect {
+
+/* \brief Return the DLTensor::dtype::type_code field
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint8_t value of the type_code, with
+ * `PrimStructInfo(DataType::UInt(8))`
+ */
+Expr tensor_dtype_code(Expr expr);
+
+/* \brief Return the DLTensor::dtype::bits field
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint8_t value of the number of bits, with
+ * `PrimStructInfo(DataType::UInt(8))`. For vectorized types, returns
+ * the bit width of the underlying scalar type (e.g. 32 for
+ * "float32x4", not 128).
+ */
+Expr tensor_dtype_bits(Expr expr);
+
+/* \brief Return the DLTensor::dtype::lanes field
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint16_t value of the number of lanes, with
+ * `PrimStructInfo(DataType::UInt(16))`
+ */
+Expr tensor_dtype_lanes(Expr expr);
+
+/* \brief Return the DLTensor::ndim field
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The int32_t value of the dimensionality, with
+ * `PrimStructInfo(DataType::Int(32))`.
+ */
+Expr tensor_ndim(Expr expr);
+
+/* \brief Return the DLTensor::shape[i] field
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \param axis The axis to inspect. Must be within the range `0 <=
+ * axis < tensor_ndim(expr)`, or else the results are undefined.
+ *
+ * \returns The int64_t extent of the specified tensor axis, with
+ * `PrimStructInfo(DataType::Int(64))`.
+ */
+Expr tensor_shape_i(Expr expr, Expr axis);
+
+} // namespace inspect
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_TENSOR_INSPECT_H_
diff --git a/src/relax/transform/legalize_ops.cc
b/src/relax/transform/legalize_ops.cc
index c8fba59dab..343c18acd7 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -90,23 +90,32 @@ class LegalizeMutator : public ExprMutator {
bool WrapPureCondition(const Op& op, const Expr& legalized) {
static const auto& purity_map = Op::GetAttrMap<Bool>("FPurity");
- // unlikely for this condition not to be met
- if (const CallNode* call = legalized.as<CallNode>()) {
- // if the original op is not pure, don't wrap
- if (!(purity_map.count(op) && purity_map[op]->value)) {
+ const CallNode* call = legalized.as<CallNode>();
+
+ if (!call) {
+ // Unlikely for this condition to be met, but it is possible.
+ // For example, an operation could produce a Tuple output, and
+ // be legalized into separate calls for each item in the Tuple.
+ return false;
+ }
+
+ bool pure_original_op = purity_map.get(op, Bool(false))->value;
+ bool pure_legalized_op = [&]() -> bool {
+ if (auto legalized_op = call->op.as<Op>()) {
+ return purity_map.get(legalized_op.value(), Bool(false))->value;
+ } else if (auto func_sinfo =
call->op->struct_info_.as<FuncStructInfoNode>()) {
+ return func_sinfo->purity;
+ } else {
return false;
}
- if (const OpNode* call_op = call->op.as<OpNode>()) {
- auto res_op = GetRef<Op>(call_op);
- if (purity_map.count(res_op)) {
- // if the legalized op is already pure, we *don't* need a wrapper
- return !purity_map[res_op]->value;
- }
- }
- // simplest case: wrap if the original op was pure and the result is
somehow not
- return true;
- }
- return false;
+ }();
+
+ // If the original op was pure, but the legalized op was not,
+ // the legalized op may occur in a context that requires pure
+ // functions, such as a `relax::DataflowBlock`. In this case,
+ // we should wrap the legalized operation to indicate that it is
+ // still pure.
+ return pure_original_op && !pure_legalized_op;
}
Call WrapPureCall(const Call& ret) {
@@ -148,6 +157,7 @@ class LegalizeMutator : public ExprMutator {
Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+ static const auto& requires_arg_shapes_map =
Op::GetAttrMap<Bool>("RequiresArgumentShapes");
static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
@@ -157,17 +167,72 @@ class LegalizeMutator : public ExprMutator {
if (op_node == nullptr) {
return visited_call;
}
-
auto op = GetRef<Op>(op_node);
- std::string op_name(op->name);
- bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
- // Not all shape values are known
- // Data-dependent ops are exception since their output shape will be
identified at runtime.
- // Legalizer will insert their shape functions, which are manually
registered, and match cast
- // to define symbolic output shape at compile time.
- if (!std::all_of(visited_call->args.begin(), visited_call->args.end(),
- [](Expr arg) { return
KnowAllShapeValues(GetStructInfo(arg)); }) ||
- (!is_data_dependent_op &&
!KnowAllShapeValues(GetStructInfo(visited_call)))) {
+
+ bool can_legalize = [&]() -> bool {
+ bool requires_arg_shapes = requires_arg_shapes_map.get(op,
Bool(true))->value;
+ if (!requires_arg_shapes) {
+ // This operator does not require its arguments to have a
+ // known shape/dtype. For example, the "relax.tensor_ndim"
+ // operator can output the dimensionality of a tensor at
+ // runtime, and does not require the dimensionality to be
+ // known at compile-time.
+ return true;
+ }
+
+ bool arg_shapes_defined =
+ std::all_of(visited_call->args.begin(), visited_call->args.end(),
+ [](Expr arg) { return
KnowAllShapeValues(GetStructInfo(arg)); });
+ if (!arg_shapes_defined) {
+ // This operator cannot be legalized, because legalization
+ // requires the argument shapes to be known.
+ //
+ // TODO(Lunderberg):
+ //
+ // Improve this fallback case, as failure to legalize can
+ // produce unexpected errors during CodeGenVM. This could
+ // be done by having `R.Tensor(ndim=2)` be syntactic sugar
+ // for `R.Tensor(shape=[m, n])`, where `m` and `n` are new
+ // shape variables. This would allow legalization into
+ // dynamic TIR PrimFuncs.
+ //
+ // This fallback would only be applicable for cases where
+ // both the dtype and the dimensionality are known. While
+ // Relax can express a tensor with unknown dtype and
+ // dimensionality as `TensorStructInfo(DataType::Void(),
+ // kUnknownNDim)`, TIR cannot express unknown dtype or
+ // unknown dimensionality.
+ return false;
+ }
+
+ std::string op_name(op->name);
+ bool is_data_dependent_op = (op_name.find("dynamic") !=
std::string::npos);
+ bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call));
+ if (!is_data_dependent_op && !ret_shape_defined) {
+ // This operator cannot be legalized, because legalization by
+ // default requires the output shape. The exception is
+ // data-dependent operators (e.g. `R.dynamic_strided_slice`),
+ // where the shape of the output depends on the runtime values
+ // stored in a tensor.
+ //
+ // For data-dependent ops, the output shape will be identified
+ // at runtime. The Legalizer will insert their shape
+ // functions, which are manually registered for each
+ // data-dependent op, and match cast to define symbolic output
+ // shapes. These symbolic output shapes at compile time can
+ // be used by later operations to refer to the runtime shape.
+ //
+ // TODO(Lunderberg): Make a new operator attribute
+ // `.set_attr<Bool>("DataDependent")`, rather than relying on
+ // the name of the operator.
+ return false;
+ }
+
+ // All checks pass, this operator can be legalized.
+ return true;
+ }();
+
+ if (!can_legalize) {
return visited_call;
}
diff --git a/tests/python/relax/test_op_unpack.py
b/tests/python/relax/test_op_unpack.py
new file mode 100644
index 0000000000..03e4e0fc85
--- /dev/null
+++ b/tests/python/relax/test_op_unpack.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+
+from tvm import relax
+from tvm.ir import Op
+from tvm.script import ir as I, relax as R
+
+# Parameterization for reading dtype of DLTensor. Chosen to have
+# multiple distinct type codes, number of lanes, and widths.
+dtype = tvm.testing.parameter(
+ "int32",
+ "int64",
+ "float32",
+ "float32x4",
+ "bfloat",
+ "e4m3_float8",
+)
+shape = tvm.testing.parameter(
+ [],
+ [16],
+ [128, 256],
+ [1] * 64,
+)
+
+
+def test_tensor_dtype_code(dtype):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.dtype.type_code
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty([16], dtype)
+ res = vm["main"](arg)
+
+ expected_type_code = tvm.runtime.DataType(dtype).type_code
+ assert res == expected_type_code
+
+
+def test_tensor_dtype_bits(dtype):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.dtype.bits
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty([16], dtype)
+ res = vm["main"](arg)
+
+ expected_type_bits = tvm.runtime.DataType(dtype).bits
+ assert res == expected_type_bits
+
+
+def test_tensor_dtype_lanes(dtype):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.dtype.lanes
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty([16], dtype)
+ res = vm["main"](arg)
+
+ expected_type_lanes = tvm.runtime.DataType(dtype).lanes
+ assert res == expected_type_lanes
+
+
+def test_tensor_ndim(shape):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.ndim
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty(shape, "int32")
+ res = vm["main"](arg)
+
+ assert res == len(shape)
+
+
+def test_tensor_shape(shape):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor, axis: R.Prim("int64")):
+ return A.shape[axis]
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty(shape, "int32")
+
+ res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
+
+ tvm.ir.assert_structural_equal(res, shape)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()