This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 384f9b69ea [Unity] Add `axis` field to scatter_from_worker0 (#16092)
384f9b69ea is described below

commit 384f9b69eae44b8331572f2c43be6a7b87913d10
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Nov 9 07:39:29 2023 -0800

    [Unity] Add `axis` field to scatter_from_worker0 (#16092)
    
    This PR adds an `axis` field to scatter_from_worker0, which specifies the 
tensor axis along which the tensor is scattered. legalize_ops will automatically 
generate reshape and transpose to preserve the constraint of ccl that 
collective communication ops must be performed on contiguous memory. For 
example, if the tensor shape of x is [10, 20], and we have 
`scatter_from_worker0(x, num_workers=2, axis=1)`, then after legalization it 
will expand to
    
    ```
    x = reshape(x, [10, 2, 10]) # shape: [10, 2, 10]
    x = permute_dims(x, [1, 0, 2]) # shape: [2, 10, 10]
    x = call_dps_packed("scatter_from_worker0", x) # shape: [10, 10]
    ```
    
    When axis=0, the behavior is the same as before.
    
    Also, this PR renames ScatterFromWorker0Attrs to ScatterCollectiveAttrs to 
enable reuse by other ops like worker-id-aware slicing (scatter_from_worker0 = 
broadcast_from_worker0 + worker-id-aware slicing).
---
 include/tvm/relax/attrs/ccl.h                      | 12 ++++--
 python/tvm/relax/op/ccl/ccl.py                     |  7 +++-
 python/tvm/relax/transform/legalize_ops/ccl.py     | 45 ++++++++++++++--------
 src/relax/op/ccl/ccl.cc                            | 17 ++++----
 src/relax/op/ccl/ccl.h                             |  2 +-
 .../relax/test_transform_legalize_ops_ccl.py       | 32 +++++++++++++--
 6 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index b4b3880384..42cec88de6 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -40,17 +40,21 @@ struct AllReduceAttrs : public 
tvm::AttrsNode<AllReduceAttrs> {
   }
 };  // struct AllReduceAttrs
 
-/*! \brief Attributes used in scatter_from_worker0 operators */
-struct ScatterFromWorker0Attrs : public 
tvm::AttrsNode<ScatterFromWorker0Attrs> {
+/*! \brief Attributes used in scatter operators */
+struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
   int num_workers;
+  int axis;
 
-  TVM_DECLARE_ATTRS(ScatterFromWorker0Attrs, 
"relax.attrs.ScatterFromWorker0Attrs") {
+  TVM_DECLARE_ATTRS(ScatterCollectiveAttrs, 
"relax.attrs.ScatterCollectiveAttrs") {
     TVM_ATTR_FIELD(num_workers)
         .describe(
             "The number of workers, also the number of parts the given buffer 
should be chunked "
             "into.");
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis of the tensor to be scattered. The tensor will be chunked 
along "
+        "this axis.");
   }
-};  // struct ScatterFromWorker0Attrs
+};  // struct ScatterCollectiveAttrs
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 4829bac761..21c7946120 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -84,7 +84,7 @@ def broadcast_from_worker0(x: Expr) -> Expr:
     return _ffi_api.broadcast_from_worker0(x)
 
 
-def scatter_from_worker0(x: Expr, num_workers: int) -> Expr:
+def scatter_from_worker0(x: Expr, num_workers: int, axis: int = 0) -> Expr:
     """Perform a scatter operation from worker-0, chunking the given buffer 
into equal parts.
 
     Parameters
@@ -95,9 +95,12 @@ def scatter_from_worker0(x: Expr, num_workers: int) -> Expr:
     num_worker : int
       The number of workers, i.e. the number of parts the given buffer should 
be chunked into.
 
+    axis : int
+      The dimension of the tensor to be scattered. Default is 0.
+
     Returns
     -------
     result : relax.Expr
       Chunked Tensor received by different workers.
     """
