jcf94 commented on a change in pull request #8234:
URL: https://github.com/apache/tvm/pull/8234#discussion_r660225637



##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,6 +1471,47 @@ def bias_add(data, bias, axis=1):
     return _make.bias_add(data, bias, axis)
 
 
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+    """Matmul operator.
+    Applies a linear transformation. The X & W can be transposed.
+
+    .. math::
+
+        `Y = X * W`
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.

Review comment:
       No, the input of matmul is supposed to be a multi-dimensional tensor (not limited to 2-D). This is copied from the original `nn.dense`.
   
   Other frameworks like PyTorch also have such a definition; see the sketch below.
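   
   A minimal sketch (assuming this PR's `nn.matmul` signature from the diff above; the shapes are hypothetical):
   
   ```python
   from tvm import relay
   
   # data may carry extra leading dims; only the last axis is reduced
   x = relay.var("x", shape=(4, 8, 16))  # (d_1, d_2, units_in)
   w = relay.var("w", shape=(16, 32))    # (units_in, units), non-transposed
   y = relay.nn.matmul(x, w)             # y has shape (4, 8, 32)
   ```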

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -1160,21 +1186,46 @@ def batch_flatten_shape_func(attrs, inputs, _):
 
 
 @script
-def _dense_shape_func(data_shape, weight_shape):
+def _matmul_shape_func(data_shape, weight_shape, data_transposed, weight_transposed):
     out = output_tensor((data_shape.shape[0],), "int64")
     for i in const_range(out.shape[0] - 1):
         out[i] = data_shape[i]
-    out[out.shape[0] - 1] = weight_shape[0]
+    if data_transposed:
+        out[out.shape[0] - 2] = out[out.shape[0] - 1]
+    out[out.shape[0] - 1] = weight_shape[0] if weight_transposed else weight_shape[1]

Review comment:
       Since the data tensor can have more than two dimensions, this is the simplest way to implement it; a worked sketch of the shape logic follows.
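   
   A plain-Python mirror of the shape logic for the common non-transposed-data case (the shapes are hypothetical):
   
   ```python
   # Copy every leading dim of data_shape, then take the output dim from
   # weight_shape according to the transpose flag.
   def matmul_out_shape(data_shape, weight_shape, transpose_b=False):
       out = list(data_shape)
       out[-1] = weight_shape[0] if transpose_b else weight_shape[1]
       return out
   
   assert matmul_out_shape([2, 3, 8], [8, 16]) == [2, 3, 16]
   assert matmul_out_shape([2, 3, 8], [16, 8], transpose_b=True) == [2, 3, 16]
   ```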

##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,6 +1471,47 @@ def bias_add(data, bias, axis=1):
     return _make.bias_add(data, bias, axis)
 
 
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+    """Matmul operator.
+    Applies a linear transformation. The X & W can be transposed.
+
+    .. math::
+
+        `Y = X * W`
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
+
+    weight : tvm.relay.Expr
+        The weight expressions, 2-D matrix,
+        of shape `(units_in, units)` or `(units, units_in)`.
+
+    units : int, optional
+        Number of hidden units of the matmul transformation.

Review comment:
       I think the doc already explains it well enough: "The hidden units." This is copied from the original `nn.dense`.

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1204,7 +1208,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         return func, self._params
 
 
-def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
+def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):

Review comment:
       The problem is that we're not able to remove all of the `nn.dense` usage at the moment, and there are not enough AutoTVM templates for `nn.matmul`.
   
   So `nn.matmul` can only be regarded as an experimental feature for now. We should not change the default behavior, since that could affect users who rely on `nn.dense`. A usage sketch for opting in is shown below.
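   
   A hedged sketch of opting in (`graph_def` and `shape_dict` are assumed to be defined by the caller):
   
   ```python
   from tvm import relay
   
   # use_dense_op=False switches the TF frontend to emit nn.matmul
   # instead of the default nn.dense.
   mod, params = relay.frontend.from_tensorflow(
       graph_def, layout="NHWC", shape=shape_dict, use_dense_op=False
   )
   ```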

##########
File path: include/tvm/relay/attrs/nn.h
##########
@@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
   }
 };
 
+/*! \brief Attributes for matmul operator */
+struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
+  IndexExpr units;
+  DataType out_dtype;
+  bool data_transposed;
+  bool weight_transposed;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite

Review comment:
       You mean `MatmulAttrs`? We're not able to remove all of `nn.dense` at the moment, so `nn.dense` and `nn.matmul` should remain two different ops for now, and they need different `Attrs`; a short sketch of the two coexisting ops is below.
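   
   Illustrative only (the vars are hypothetical): the two ops coexist, each carrying its own attrs type.
   
   ```python
   from tvm import relay
   
   x = relay.var("x", shape=(4, 8))
   # nn.dense keeps its (units, units_in) weight layout and DenseAttrs ...
   d = relay.nn.dense(x, relay.var("w_nt", shape=(32, 8)))
   # ... while nn.matmul carries MatmulAttrs with explicit transpose flags.
   m = relay.nn.matmul(x, relay.var("w_nn", shape=(8, 32)), weight_transposed=False)
   ```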

##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -51,37 +65,120 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
-    batch, in_dim = data.shape
+    if data_transposed:
+        in_dim, batch = data.shape
+    else:
+        batch, in_dim = data.shape
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
         out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
-            auto_scheduler_rewritten_layout, ["j", "k"]
+            auto_scheduler_rewritten_layout, ["j", "k"] if weight_transposed else ["k", "j"]
         )
         auto_scheduler.remove_index_check(weight)
-    else:
+    elif weight_transposed:
         out_dim, red_dim = weight.shape
+    else:
+        red_dim, out_dim = weight.shape
     assert in_dim == red_dim
 
     k = te.reduce_axis((0, in_dim), name="k")
-    matmul = te.compute(
+    if data_transposed:
+        if weight_transposed:
+            compute_lambda = lambda i, j: te.sum(
+                data[k, i].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul_TT"
+        else:
+            compute_lambda = lambda i, j: te.sum(
+                data[k, i].astype(out_dtype) * weight[k, j].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul_TN"
+        compute_tag = "matmul"
+    else:
+        if weight_transposed:
+            compute_lambda = lambda i, j: te.sum(
+                data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k
+            )
+            compute_name = "T_dense"

Review comment:
       I think it's fine since it is just an op name. 😄  But the tag `dense` has been used in some schedule checks, so I think we'd better keep that.
   
   There are some options I can come up with (see the sketch after this list):
   - A: Use `T_dense` as the name and `dense` as the tag for the NT format; use `T_matmul` as the name and `matmul` as the tag for the other 3 formats.
   - B: Use `T_matmul_NN`, `T_matmul_NT`, `T_matmul_TN`, `T_matmul_TT` as the name for each format; use `dense` as the tag for the NT format and `matmul` as the tag for the others.
   
   What do you think?
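   
   An illustrative sketch of option B's mapping (not actual TVM code):
   
   ```python
   # Derive the op name from the transpose flags, but keep the legacy
   # `dense` tag for the NT layout so existing schedule checks still match.
   def name_and_tag(transpose_a, transpose_b):
       name = "T_matmul_%s%s" % ("T" if transpose_a else "N", "T" if transpose_b else "N")
       tag = "dense" if (not transpose_a and transpose_b) else "matmul"
       return name, tag
   
   assert name_and_tag(False, True) == ("T_matmul_NT", "dense")
   assert name_and_tag(True, True) == ("T_matmul_TT", "matmul")
   ```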

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -1160,21 +1186,46 @@ def batch_flatten_shape_func(attrs, inputs, _):
 
 
 @script
-def _dense_shape_func(data_shape, weight_shape):
+def _matmul_shape_func(data_shape, weight_shape, data_transposed, weight_transposed):

Review comment:
       Updated all `data_transposed` & `weight_transposed` to `transpose_a` & `transpose_b`, and also renamed all of the `data` & `weight` tensors in matmul to `tensor_a` & `tensor_b`. The tensor names in dense remain unchanged; the signature now reads roughly as sketched below.
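   
   A rough sketch of the renamed signature (based on this comment, not the exact merged code):
   
   ```python
   def matmul(tensor_a, tensor_b, units=None, out_dtype="", transpose_a=False, transpose_b=False):
       ...
   ```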

##########
File path: python/tvm/relay/op/strategy/x86.py
##########
@@ -370,6 +370,55 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register("cpu")
+def matmul_strategy_cpu(attrs, inputs, out_type, target):
+    """matmul x86 strategy"""
+    strategy = _op.OpStrategy()
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul, need_auto_scheduler_layout=True),
+            naive_schedule,
+            name="matmul.generic",
+            plevel=11,
+        )
+    else:
+        logger.warning("Matmul other than NT format is not optimized for x86.")
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul),
+            naive_schedule,
+            name="matmul.generic",
+        )
+
+    same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
+    dtype = inputs[0].dtype
+    u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
+    if "cblas" in target.libs:

Review comment:
       I agree, but it seems there is no API for `SpecializedCondition` to handle the False path? A sketch of the current workaround is below.
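
   A hedged sketch of the workaround (the cblas compute/schedule names are hypothetical; `strategy`, `same_type`, and `target` come from the surrounding strategy function): register the generic implementation unconditionally at a low plevel, so it effectively serves as the False branch.
   
   ```python
   from tvm.te import SpecializedCondition
   
   # Generic fallback, always registered (implicitly the "False" branch).
   strategy.add_implementation(
       wrap_compute_matmul(topi.nn.matmul),
       naive_schedule,
       name="matmul.generic",
   )
   # Specialized path, only considered when the condition holds; the higher
   # plevel makes it win over the generic fallback.
   with SpecializedCondition(same_type and "cblas" in target.libs):
       strategy.add_implementation(
           wrap_compute_matmul(topi.x86.matmul_cblas),          # hypothetical
           wrap_topi_schedule(topi.x86.schedule_matmul_cblas),  # hypothetical
           name="matmul_cblas.x86",
           plevel=13,
       )
   ```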



