comaniac commented on a change in pull request #8234:
URL: https://github.com/apache/tvm/pull/8234#discussion_r649373782
##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,7 +1471,46 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)
-def dense(data, weight, units=None, out_dtype=""):
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+ """Dense operator.
+ Applies a linear transformation. The X & W can be transposed.
+
+ .. math::
+
+ `Y = X * W`
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
+
+ weight : tvm.relay.Expr
+ The weight expressions, 2-D matrix,
+ of shape `(units_in, units)` or `(units, units_in)`.
+
+ units : int, optional
+ Number of hidden units of the dense transformation.
+
+ out_dtype : str, optional
+ Specifies the output data type for mixed precision dense,
+ of shape `(d_1, d_2, ..., d_n, units)`.
+
+ data_transposed : bool, optional
+ Whether the data tensor is in transposed format.
+
+ weight_transposed : bool, optional
+ Whether the weight tensor is in transposed format.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+    return _make.matmul(data, weight, units, out_dtype, data_transposed, weight_transposed)
+
+
+def dense(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=True):
Review comment:
It's weird to have transpose arguments in dense. This op should enforce
transposed weight. Ideally, the implementation of dense at the Relay level
should be:
```
return _make.matmul(data, weight, units, out_dtype, False, True)
```
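For illustration, the relationship this comment asks for could be sketched in plain NumPy (`matmul_ref` and `dense_ref` are hypothetical reference helpers for this review, not the TVM API):

```python
import numpy as np

def matmul_ref(data, weight, data_transposed=False, weight_transposed=False):
    # NumPy reference for the proposed nn.matmul semantics.
    a = data.T if data_transposed else data
    b = weight.T if weight_transposed else weight
    return a @ b

def dense_ref(data, weight):
    # dense is matmul with the transpose pair fixed to (False, True):
    # weight is stored as (units, units_in), so it is always transposed.
    return matmul_ref(data, weight, data_transposed=False, weight_transposed=True)

x = np.arange(6, dtype="float32").reshape(2, 3)   # (batch, in_dim)
w = np.arange(12, dtype="float32").reshape(4, 3)  # (units, in_dim)
y = dense_ref(x, w)                               # (batch, units)
```

Fixing the pair inside the `dense` wrapper keeps existing call sites unchanged while routing everything through the single matmul make function.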
##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,7 +1471,46 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)
-def dense(data, weight, units=None, out_dtype=""):
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+ """Dense operator.
Review comment:
```suggestion
"""Matmul operator.
```
##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -698,6 +698,26 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
return strategy
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
Review comment:
It's highly possible that this function generates no strategy. We need a
checker that throws an error, as well as a workaround for this situation.
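A minimal sketch of such a guard, in plain Python with a toy stand-in (`OpStrategy` below is a dummy class for illustration, not `tvm.relay.op.OpStrategy`; `finalize_strategy` is a hypothetical helper):

```python
class OpStrategy:
    # Toy stand-in: collects (compute, schedule, name) implementations,
    # mimicking how a strategy function registers candidates.
    def __init__(self):
        self.implementations = []

    def add_implementation(self, compute, schedule, name):
        self.implementations.append((compute, schedule, name))

def finalize_strategy(strategy, op_name="matmul"):
    # Fail loudly instead of returning an empty strategy, so the problem
    # surfaces at strategy-selection time rather than later in codegen.
    if not strategy.implementations:
        raise RuntimeError(
            "No valid strategy found for %s on this target; "
            "a fallback schedule should be registered instead." % op_name
        )
    return strategy
```

The real check would live at the end of `matmul_strategy_cuda`, with a default schedule registered as the workaround when no specialized implementation applies.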
##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -51,37 +65,120 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
- batch, in_dim = data.shape
+ if data_transposed:
+ in_dim, batch = data.shape
+ else:
+ batch, in_dim = data.shape
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
- auto_scheduler_rewritten_layout, ["j", "k"]
+            auto_scheduler_rewritten_layout, ["j", "k"] if weight_transposed else ["k", "j"]
)
auto_scheduler.remove_index_check(weight)
- else:
+ elif weight_transposed:
out_dim, red_dim = weight.shape
+ else:
+ red_dim, out_dim = weight.shape
assert in_dim == red_dim
k = te.reduce_axis((0, in_dim), name="k")
- matmul = te.compute(
+ if data_transposed:
Review comment:
nit: This might be clearer:
```
if (data_transposed, weight_transposed) == (False, False):
# ...
elif (data_transposed, weight_transposed) == (False, True):
# ...
elif (data_transposed, weight_transposed) == (True, False):
# ...
elif (data_transposed, weight_transposed) == (True, True):
# ...
```
Also I think it's fine to just use `T_matmul_NT` instead of `T_dense`.
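The four-way dispatch in that sketch can be checked against a NumPy reference (`matmul_compute` is a hypothetical helper mirroring the suggested structure, not the actual TOPI compute):

```python
import numpy as np

def matmul_compute(data, weight, data_transposed, weight_transposed):
    # Dispatch on the (data_transposed, weight_transposed) pair, one
    # explicit branch per combination, as suggested in the nit above.
    if (data_transposed, weight_transposed) == (False, False):
        return data @ weight        # T_matmul_NN
    elif (data_transposed, weight_transposed) == (False, True):
        return data @ weight.T      # T_matmul_NT, the dense case
    elif (data_transposed, weight_transposed) == (True, False):
        return data.T @ weight      # T_matmul_TN
    else:
        return data.T @ weight.T    # T_matmul_TT
```

Naming each branch after its transpose tag (`T_matmul_NT` rather than `T_dense`) also makes the generated compute names self-describing.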
##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,7 +1471,46 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)
-def dense(data, weight, units=None, out_dtype=""):
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+ """Dense operator.
+ Applies a linear transformation. The X & W can be transposed.
+
+ .. math::
+
+ `Y = X * W`
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
+
+ weight : tvm.relay.Expr
+ The weight expressions, 2-D matrix,
+ of shape `(units_in, units)` or `(units, units_in)`.
+
+ units : int, optional
+ Number of hidden units of the dense transformation.
+
+ out_dtype : str, optional
Review comment:
Default?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]