This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8274d142a3 [Relax] Implement operators to inspect DLTensor::strides and
offset (#16721)
8274d142a3 is described below
commit 8274d142a3c229eb664d041c5a8034c3638f8c0f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Mar 26 08:55:10 2024 -0500
[Relax] Implement operators to inspect DLTensor::strides and offset (#16721)
* [TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation
If an allocation occurs within a host function, it may not have a
device/host split.
* lint fix
* [Relax] Implement operators to inspect DLTensor::strides and offset
A follow-up PR to https://github.com/apache/tvm/pull/16563. This PR
implements similar operators to inspect the runtime values of
`DLTensor::strides` and `DLTensor::byte_offset`. In addition, while the
element offset is not explicitly present in the `DLTensor` struct, a
Relax operator is implemented to infer it from the `byte_offset` and
`data_type` fields, for use when interacting with the TIR
`BufferNode::elem_offset` field.
---
python/tvm/relax/expr.py | 97 ++++++++
.../tvm/relax/transform/legalize_ops/__init__.py | 1 +
.../tvm/relax/transform/legalize_ops/inspect_op.py | 128 +++++++++++
src/relax/op/tensor/inspect.cc | 180 ++++++++++++---
src/relax/op/tensor/inspect.h | 39 ++++
src/tir/transforms/lower_tvm_builtin.cc | 36 ++-
tests/python/relax/test_op_inspect.py | 252 +++++++++++++++++++++
tests/python/relax/test_op_unpack.py | 127 -----------
.../test_tir_transform_lower_tvm_builtin.py | 37 ++-
9 files changed, 727 insertions(+), 170 deletions(-)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 12f08f4dbf..4dca710e77 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -280,6 +280,33 @@ class ExprWithOp(Expr, Scriptable):
self._check_for_tensor_struct_info()
return _DLTensorShapeProxy(self)
+ @property
+ def strides(self) -> "_DLTensorStrideProxy":
+ """Returns a proxy object for accessing DLTensor::strides"""
+ self._check_for_tensor_struct_info()
+ return _DLTensorStrideProxy(self)
+
+ @property
+ def byte_offset(self) -> "Expr":
+ """Returns a relax expression accessing the DLTensor::byte_offset field"""
+ self._check_for_tensor_struct_info()
+ op = tvm.ir.Op.get("relax.inspect.tensor_byte_offset")
+ return tvm.relax.Call(op, [self])
+
+ @property
+ def elem_offset(self) -> "Expr":
+ """Returns a relax expression computing a DLTensor's elem_offset
+
+ This parameter is not stored in the DLTensor, but is instead
+ derived from the DLTensor's byte offset and datatype. This is
+ exposed in Relax for ease of use, and for translation into the
+ `tir::BufferNode::elem_offset` field when interacting with TIR
+ buffers.
+ """
+ self._check_for_tensor_struct_info()
+ op = tvm.ir.Op.get("relax.inspect.tensor_elem_offset")
+ return tvm.relax.Call(op, [self])
+
class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
"""A proxy object for unpacking DLDatatype from DLTensor
@@ -431,6 +458,76 @@ class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric):
return tvm.relax.Call(op, [self.tensor, axis])
+class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric):
+ """A proxy object for unpacking the strides from DLTensor
+
+ Exposes accessors for the `DLTensor::strides` field. Accessing
+ these fields will produce `relax.Call` expressions, representing
+ the field's runtime value. If the datatype of the tensor is known
+ at compile-time, the `relax.Call` will be normalized into a
+ `relax.PrimValue`, with no runtime cost.
+
+ Parameters
+ ----------
+ tensor: relax.Expr
+
+ The relax tensor (or a variable referring to a relax tensor),
+ whose runtime strides are being inspected.
+ """
+
+ def __init__(self, tensor):
+ self.tensor = tensor
+
+ def asobject(self):
+ """Provide expected usage in error message
+
+ This method is called when `_DLTensorStrideProxy` is used in a
+ context that requires a `relax.Expr`. This usage is not
+ supported, and raising an error here can provide suggested
+ fixes that are not present in the default error message from
+ `tvm.runtime.convert_to_object`.
+ """
+ raise TypeError(
+ f"{self.tensor}.strides cannot be converted to a relax expression,
"
+ f"and should be used as a proxy object to access the runtime
strides of the DLTensor. "
+ f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
+ f"and the DLTensor::strides array can be accessed as
{self.tensor}.strides[i]"
+ )
+
+ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
+ """Returns the stride of a tensor axis
+
+ Parameters
+ ----------
+ axis: Union[int, PrimExpr, Expr]
+
+ The tensor axis whose stride should be returned. For ease
+ of use, any python integers or TIR expressions are
+ converted to `relax.Expr`.
+
+ Returns
+ -------
+ stride: Expr
+
+ The stride of the tensor's axis.
+ """
+
+ if not isinstance(axis, tvm.relax.Expr):
+ axis = tvm.relax.PrimValue(axis)
+
+ if axis.struct_info_ is not None and not isinstance(
+ axis.struct_info_, tvm.relax.PrimStructInfo
+ ):
+ raise TypeError(
+ f"The index used to access {self.tensor}.strides "
+ f'must have struct info R.Prim("int64"), '
+ f"but index {axis} had struct info {axis.struct_info_}."
+ )
+
+ op = tvm.ir.Op.get("relax.inspect.tensor_stride_i")
+ return tvm.relax.Call(op, [self.tensor, axis])
+
+
@tvm._ffi.register_object("relax.expr.Call")
class Call(ExprWithOp):
"""Function call node in Relax.
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py
b/python/tvm/relax/transform/legalize_ops/__init__.py
index e3b3213a38..b4aba0291f 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -23,6 +23,7 @@ from . import distributed
from . import grad
from . import image
from . import index
+from . import inspect_op
from . import linear_algebra
from . import manipulate
from . import nn
diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py
b/python/tvm/relax/transform/legalize_ops/inspect_op.py
new file mode 100644
index 0000000000..5f1b36667a
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Legalization functions for DLTensor inspection."""
+
+import enum
+
+from tvm.script import tir as T
+
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from .common import register_legalize
+
+
+class TVMStructFieldKind(enum.IntEnum):
+ """Equivalent to tvm::tir::builtin::TVMStructFieldKind
+
+ This does not use `enum.auto()` to define the values, because
+ `enum.auto()` starts from 1, and this must match the C++
+ definition which starts from 0.
+ """
+
+ kArrAddr = 0
+ kArrData = 1
+ kArrShape = 2
+ kArrStrides = 3
+ kArrNDim = 4
+ kArrTypeCode = 5
+ kArrTypeBits = 6
+ kArrTypeLanes = 7
+ kArrByteOffset = 8
+ kArrDeviceId = 9
+ kArrDeviceType = 10
+ kArrKindBound_ = 11
+ kTVMValueContent = 12
+ kTVMValueKindBound_ = 13
+
+
+@register_legalize("relax.inspect.tensor_stride_i")
+def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr:
+ @T.prim_func(private=True)
+ def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) ->
T.int64:
+ T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled":
T.bool(True)})
+ assert T.int64(0) <= axis, "Specified axis may not be negative"
+ ndim: T.int32 = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrNDim), "int32"
+ )
+ assert axis < T.Cast(
+ "int64", ndim
+ ), "Specified axis may not be larger than the tensor's dimensionality"
+ stride_ptr: T.handle("int64") = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrStrides), "handle"
+ )
+
+ if T.isnullptr(stride_ptr):
+ shape_ptr: T.handle("int64") = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrShape), "handle"
+ )
+ shape = T.decl_buffer(ndim, "int64", data=shape_ptr)
+
+ product = T.decl_buffer([], "int64")
+ product[()] = 1
+
+ # TODO(Lunderberg): Add a TIR lowering pass to allow
+ # ranges to start somewhere other than zero. This loop
+ # could then iterate on `range(axis+1, ndim)`.
+ for dim_offset in range(ndim - (axis + 1)):
+ dim = dim_offset + (axis + 1)
+ product[()] = product[()] * shape[dim]
+
+ return product[()]
+ else:
+ strides = T.decl_buffer(ndim, "int64", data=stride_ptr)
+ stride: T.int64 = strides[axis]
+ return stride
+
+ gvar = bb.add_func(_get_tensor_stride_i, "_get_tensor_stride_i")
+ return Call(gvar, call.args)
+
+
+@register_legalize("relax.inspect.tensor_byte_offset")
+def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr:
+ @T.prim_func(private=True)
+ def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64:
+ T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled":
T.bool(True)})
+ byte_offset: T.uint64 = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
+ )
+ return byte_offset
+
+ gvar = bb.add_func(_get_tensor_byte_offset, "_get_tensor_byte_offset")
+ return Call(gvar, call.args)
+
+
+@register_legalize("relax.inspect.tensor_elem_offset")
+def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
+ @T.prim_func(private=True)
+ def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64:
+ T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled":
T.bool(True)})
+ byte_offset: T.uint64 = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
+ )
+ scalar_bits: T.uint8 = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeBits), "uint8"
+ )
+ lanes: T.uint16 = T.tvm_struct_get(
+ dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeLanes), "uint16"
+ )
+ bytes_per_element = T.ceildiv(scalar_bits.astype("uint64") *
lanes.astype("uint64"), 8)
+ elem_offset = byte_offset // bytes_per_element
+ return elem_offset
+
+ gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
+ return Call(gvar, call.args)
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index a40b2af5ef..186fc9fa86 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -29,6 +29,8 @@
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
+#include <tuple>
+
namespace tvm {
namespace relax {
namespace inspect {
@@ -50,6 +52,42 @@ TensorStructInfo GetTensorArgInfo(const Call& call) {
return tensor_sinfo.value();
}
+std::tuple<TensorStructInfo, PrimStructInfo> GetTensorArgInfoWithIndex(const
Call& call) {
+ CHECK_EQ(call->args.size(), 2) << "TypeError: "
+ << "Operator " << call->op << " expects two
arguments, "
+ << "but received " << call->args.size()
+ << " arguments: " << call->args;
+ const auto& arg = call->args[0];
+ const auto& axis = call->args[1];
+
+ auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>();
+ CHECK(tensor_sinfo) << "TypeError: "
+ << "Operator " << call->op << " expects arguments
(tensor, axis), "
+ << "but the first argument " << arg << " in expression "
<< call
+ << " has struct info " << arg->struct_info_;
+
+ auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>();
+ CHECK(axis_sinfo) << "TypeError: "
+ << "Operator " << call->op << " expects arguments (tensor,
axis), "
+ << "but the second argument " << arg << " in expression "
<< call
+ << " has struct info " << axis->struct_info_;
+
+ auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
+
+ if (int_imm_axis) {
+ CHECK_GE(int_imm_axis->value, 0);
+ }
+ if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) {
+ CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim)
+ << "ValueError: "
+ << "Expression " << call << " attempts to access " << arg << ".shape["
+ << int_imm_axis->value << "]"
+ << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << "
elements";
+ }
+
+ return {GetRef<TensorStructInfo>(tensor_sinfo),
GetRef<PrimStructInfo>(axis_sinfo)};
+}
+
DataType GetTensorDataType(const Call& call) { return
GetTensorArgInfo(call)->dtype; }
tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field,
DataType field_dtype) {
@@ -244,39 +282,11 @@ Expr tensor_shape_i(Expr expr) {
StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) {
auto dlpack_type = DataType::Int(64);
- CHECK_EQ(call->args.size(), 2) << "TypeError: "
- << "Operator " << call->op << " expects two
arguments, "
- << "but received " << call->args.size()
- << " arguments: " << call->args;
- const auto& arg = call->args[0];
- const auto& axis = call->args[1];
-
- auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>();
- CHECK(tensor_sinfo) << "TypeError: "
- << "Operator " << call->op << " expects arguments
(tensor, axis), "
- << "but the first argument " << arg << " in expression "
<< call
- << " has struct info " << arg->struct_info_;
-
- auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>();
- CHECK(axis_sinfo) << "TypeError: "
- << "Operator " << call->op << " expects arguments (tensor,
axis), "
- << "but the second argument " << arg << " in expression "
<< call
- << " has struct info " << axis->struct_info_;
+ auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call);
+ auto tensor_shape = tensor_sinfo->GetShape();
auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
- if (int_imm_axis) {
- CHECK_GE(int_imm_axis->value, 0);
- }
- if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) {
- CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim)
- << "ValueError: "
- << "Expression " << call << " attempts to access " << arg << ".shape["
- << int_imm_axis->value << "]"
- << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << "
elements";
- }
-
- auto tensor_shape = tensor_sinfo->GetShape();
if (int_imm_axis && tensor_shape.defined()) {
return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]);
} else {
@@ -346,6 +356,116 @@ TVM_REGISTER_OP("relax.inspect.tensor_shape_i")
.set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
.set_attr<Bool>("FPurity", Bool(true));
+//// relax.tensor_stride_i
+
+Expr tensor_stride_i(Expr expr, Expr axis) {
+ static const Op& op = Op::Get("relax.inspect.tensor_stride_i");
+ return Call(op, {expr, axis});
+}
+
+StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) {
+ auto dlpack_type = DataType::Int(64);
+
+ auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call);
+
+ auto opt_tensor_shape = tensor_sinfo->GetShape();
+ auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
+
+ if (int_imm_axis && opt_tensor_shape.defined()) {
+ // As of 2024-03-14, Relax does not have an explicit
+ // representation for striding in `TensorStructInfo`. The
+ // `FLegalize` function for most operators is implemented in terms
+ // of `topi`, and is then converted from TE to `tir::PrimFunc`
+ // using `tvm::tir::CreatePrimFunc`. The `te::Tensor` is
+ // converted to a `tir::Buffer` in `RewriteStageToBlock`, and uses
+ // the default empty list for the strides. The empty strides
+ // represent a compact data array.
+ //
+ // Therefore, while Relax does not explicitly represent the
+ // striding of a tensor, it implicitly requires compact striding
+ // for any legalizable Tensor.
+ auto tensor_shape = opt_tensor_shape.value();
+ PrimExpr stride = IntImm(DataType::Int(64), 1);
+ for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size();
axis++) {
+ stride = stride * tensor_shape[axis];
+ }
+ return PrimStructInfo(stride);
+ } else {
+ return PrimStructInfo(dlpack_type);
+ }
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_stride_i")
+ .set_num_inputs(2)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .add_argument("axis", "Prim(int64)", "The axis whose stride should be
returned")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoTensorStride)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_byte_offset
+
+Expr tensor_byte_offset(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_byte_offset");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorByteOffset(const Call& call, const
BlockBuilder&) {
+ auto dlpack_type = DataType::UInt(64);
+
+ auto tensor_sinfo = GetTensorArgInfo(call);
+
+ auto opt_tensor_shape = tensor_sinfo->GetShape();
+ if (opt_tensor_shape.defined()) {
+ // Relax implicitly requires that the byte offset is zero for any
+ // legalizable tensor. See InferStructInfoTensorStride for full
+ // explanation.
+ return PrimStructInfo(IntImm(dlpack_type, 0));
+ } else {
+ return PrimStructInfo(dlpack_type);
+ }
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_byte_offset")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoTensorByteOffset)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_elem_offset
+
+Expr tensor_elem_offset(Expr expr) {
+ static const Op& op = Op::Get("relax.inspect.tensor_elem_offset");
+ return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorElemOffset(const Call& call, const
BlockBuilder&) {
+ auto dlpack_type = DataType::UInt(64);
+
+ auto tensor_sinfo = GetTensorArgInfo(call);
+
+ auto opt_tensor_shape = tensor_sinfo->GetShape();
+ if (opt_tensor_shape.defined()) {
+ // Relax implicitly requires that the element offset is zero for
+ // any legalizable tensor. See InferStructInfoTensorStride for
+ // full explanation.
+ return PrimStructInfo(IntImm(dlpack_type, 0));
+ } else {
+ return PrimStructInfo(dlpack_type);
+ }
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_elem_offset")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The tensor to be inspected")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoTensorElemOffset)
+ .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+ .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace inspect
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h
index 0225b00fb3..2aa20a1381 100644
--- a/src/relax/op/tensor/inspect.h
+++ b/src/relax/op/tensor/inspect.h
@@ -85,6 +85,45 @@ Expr tensor_ndim(Expr expr);
*/
Expr tensor_shape_i(Expr expr, Expr axis);
+/* \brief Return the DLTensor::strides[i] field
+ *
+ * The `int64_t* DLTensor::strides` is allowed to be NULL, which
+ * represents a compact packing of the data. In this case, the
+ * returned stride is computed from the `DLTensor::shape`.
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \param axis The axis to inspect. Must be within the range `0 <=
+ * axis < tensor_ndim(expr)`, or else the results are undefined.
+ *
+ * \returns The int64_t stride of the specified tensor axis, with
+ * `PrimStructInfo(DataType::Int(64))`.
+ */
+Expr tensor_stride_i(Expr expr, Expr axis);
+
+/* \brief Return the DLTensor::byte_offset field
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint64_t byte offset, with
`PrimStructInfo(DataType::UInt(64))`.
+ */
+Expr tensor_byte_offset(Expr expr);
+
+/* \brief Return the element offset of a DLTensor
+ *
+ * While the DLTensor does not directly contain the element offset, it
+ * can be inferred from the `DLTensor::byte_offset` and
+ * `DLTensor::data_type` fields.
+ *
+ * \param expr The relax expression to be inspected. Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint64_t element offset, with
`PrimStructInfo(DataType::UInt(64))`.
+ */
+Expr tensor_elem_offset(Expr expr);
+
} // namespace inspect
} // namespace relax
} // namespace tvm
diff --git a/src/tir/transforms/lower_tvm_builtin.cc
b/src/tir/transforms/lower_tvm_builtin.cc
index 6da2f873b7..1a3888a7cd 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -38,6 +38,19 @@ namespace tir {
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
public:
+ static PrimFunc Build(PrimFunc func) {
+ Optional<PrimExpr> device_type = NullOpt;
+ if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
+ device_type = Integer(target.value()->kind->default_device_type);
+ }
+
+ BuiltinLower mutator(device_type);
+ func.CopyOnWrite()->body = mutator.VisitBodyAndRealizeAlloca(func->body);
+ return func;
+ }
+
+ explicit BuiltinLower(Optional<PrimExpr> device_type = NullOpt) :
device_type_(device_type) {}
+
// NOTE: Right now, we make the following scoping requirement
// for memory allocated by the following primitives
// - tvm_stack_make_array
@@ -284,13 +297,17 @@ class BuiltinLower : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_id) {
- ICHECK(!device_id_);
+ auto cache = device_id_;
device_id_ = op->value;
- return this->VisitStmt(op->body);
+ Stmt out = this->VisitStmt(op->body);
+ device_id_ = cache;
+ return out;
} else if (op->attr_key == attr::device_type) {
- ICHECK(!device_type_);
+ auto cache = device_type_;
device_type_ = op->value;
- return this->VisitStmt(op->body);
+ Stmt out = this->VisitStmt(op->body);
+ device_type_ = cache;
+ return out;
} else {
return StmtExprMutator::VisitStmt_(op);
}
@@ -656,13 +673,12 @@ class BuiltinLower : public StmtExprMutator {
namespace transform {
Pass LowerTVMBuiltin() {
- auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- if (IsHostFunc(f).value_or(false)) {
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- f.CopyOnWrite()->body = BuiltinLower().Build(f->body);
- VLOG(2) << "LowerTVMBuiltin: " << f;
+ auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) {
+ if (IsHostFunc(func).value_or(false)) {
+ func = BuiltinLower::Build(func);
+ VLOG(2) << "LowerTVMBuiltin: " << func;
}
- return f;
+ return func;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}
diff --git a/tests/python/relax/test_op_inspect.py
b/tests/python/relax/test_op_inspect.py
new file mode 100644
index 0000000000..18d7a88f05
--- /dev/null
+++ b/tests/python/relax/test_op_inspect.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+
+import numpy as np
+import pytest
+
+import tvm.testing
+
+from tvm import relax
+from tvm.ir import Op
+from tvm.script import ir as I, relax as R
+
+# Parameterization for reading dtype of DLTensor. Chosen to have
+# multiple distinct type codes, number of lanes, and widths.
+dtype = tvm.testing.parameter(
+ "int32",
+ "int64",
+ "float32",
+ "float32x4",
+ "bfloat",
+ "e4m3_float8",
+)
+shape = tvm.testing.parameter(
+ [],
+ [16],
+ [128, 256],
+ [1] * 64,
+)
+
+elem_offset = tvm.testing.parameter(0, 64, 128)
+
+
+def test_tensor_dtype_code(dtype):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.dtype.type_code
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty([16], dtype)
+ res = vm["main"](arg)
+
+ expected_type_code = tvm.runtime.DataType(dtype).type_code
+ assert res == expected_type_code
+
+
+def test_tensor_dtype_bits(dtype):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.dtype.bits
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty([16], dtype)
+ res = vm["main"](arg)
+
+ expected_type_bits = tvm.runtime.DataType(dtype).bits
+ assert res == expected_type_bits
+
+
+def test_tensor_dtype_lanes(dtype):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.dtype.lanes
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty([16], dtype)
+ res = vm["main"](arg)
+
+ expected_type_lanes = tvm.runtime.DataType(dtype).lanes
+ assert res == expected_type_lanes
+
+
+def test_tensor_ndim(shape):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.ndim
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty(shape, "int32")
+ res = vm["main"](arg)
+
+ assert res == len(shape)
+
+
+def test_tensor_shape(shape):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor, axis: R.Prim("int64")):
+ return A.shape[axis]
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty(shape, "int32")
+
+ res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
+
+ tvm.ir.assert_structural_equal(res, shape)
+
+
+def _get_compact_striding(shape):
+ strides = []
+ product = 1
+ for dim in reversed(shape):
+ strides.append(product)
+ product *= dim
+ return list(reversed(strides))
+
+
+def test_strides_of_compact_tensor(shape):
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor, axis: R.Prim("int64")):
+ return A.strides[axis]
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ arg = tvm.nd.empty(shape, "int32")
+
+ res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
+ expected = _get_compact_striding(shape)
+
+ tvm.ir.assert_structural_equal(res, expected)
+
+
+def test_strides_of_non_compact_tensor():
+ backing_shape = [64, 64]
+ view_shape = [16, 16]
+ expected_strides = [backing_shape[0], 1]
+
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor, axis: R.Prim("int64")):
+ return A.strides[axis]
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ backing_ndarray = tvm.nd.empty(backing_shape, "int32")
+
+ # Manually overwrite the DLTensor fields to make a view into the
+ # tensor.
+ view = backing_ndarray.handle[0]
+ np_shape = np.array([16, 16], "int64")
+ view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
+ np_strides = np.array([64, 1], "int64")
+ view.strides = np_strides.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
+ backing_ndarray.handle[0] = view
+
+ res = [vm["main"](backing_ndarray, i) for i, _ in enumerate(view_shape)]
+
+ tvm.ir.assert_structural_equal(res, expected_strides)
+
+
+def test_byte_offset(elem_offset):
+ backing_shape = [64, 64]
+ view_shape = [16, 16]
+ byte_offset = elem_offset * 4
+
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.byte_offset
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ backing_ndarray = tvm.nd.empty(backing_shape, "int32")
+
+ # Manually overwrite the DLTensor fields to make a view into the
+ # tensor.
+ view = backing_ndarray.handle[0]
+ np_shape = np.array(view_shape, "int64")
+ view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
+ view.byte_offset = byte_offset
+ backing_ndarray.handle[0] = view
+
+ res = vm["main"](backing_ndarray)
+
+ assert res == byte_offset
+
+
+def test_elem_offset(elem_offset, dtype):
+ tvm_dtype = tvm.runtime.DataType(dtype)
+
+ backing_shape = [64, 64]
+ view_shape = [16, 16]
+ element_bytes = (tvm_dtype.bits * tvm_dtype.lanes) // 8
+ byte_offset = elem_offset * element_bytes
+
+ @I.ir_module
+ class mod:
+ @R.function
+ def main(A: R.Tensor):
+ return A.elem_offset
+
+ built = relax.build(mod)
+ vm = relax.VirtualMachine(built, tvm.cpu())
+
+ backing_ndarray = tvm.nd.empty(backing_shape, dtype)
+
+ # Manually overwrite the DLTensor fields to make a view into the
+ # tensor.
+ view = backing_ndarray.handle[0]
+ np_shape = np.array(view_shape, "int64")
+ view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long))
+ view.byte_offset = byte_offset
+ backing_ndarray.handle[0] = view
+
+ res = vm["main"](backing_ndarray)
+
+ assert res == elem_offset
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_op_unpack.py
b/tests/python/relax/test_op_unpack.py
deleted file mode 100644
index 03e4e0fc85..0000000000
--- a/tests/python/relax/test_op_unpack.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm.testing
-
-from tvm import relax
-from tvm.ir import Op
-from tvm.script import ir as I, relax as R
-
-# Parameterization for reading dtype of DLTensor. Chosen to have
-# multiple distinct type codes, number of lanes, and widths.
-dtype = tvm.testing.parameter(
- "int32",
- "int64",
- "float32",
- "float32x4",
- "bfloat",
- "e4m3_float8",
-)
-shape = tvm.testing.parameter(
- [],
- [16],
- [128, 256],
- [1] * 64,
-)
-
-
-def test_tensor_dtype_code(dtype):
- @I.ir_module
- class mod:
- @R.function
- def main(A: R.Tensor):
- return A.dtype.type_code
-
- built = relax.build(mod)
- vm = relax.VirtualMachine(built, tvm.cpu())
-
- arg = tvm.nd.empty([16], dtype)
- res = vm["main"](arg)
-
- expected_type_code = tvm.runtime.DataType(dtype).type_code
- assert res == expected_type_code
-
-
-def test_tensor_dtype_bits(dtype):
- @I.ir_module
- class mod:
- @R.function
- def main(A: R.Tensor):
- return A.dtype.bits
-
- built = relax.build(mod)
- vm = relax.VirtualMachine(built, tvm.cpu())
-
- arg = tvm.nd.empty([16], dtype)
- res = vm["main"](arg)
-
- expected_type_bits = tvm.runtime.DataType(dtype).bits
- assert res == expected_type_bits
-
-
-def test_tensor_dtype_lanes(dtype):
- @I.ir_module
- class mod:
- @R.function
- def main(A: R.Tensor):
- return A.dtype.lanes
-
- built = relax.build(mod)
- vm = relax.VirtualMachine(built, tvm.cpu())
-
- arg = tvm.nd.empty([16], dtype)
- res = vm["main"](arg)
-
- expected_type_lanes = tvm.runtime.DataType(dtype).lanes
- assert res == expected_type_lanes
-
-
-def test_tensor_ndim(shape):
- @I.ir_module
- class mod:
- @R.function
- def main(A: R.Tensor):
- return A.ndim
-
- built = relax.build(mod)
- vm = relax.VirtualMachine(built, tvm.cpu())
-
- arg = tvm.nd.empty(shape, "int32")
- res = vm["main"](arg)
-
- assert res == len(shape)
-
-
-def test_tensor_shape(shape):
- @I.ir_module
- class mod:
- @R.function
- def main(A: R.Tensor, axis: R.Prim("int64")):
- return A.shape[axis]
-
- built = relax.build(mod)
- vm = relax.VirtualMachine(built, tvm.cpu())
-
- arg = tvm.nd.empty(shape, "int32")
-
- res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
-
- tvm.ir.assert_structural_equal(res, shape)
-
-
-if __name__ == "__main__":
- tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index de1020ef20..754ce03240 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -260,11 +260,13 @@ class
TestLowerCPUAllocation(tvm.testing.CompareBeforeAfter):
class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
+ """If device id is missing, error."""
+
transform = tvm.tir.transform.LowerTVMBuiltin()
def before():
T.func_attr({"target": T.target("llvm")})
- T.attr("dummy", "device_id", 0)
+ T.attr("dummy", "device_type", 2) # kDLCuda
ptr = T.allocate([16], "float32")
buf = T.decl_buffer(16, "float32", data=ptr)
buf[0] = 0.0
@@ -273,16 +275,45 @@ class
TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
class TestLowerAllocateRequiresDeviceType(tvm.testing.CompareBeforeAfter):
+ """If device type is missing, error.
+
+ The device type can be inferred either from the `"device_type"`
+ statement attribute, or from the `"target"` function attribute.
+ Here, we provide neither. The `"tir.is_host_func"` attribute is
+ provided as otherwise the function would be skipped altogether by
+ LowerTVMBuiltin.
+ """
+
transform = tvm.tir.transform.LowerTVMBuiltin()
def before():
- T.func_attr({"target": T.target("llvm")})
+ T.func_attr({"tir.is_host_func": True})
T.attr("dummy", "device_id", 0)
+ ptr = T.allocate([1024 * 1024], "float32")
+ buf = T.decl_buffer(1024 * 1024, "float32", data=ptr)
+ buf[0] = 0.0
+
+ expected = tvm.TVMError
+
+
+class TestLowerCPUAllocWithFunctionAttr(tvm.testing.CompareBeforeAfter):
+ """CPU allocations can be handled at codegen time
+
+ Like `TestLowerCPUAllocation`, but the device type is taken from
+ the function attribute. The `AttrStmt` can override the device
+ type for allocations within its scope, but it defaults to the
+ function's target.
+ """
+
+ transform = tvm.tir.transform.LowerTVMBuiltin()
+
+ def before():
+ T.func_attr({"target": T.target("llvm")})
ptr = T.allocate([16], "float32")
buf = T.decl_buffer(16, "float32", data=ptr)
buf[0] = 0.0
- expected = tvm.TVMError
+ expected = before
if __name__ == "__main__":