This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 384f9b69ea [Unity] Add `axis` field to scatter_from_worker0 (#16092)
384f9b69ea is described below
commit 384f9b69eae44b8331572f2c43be6a7b87913d10
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Nov 9 07:39:29 2023 -0800
[Unity] Add `axis` field to scatter_from_worker0 (#16092)
This PR adds an `axis` field to scatter_from_worker0, which specifies the
tensor axis along which the input is scattered. legalize_ops will automatically
generate reshape and transpose to preserve the constraint of ccl that
collective communication ops must be performed on contiguous memory. For
example, if the tensor shape of x is [10, 20], and we have
`scatter_from_worker0(x, num_workers=2, axis=1)`, then after legalization it
will expand to
```
x = reshape(x, [10, 2, 10]) # shape: [10, 2, 10]
x = permute_dims(x, [1, 0, 2]) # shape: [2, 10, 10]
x = call_dps_packed("scatter_from_worker0", x) # shape: [10, 10]
```
When axis=0, the behavior is the same as before.
Also, this PR renames ScatterFromWorker0Attrs to ScatterCollectiveAttrs to
enable reuse by other ops like worker-id-aware slicing (scatter_from_worker0 =
broadcast_from_worker0 + worker-id-aware slicing).
---
include/tvm/relax/attrs/ccl.h | 12 ++++--
python/tvm/relax/op/ccl/ccl.py | 7 +++-
python/tvm/relax/transform/legalize_ops/ccl.py | 45 ++++++++++++++--------
src/relax/op/ccl/ccl.cc | 17 ++++----
src/relax/op/ccl/ccl.h | 2 +-
.../relax/test_transform_legalize_ops_ccl.py | 32 +++++++++++++--
6 files changed, 82 insertions(+), 33 deletions(-)
diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index b4b3880384..42cec88de6 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -40,17 +40,21 @@ struct AllReduceAttrs : public
tvm::AttrsNode<AllReduceAttrs> {
}
}; // struct AllReduceAttrs
-/*! \brief Attributes used in scatter_from_worker0 operators */
-struct ScatterFromWorker0Attrs : public
tvm::AttrsNode<ScatterFromWorker0Attrs> {
+/*! \brief Attributes used in scatter operators */
+struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
int num_workers;
+ int axis;
- TVM_DECLARE_ATTRS(ScatterFromWorker0Attrs,
"relax.attrs.ScatterFromWorker0Attrs") {
+ TVM_DECLARE_ATTRS(ScatterCollectiveAttrs,
"relax.attrs.ScatterCollectiveAttrs") {
TVM_ATTR_FIELD(num_workers)
.describe(
"The number of workers, also the number of parts the given buffer
should be chunked "
"into.");
+ TVM_ATTR_FIELD(axis).describe(
+ "The axis of the tensor to be scattered. The tensor will be chunked
along "
+ "this axis.");
}
-}; // struct ScatterFromWorker0Attrs
+}; // struct ScatterCollectiveAttrs
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 4829bac761..21c7946120 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -84,7 +84,7 @@ def broadcast_from_worker0(x: Expr) -> Expr:
return _ffi_api.broadcast_from_worker0(x)
-def scatter_from_worker0(x: Expr, num_workers: int) -> Expr:
+def scatter_from_worker0(x: Expr, num_workers: int, axis: int = 0) -> Expr:
"""Perform a scatter operation from worker-0, chunking the given buffer
into equal parts.
Parameters
@@ -95,9 +95,12 @@ def scatter_from_worker0(x: Expr, num_workers: int) -> Expr:
num_worker : int
The number of workers, i.e. the number of parts the given buffer should
be chunked into.
+ axis : int
+ The dimension of the tensor to be scattered. Default is 0.
+
Returns
-------
result : relax.Expr
Chunked Tensor received by different workers.
"""
- return _ffi_api.scatter_from_worker0(x, num_workers)
+ return _ffi_api.scatter_from_worker0(x, num_workers, axis)
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py
b/python/tvm/relax/transform/legalize_ops/ccl.py
index 9b13d1be7c..ae0be3c228 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Default legalization function for ccl operators."""
-from tvm import tir, arith
+from tvm import tir, arith, topi
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, ShapeExpr
from ...op import call_dps_packed
@@ -80,28 +80,43 @@ def _broadcast_from_worker0(_bb: BlockBuilder, call: Call)
-> Expr:
)
-@register_legalize("relax.ccl.scatter_from_worker0")
-def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
- output_shape = []
+# Since collective communication ops are performed on contiguous memory,
+# we need to reshape and transpose the input tensor to make sharding dimension
in the highest order
+def _transpose_for_ccl(_bb: BlockBuilder, expr: Expr, axis: int, num_workers:
int):
assert isinstance(
- call.args[0].struct_info, TensorStructInfo
- ), "The input struct info of scatter_from_worker0 should be
TensorStructInfo."
- assert isinstance(call.args[0].struct_info.shape.struct_info,
ShapeStructInfo)
- arg_shape = call.args[0].struct_info.shape.struct_info
+ expr.struct_info, TensorStructInfo
+ ), "The input struct info should be TensorStructInfo."
+ assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo)
+ arg_shape = expr.struct_info.shape.struct_info
+ new_shape = []
for i, shape_value in enumerate(arg_shape.values):
- if i == 0:
- modulo = arith.Analyzer().simplify(shape_value %
call.attrs.num_workers)
+ if i == axis:
+ modulo = arith.Analyzer().simplify(shape_value % num_workers)
assert modulo == 0, (
- "scatter_from_worker0 expects the size of axis 0 of input
tensor "
+ f"scatter_from_worker0 expects the size of axis {axis} of
input tensor "
"to be divisible by num_workers. However, the axis 0 of input
tensor "
- f"is {shape_value} while num_workers is
{call.attrs.num_workers}"
+ f"is {shape_value} while num_workers is {num_workers}"
)
- output_shape.append(tir.div(shape_value, call.attrs.num_workers))
+ new_shape.append(num_workers)
+ new_shape.append(tir.div(shape_value, num_workers))
else:
- output_shape.append(shape_value)
+ new_shape.append(shape_value)
+ reshape_var = _bb.emit_te(topi.reshape, expr, new_shape)
+ if axis == 0:
+ return reshape_var
+ permute_order = [axis] + list(range(axis)) + list(range(axis + 1,
len(new_shape)))
+ transpose_var = _bb.emit_te(topi.transpose, reshape_var, permute_order)
+ return transpose_var
+
+
+@register_legalize("relax.ccl.scatter_from_worker0")
+def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+ transpose_var = _transpose_for_ccl(_bb, call.args[0], call.attrs.axis,
call.attrs.num_workers)
+ output_shape = transpose_var.struct_info.shape.struct_info.values
+ output_shape = output_shape[1:]
return call_dps_packed(
"runtime.disco.scatter_from_worker0",
- call.args[0],
+ transpose_var,
out_sinfo=TensorStructInfo(
shape=output_shape,
dtype=call.args[0].struct_info.dtype,
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 4372dd0aa6..22ab22e940 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -107,11 +107,12 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.ccl.scatter_from_worker0 */
-TVM_REGISTER_NODE_TYPE(ScatterFromWorker0Attrs);
+TVM_REGISTER_NODE_TYPE(ScatterCollectiveAttrs);
-Expr scatter_from_worker0(Expr data, int num_workers) {
- ObjectPtr<ScatterFromWorker0Attrs> attrs =
make_object<ScatterFromWorker0Attrs>();
+Expr scatter_from_worker0(Expr data, int num_workers, int axis) {
+ ObjectPtr<ScatterCollectiveAttrs> attrs =
make_object<ScatterCollectiveAttrs>();
attrs->num_workers = std::move(num_workers);
+ attrs->axis = std::move(axis);
static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
return Call(op, {std::move(data)}, Attrs{attrs}, {});
@@ -119,11 +120,11 @@ Expr scatter_from_worker0(Expr data, int num_workers) {
TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0);
-StructInfo InferStructInfoScatterFromWorker0(const Call& call, const
BlockBuilder& ctx) {
+StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
DataType output_dtype = input_sinfo->dtype;
- const auto* attrs = call->attrs.as<ScatterFromWorker0Attrs>();
+ const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
int num_workers = attrs->num_workers;
arith::Analyzer* analyzer = ctx->GetAnalyzer();
@@ -139,7 +140,7 @@ StructInfo InferStructInfoScatterFromWorker0(const Call&
call, const BlockBuilde
}
Array<PrimExpr> output_shape = input_shape.value();
- output_shape.Set(0, div(output_shape[0], num_workers));
+ output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
if (input_sinfo->vdevice.defined()) {
return TensorStructInfo(ShapeExpr(output_shape), output_dtype,
input_sinfo->vdevice.value());
}
@@ -150,8 +151,8 @@ TVM_REGISTER_OP("relax.ccl.scatter_from_worker0")
.set_num_inputs(1)
.add_argument("x", "Tensor",
"The buffer to be divided into equal parts and sent to each
worker accordingly.")
- .set_attrs_type<ScatterFromWorker0Attrs>()
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoScatterFromWorker0)
+ .set_attrs_type<ScatterCollectiveAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatter)
.set_attr<Bool>("FPurity", Bool(true));
} // namespace relax
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 7742997d5a..3e7f0220c9 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -42,7 +42,7 @@ Expr allgather(Expr data, Expr num_workers);
Expr broadcast_from_worker0(Expr data);
/*! \brief Perform a scatter operation from worker-0, chunking the given
buffer into equal parts. */
-Expr scatter_from_worker0(Expr data, int num_workers);
+Expr scatter_from_worker0(Expr data, int num_workers, int axis);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py
b/tests/python/relax/test_transform_legalize_ops_ccl.py
index bb2e8f3394..63563ee3c9 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -20,6 +20,7 @@ import tvm.testing
from tvm.relax.transform import LegalizeOps
from tvm.script import ir as I
from tvm.script import relax as R
+from tvm.script import tir as T
def test_allreduce():
@@ -101,14 +102,39 @@ def test_scatter_from_worker0():
class ScatterFromWorker0:
@R.function
def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10),
"float32"):
- gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x,
2)
+ gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x,
num_workers=2, axis=1)
return gv0
@I.ir_module
class Expected:
+ @T.prim_func(private=True)
+ def reshape(A: T.Buffer((T.int64(10), T.int64(10)), "float32"),
T_reshape: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2 in T.grid(T.int64(10), T.int64(2), T.int64(5)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[((v_ax1 * T.int64(5) + v_ax2) // T.int64(10) +
v_ax0) % T.int64(10), (v_ax1 * T.int64(5) + v_ax2) % T.int64(10)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[((v_ax1 * T.int64(5) +
v_ax2) // T.int64(10) + v_ax0) % T.int64(10), (v_ax1 * T.int64(5) + v_ax2) %
T.int64(10)]
+
+ @T.prim_func(private=True)
+ def transpose(A: T.Buffer((T.int64(10), T.int64(2), T.int64(5)),
"float32"), T_transpose: T.Buffer((T.int64(2), T.int64(10), T.int64(5)),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(10), T.int64(5)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[v_ax1, v_ax0, v_ax2])
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2])
+ T_transpose[v_ax0, v_ax1, v_ax2] = A[v_ax1, v_ax0, v_ax2]
+
@R.function
- def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((5, 10),
dtype="float32"):
- gv0: R.Tensor((5, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.scatter_from_worker0", x,
out_sinfo=R.Tensor((5, 10), dtype="float32"))
+ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5),
dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5),
dtype="float32"))
+ gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10,
5), dtype="float32"))
+ gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0",
(gv1,), out_sinfo=R.Tensor((10, 5), dtype="float32"))
return gv0
# fmt: on