This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8274d142a3 [Relax] Implement operators to inspect DLTensor::strides and offset  (#16721)
8274d142a3 is described below

commit 8274d142a3c229eb664d041c5a8034c3638f8c0f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Mar 26 08:55:10 2024 -0500

    [Relax] Implement operators to inspect DLTensor::strides and offset  (#16721)
    
    * [TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation
    
    If an allocation occurs within a host function, it may not have a
    device/host split.
    
    * lint fix
    
    * [Relax] Implement operators to inspect DLTensor::strides and offset
    
    A follow-up PR to https://github.com/apache/tvm/pull/16563.  This PR
    implements similar operators to inspect the runtime values of
    `DLTensor::strides` and `DLTensor::byte_offset`.  In addition, while the
    element offset is not explicitly present in the `DLTensor` struct, a
    Relax operator is implemented to infer it from the `byte_offset` and
    `dtype` fields, for use when interacting with the TIR
    `BufferNode::elem_offset` field.
---
 python/tvm/relax/expr.py                           |  97 ++++++++
 .../tvm/relax/transform/legalize_ops/__init__.py   |   1 +
 .../tvm/relax/transform/legalize_ops/inspect_op.py | 128 +++++++++++
 src/relax/op/tensor/inspect.cc                     | 180 ++++++++++++---
 src/relax/op/tensor/inspect.h                      |  39 ++++
 src/tir/transforms/lower_tvm_builtin.cc            |  36 ++-
 tests/python/relax/test_op_inspect.py              | 252 +++++++++++++++++++++
 tests/python/relax/test_op_unpack.py               | 127 -----------
 .../test_tir_transform_lower_tvm_builtin.py        |  37 ++-
 9 files changed, 727 insertions(+), 170 deletions(-)

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 12f08f4dbf..4dca710e77 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -280,6 +280,33 @@ class ExprWithOp(Expr, Scriptable):
         self._check_for_tensor_struct_info()
         return _DLTensorShapeProxy(self)
 
+    @property
+    def strides(self) -> "_DLTensorStrideProxy":
+        """Returns a proxy object for accessing DLTensor::strides"""
+        self._check_for_tensor_struct_info()
+        return _DLTensorStrideProxy(self)
+
+    @property
+    def byte_offset(self) -> "Expr":
+        """Returns a relax expression for accessing DLTensor::byte_offset
+
+        The result is a call to the `relax.inspect.tensor_byte_offset`
+        operator applied to this tensor, representing the runtime value
+        of the field.  Unlike `.shape` and `.strides`, no proxy object
+        is involved.
+        """
+        self._check_for_tensor_struct_info()
+        op = tvm.ir.Op.get("relax.inspect.tensor_byte_offset")
+        return tvm.relax.Call(op, [self])
+
+    @property
+    def elem_offset(self) -> "Expr":
+        """Returns a relax expression for accessing a DLTensor's elem_offset
+
+        This parameter is not stored in the DLTensor, but is instead
+        derived from the DLTensor's byte offset and datatype.  This is
+        exposed in Relax for ease of use, and for translation into the
+        `tir::BufferNode::elem_offset` field when interacting with TIR
+        buffers.
+
+        The result is a call to the `relax.inspect.tensor_elem_offset`
+        operator applied to this tensor.
+        """
+        self._check_for_tensor_struct_info()
+        op = tvm.ir.Op.get("relax.inspect.tensor_elem_offset")
+        return tvm.relax.Call(op, [self])
+
 
 class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
     """A proxy object for unpacking DLDatatype from DLTensor
@@ -431,6 +458,76 @@ class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric):
         return tvm.relax.Call(op, [self.tensor, axis])
 
 
+class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric):
+    """A proxy object for unpacking the strides from DLTensor
+
+    Exposes accessors for the `DLTensor::strides` field.  Accessing
+    these fields will produce `relax.Call` expressions, representing
+    the field's runtime value.  If both the shape of the tensor and
+    the requested axis are known at compile-time, the `relax.Call`
+    may be normalized into a `relax.PrimValue` holding the compact
+    stride, with no runtime cost.
+
+    Parameters
+    ----------
+    tensor: relax.Expr
+
+        The relax tensor (or a variable referring to a relax tensor),
+        whose runtime strides are being inspected.
+    """
+
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def asobject(self):
+        """Raise a descriptive error when used as a relax expression
+
+        This method is called when `_DLTensorStrideProxy` is used in a
+        context that requires a `relax.Expr`.  This usage is not
+        supported, and raising an error here can provide suggested
+        fixes that are not present in the default error message from
+        `tvm.runtime.convert_to_object`.
+        """
+        raise TypeError(
+            f"{self.tensor}.strides cannot be converted to a relax expression, "
+            f"and should be used as a proxy object to access the runtime strides of the DLTensor. "
+            f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
+            f"and the DLTensor::strides array can be accessed as {self.tensor}.strides[i]"
+        )
+
+    def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
+        """Returns the stride of a tensor axis
+
+        Parameters
+        ----------
+        axis: Union[int, PrimExpr, Expr]
+
+            The tensor axis whose stride should be returned.  For ease
+            of use, any python integers or TIR expressions are
+            converted to `relax.Expr`.
+
+        Returns
+        -------
+        stride: Expr
+
+            The stride of the tensor's axis, measured in elements.
+        """
+
+        if not isinstance(axis, tvm.relax.Expr):
+            axis = tvm.relax.PrimValue(axis)
+
+        if axis.struct_info_ is not None and not isinstance(
+            axis.struct_info_, tvm.relax.PrimStructInfo
+        ):
+            raise TypeError(
+                f"The index used to access {self.tensor}.strides "
+                f'must have struct info R.Prim("int64"), '
+                f"but index {axis} had struct info {axis.struct_info_}."
+            )
+
+        op = tvm.ir.Op.get("relax.inspect.tensor_stride_i")
+        return tvm.relax.Call(op, [self.tensor, axis])
+
+
 @tvm._ffi.register_object("relax.expr.Call")
 class Call(ExprWithOp):
     """Function call node in Relax.
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py 
b/python/tvm/relax/transform/legalize_ops/__init__.py
index e3b3213a38..b4aba0291f 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -23,6 +23,7 @@ from . import distributed
 from . import grad
 from . import image
 from . import index
+from . import inspect_op
 from . import linear_algebra
 from . import manipulate
 from . import nn
diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py 
b/python/tvm/relax/transform/legalize_ops/inspect_op.py
new file mode 100644
index 0000000000..5f1b36667a
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Legalization functions for DLTensor inspection."""
+
+import enum
+
+from tvm.script import tir as T
+
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from .common import register_legalize
+
+
+class TVMStructFieldKind(enum.IntEnum):
+    """Equivalent to tvm::tir::builtin::TVMStructFieldKind
+
+    This does not use `enum.auto()` to define the values, because
+    `enum.auto()` starts from 1, and this must match the C++
+    definition which starts from 0.
+    """
+
+    kArrAddr = 0
+    kArrData = 1
+    kArrShape = 2
+    kArrStrides = 3
+    kArrNDim = 4
+    kArrTypeCode = 5
+    kArrTypeBits = 6
+    kArrTypeLanes = 7
+    kArrByteOffset = 8
+    kArrDeviceId = 9
+    kArrDeviceType = 10
+    kArrKindBound_ = 11
+    kTVMValueContent = 12
+    kTVMValueKindBound_ = 13
+
+
+@register_legalize("relax.inspect.tensor_stride_i")
+def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr:
+    @T.prim_func(private=True)
+    def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> 
T.int64:
+        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": 
T.bool(True)})
+        assert T.int64(0) <= axis, "Specified axis may not be negative"
+        ndim: T.int32 = T.tvm_struct_get(
+            dlpack_handle, 0, int(TVMStructFieldKind.kArrNDim), "int32"
+        )
+        assert axis < T.Cast(
+            "int64", ndim
+        ), "Specified axis may not be larger than the tensor's dimensionality"
+        stride_ptr: T.handle("int64") = T.tvm_struct_get(
+            dlpack_handle, 0, int(TVMStructFieldKind.kArrStrides), "handle"
+        )
+
+        if T.isnullptr(stride_ptr):
+            shape_ptr: T.handle("int64") = T.tvm_struct_get(
+                dlpack_handle, 0, int(TVMStructFieldKind.kArrShape), "handle"
+            )
+            shape = T.decl_buffer(ndim, "int64", data=shape_ptr)
+
+            product = T.decl_buffer([], "int64")
+            product[()] = 1
+
+            # TODO(Lunderberg): Add a TIR lowering pass to allow
+            # ranges to start somewhere other than zero.  This loop
+            # could then iterate on `range(axis+1, ndim)`.
+            for dim_offset in range(ndim - (axis + 1)):
+                dim = dim_offset + (axis + 1)
+                product[()] = product[()] * shape[dim]
+
+            return product[()]
+        else:
+            strides = T.decl_buffer(ndim, "int64", data=stride_ptr)
+            stride: T.int64 = strides[axis]
+            return stride
+
+    gvar = bb.add_func(_get_tensor_stride_i, "_get_tensor_stride_i")
+    return Call(gvar, call.args)
+
+
+@register_legalize("relax.inspect.tensor_byte_offset")
+def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr:
+    @T.prim_func(private=True)
+    def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64:
+        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": 
T.bool(True)})
+        byte_offset: T.uint64 = T.tvm_struct_get(
+            dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
+        )
+        return byte_offset
+
+    gvar = bb.add_func(_get_tensor_byte_offset, "_get_tensor_byte_offset")
+    return Call(gvar, call.args)
+
+
+@register_legalize("relax.inspect.tensor_elem_offset")
+def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
+    @T.prim_func(private=True)
+    def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64:
+        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": 
T.bool(True)})
+        byte_offset: T.uint64 = T.tvm_struct_get(
+            dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
+        )
+        scalar_bits: T.uint8 = T.tvm_struct_get(
+            dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeBits), "uint8"
+        )
+        lanes: T.uint16 = T.tvm_struct_get(
+            dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeLanes), "uint16"
+        )
+        bytes_per_element = T.ceildiv(scalar_bits.astype("uint64") * 
lanes.astype("uint64"), 8)
+        elem_offset = byte_offset // bytes_per_element
+        return elem_offset
+
+    gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
+    return Call(gvar, call.args)
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index a40b2af5ef..186fc9fa86 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -29,6 +29,8 @@
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
 
+#include <tuple>
+
 namespace tvm {
 namespace relax {
 namespace inspect {
@@ -50,6 +52,42 @@ TensorStructInfo GetTensorArgInfo(const Call& call) {
   return tensor_sinfo.value();
 }
 
+/* \brief Validate a (tensor, axis) call and return both argument struct infos
+ *
+ * Shared by `relax.inspect.tensor_shape_i` and
+ * `relax.inspect.tensor_stride_i`.  Raises an error if the call does
+ * not have exactly two arguments, if the first argument does not have
+ * `TensorStructInfo`, or if the second argument does not have
+ * `PrimStructInfo`.  When the axis is a compile-time constant, it is
+ * additionally checked to be non-negative and, if the tensor's ndim is
+ * known, smaller than ndim.
+ */
+std::tuple<TensorStructInfo, PrimStructInfo> GetTensorArgInfoWithIndex(const Call& call) {
+  CHECK_EQ(call->args.size(), 2) << "TypeError: "
+                                 << "Operator " << call->op << " expects two arguments, "
+                                 << "but received " << call->args.size()
+                                 << " arguments: " << call->args;
+  const auto& arg = call->args[0];
+  const auto& axis = call->args[1];
+
+  auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>();
+  CHECK(tensor_sinfo) << "TypeError: "
+                      << "Operator " << call->op << " expects arguments (tensor, axis), "
+                      << "but the first argument " << arg << " in expression " << call
+                      << " has struct info " << arg->struct_info_;
+
+  auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>();
+  CHECK(axis_sinfo) << "TypeError: "
+                    << "Operator " << call->op << " expects arguments (tensor, axis), "
+                    << "but the second argument " << axis << " in expression " << call
+                    << " has struct info " << axis->struct_info_;
+
+  auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
+
+  if (int_imm_axis) {
+    CHECK_GE(int_imm_axis->value, 0)
+        << "ValueError: "
+        << "Expression " << call << " attempts to access axis " << int_imm_axis->value
+        << " of " << arg << ", but negative axes are not supported";
+  }
+  if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) {
+    // This helper serves both the shape and the stride accessors, so
+    // the message names the axis rather than a specific DLTensor field.
+    CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim)
+        << "ValueError: "
+        << "Expression " << call << " attempts to access axis " << int_imm_axis->value
+        << " of " << arg << ", but " << arg << " only has " << tensor_sinfo->ndim
+        << " dimensions";
+  }
+
+  return {GetRef<TensorStructInfo>(tensor_sinfo), GetRef<PrimStructInfo>(axis_sinfo)};
+}
+
 DataType GetTensorDataType(const Call& call) { return 
GetTensorArgInfo(call)->dtype; }
 
 tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, 
DataType field_dtype) {
@@ -244,39 +282,11 @@ Expr tensor_shape_i(Expr expr) {
 StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) {
   auto dlpack_type = DataType::Int(64);
 
-  CHECK_EQ(call->args.size(), 2) << "TypeError: "
-                                 << "Operator " << call->op << " expects two 
arguments, "
-                                 << "but received " << call->args.size()
-                                 << " arguments: " << call->args;
-  const auto& arg = call->args[0];
-  const auto& axis = call->args[1];
-
-  auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>();
-  CHECK(tensor_sinfo) << "TypeError: "
-                      << "Operator " << call->op << " expects arguments 
(tensor, axis), "
-                      << "but the first argument " << arg << " in expression " 
<< call
-                      << " has struct info " << arg->struct_info_;
-
-  auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>();
-  CHECK(axis_sinfo) << "TypeError: "
-                    << "Operator " << call->op << " expects arguments (tensor, 
axis), "
-                    << "but the second argument " << arg << " in expression " 
<< call
-                    << " has struct info " << axis->struct_info_;
+  auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call);
 
+  auto tensor_shape = tensor_sinfo->GetShape();
   auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
 
-  if (int_imm_axis) {
-    CHECK_GE(int_imm_axis->value, 0);
-  }
-  if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) {
-    CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim)
-        << "ValueError: "
-        << "Expression " << call << " attempts to access " << arg << ".shape["
-        << int_imm_axis->value << "]"
-        << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " 
elements";
-  }
-
-  auto tensor_shape = tensor_sinfo->GetShape();
   if (int_imm_axis && tensor_shape.defined()) {
     return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]);
   } else {
@@ -346,6 +356,116 @@ TVM_REGISTER_OP("relax.inspect.tensor_shape_i")
     .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
     .set_attr<Bool>("FPurity", Bool(true));
 
+//// relax.tensor_stride_i
+
+// Must match the declaration in inspect.h: the operator is registered
+// with two inputs (tensor, axis), so the axis is forwarded to the Call.
+Expr tensor_stride_i(Expr expr, Expr axis) {
+  static const Op& op = Op::Get("relax.inspect.tensor_stride_i");
+  return Call(op, {expr, axis});
+}
+
+StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) {
+  auto dlpack_type = DataType::Int(64);
+
+  auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call);
+
+  auto opt_tensor_shape = tensor_sinfo->GetShape();
+  auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
+
+  if (int_imm_axis && opt_tensor_shape.defined()) {
+    // As of 2024-03-14, Relax does not have an explicit
+    // representation for striding in `TensorStructInfo`.  The
+    // `FLegalize` function for most operators is implemented in terms
+    // of `topi`, and is then converted from TE to `tir::PrimFunc`
+    // using `tvm::tir::CreatePrimFunc`.  The `te::Tensor` is
+    // converted to a `tir::Buffer` in `RewriteStageToBlock`, and uses
+    // the default empty list for the strides.  The empty strides
+    // represent a compact data array.
+    //
+    // Therefore, while Relax does not explicitly represent the
+    // striding of a tensor, it implicitly requires compact striding
+    // for any legalizable Tensor.
+    auto tensor_shape = opt_tensor_shape.value();
+    PrimExpr stride = IntImm(DataType::Int(64), 1);
+    for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size(); 
axis++) {
+      stride = stride * tensor_shape[axis];
+    }
+    return PrimStructInfo(stride);
+  } else {
+    return PrimStructInfo(dlpack_type);
+  }
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_stride_i")
+    .set_num_inputs(2)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .add_argument("axis", "Prim(int64)", "The axis whose extent should be 
returned")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoTensorStride)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_byte_offset
+
+Expr tensor_byte_offset(Expr expr) {
+  static const Op& op = Op::Get("relax.inspect.tensor_byte_offset");
+  return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorByteOffset(const Call& call, const 
BlockBuilder&) {
+  auto dlpack_type = DataType::UInt(64);
+
+  auto tensor_sinfo = GetTensorArgInfo(call);
+
+  auto opt_tensor_shape = tensor_sinfo->GetShape();
+  if (opt_tensor_shape.defined()) {
+    // Relax implicitly requires that the byte offset is zero for any
+    // legalizable tensor.  See InferStructInfoTensorStride for full
+    // explanation.
+    return PrimStructInfo(IntImm(dlpack_type, 0));
+  } else {
+    return PrimStructInfo(dlpack_type);
+  }
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_byte_offset")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoTensorByteOffset)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.tensor_elem_offset
+
+Expr tensor_elem_offset(Expr expr) {
+  static const Op& op = Op::Get("relax.inspect.tensor_elem_offset");
+  return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorElemOffset(const Call& call, const 
BlockBuilder&) {
+  auto dlpack_type = DataType::UInt(64);
+
+  auto tensor_sinfo = GetTensorArgInfo(call);
+
+  auto opt_tensor_shape = tensor_sinfo->GetShape();
+  if (opt_tensor_shape.defined()) {
+    // Relax implicitly requires that the element offset is zero for
+    // any legalizable tensor.  See InferStructInfoTensorStride for
+    // full explanation.
+    return PrimStructInfo(IntImm(dlpack_type, 0));
+  } else {
+    return PrimStructInfo(dlpack_type);
+  }
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_elem_offset")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoTensorElemOffset)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace inspect
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h
index 0225b00fb3..2aa20a1381 100644
--- a/src/relax/op/tensor/inspect.h
+++ b/src/relax/op/tensor/inspect.h
@@ -85,6 +85,45 @@ Expr tensor_ndim(Expr expr);
  */
 Expr tensor_shape_i(Expr expr, Expr axis);
 
+/* \brief Return the DLTensor::strides[i] field
+ *
+ * The `int64_t* DLTensor::strides` is allowed to be NULL, which
+ * represents a compact packing of the data.  In this case, the
+ * returned stride is computed from the `DLTensor::shape`.
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \param axis The axis to inspect.  Must be within the range `0 <=
+ *     axis < tensor_ndim(expr)`, or else the results are undefined.
+ *
+ * \returns The int64_t stride of the specified tensor axis, with
+ * `PrimStructInfo(DataType::Int(64))`.
+ */
+Expr tensor_stride_i(Expr expr, Expr axis);
+
+/* \brief Return the DLTensor::byte_offset field
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint64_t byte offset, with `PrimStructInfo(DataType::UInt(64))`.
+ */
+Expr tensor_byte_offset(Expr expr);
+
+/* \brief Return the element offset of a DLTensor
+ *
+ * While the DLTensor does not directly contain the element offset, it
+ * can be inferred from the `DLTensor::byte_offset` and
+ * `DLTensor::dtype` fields.
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint64_t element offset, with `PrimStructInfo(DataType::UInt(64))`.
+ */
+Expr tensor_elem_offset(Expr expr);
+
 }  // namespace inspect
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/tir/transforms/lower_tvm_builtin.cc 
b/src/tir/transforms/lower_tvm_builtin.cc
index 6da2f873b7..1a3888a7cd 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -38,6 +38,19 @@ namespace tir {
 // These information are needed during codegen.
 class BuiltinLower : public StmtExprMutator {
  public:
+  static PrimFunc Build(PrimFunc func) {
+    Optional<PrimExpr> device_type = NullOpt;
+    if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
+      device_type = Integer(target.value()->kind->default_device_type);
+    }
+
+    BuiltinLower mutator(device_type);
+    func.CopyOnWrite()->body = mutator.VisitBodyAndRealizeAlloca(func->body);
+    return func;
+  }
+
+  explicit BuiltinLower(Optional<PrimExpr> device_type = NullOpt) : 
device_type_(device_type) {}
+
   // NOTE: Right now, we make the following scoping requirement
   // for memory allocated by the following primitives
   // - tvm_stack_make_array
@@ -284,13 +297,17 @@ class BuiltinLower : public StmtExprMutator {
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::device_id) {
-      ICHECK(!device_id_);
+      auto cache = device_id_;
       device_id_ = op->value;
-      return this->VisitStmt(op->body);
+      Stmt out = this->VisitStmt(op->body);
+      device_id_ = cache;
+      return out;
     } else if (op->attr_key == attr::device_type) {
-      ICHECK(!device_type_);
+      auto cache = device_type_;
       device_type_ = op->value;
-      return this->VisitStmt(op->body);
+      Stmt out = this->VisitStmt(op->body);
+      device_type_ = cache;
+      return out;
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -656,13 +673,12 @@ class BuiltinLower : public StmtExprMutator {
 namespace transform {
 
 Pass LowerTVMBuiltin() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    if (IsHostFunc(f).value_or(false)) {
-      auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-      f.CopyOnWrite()->body = BuiltinLower().Build(f->body);
-      VLOG(2) << "LowerTVMBuiltin: " << f;
+  auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) {
+    if (IsHostFunc(func).value_or(false)) {
+      func = BuiltinLower::Build(func);
+      VLOG(2) << "LowerTVMBuiltin: " << func;
     }
-    return f;
+    return func;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
 }
diff --git a/tests/python/relax/test_op_inspect.py 
b/tests/python/relax/test_op_inspect.py
new file mode 100644
index 0000000000..18d7a88f05
--- /dev/null
+++ b/tests/python/relax/test_op_inspect.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+
+import numpy as np
+import pytest
+
+import tvm.testing
+
+from tvm import relax
+from tvm.ir import Op
+from tvm.script import ir as I, relax as R
+
+# Parameterization for reading dtype of DLTensor.  Chosen to have
+# multiple distinct type codes, number of lanes, and widths.
+dtype = tvm.testing.parameter(
+    "int32",
+    "int64",
+    "float32",
+    "float32x4",
+    "bfloat",
+    "e4m3_float8",
+)
+shape = tvm.testing.parameter(
+    [],
+    [16],
+    [128, 256],
+    [1] * 64,
+)
+
+elem_offset = tvm.testing.parameter(0, 64, 128)
+
+
+def test_tensor_dtype_code(dtype):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.dtype.type_code
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty([16], dtype)
+    res = vm["main"](arg)
+
+    expected_type_code = tvm.runtime.DataType(dtype).type_code
+    assert res == expected_type_code
+
+
+def test_tensor_dtype_bits(dtype):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.dtype.bits
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty([16], dtype)
+    res = vm["main"](arg)
+
+    expected_type_bits = tvm.runtime.DataType(dtype).bits
+    assert res == expected_type_bits
+
+
+def test_tensor_dtype_lanes(dtype):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.dtype.lanes
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty([16], dtype)
+    res = vm["main"](arg)
+
+    expected_type_lanes = tvm.runtime.DataType(dtype).lanes
+    assert res == expected_type_lanes
+
+
+def test_tensor_ndim(shape):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.ndim
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty(shape, "int32")
+    res = vm["main"](arg)
+
+    assert res == len(shape)
+
+
+def test_tensor_shape(shape):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor, axis: R.Prim("int64")):
+            return A.shape[axis]
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty(shape, "int32")
+
+    res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
+
+    tvm.ir.assert_structural_equal(res, shape)
+
+
+def _get_compact_striding(shape):
+    strides = []
+    product = 1
+    for dim in reversed(shape):
+        strides.append(product)
+        product *= dim
+    return list(reversed(strides))
+
+
+def test_strides_of_compact_tensor(shape):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor, axis: R.Prim("int64")):
+            return A.strides[axis]
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty(shape, "int32")
+
+    res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
+    expected = _get_compact_striding(shape)
+
+    tvm.ir.assert_structural_equal(res, expected)
+
+
+def test_strides_of_non_compact_tensor():
+    backing_shape = [64, 64]
+    view_shape = [16, 16]
+    # Row-major view into the backing array: stepping one index along
+    # axis 0 skips a full row of the backing array, i.e. its axis-1 extent.
+    expected_strides = [backing_shape[1], 1]
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor, axis: R.Prim("int64")):
+            return A.strides[axis]
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    backing_ndarray = tvm.nd.empty(backing_shape, "int32")
+
+    # Manually overwrite the DLTensor fields to make a view into the
+    # tensor.  `ctypes.c_int64` matches the `int64_t*` shape/strides
+    # fields on every platform (`c_long` is only 32 bits on Windows).
+    # `np_shape`/`np_strides` are kept as locals so the buffers outlive the view.
+    view = backing_ndarray.handle[0]
+    np_shape = np.array(view_shape, "int64")
+    view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
+    np_strides = np.array(expected_strides, "int64")
+    view.strides = np_strides.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
+    backing_ndarray.handle[0] = view
+
+    res = [vm["main"](backing_ndarray, i) for i, _ in enumerate(view_shape)]
+
+    tvm.ir.assert_structural_equal(res, expected_strides)
+
+
+def test_byte_offset(elem_offset):
+    backing_shape = [64, 64]
+    view_shape = [16, 16]
+    # The backing array is int32, so each element occupies 4 bytes.
+    byte_offset = elem_offset * 4
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.byte_offset
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    backing_ndarray = tvm.nd.empty(backing_shape, "int32")
+
+    # Manually overwrite the DLTensor fields to make a view into the
+    # tensor.  `ctypes.c_int64` matches the `int64_t*` shape field on
+    # every platform (`c_long` is only 32 bits on Windows).
+    view = backing_ndarray.handle[0]
+    np_shape = np.array(view_shape, "int64")
+    view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
+    view.byte_offset = byte_offset
+    backing_ndarray.handle[0] = view
+
+    res = vm["main"](backing_ndarray)
+
+    assert res == byte_offset
+
+
+def test_elem_offset(elem_offset, dtype):
+    tvm_dtype = tvm.runtime.DataType(dtype)
+
+    backing_shape = [64, 64]
+    view_shape = [16, 16]
+    element_bytes = (tvm_dtype.bits * tvm_dtype.lanes) // 8
+    byte_offset = elem_offset * element_bytes
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.elem_offset
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    backing_ndarray = tvm.nd.empty(backing_shape, dtype)
+
+    # Manually overwrite the DLTensor fields to make a view into the
+    # tensor.  `ctypes.c_int64` matches the `int64_t*` shape field on
+    # every platform (`c_long` is only 32 bits on Windows).
+    view = backing_ndarray.handle[0]
+    np_shape = np.array(view_shape, "int64")
+    view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
+    view.byte_offset = byte_offset
+    backing_ndarray.handle[0] = view
+
+    res = vm["main"](backing_ndarray)
+
+    assert res == elem_offset
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_op_unpack.py 
b/tests/python/relax/test_op_unpack.py
deleted file mode 100644
index 03e4e0fc85..0000000000
--- a/tests/python/relax/test_op_unpack.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm.testing
-
-from tvm import relax
-from tvm.ir import Op
-from tvm.script import ir as I, relax as R
-
-# Parameterization for reading dtype of DLTensor.  Chosen to have
-# multiple distinct type codes, number of lanes, and widths.
-dtype = tvm.testing.parameter(
-    "int32",
-    "int64",
-    "float32",
-    "float32x4",
-    "bfloat",
-    "e4m3_float8",
-)
-shape = tvm.testing.parameter(
-    [],
-    [16],
-    [128, 256],
-    [1] * 64,
-)
-
-
-def test_tensor_dtype_code(dtype):
-    @I.ir_module
-    class mod:
-        @R.function
-        def main(A: R.Tensor):
-            return A.dtype.type_code
-
-    built = relax.build(mod)
-    vm = relax.VirtualMachine(built, tvm.cpu())
-
-    arg = tvm.nd.empty([16], dtype)
-    res = vm["main"](arg)
-
-    expected_type_code = tvm.runtime.DataType(dtype).type_code
-    assert res == expected_type_code
-
-
-def test_tensor_dtype_bits(dtype):
-    @I.ir_module
-    class mod:
-        @R.function
-        def main(A: R.Tensor):
-            return A.dtype.bits
-
-    built = relax.build(mod)
-    vm = relax.VirtualMachine(built, tvm.cpu())
-
-    arg = tvm.nd.empty([16], dtype)
-    res = vm["main"](arg)
-
-    expected_type_bits = tvm.runtime.DataType(dtype).bits
-    assert res == expected_type_bits
-
-
-def test_tensor_dtype_lanes(dtype):
-    @I.ir_module
-    class mod:
-        @R.function
-        def main(A: R.Tensor):
-            return A.dtype.lanes
-
-    built = relax.build(mod)
-    vm = relax.VirtualMachine(built, tvm.cpu())
-
-    arg = tvm.nd.empty([16], dtype)
-    res = vm["main"](arg)
-
-    expected_type_lanes = tvm.runtime.DataType(dtype).lanes
-    assert res == expected_type_lanes
-
-
-def test_tensor_ndim(shape):
-    @I.ir_module
-    class mod:
-        @R.function
-        def main(A: R.Tensor):
-            return A.ndim
-
-    built = relax.build(mod)
-    vm = relax.VirtualMachine(built, tvm.cpu())
-
-    arg = tvm.nd.empty(shape, "int32")
-    res = vm["main"](arg)
-
-    assert res == len(shape)
-
-
-def test_tensor_shape(shape):
-    @I.ir_module
-    class mod:
-        @R.function
-        def main(A: R.Tensor, axis: R.Prim("int64")):
-            return A.shape[axis]
-
-    built = relax.build(mod)
-    vm = relax.VirtualMachine(built, tvm.cpu())
-
-    arg = tvm.nd.empty(shape, "int32")
-
-    res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
-
-    tvm.ir.assert_structural_equal(res, shape)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index de1020ef20..754ce03240 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -260,11 +260,13 @@ class TestLowerCPUAllocation(tvm.testing.CompareBeforeAfter):
 
 
 class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
+    """If device id is missing, error."""
+
     transform = tvm.tir.transform.LowerTVMBuiltin()
 
     def before():
         T.func_attr({"target": T.target("llvm")})
-        T.attr("dummy", "device_id", 0)
+        T.attr("dummy", "device_type", 2)  # kDLCuda
         ptr = T.allocate([16], "float32")
         buf = T.decl_buffer(16, "float32", data=ptr)
         buf[0] = 0.0
@@ -273,16 +275,45 @@ class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
 
 
 class TestLowerAllocateRequiresDeviceType(tvm.testing.CompareBeforeAfter):
+    """If the device type is missing, LowerTVMBuiltin raises an error.
+
+    The device type can be inferred from either the `"device_type"`
+    statement attribute or the `"target"` function attribute; here we
+    provide neither.  `"tir.is_host_func"` is set because otherwise
+    LowerTVMBuiltin would skip the function altogether.  (The large
+    allocation presumably avoids a small stack-alloca path; confirm.)
+    """
+
     transform = tvm.tir.transform.LowerTVMBuiltin()
 
     def before():
-        T.func_attr({"target": T.target("llvm")})
+        T.func_attr({"tir.is_host_func": True})
         T.attr("dummy", "device_id", 0)
+        ptr = T.allocate([1024 * 1024], "float32")
+        buf = T.decl_buffer(1024 * 1024, "float32", data=ptr)
+        buf[0] = 0.0
+
+    expected = tvm.TVMError  # the pass is expected to fail on this input
+
+
+class TestLowerCPUAllocWithFunctionAttr(tvm.testing.CompareBeforeAfter):
+    """CPU allocations can be handled at codegen time.
+
+    Like `TestLowerCPUAllocation`, but the device type is taken from
+    the function's `"target"` attribute.  An `AttrStmt` can override
+    the device type for allocations within its scope; otherwise the
+    device type defaults to that of the function's target.
+    """
+
+    transform = tvm.tir.transform.LowerTVMBuiltin()
+
+    def before():
+        T.func_attr({"target": T.target("llvm")})
         ptr = T.allocate([16], "float32")
         buf = T.decl_buffer(16, "float32", data=ptr)
         buf[0] = 0.0
 
-    expected = tvm.TVMError
+    expected = before  # output should be unchanged from the input
 
 
 if __name__ == "__main__":


Reply via email to