-    return _ffi_api.scatter_from_worker0(x, num_workers)
+    return _ffi_api.scatter_from_worker0(x, num_workers, axis)
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py 
b/python/tvm/relax/transform/legalize_ops/ccl.py
index 9b13d1be7c..ae0be3c228 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Default legalization function for ccl operators."""
-from tvm import tir, arith
+from tvm import tir, arith, topi
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr, ShapeExpr
 from ...op import call_dps_packed
@@ -80,28 +80,43 @@ def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) 
-> Expr:
     )
 
 
-@register_legalize("relax.ccl.scatter_from_worker0")
-def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
-    output_shape = []
+# Since collective communication ops are performed on contiguous memory,
+# we need to reshape and transpose the input tensor to make sharding dimension 
in the highest order
+def _transpose_for_ccl(_bb: BlockBuilder, expr: Expr, axis: int, num_workers: 
int):
     assert isinstance(
-        call.args[0].struct_info, TensorStructInfo
-    ), "The input struct info of scatter_from_worker0 should be 
TensorStructInfo."
-    assert isinstance(call.args[0].struct_info.shape.struct_info, 
ShapeStructInfo)
-    arg_shape = call.args[0].struct_info.shape.struct_info
+        expr.struct_info, TensorStructInfo
+    ), "The input struct info should be TensorStructInfo."
+    assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo)
+    arg_shape = expr.struct_info.shape.struct_info
+    new_shape = []
     for i, shape_value in enumerate(arg_shape.values):
-        if i == 0:
-            modulo = arith.Analyzer().simplify(shape_value % 
call.attrs.num_workers)
+        if i == axis:
+            modulo = arith.Analyzer().simplify(shape_value % num_workers)
             assert modulo == 0, (
-                "scatter_from_worker0 expects the size of axis 0 of input 
tensor "
+                f"scatter_from_worker0 expects the size of axis {axis} of 
input tensor "
                 "to be divisible by num_workers. However, the axis 0 of input 
tensor "
-                f"is {shape_value} while num_workers is 
{call.attrs.num_workers}"
+                f"is {shape_value} while num_workers is {num_workers}"
             )
-            output_shape.append(tir.div(shape_value, call.attrs.num_workers))
+            new_shape.append(num_workers)
+            new_shape.append(tir.div(shape_value, num_workers))
         else:
-            output_shape.append(shape_value)
+            new_shape.append(shape_value)
+    reshape_var = _bb.emit_te(topi.reshape, expr, new_shape)
+    if axis == 0:
+        return reshape_var
+    permute_order = [axis] + list(range(axis)) + list(range(axis + 1, 
len(new_shape)))
+    transpose_var = _bb.emit_te(topi.transpose, reshape_var, permute_order)
+    return transpose_var
+
+
+@register_legalize("relax.ccl.scatter_from_worker0")
+def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+    transpose_var = _transpose_for_ccl(_bb, call.args[0], call.attrs.axis, 
call.attrs.num_workers)
+    output_shape = transpose_var.struct_info.shape.struct_info.values
+    output_shape = output_shape[1:]
     return call_dps_packed(
         "runtime.disco.scatter_from_worker0",
-        call.args[0],
+        transpose_var,
         out_sinfo=TensorStructInfo(
             shape=output_shape,
             dtype=call.args[0].struct_info.dtype,
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 4372dd0aa6..22ab22e940 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -107,11 +107,12 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.ccl.scatter_from_worker0 */
-TVM_REGISTER_NODE_TYPE(ScatterFromWorker0Attrs);
+TVM_REGISTER_NODE_TYPE(ScatterCollectiveAttrs);
 
-Expr scatter_from_worker0(Expr data, int num_workers) {
-  ObjectPtr<ScatterFromWorker0Attrs> attrs = 
make_object<ScatterFromWorker0Attrs>();
+Expr scatter_from_worker0(Expr data, int num_workers, int axis) {
+  ObjectPtr<ScatterCollectiveAttrs> attrs = 
make_object<ScatterCollectiveAttrs>();
   attrs->num_workers = std::move(num_workers);
+  attrs->axis = std::move(axis);
   static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
 
   return Call(op, {std::move(data)}, Attrs{attrs}, {});
@@ -119,11 +120,11 @@ Expr scatter_from_worker0(Expr data, int num_workers) {
 
 
TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0);
 
-StructInfo InferStructInfoScatterFromWorker0(const Call& call, const 
BlockBuilder& ctx) {
+StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
   DataType output_dtype = input_sinfo->dtype;
 
-  const auto* attrs = call->attrs.as<ScatterFromWorker0Attrs>();
+  const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
   int num_workers = attrs->num_workers;
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
@@ -139,7 +140,7 @@ StructInfo InferStructInfoScatterFromWorker0(const Call& 
call, const BlockBuilde
   }
 
   Array<PrimExpr> output_shape = input_shape.value();
-  output_shape.Set(0, div(output_shape[0], num_workers));
+  output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
   if (input_sinfo->vdevice.defined()) {
     return TensorStructInfo(ShapeExpr(output_shape), output_dtype, 
input_sinfo->vdevice.value());
   }
@@ -150,8 +151,8 @@ TVM_REGISTER_OP("relax.ccl.scatter_from_worker0")
     .set_num_inputs(1)
     .add_argument("x", "Tensor",
                   "The buffer to be divided into equal parts and sent to each 
worker accordingly.")
-    .set_attrs_type<ScatterFromWorker0Attrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoScatterFromWorker0)
+    .set_attrs_type<ScatterCollectiveAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatter)
     .set_attr<Bool>("FPurity", Bool(true));
 
 }  // namespace relax
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 7742997d5a..3e7f0220c9 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -42,7 +42,7 @@ Expr allgather(Expr data, Expr num_workers);
 Expr broadcast_from_worker0(Expr data);
 
 /*! \brief Perform a scatter operation from worker-0, chunking the given 
buffer into equal parts. */
-Expr scatter_from_worker0(Expr data, int num_workers);
+Expr scatter_from_worker0(Expr data, int num_workers, int axis);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py 
b/tests/python/relax/test_transform_legalize_ops_ccl.py
index bb2e8f3394..63563ee3c9 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -20,6 +20,7 @@ import tvm.testing
 from tvm.relax.transform import LegalizeOps
 from tvm.script import ir as I
 from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def test_allreduce():
@@ -101,14 +102,39 @@ def test_scatter_from_worker0():
     class ScatterFromWorker0:
         @R.function
         def main(x: R.Tensor((10, 10), "float32"))  -> R.Tensor((5, 10), 
"float32"):
-            gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, 
2)
+            gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, 
num_workers=2, axis=1)
             return gv0
 
     @I.ir_module
     class Expected:
+        @T.prim_func(private=True)
+        def reshape(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), 
T_reshape: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, ax2 in T.grid(T.int64(10), T.int64(2), T.int64(5)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(A[((v_ax1 * T.int64(5) + v_ax2) // T.int64(10) + 
v_ax0) % T.int64(10), (v_ax1 * T.int64(5) + v_ax2) % T.int64(10)])
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = A[((v_ax1 * T.int64(5) + 
v_ax2) // T.int64(10) + v_ax0) % T.int64(10), (v_ax1 * T.int64(5) + v_ax2) % 
T.int64(10)]
+
+        @T.prim_func(private=True)
+        def transpose(A: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), 
"float32"), T_transpose: T.Buffer((T.int64(2), T.int64(10), T.int64(5)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(10), T.int64(5)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(A[v_ax1, v_ax0, v_ax2])
+                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2])
+                    T_transpose[v_ax0, v_ax1, v_ax2] = A[v_ax1, v_ax0, v_ax2]
+
         @R.function
-        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((5, 10), 
dtype="float32"):
-            gv0: R.Tensor((5, 10), dtype="float32") = 
R.call_dps_packed("runtime.disco.scatter_from_worker0", x, 
out_sinfo=R.Tensor((5, 10), dtype="float32"))
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), 
dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5), 
dtype="float32"))
+            gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10, 
5), dtype="float32"))
+            gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", 
(gv1,), out_sinfo=R.Tensor((10, 5), dtype="float32"))
             return gv0
     # fmt: on
 

Reply via email to