This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new da75d85 [Relay][Dynamic] OneHot operation (#6209)
da75d85 is described below
commit da75d85cdce6fa189f3662793e0a68e0f84309f1
Author: Matthew Brookhart <[email protected]>
AuthorDate: Thu Aug 6 08:46:58 2020 -0700
[Relay][Dynamic] OneHot operation (#6209)
* Dynamic OneHot Op
* refactor dynamic_to_static
* add onehot to dynamic_to_static pass
---
include/tvm/topi/transform.h | 19 ++--
python/tvm/relay/op/dyn/_transform.py | 35 +++++--
python/tvm/relay/op/transform.py | 15 ++-
src/relay/op/dyn/tensor/transform.cc | 70 ++++++++++++++
src/relay/op/make_op.h | 2 +
src/relay/transforms/dynamic_to_static.cc | 113 +++++++++++++++-------
tests/python/relay/dyn/test_dynamic_op_level10.py | 64 ++++++++++--
tests/python/relay/test_pass_dynamic_to_static.py | 28 ++++++
8 files changed, 285 insertions(+), 61 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index cd19436..19b2ef4 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1421,22 +1421,25 @@ inline Tensor ndarray_size(const Tensor& src, const
DataType& dtype,
* \param depth depth of the one-hot dimension.
* \param axis axis to fill.
* \param dtype data type of the output tensor.
+ * \param oshape shape of the output tensor.
* \param name output tensor name.
* \param tag output tensor tag.
* \return one-hot tensor.
*/
inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const
PrimExpr off_value,
int depth, int axis, const DataType& dtype,
+ Array<PrimExpr> oshape = Array<PrimExpr>(),
const std::string name = "T_one_hot", const std::string
tag = kInjective) {
- Array<PrimExpr> oshape;
- int ndim = indices->shape.size() + 1;
- int indices_index = 0;
int true_axis = (axis == -1) ? indices->shape.size() : axis;
- for (int i = 0; i < ndim; i++) {
- if (i == true_axis) {
- oshape.push_back(Integer(depth));
- } else {
- oshape.push_back(indices->shape[indices_index++]);
+ if (oshape.size() == 0) {
+ int ndim = indices->shape.size() + 1;
+ int indices_index = 0;
+ for (int i = 0; i < ndim; i++) {
+ if (i == true_axis) {
+ oshape.push_back(Integer(depth));
+ } else {
+ oshape.push_back(indices->shape[indices_index++]);
+ }
}
}
diff --git a/python/tvm/relay/op/dyn/_transform.py
b/python/tvm/relay/op/dyn/_transform.py
index e2704bc..3a80f5a 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -25,11 +25,13 @@ from .. import op as _reg
_reg.register_broadcast_schedule("dyn.broadcast_to")
_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")
+_reg.register_injective_schedule("dyn.one_hot")
+
@script
def _reshape_shape_func_input_data(data, newshape, ndim):
- out = output_tensor((ndim,), "int64")
- data_shape = allocate((len(data.shape),), "int64")
+ out = output_tensor((ndim, ), "int64")
+ data_shape = allocate((len(data.shape), ), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
@@ -59,7 +61,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
- out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+ out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
@@ -82,6 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
out[infer_idx] = old_size // new_size
return out
+
@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
@@ -89,7 +92,7 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
- out = output_tensor((tndim,), "int64")
+ out = output_tensor((tndim, ), "int64")
if ndim == rndim:
for i in const_range(tndim):
@@ -120,5 +123,25 @@ def tile_shape_func(attrs, inputs, _):
ndim = len(inputs[0].shape)
rndim = inputs[1].shape[0].value
tndim = ndim if ndim > rndim else rndim
- return [_tile_shape_func(inputs[0], reps, convert(ndim),
- convert(tndim), convert(rndim))]
+ return [_tile_shape_func(inputs[0], reps, convert(ndim), convert(tndim),
convert(rndim))]
+
+
+@script
+def _onehot_shape_func(dshape, k, axis):
+ ndim = len(dshape) + 1
+ out = output_tensor((ndim, ), "int64")
+ for i in const_range(axis):
+ out[i] = int64(dshape[i])
+ out[axis] = int64(k[0])
+ for j in const_range(axis + 1, ndim):
+ out[j] = int64(dshape[j - 1])
+ return out
+
+
+@_reg.register_shape_func("dyn.one_hot", True)
+def one_hot_shape_func(attrs, inputs, _):
+ """
+ Shape function for dyn.one_hot op.
+ """
+ axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
+ return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 6f23af2..5e5b867 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -148,6 +148,7 @@ def squeeze(data, axis=None):
"""
return _make.squeeze(data, axis)
+
def reshape(data, newshape):
"""Reshape the input array.
@@ -228,6 +229,7 @@ def reshape(data, newshape):
newshape = tempshape
return _make.reshape(data, list(newshape))
+
def argwhere(condition):
"""Find the indices of elements of a tensor that are
non-zero.
@@ -251,6 +253,7 @@ def argwhere(condition):
"""
return _make.argwhere(condition)
+
def scatter(data, indices, updates, axis):
"""Update data at positions defined by indices with values in updates
@@ -275,6 +278,7 @@ def scatter(data, indices, updates, axis):
"""
return _make.scatter(data, indices, updates, axis)
+
def scatter_add(data, indices, updates, axis):
"""Update data by adding values in updates at positions defined by indices
@@ -299,6 +303,7 @@ def scatter_add(data, indices, updates, axis):
"""
return _make.scatter_add(data, indices, updates, axis)
+
def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like`
operation reshapes
@@ -442,6 +447,7 @@ def arange(start, stop=None, step=None, dtype="float32"):
return _make.arange(start, stop, step, dtype)
+
def meshgrid(data, indexing="ij"):
"""Create coordinate matrices from coordinate vectors.
@@ -482,6 +488,7 @@ def meshgrid(data, indexing="ij"):
ret_size = len(data)
return TupleWrapper(_make.meshgrid(Tuple(data), indexing), ret_size)
+
def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the
elements.
@@ -668,6 +675,7 @@ def where(condition, x, y):
"""
return _make.where(condition, x, y)
+
def broadcast_to(data, shape):
"""Return a scalar value array with the same type, broadcast to
the provided shape.
@@ -693,6 +701,7 @@ def broadcast_to(data, shape):
shape = list(shape)
return _make.broadcast_to(data, shape)
+
def broadcast_to_like(data, broadcast_type):
"""Return a scalar value array with the same shape and type as the input
array.
@@ -1053,6 +1062,7 @@ def sequence_mask(data, valid_length, mask_value=0,
axis=0):
"""
return _make.sequence_mask(data, valid_length, mask_value, axis)
+
def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take
value on_value,
@@ -1070,7 +1080,7 @@ def one_hot(indices, on_value, off_value, depth, axis,
dtype):
off_value : relay.Expr
Value to fill at all other positions besides indices.
- depth : int
+ depth : int or relay.Expr
Depth of the one-hot dimension.
axis : int
@@ -1095,6 +1105,8 @@ def one_hot(indices, on_value, off_value, depth, axis,
dtype):
[0, 1, 0],
[0, 0, 1]]
"""
+ if isinstance(depth, Expr):
+ return _dyn_make.one_hot(indices, on_value, off_value, depth, axis,
dtype)
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
@@ -1120,6 +1132,7 @@ def unravel_index(indices, shape):
return _make.unravel_index(indices, shape)
+
def sparse_to_dense(sparse_indices, output_shape, sparse_values,
default_value=0):
"""Converts a sparse representation into a dense tensor.
diff --git a/src/relay/op/dyn/tensor/transform.cc
b/src/relay/op/dyn/tensor/transform.cc
index 2bb87ac..d2d6d69 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -304,6 +304,76 @@ RELAY_REGISTER_OP("dyn.ones")
.set_support_level(3)
.add_type_rel("DynamicInitOp", InitOpRel);
+bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [indices, on_value, off_value, result]
+ CHECK_EQ(types.size(), 5);
+ const auto* indices = types[0].as<TensorTypeNode>();
+ CHECK(indices);
+
+ const auto param = attrs.as<OneHotAttrs>();
+
+ Array<IndexExpr> oshape;
+ int ndim = indices->shape.size() + 1;
+ int indices_index = 0;
+ int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
+ for (int i = 0; i < ndim; i++) {
+ if (i == true_axis) {
+ oshape.push_back(Any());
+ } else {
+ oshape.push_back(indices->shape[indices_index++]);
+ }
+ }
+
+ reporter->Assign(types[4], TensorType(oshape, param->dtype));
+ return true;
+}
+
+Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>&
inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<OneHotAttrs>();
+ CHECK(param != nullptr);
+ const auto* out_ttype = out_type.as<TensorTypeNode>();
+ return Array<te::Tensor>{topi::one_hot(inputs[0], inputs[1](), inputs[2](),
-1, param->axis,
+ param->dtype, out_ttype->shape)};
+}
+
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, Expr depth, int
axis, DataType dtype) {
+ auto attrs = make_object<OneHotAttrs>();
+ attrs->axis = axis;
+ attrs->dtype = dtype;
+ static const Op& op = Op::Get("dyn.one_hot");
+ return Call(op, {indices, on_value, off_value, depth}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.one_hot").set_body_typed(MakeOneHot);
+
+RELAY_REGISTER_OP("dyn.one_hot")
+ .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value 1,
+ other locations take value 0. Final dimension is <indices dimensions> x
depth.
+
+ **indices** Locations to set to 1.
+
+ **on_value** Value to fill at indices.
+
+ **off_value** Value to fill at all other positions besides indices.
+
+ **depth** Depth of the one-hot dimension.
+
+ **axis** Axis to fill.
+
+ **dtype**)code" TVM_ADD_FILELINE)
+ .set_attrs_type<OneHotAttrs>()
+ .set_num_inputs(4)
+ .add_argument("indices", "Tensor", "Locations to set to on_value.")
+ .add_argument("on_value", "Expr", "Value to fill at indices.")
+ .add_argument("off_value", "Expr", "Value to fill at all other positions
besides indices.")
+ .add_argument("depth", "Expr", "Depth of the one-hot dimension.")
+ .set_support_level(10)
+ .add_type_rel("DynOneHot", OneHotRel)
+ .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
+ .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
+
} // namespace dyn
} // namespace relay
} // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 3b5e9a1..d2c170d 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -78,6 +78,8 @@ Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis,
bool keepdims, bool
Expr MakeZeros(Array<Integer> shape, DataType dtype);
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int
axis, DataType dtype);
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/transforms/dynamic_to_static.cc
b/src/relay/transforms/dynamic_to_static.cc
index d4de15c..8501ee5 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -33,44 +33,82 @@ namespace relay {
class DynamicToStaticMutator : public MixedModeMutator {
public:
- DynamicToStaticMutator() {}
+ DynamicToStaticMutator() {
+ op_map_ = {
+ {Op::Get("dyn.reshape"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* shape =
call_node->args[1].as<ConstantNode>()) {
+ CHECK_EQ(shape->data->ndim, 1);
+ return MakeReshape(call_node->args[0], ToVector(shape->data));
+ }
+ return Expr(nullptr);
+ }},
+ {Op::Get("dyn.tile"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* reps =
call_node->args[1].as<ConstantNode>()) {
+ CHECK_EQ(reps->data->ndim, 1);
+ return MakeTile(call_node->args[0], ToVector(reps->data));
+ }
+ return Expr(nullptr);
+ }},
+ {Op::Get("dyn.topk"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+ const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
+ CHECK(param);
+ return MakeTopK(call_node->args[0],
static_cast<int>(ToScalar(k->data, 0)),
+ param->axis, param->ret_type, param->is_ascend,
param->dtype);
+ }
+ return Expr(nullptr);
+ }},
+ {Op::Get("dyn.broadcast_to"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* shape =
call_node->args[1].as<ConstantNode>()) {
+ CHECK_EQ(shape->data->ndim, 1);
+ return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
+ }
+ return Expr(nullptr);
+ }},
+ {Op::Get("dyn.zeros"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* shape =
call_node->args[0].as<ConstantNode>()) {
+ const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+ CHECK(param);
+ return MakeZeros(ToVector(shape->data), param->dtype);
+ }
+ return Expr(nullptr);
+ }},
+ {Op::Get("dyn.ones"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* shape =
call_node->args[0].as<ConstantNode>()) {
+ const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+ CHECK(param);
+ return MakeOnes(ToVector(shape->data), param->dtype);
+ }
+ return Expr(nullptr);
+ }},
+ {Op::Get("dyn.one_hot"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* depth =
call_node->args[3].as<ConstantNode>()) {
+ const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
+ CHECK(param);
+ return MakeOneHot(call_node->args[0], call_node->args[1],
call_node->args[2],
+ static_cast<int>(ToScalar(depth->data, 0)),
param->axis,
+ param->dtype);
+ }
+ return Expr(nullptr);
+ }},
+ };
+ }
private:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
- const CallNode* call_node = post.as<CallNode>();
- if (call_node->op == Op::Get("dyn.reshape")) {
- if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
- CHECK_EQ(shape->data->ndim, 1);
- return MakeReshape(call_node->args[0], ToVector(shape->data));
- }
- } else if (call_node->op == Op::Get("dyn.tile")) {
- if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
- CHECK_EQ(reps->data->ndim, 1);
- return MakeTile(call_node->args[0], ToVector(reps->data));
- }
- } else if (call_node->op == Op::Get("dyn.topk")) {
- if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
- const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
- CHECK(param);
- return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data,
0)), param->axis,
- param->ret_type, param->is_ascend, param->dtype);
- }
- } else if (call_node->op == Op::Get("dyn.broadcast_to")) {
- if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
- CHECK_EQ(shape->data->ndim, 1);
- return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
- }
- } else if (call_node->op == Op::Get("dyn.zeros")) {
- if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
- const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
- CHECK(param);
- return MakeZeros(ToVector(shape->data), param->dtype);
- }
- } else if (call_node->op == Op::Get("dyn.ones")) {
- if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
- const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
- CHECK(param);
- return MakeOnes(ToVector(shape->data), param->dtype);
+ if (const CallNode* call_node = post.as<CallNode>()) {
+ if (op_map_.count(call_node->op)) {
+ auto out = op_map_[call_node->op](call_node);
+ if (out.defined()) {
+ return out;
+ }
}
}
return post;
@@ -83,6 +121,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return post;
}
+ std::unordered_map<Expr, std::function<Expr(const CallNode*)>,
ObjectPtrHash, ObjectPtrEqual>
+ op_map_;
};
Expr DynamicToStatic(Function f, IRModule m) {
@@ -90,6 +130,7 @@ Expr DynamicToStatic(Function f, IRModule m) {
Expr expr = f;
auto fold_const = transform::FoldConstant();
auto infer_type = transform::InferType();
+ DynamicToStaticMutator mutator;
Map<BaseFunc, GlobalVar> vars;
for (auto kv : m->functions) {
vars.Set(kv.second, kv.first);
@@ -101,7 +142,7 @@ Expr DynamicToStatic(Function f, IRModule m) {
// TODO(mbrookhart): Is it possible to run these passes JUST on the
current function?
m = infer_type(m);
m = fold_const(m);
- expr = DynamicToStaticMutator().Mutate(m->functions[gv]);
+ expr = mutator.Mutate(m->functions[gv]);
m->Update(gv, Downcast<BaseFunc>(expr));
i += 1;
} while (pre != expr && i < 1000);
diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py
b/tests/python/relay/dyn/test_dynamic_op_level10.py
index d9b23a7..95a030f 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level10.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level10.py
@@ -19,36 +19,80 @@ Support level10 operator test cases.
"""
-
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list, run_infer_type
+import tvm.topi.testing
import random
+
def test_dyn_broadcast_to():
dtype = 'uint8'
rank = 3
shape_type = 'int64'
- dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
- x_shape = (1,)
+ dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type))
+ x_shape = (1, )
x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
z = relay.broadcast_to(x, dyn_shape)
zz = run_infer_type(z)
-
- assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)
+
+ assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank,
dtype)
func = relay.Function([x, dyn_shape], z)
-
+
x = np.random.uniform(size=x_shape).astype(dtype)
- dyn_shape = (1,)*rank
+ dyn_shape = (1, ) * rank
ref_res = np.broadcast_to(x, dyn_shape)
for target, ctx in ctx_list():
- if (target != 'cuda'): #skip cuda because we don't have dynamic
support for GPU
+ if (target != 'cuda'): #skip cuda because we don't have dynamic
support for GPU
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx,
target=target)
- op_res =
intrp.evaluate(func)(x,np.array(dyn_shape).astype(shape_type))
+ op_res = intrp.evaluate(func)(x,
np.array(dyn_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
rtol=1e-5)
-test_dyn_broadcast_to()
+
+def test_dyn_one_hot():
+ def _get_oshape(indices_shape, depth, axis):
+ oshape = []
+ true_axis = len(indices_shape) if axis == -1 else axis
+ ndim = len(indices_shape) + 1
+ indices_index = 0
+ for i in range(0, ndim):
+ if i == true_axis:
+ oshape.append(depth)
+ else:
+ oshape.append(indices_shape[indices_index])
+ indices_index += 1
+
+ return oshape
+
+ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+ indices = relay.var("indices", relay.TensorType(indices_shape,
"int32"))
+ depth_var = relay.var("depth", relay.TensorType((), "int32"))
+ on_value_const = relay.const(on_value)
+ off_value_const = relay.const(off_value)
+ out = relay.one_hot(indices, on_value_const, off_value_const,
depth_var, axis, dtype)
+ func = relay.Function([indices, depth_var], out)
+ indices_np = np.random.randint(0, depth,
size=indices_shape).astype("int32")
+ out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value,
depth, axis, dtype)
+ for target, ctx in ctx_list():
+ if (target != 'cuda'): #skip cuda because we don't have dynamic
support for GPU
+ for kind in ["vm", "debug"]:
+ mod = tvm.ir.IRModule.from_expr(func)
+ intrp = relay.create_executor(kind, mod=mod, ctx=ctx,
target=target)
+ out_relay = intrp.evaluate()(indices_np,
np.array(depth).astype("int32"))
+ tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
+
+ _verify((3, ), 3, 1, 0, -1, "int32")
+ _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+ _verify((2, 2), 5, 2, -2, 0, "int32")
+ _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+ _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+ _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
+
+if __name__ == "__main__":
+ test_dyn_broadcast_to()
+ test_dyn_one_hot()
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py
b/tests/python/relay/test_pass_dynamic_to_static.py
index 8ca7882..a50c9df 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -22,6 +22,8 @@ from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload, ctx_list
+import tvm.topi.testing
+
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, tvm.transform.Pass)
@@ -222,6 +224,32 @@ def test_dynamic_to_static_zeros_ones():
verify_ones_zeros((1, 2, 3), 'int64')
verify_ones_zeros((9, 8, 3, 4), 'float32')
+def test_dynamic_to_static_one_hot():
+ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+ indices = relay.var("indices", relay.TensorType(indices_shape,
"int32"))
+ depth_var = relay.const(depth)
+ on_value_const = relay.const(on_value)
+ off_value_const = relay.const(off_value)
+ out = relay.one_hot(indices, on_value_const, off_value_const,
depth_var, axis, dtype)
+ func = relay.Function([indices], out)
+
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
transform.InferType())
+
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("one_hot")
+
+ indices_np = np.random.randint(0, depth,
size=indices_shape).astype("int32")
+ out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value,
depth, axis, dtype)
+ verify_func(func2, [indices_np], out_np)
+
+ _verify((3, ), 3, 1, 0, -1, "int32")
+ _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+ _verify((2, 2), 5, 2, -2, 0, "int32")
+ _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+ _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+ _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()