tkonolige commented on a change in pull request #6580:
URL: https://github.com/apache/incubator-tvm/pull/6580#discussion_r498364820
##########
File path: src/relay/op/nn/sparse.cc
##########
@@ -82,6 +82,35 @@ RELAY_REGISTER_OP("nn.sparse_dense")
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<SparseDenseAttrs>()
+ .set_num_inputs(4)
+ .add_argument("data", "nD Tensor", "Input data.")
+ .add_argument("weight_data", "1D Tensor", "Weight data matrix.")
+ .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
+ .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
+ .set_support_level(1)
+ .add_type_rel("SparseDense", SparseDenseRel);
+
+Expr MakeSparseDensePadded(Expr data, Expr weight_data, Expr weight_indices,
Expr weight_indptr) {
+ auto attrs = make_object<SparseDenseAttrs>();
+ static const Op& op = Op::Get("nn.sparse_dense_padded");
+ return Call(op, {data, weight_data, weight_indices, weight_indptr},
Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded")
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 4>(MakeSparseDensePadded, args, rv);
+ });
+
+RELAY_REGISTER_OP("nn.sparse_dense_padded")
+ .describe(
+ R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W
sparse. This variation uses a matrix with row lengths padded to a multiple of
32 for better GPU performance.
Review comment:
Good catch, fixed it for sparse_dense too.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]