This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new cc79591 [Relay][Op]Support symbolic TopK, Ones, Zeros and Full (#5459)
cc79591 is described below
commit cc79591f5e0f05955fd3180b1975e4344b532345
Author: Yao Wang <[email protected]>
AuthorDate: Mon May 25 18:09:44 2020 -0700
[Relay][Op]Support symbolic TopK, Ones, Zeros and Full (#5459)
* Support symbolic TopK, Ones, Zeros and Full
* Fix pylint
* Add docstring for topk shape func
* Fix grad
* Fix lazy_gradient_init
* Fix parser
* Fix print ir text
* Fix lint
* Improve pattern_util
* Fix topk
* Fix build
* Use Optional for attribute
* Fix clang-format
* Minor fix
* Fix pylint
* Fix build warning
* Fix parser
* Move ToScalar
* Fix lint
* Fix lint
* Make topk shape func data independent when k is constant.
* Fix lint
* Minor fix
---
include/tvm/relay/attrs/algorithm.h | 5 +-
include/tvm/relay/attrs/transform.h | 2 +-
include/tvm/runtime/ndarray.h | 6 +-
python/tvm/relay/_parser.py | 2 +
python/tvm/relay/op/_algorithm.py | 68 ++++++++++++
python/tvm/relay/op/_tensor.py | 41 ++++----
python/tvm/relay/op/_tensor_grad.py | 8 +-
python/tvm/relay/op/_transform.py | 2 +
python/tvm/relay/op/algorithm.py | 9 +-
python/tvm/relay/op/strategy/generic.py | 4 +-
python/tvm/relay/op/tensor.py | 10 +-
python/tvm/relay/op/transform.py | 8 +-
src/relay/analysis/util.cc | 8 ++
src/relay/op/algorithm/topk.cc | 32 ++++--
src/relay/op/image/resize.cc | 4 +-
src/relay/op/tensor/transform.cc | 161 +++++++++++++++--------------
src/relay/op/tensor/transform.h | 43 ++++----
src/relay/qnn/util.cc | 4 +-
src/relay/transforms/lazy_gradient_init.cc | 8 +-
src/relay/transforms/pattern_util.h | 103 ++++++++++++++++--
tests/python/relay/test_any.py | 83 ++++++++++++---
topi/python/topi/sort.py | 10 +-
22 files changed, 435 insertions(+), 186 deletions(-)
diff --git a/include/tvm/relay/attrs/algorithm.h
b/include/tvm/relay/attrs/algorithm.h
index a7d4708..83b4dda 100644
--- a/include/tvm/relay/attrs/algorithm.h
+++ b/include/tvm/relay/attrs/algorithm.h
@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
+#include <tvm/relay/expr.h>
#include <string>
@@ -52,14 +53,14 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
};
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
- int k;
+ Optional<Integer> k;
int axis;
bool is_ascend;
std::string ret_type;
DataType dtype;
TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
- TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to
select");
+ TVM_ATTR_FIELD(k).describe("Number of top elements to select");
TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort
the input tensor.");
TVM_ATTR_FIELD(ret_type).set_default("both").describe(
"The return type [both, values, indices]."
diff --git a/include/tvm/relay/attrs/transform.h
b/include/tvm/relay/attrs/transform.h
index 7fb7f3a..ccf8e54 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -111,7 +111,7 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
/*! \brief Attributes that specify a tensor */
struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
- Array<IndexExpr> shape;
+ Optional<Array<Integer>> shape;
DataType dtype;
TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 0171d8a..e69d802 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -462,7 +462,11 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file
format";
- CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file
format";
+ auto read_ret = strm->Read(ret->data, data_byte_size);
+ // Only check non-empty data. A tensor with any zero extent has
+ // data_byte_size == 0, so Read legitimately returns 0 bytes. A 0-d scalar
+ // still carries one element, so its read must be checked.
+ if (data_byte_size > 0) {
+ CHECK(read_ret) << "Invalid DLTensor file format";
+ }
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 1d97b55..49f2d4d 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -116,6 +116,8 @@ class FuncOp(OpWrapper):
attrs = {}
if self.operator is op.reshape:
x = self.operator(*args)
+ elif self.operator in (op.zeros, op.ones, op.full):
+ x = self.operator(*args, dtype=attrs["dtype"])
+ elif self.operator is op.broadcast_to:
+ # broadcast_to(data, shape) accepts no dtype argument; the output
+ # dtype follows the data input, so pass only the positional args.
+ x = self.operator(*args)
else:
x = self.operator(*args, **{k: self.convert(v) for k, v in
attrs.items()})
if isinstance(x, expr.TupleWrapper):
diff --git a/python/tvm/relay/op/_algorithm.py
b/python/tvm/relay/op/_algorithm.py
index e1e6fd3..5a20480 100644
--- a/python/tvm/relay/op/_algorithm.py
+++ b/python/tvm/relay/op/_algorithm.py
@@ -18,7 +18,11 @@
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
+from tvm.te.hybrid import script
+from tvm.runtime import convert
+
from . import strategy
+from . import op as _reg
from .op import OpPattern, register_pattern
from .op import register_strategy
@@ -29,3 +33,67 @@ register_pattern("argsort", OpPattern.OPAQUE)
# topk
register_strategy("topk", strategy.topk_strategy)
register_pattern("topk", OpPattern.OPAQUE)
+
+@script
+def _topk_shape_func_input_data(data, k, axis):
+ ndim = len(data.shape)
+ val_out = output_tensor((ndim,), "int64")
+ indices_out = output_tensor((ndim,), "int64")
+
+ for i in const_range(ndim):
+ if i != axis:
+ val_out[i] = int64(data.shape[i])
+ indices_out[i] = int64(data.shape[i])
+ else:
+ if k[0] < 1:
+ val_out[i] = int64(data.shape[i])
+ indices_out[i] = int64(data.shape[i])
+ else:
+ val_out[i] = int64(k[0])
+ indices_out[i] = int64(k[0])
+ return val_out, indices_out
+
+@script
+def _topk_shape_func_input_shape(data_shape, k, axis):
+ ndim = data_shape.shape[0]
+ val_out = output_tensor((ndim,), "int64")
+ indices_out = output_tensor((ndim,), "int64")
+
+ for i in const_range(ndim):
+ if i != axis:
+ val_out[i] = int64(data_shape[i])
+ indices_out[i] = int64(data_shape[i])
+ else:
+ if k < 1:
+ val_out[i] = int64(data_shape[i])
+ indices_out[i] = int64(data_shape[i])
+ else:
+ val_out[i] = int64(k)
+ indices_out[i] = int64(k)
+ return val_out, indices_out
+
+@_reg.register_shape_func("topk", True)
+def topk_shape_func(attrs, inputs, _):
+ """
+ Shape func for topk.
+ """
+ axis = attrs.axis
+ if attrs.k is not None:
+ if axis < 0:
+ axis += inputs[0].shape[0]
+ val_out, indices_out = \
+ _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
+ else:
+ if axis < 0:
+ axis += len(inputs[0].shape)
+ val_out, indices_out = \
+ _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+ ret_type = attrs.ret_type
+ if ret_type == "both":
+ ret = [val_out, indices_out]
+ elif ret_type == "values":
+ ret = [val_out]
+ else:
+ ret = [indices_out]
+
+ return ret
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index e029e0c..cd9e4ed 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -17,10 +17,9 @@
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration"""
-from tvm.runtime import convert
from tvm.te.hybrid import script
import topi
-from topi.util import get_const_tuple
+
from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern
@@ -93,7 +92,7 @@ register_broadcast_schedule("fast_erf")
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type):
- assert not inputs
+ assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 0.0)]
register_broadcast_schedule("zeros")
@@ -110,7 +109,7 @@ register_broadcast_schedule("zeros_like")
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type):
- assert not inputs
+ assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 1.0)]
register_broadcast_schedule("ones")
@@ -132,20 +131,10 @@ def clip_compute(attrs, inputs, output_type):
register_injective_schedule("clip")
-@script
-def _cast_shape_function(x):
- out_ndim = len(x)
- out = output_tensor((out_ndim,), "int64")
- for i in const_range(out_ndim):
- out[i] = x[i]
- return out
-
-def cast_shape_func(attrs, inputs, out_ndims):
- return [_cast_shape_function(*inputs)]
-
+# full
@script
def _full_shape_func(shape):
- out_ndim = len(shape)
+ out_ndim = shape.shape[0]
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = int64(shape[i])
@@ -153,10 +142,15 @@ def _full_shape_func(shape):
def full_shape_func(attrs, inputs, out_ndims):
"""
- Shape func for zeros, zeros_like, ones, ones_like.
+ Shape func for full.
+ """
+ return [_full_shape_func(inputs[1])]
+
+def no_data_full_shape_func(attrs, inputs, out_ndims):
+ """
+ Shape func for zeros and ones.
"""
- shape = get_const_tuple(attrs.shape)
- return [_full_shape_func(convert(shape))]
+ return [_full_shape_func(inputs[0])]
@script
def _broadcast_shape_func(x, y, ndim):
@@ -198,13 +192,14 @@ def elemwise_shape_func(attrs, inputs, _):
"""
return [topi.math.identity(inputs[0])]
-register_shape_func("cast", False, cast_shape_func)
-register_shape_func("zeros", False, full_shape_func)
+register_shape_func("cast", False, elemwise_shape_func)
+register_shape_func("zeros", True, no_data_full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
-register_shape_func("ones", False, full_shape_func)
+register_shape_func("ones", True, no_data_full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
-register_shape_func("full", False, full_shape_func)
+register_shape_func("full", True, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
+register_shape_func("broadcast_to", True, full_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
diff --git a/python/tvm/relay/op/_tensor_grad.py
b/python/tvm/relay/op/_tensor_grad.py
index 8be3358..8ba1020 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -232,14 +232,14 @@ def divide_grad(orig, grad):
@register_gradient("zeros")
def zeros_grad(orig, grad):
- """Returns []"""
- return []
+ """Returns [shape]"""
+ return [orig.args[0]]
@register_gradient("ones")
def ones_grad(orig, grad):
- """Returns []"""
- return []
+ """Returns [shape]"""
+ return [orig.args[0]]
@register_gradient("zeros_like")
diff --git a/python/tvm/relay/op/_transform.py
b/python/tvm/relay/op/_transform.py
index 43d8d62..e1c2bd7 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -120,6 +120,8 @@ def _concatenate_shape_func(inputs, axis):
@_reg.register_shape_func("concatenate", False)
def concatenate_shape_func(attrs, inputs, _):
axis = get_const_int(attrs.axis)
+ if axis < 0:
+ axis += inputs[0].shape[0]
return [_concatenate_shape_func(inputs, convert(axis))]
@script
diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py
index 17fab80..d31e89a 100644
--- a/python/tvm/relay/op/algorithm.py
+++ b/python/tvm/relay/op/algorithm.py
@@ -17,7 +17,7 @@
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
-from ..expr import TupleWrapper
+from ..expr import TupleWrapper, const
def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies
@@ -48,7 +48,8 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
return _make.argsort(data, axis, is_ascend, dtype)
-def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
+def topk(data, k=1, axis=-1, ret_type="both",
+ is_ascend=False, dtype="int32"):
"""Get the top k elements in an input tensor along the given axis.
ret_type specifies the return type, can be one of ("both", "values",
"indices").
@@ -58,7 +59,7 @@ def topk(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int32"):
data : relay.Expr
The input data tensor.
- k : int, optional
+ k : int or relay.Expr, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
@@ -81,6 +82,8 @@ def topk(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int32"):
out : relay.Expr or List[relay.Expr]
The computed result.
"""
+ if isinstance(k, int):
+ k = const(k, "int64")
out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
if ret_type == "both":
return TupleWrapper(out, 2)
diff --git a/python/tvm/relay/op/strategy/generic.py
b/python/tvm/relay/op/strategy/generic.py
index 6db5b14..99439af 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -598,7 +598,9 @@ def argsort_strategy(attrs, inputs, out_type, target):
def wrap_compute_topk(topi_compute):
"""Wrap topk compute"""
def _compute_topk(attrs, inputs, out_type):
- k = get_const_int(attrs.k)
+ k = inputs[1]
+ if attrs.k is not None:
+ k = attrs.k
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index d5ae5cd..c60dbee 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -20,7 +20,7 @@ from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make
-from ..expr import Tuple
+from ..expr import Tuple, const
# We create a wrapper function for each operator in the
@@ -928,7 +928,7 @@ def zeros(shape, dtype):
Parameters
----------
- shape : tuple of int
+ shape : tuple of int or relay.Expr
The shape of the target.
dtype : data type
@@ -939,6 +939,8 @@ def zeros(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, (list, tuple)):
+ shape = const(list(shape), "int32")
return _make.zeros(shape, dtype)
@@ -963,7 +965,7 @@ def ones(shape, dtype):
Parameters
----------
- shape : tuple of int
+ shape : tuple of int or relay.Expr
The shape of the target.
dtype : data type
@@ -974,6 +976,8 @@ def ones(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, (list, tuple)):
+ shape = const(list(shape), "int32")
return _make.ones(shape, dtype)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 2d9e4ba..1da58ae 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -299,7 +299,7 @@ def full(fill_value, shape=(), dtype=""):
fill_value : relay.Expr
The value to fill. Must be a scalar.
- shape : tuple of int
+ shape : tuple of int or relay.Expr
The shape of the target.
dtype : data type, optional (defaults to data type of the fill value)
@@ -310,6 +310,8 @@ def full(fill_value, shape=(), dtype=""):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, (list, tuple)):
+ shape = const(list(shape), "int32")
return _make.full(fill_value, shape, dtype)
@@ -527,7 +529,7 @@ def broadcast_to(data, shape):
data : relay.Expr
The input tensor.
- shape : shape
+ shape : tuple of int or relay.Expr
Provide the shape to broadcast to.
Returns
@@ -535,6 +537,8 @@ def broadcast_to(data, shape):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, (list, tuple)):
+ shape = const(list(shape), "int32")
return _make.broadcast_to(data, shape)
def broadcast_to_like(data, broadcast_type):
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index a05bb8f..2853165 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -25,6 +25,7 @@
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
@@ -450,6 +451,13 @@ bool IsDataDependant(const CallNode* call) {
return false;
}
}
+ } else if (op->name == "topk") {
+ if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
+ if (attrs->k) {
+ // If k attribute exists, it isn't data dependant.
+ return false;
+ }
+ }
}
return tshape_data_dependant[op];
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
index 5ff5904..3db8eee 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/algorithm/topk.cc
@@ -23,9 +23,11 @@
*/
#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/op.h>
+#include <tvm/tir/op.h>
namespace tvm {
namespace relay {
+using tir::make_const;
TVM_REGISTER_NODE_TYPE(TopKAttrs);
@@ -33,7 +35,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const
Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
- CHECK_EQ(types.size(), 2);
+ CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data);
int ndim = data->shape.size();
@@ -44,35 +46,44 @@ bool TopKRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
CHECK(axis >= 0 && axis < ndim);
Array<IndexExpr> out_shape;
for (int i = 0; i < ndim; ++i) {
- if (i != axis || param->k < 1) {
+ if (i != axis) {
out_shape.push_back(data->shape[i]);
+ } else if (param->k) {
+ const Integer& ck = param->k.value();
+ if (ck->value < 1) {
+ out_shape.push_back(data->shape[i]);
+ } else {
+ out_shape.push_back(ck);
+ }
} else {
- out_shape.push_back(param->k);
+ out_shape.push_back(Any::make());
}
}
auto values_ty = TensorType(out_shape, data->dtype);
auto indices_ty = TensorType(out_shape, param->dtype);
if (param->ret_type == "both") {
- reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
+ reporter->Assign(types[2], TupleType({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
- reporter->Assign(types[1], values_ty);
+ reporter->Assign(types[2], values_ty);
} else if (param->ret_type == "indices") {
- reporter->Assign(types[1], indices_ty);
+ reporter->Assign(types[2], indices_ty);
} else {
LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
}
return true;
}
-Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend,
DataType dtype) {
+Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend,
DataType dtype) {
auto attrs = make_object<TopKAttrs>();
- attrs->k = k;
+ if (const auto* ck = k.as<ConstantNode>()) {
+ // The Python frontend builds k as an int64 scalar constant. Read it
+ // according to its actual dtype rather than reinterpreting the buffer
+ // as int32, which is only accidentally correct on little-endian hosts.
+ const DLDataType& kdtype = ck->data->dtype;
+ CHECK(kdtype.code == kDLInt && (kdtype.bits == 32 || kdtype.bits == 64))
+ << "topk expects k to be an int32 or int64 constant";
+ int64_t kval = kdtype.bits == 64 ? static_cast<const int64_t*>(ck->data->data)[0]
+ : static_cast<const int32_t*>(ck->data->data)[0];
+ attrs->k = Integer(static_cast<int>(kval));
+ }
attrs->axis = axis;
attrs->ret_type = ret_type;
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("topk");
- return Call(op, {data}, Attrs(attrs), {});
+ return Call(op, {data, k}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
@@ -80,9 +91,10 @@
TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given
axis.
)doc" TVM_ADD_FILELINE)
- .set_num_inputs(1)
+ .set_num_inputs(2)
.set_attrs_type<TopKAttrs>()
.add_argument("data", "Tensor", "Input data.")
+ .add_argument("k", "Tensor", "Number of top elements.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc
index 7bddb29..b6d2c71 100644
--- a/src/relay/op/image/resize.cc
+++ b/src/relay/op/image/resize.cc
@@ -194,12 +194,12 @@ bool CropAndResizeRel(const Array<Type>& types, int
num_inputs, const Attrs& att
const Layout in_layout(param->layout);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
auto oshape = layout_converter.ForwardShape(data->shape);
- oshape.Set(0, box_indices->shape[0]);
+ oshape.Set(0, boxes->shape[0]);
oshape.Set(2, crop_size[0]);
oshape.Set(3, crop_size[1]);
auto bshape = layout_converter.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[3],
TensorType(layout_converter.BackwardShape(oshape), out_dtype));
+ reporter->Assign(types[3], TensorType(bshape, out_dtype));
return true;
}
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 6ccf585..7282ac7 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -447,44 +447,6 @@ RELAY_REGISTER_OP("transpose")
/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
-double ToScalar(const runtime::NDArray& array, int i = 0) {
- if (array->dtype.code == kDLInt) {
- if (array->dtype.bits == 8) {
- return reinterpret_cast<int8_t*>(array->data)[i];
- } else if (array->dtype.bits == 16) {
- return reinterpret_cast<int16_t*>(array->data)[i];
- } else if (array->dtype.bits == 32) {
- return reinterpret_cast<int32_t*>(array->data)[i];
- } else if (array->dtype.bits == 64) {
- return reinterpret_cast<int64_t*>(array->data)[i];
- }
- } else if (array->dtype.code == kDLUInt) {
- if (array->dtype.bits == 8) {
- return reinterpret_cast<uint8_t*>(array->data)[i];
- } else if (array->dtype.bits == 16) {
- return reinterpret_cast<uint16_t*>(array->data)[i];
- } else if (array->dtype.bits == 32) {
- return reinterpret_cast<uint32_t*>(array->data)[i];
- } else if (array->dtype.bits == 64) {
- return reinterpret_cast<uint64_t*>(array->data)[i];
- }
- } else if (array->dtype.code == kDLFloat) {
-#if (__ARM_FP16_FORMAT_IEEE == 1)
- if (array->dtype.bits == 16) {
- return reinterpret_cast<__fp16*>(array->data)[i];
- }
-#endif
- if (array->dtype.bits == 32) {
- return reinterpret_cast<float*>(array->data)[i];
- } else if (array->dtype.bits == 64) {
- return reinterpret_cast<double*>(array->data)[i];
- }
- }
- LOG(FATAL) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
- // make compiler happy
- return -std::numeric_limits<double>::infinity();
-}
-
bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const auto* param = attrs.as<ReshapeAttrs>();
@@ -663,11 +625,7 @@ Expr MakeReshape(Expr data, Expr newshape) {
auto attrs = make_object<ReshapeAttrs>();
if (const ConstantNode* c = newshape.as<ConstantNode>()) {
CHECK_EQ(c->data->ndim, 1);
- Array<Integer> newshape;
- for (int i = 0; i < c->data->shape[0]; i++) {
- newshape.push_back(Integer(static_cast<int>(ToScalar(c->data, i))));
- }
- attrs->newshape = newshape;
+ attrs->newshape = ToVector(c->data);
}
attrs->reverse = false;
static const Op& op = Op::Get("reshape");
@@ -929,9 +887,10 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs);
bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
- CHECK_EQ(types.size(), 2);
+ CHECK_EQ(types.size(), 3);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* fill_value = types[0].as<TensorTypeNode>();
+ const auto* fill_shape = types[1].as<TensorTypeNode>();
if (fill_value == nullptr) {
return false;
}
@@ -944,7 +903,21 @@ bool FullRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
CHECK_EQ(fill_value->shape.size(), 0)
<< "Fill value should be a scalar but has dimension " <<
fill_value->shape.size() << ".";
- reporter->Assign(types[1], TensorType(param->shape, out_dtype));
+ // The shape argument's type may not be resolved yet; defer instead of
+ // dereferencing a null pointer.
+ if (fill_shape == nullptr) {
+ return false;
+ }
+ CHECK_EQ(fill_shape->shape.size(), 1U) << "Parameter shape must be a 1-D tensor";
+ const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
+ CHECK(shape_shape) << "Parameter shape must have static shape";
+
+ std::vector<IndexExpr> oshape;
+ if (param->shape) {
+ const Array<Integer>& cshape_array = param->shape.value();
+ for (size_t i = 0; i < cshape_array.size(); ++i) {
+ oshape.push_back(cshape_array[i]);
+ }
+ } else {
+ for (int i = 0; i < shape_shape->value; ++i) {
+ oshape.push_back(Any::make());
+ }
+ }
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
@@ -954,12 +927,14 @@ Array<te::Tensor> FullCompute(const Attrs& attrs, const
Array<te::Tensor>& input
return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
}
-Expr MakeFull(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
+Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
- attrs->shape = std::move(shape);
+ if (const auto* cshape = shape.as<ConstantNode>()) {
+ attrs->shape = ToVector(cshape->data);
+ }
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
- return Call(op, {fill_value}, Attrs(attrs), {});
+ return Call(op, {fill_value, shape}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);
@@ -969,8 +944,9 @@ RELAY_REGISTER_OP("full")
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
- .set_num_inputs(1)
+ .set_num_inputs(2)
.add_argument("fill_value", "double", "The value to fill.")
+ .add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("Full", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
@@ -978,19 +954,37 @@ RELAY_REGISTER_OP("full")
bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
- CHECK_EQ(types.size(), 1);
+ CHECK_EQ(types.size(), 2);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+ const auto* fill_shape = types[0].as<TensorTypeNode>();
+ // The shape argument's type may not be resolved yet; defer instead of
+ // dereferencing a null pointer.
+ if (fill_shape == nullptr) {
+ return false;
+ }
+ DataType out_dtype = param->dtype;
+
+ CHECK_EQ(fill_shape->shape.size(), 1U) << "Parameter shape must be a 1-D tensor";
+ const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
+ CHECK(shape_shape) << "Parameter shape must have static shape";
- reporter->Assign(types[0], TensorType(param->shape, param->dtype));
+ std::vector<IndexExpr> oshape;
+ if (param->shape) {
+ const Array<Integer>& cshape_array = param->shape.value();
+ for (size_t i = 0; i < cshape_array.size(); ++i) {
+ oshape.push_back(cshape_array[i]);
+ }
+ } else {
+ for (int i = 0; i < shape_shape->value; ++i) {
+ oshape.push_back(Any::make());
+ }
+ }
+ reporter->Assign(types[1], TensorType(oshape, out_dtype));
return true;
}
-Expr MakeZeros(Array<IndexExpr> shape, DataType dtype) {
+Expr MakeZeros(Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
- attrs->shape = std::move(shape);
+ if (const auto* cshape = shape.as<ConstantNode>()) {
+ attrs->shape = ToVector(cshape->data);
+ }
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("zeros");
- return Call(op, {}, Attrs(attrs), {});
+ return Call(op, {shape}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros);
@@ -1000,16 +994,19 @@ RELAY_REGISTER_OP("zeros")
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
- .set_num_inputs(0)
+ .set_num_inputs(1)
+ .add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
-Expr MakeOnes(Array<IndexExpr> shape, DataType dtype) {
+Expr MakeOnes(Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
- attrs->shape = std::move(shape);
+ if (const auto* cshape = shape.as<ConstantNode>()) {
+ attrs->shape = ToVector(cshape->data);
+ }
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("ones");
- return Call(op, {}, Attrs(attrs), {});
+ return Call(op, {shape}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes);
@@ -1019,7 +1016,8 @@ RELAY_REGISTER_OP("ones")
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
- .set_num_inputs(0)
+ .set_num_inputs(1)
+ .add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
@@ -1579,30 +1577,42 @@ RELAY_REGISTER_OP("collapse_sum_like")
// BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
const TypeReporter& reporter) {
- CHECK_EQ(types.size(), 2);
- auto ioattrs = attrs.as<InitOpAttrs>();
- CHECK(ioattrs);
- auto intt = types[0].as<TensorTypeNode>();
- if (intt == nullptr) {
- return false;
+ CHECK_EQ(types.size(), 3);
+ const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+ const auto* target_shape = types[1].as<TensorTypeNode>();
+ const auto* data_type = types[0].as<TensorTypeNode>();
+ // Either input type may still be unresolved during type inference; defer
+ // (as the previous implementation did) rather than dereference null.
+ if (data_type == nullptr || target_shape == nullptr) {
+ return false;
+ }
+ DataType out_dtype = data_type->dtype;
+
+ CHECK_EQ(target_shape->shape.size(), 1U) << "Parameter shape must be a 1-D tensor";
+ const IntImmNode* shape_shape = target_shape->shape[0].as<IntImmNode>();
+ CHECK(shape_shape) << "Parameter shape must have static shape";
+
+ std::vector<IndexExpr> oshape;
+ if (param->shape) {
+ const Array<Integer>& cshape_array = param->shape.value();
+ for (size_t i = 0; i < cshape_array.size(); ++i) {
+ oshape.push_back(cshape_array[i]);
+ }
+ } else {
+ for (int i = 0; i < shape_shape->value; ++i) {
+ oshape.push_back(Any::make());
+ }
}
- auto type = TensorType(ioattrs->shape, intt->dtype);
- reporter->Assign(types[1], type);
- return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
+ return BroadcastRel({types[0], types[2], types[2]}, 2, Attrs(), reporter);
}
-Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
+Expr MakeBroadCastTo(Expr data, Expr shape) {
static const Op& op = Op::Get("broadcast_to");
auto attrs = make_object<InitOpAttrs>();
- attrs->shape = std::move(shape);
- return Call(op, {data}, Attrs(attrs), {});
+ if (const auto* cshape = shape.as<ConstantNode>()) {
+ attrs->shape = ToVector(cshape->data);
+ }
+ return Call(op, {data, shape}, Attrs(attrs), {});
}
Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const
Array<te::Tensor>& inputs,
const Type& out_type) {
- auto ioattrs = attrs.as<InitOpAttrs>();
- CHECK(ioattrs != nullptr);
- return {topi::broadcast_to(inputs[0], ioattrs->shape)};
+ const auto* out_ttype = out_type.as<TensorTypeNode>();
+ return {topi::broadcast_to(inputs[0], out_ttype->shape)};
}
TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo);
@@ -1610,8 +1620,9 @@
TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastT
RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
- .set_num_inputs(1)
+ .set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("shape", "Tensor", "Target shape.")
.set_support_level(4)
.add_type_rel("BroadCastTo", BroadCastToRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index bc35ed6..1f30b68 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -90,34 +90,33 @@ bool ConcatenateRel(const Array<Type>& types, int
num_inputs, const Attrs& attrs
if (e_dtype != dtype) {
throw Error("relay.concatenate requires all tensors have the same
dtype");
}
- for (size_t j = 0; j < first->shape.size(); ++j) {
- if (j == static_cast<size_t>(axis)) continue;
- if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
- throw Error(
- "relay.concatenate requires all tensors have the same shape "
- "on non-concatenating axes");
- }
}
// Calculate shape
std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
- IndexExpr& concat_dim = oshape[axis];
- bool has_any = false;
- if (concat_dim.as<Any>()) {
- has_any = true;
- } else {
- for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
- const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
- if (e->shape[axis].as<Any>()) {
- has_any = true;
- break;
+ int data_length = static_cast<int>(tensor_tuple->fields.size());
+ for (int i = 0; i < ndim; ++i) {
+ std::vector<IndexExpr> non_any;
+ for (int j = 0; j < data_length; ++j) {
+ const auto& e = Downcast<TensorType>(tensor_tuple->fields[j]);
+ if (!e->shape[i].as<Any>()) {
+ non_any.push_back(e->shape[i]);
+ // accumulate axis dimension
+ if (j > 0 && i == axis && !oshape[i].as<Any>()) {
+ oshape[i] += e->shape[i];
+ }
+ }
+ }
+ int non_any_size = static_cast<int>(non_any.size());
+ if (non_any_size != data_length) oshape[i] = Any::make();
+ if (i != axis) {
+ for (int k = 1; k < non_any_size; k++) {
+ if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
+ throw Error(
+ "relay.concatenate requires all tensors have the same shape "
+ "on non-concatenating axes");
}
- concat_dim += e->shape[axis];
}
- }
-
- if (has_any) {
- concat_dim = Any::make();
}
auto rtype = TensorType(oshape, dtype);
diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc
index 7171ded..4daa5c9 100644
--- a/src/relay/qnn/util.cc
+++ b/src/relay/qnn/util.cc
@@ -202,8 +202,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor,
std::vector<double> multipliers,
round_scalar = exp_pos_rounding_value_expr;
} else if (rounding == "TONEAREST") {
// To satisfy where op shape requirements, the rounding values are
broadcasted.
- auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr,
input_shape);
- auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr,
input_shape);
+ auto pos_rounder = BroadCastTo(exp_pos_rounding_value_expr, input_shape);
+ auto neg_rounder = BroadCastTo(exp_neg_rounding_value_expr, input_shape);
auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder,
neg_rounder);
diff --git a/src/relay/transforms/lazy_gradient_init.cc
b/src/relay/transforms/lazy_gradient_init.cc
index 3cd29d6..f062466 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -203,9 +203,9 @@ class LazyGradientInitializer : public ExprMutator, public
TypeMutator {
}
if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
- // fn() -> T, function returns result of the operation
- Expr func =
- Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
+ // ones and zeros need TensorType input
+ Expr result = CallPrimitiveOp(call_node);
+ Expr func = Function({}, result, {call_node->checked_type()},
Array<TypeVar>());
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones") ? "One" :
"Zero";
return Call(module_->GetConstructor("GradCell", constructor_name),
{func}, Attrs(),
@@ -288,7 +288,7 @@ class LazyGradientInitializer : public ExprMutator, public
TypeMutator {
args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(),
{expr->checked_type()}));
}
// result of operation
- return Call(call_node->op, args);
+ return Call(call_node->op, args, call_node->attrs);
}
};
diff --git a/src/relay/transforms/pattern_util.h
b/src/relay/transforms/pattern_util.h
index 8f37e7c..06b1e82 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -37,6 +37,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>
+#include <limits>
#include <string>
#include <utility>
#include <vector>
@@ -311,6 +312,25 @@ static inline Constant MakeConstantTensor(DataType dtype,
std::vector<int64_t> s
}
/*!
+ * \brief Check whether a shape is static and create corresponding Constant.
+ *
+ * \param shape The Array of the shape values.
+ * \return A Constant.
+ */
+static inline Constant CheckConstantShape(const Array<IndexExpr>& shape) {
+ auto shape_array =
+ runtime::NDArray::Empty({int64_t(shape.size())}, DataType::Int(64),
{kDLCPU, 0});
+ auto* shape_data = static_cast<int64_t*>(shape_array->data);
+ for (size_t i = 0; i < shape.size(); ++i) {
+ const auto& dim_val = shape[i].as<IntImmNode>();
+ CHECK(dim_val) << "Do not support symbolic shape for "
+ "Array format. Pass shape as Expr instead.";
+ shape_data[i] = dim_val->value;
+ }
+ return Constant(shape_array);
+}
+
+/*!
* \brief Check if two expressions are equal scalars.
* \param a The expression to be checked.
* \param b The expression to be checked
@@ -325,6 +345,67 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
return tvm::StructuralEqual()(a, b);
}
+/*!
+ * \brief Convert an element of a NDArray with type int or float to scalar.
+ * \param array Input NDArray
+ * \param i element index
+ * \return Converted scalar value.
+ */
+static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+ if (array->dtype.code == kDLInt) {
+ if (array->dtype.bits == 8) {
+ return reinterpret_cast<int8_t*>(array->data)[i];
+ } else if (array->dtype.bits == 16) {
+ return reinterpret_cast<int16_t*>(array->data)[i];
+ } else if (array->dtype.bits == 32) {
+ return reinterpret_cast<int32_t*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<int64_t*>(array->data)[i];
+ }
+ } else if (array->dtype.code == kDLUInt) {
+ if (array->dtype.bits == 8) {
+ return reinterpret_cast<uint8_t*>(array->data)[i];
+ } else if (array->dtype.bits == 16) {
+ return reinterpret_cast<uint16_t*>(array->data)[i];
+ } else if (array->dtype.bits == 32) {
+ return reinterpret_cast<uint32_t*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<uint64_t*>(array->data)[i];
+ }
+ } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+ if (array->dtype.bits == 16) {
+ return reinterpret_cast<__fp16*>(array->data)[i];
+ }
+#endif
+ if (array->dtype.bits == 32) {
+ return reinterpret_cast<float*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<double*>(array->data)[i];
+ }
+ }
+ LOG(FATAL) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
+ // make compiler happy
+ return -std::numeric_limits<double>::infinity();
+}
+
+/*!
+ * \brief Convert a NDArray with type int or float to Array<Integer>.
+ * \param array Input NDArray
+ * \return Converted Array.
+ */
+static inline Array<Integer> ToVector(const runtime::NDArray& array) {
+ size_t ndim = array.Shape().size();
+ CHECK_EQ(ndim, 1) << "This function should only used for shape tensor.";
+ size_t len = array.Shape().front();
+ Array<Integer> out;
+ for (size_t i = 0; i < len; ++i) {
+ double elem_val = ToScalar(array, i);
+ out.push_back(Integer(static_cast<int>(elem_val)));
+ }
+ return out;
+}
+
inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); }
inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); }
@@ -432,12 +513,10 @@ inline Expr ZerosLike(Expr e) {
return Call(op, {e});
}
+Expr MakeZeros(Expr shape, DataType dtype);
+
inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
- auto attrs = make_object<InitOpAttrs>();
- attrs->shape = std::move(shape);
- attrs->dtype = std::move(dtype);
- static const Op& op = Op::Get("zeros");
- return Call(op, {}, Attrs(attrs), {});
+ return MakeZeros(CheckConstantShape(shape), dtype);
}
inline Expr OnesLike(Expr e) {
@@ -503,12 +582,10 @@ static inline Expr GreaterEqual(const Expr& lhs, const
Expr& rhs) {
return Call(op, {lhs, rhs}, Attrs(), {});
}
+Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);
+
static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType
dtype) {
- auto attrs = make_object<InitOpAttrs>();
- attrs->shape = std::move(shape);
- attrs->dtype = std::move(dtype);
- static const Op& op = Op::Get("full");
- return Call(op, {fill_value}, Attrs(attrs), {});
+ return MakeFull(fill_value, CheckConstantShape(shape), dtype);
}
static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
@@ -586,7 +663,11 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
return Call(op, {data}, Attrs(attrs), {});
}
-Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape);
+Expr MakeBroadCastTo(Expr data, Expr shape);
+
+static inline Expr BroadCastTo(Expr data, Array<IndexExpr> shape) {
+ return MakeBroadCastTo(data, CheckConstantShape(shape));
+}
Expr MakeConcatenate(Expr data, int axis);
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 5e5542d..504c20a 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -96,31 +96,48 @@ def test_any_broadcast_fail():
check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)
-def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
+def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op,
dtype='float32'):
x = relay.var('x', shape=x_shape, dtype=dtype)
mod = tvm.IRModule()
- mod['main'] = relay.Function([x], relay.zeros_like(x))
+ mod['main'] = relay.Function([x], relay_op(x))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
- res_np = np.zeros_like(x_np)
+ res_np = np_op(x_np)
+ for kind in ['debug', 'vm']:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
+ result = ex.evaluate()(x_np).asnumpy()
+ tvm.testing.assert_allclose(result, res_np)
+
+def test_any_full_like():
+ # zeros_like, ones_like
+ verify_any_full_like(any_dims(3), (2, 3, 5), relay.zeros_like,
np.zeros_like, "float32")
+ verify_any_full_like(any_dims(3), (225, 115, 15), relay.zeros_like,
np.zeros_like, "float32")
+ verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like,
np.zeros_like, "int32")
+ verify_any_full_like(any_dims(3), (2, 3, 5), relay.ones_like,
np.ones_like, "float32")
+ verify_any_full_like(any_dims(3), (225, 115, 15), relay.ones_like,
np.ones_like, "float32")
+ verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like,
np.ones_like, "int32")
+
+def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None):
+ x = relay.var('x', shape=(len(x_np_shape),), dtype="int32")
+ mod = tvm.IRModule()
+ out = relay_op(x, dtype) if value is None else
relay_op(relay.expr.const(value), x, dtype)
+ mod['main'] = relay.Function([x], out)
+ res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value)
+ x_np = np.array(x_np_shape).astype("int32")
for kind in ['debug', 'vm']:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
result = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(result, res_np)
def test_any_full():
- # zeros, zeros_like, ones, ones_like
- verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32")
- verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros,
"float32")
- verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros,
"int32")
- verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like,
"float32")
- verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like,
np.zeros_like, "float32")
- verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like,
np.zeros_like, "int32")
- verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32")
- verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones,
"float32")
- verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones,
"int32")
- verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like,
"float32")
- verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like,
np.ones_like, "float32")
- verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like,
np.ones_like, "int32")
+ # zeros, ones, full
+ verify_any_full((2, 3, 5), relay.zeros, np.zeros, "float32")
+ verify_any_full((225, 115, 15), relay.zeros, np.zeros, "float32")
+ verify_any_full((10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32")
+ verify_any_full((2, 3, 5), relay.ones, np.ones, "float32")
+ verify_any_full((225, 115, 15), relay.ones, np.ones, "float32")
+ verify_any_full((10, 11, 12, 13, 14), relay.ones, np.ones, "int32")
+ verify_any_full((10, 11, 12, 13, 14), relay.full, np.full, "float32", 2.0)
+ verify_any_full((1, 2, 3, 4), relay.full, np.full, "int32", -2)
def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
@@ -566,6 +583,37 @@ def test_any_softmax():
verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3))
verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
+def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
+ mod = tvm.IRModule()
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ np_data = np.random.uniform(size=np_dshape).astype(dtype)
+ if const_k:
+ k = relay.const(kval)
+ args = [data]
+ in_vals = [np_data]
+ else:
+ k = relay.var('k', shape=(), dtype="int32")
+ args = [data, k]
+ in_vals = [np_data, kval]
+ out = relay.topk(data, k, ret_type="indices")
+ mod["main"] = relay.Function(args, out)
+
+ sorted = np.argsort(-np_data)
+ if len(np_dshape) == 2:
+ ref_out = sorted[:, 0:kval]
+ else:
+ ref_out = sorted[0:kval]
+
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(*in_vals)
+ tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_topk():
+ verify_any_topk(any_dims(1), 5, (10,), "float32")
+ verify_any_topk(any_dims(2), 2, (6, 3), "int32")
+ verify_any_topk(any_dims(2), 3, (6, 3), "float32", True)
+
def test_fused_ops():
x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
y0 = x + relay.const(1.0, 'float32')
@@ -723,6 +771,7 @@ def test_mixed_input_type():
if __name__ == "__main__":
test_any_full()
+ test_any_full_like()
test_any_broadcast()
test_any_elemwise()
test_any_broadcast_fail()
@@ -745,10 +794,10 @@ if __name__ == "__main__":
test_any_dense()
test_any_pad()
test_any_softmax()
+ test_any_topk()
test_fused_ops()
test_arange_with_dynamic_shape()
test_recursive_concat()
test_recursive_concat_with_wrong_annotation()
test_tuple_get_item()
test_mixed_input_type()
-
diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py
index 744da62..e492d68 100644
--- a/topi/python/topi/sort.py
+++ b/topi/python/topi/sort.py
@@ -107,7 +107,7 @@ def topk(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int64"):
data : tvm.te.Tensor
The input tensor.
- k : int, optional
+ k : int or tvm.te.Tensor, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
@@ -133,7 +133,10 @@ def topk(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int64"):
assert ret_type in ["both", "values", "indices"]
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
data_alignment=8)
out_shape = list(get_const_tuple(data.shape))
- if k >= 1:
+ kvar = tvm.te.size_var("k")
+ if not isinstance(k, int):
+ out_shape[axis] = kvar
+ elif k >= 1:
out_shape[axis] = k
out_bufs = []
if ret_type in ["both", "values"]:
@@ -142,10 +145,11 @@ def topk(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int64"):
out_bufs.append(tvm.tir.decl_buffer(out_shape, dtype, "indices_buf",
data_alignment=8))
out_shapes = [out_shape] * len(out_bufs)
+ kv = kvar if not isinstance(k, int) else k
out = te.extern(out_shapes,
[data],
lambda ins, outs: tvm.tir.call_packed(
- "tvm.contrib.sort.topk", ins[0], *outs, k, axis,
ret_type, is_ascend),
+ "tvm.contrib.sort.topk", ins[0], *outs, kv, axis,
ret_type, is_ascend),
in_buffers=[data_buf],
out_buffers=out_bufs,
name="topk_cpu",