This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6f650db042 [Unity][DistIR] Legalize redistribute (#16098)
6f650db042 is described below
commit 6f650db0423877bda16798b0593a6ebab145c8cb
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Nov 14 10:09:39 2023 -0800
[Unity][DistIR] Legalize redistribute (#16098)
* legalize redistribute
* format
* add whitespace
* fix lint
* fix ci
* address comments
* fix lint
* fix
* fix test
---
include/tvm/relax/distributed/transform.h | 7 ++
python/tvm/relax/distributed/transform/__init__.py | 2 +-
.../tvm/relax/distributed/transform/transform.py | 13 +++
python/tvm/relax/op/distributed/__init__.py | 2 +-
python/tvm/relax/op/distributed/distributed.py | 26 +++++
.../tvm/relax/transform/legalize_ops/__init__.py | 1 +
.../relax/transform/legalize_ops/distributed.py | 43 +++++++
.../tvm/script/ir_builder/relax/distributed/ir.py | 3 +-
python/tvm/script/parser/relax/__init__.py | 27 ++---
python/tvm/script/parser/relax/dist.py | 8 +-
.../distributed/transform/legalize_redistribute.cc | 123 +++++++++++++++++++++
src/relax/op/distributed/distributed.cc | 82 ++++++++++++++
src/relax/op/distributed/distributed.h | 12 ++
..._distributed_transform_legalize_redistribute.py | 69 ++++++++++++
tests/python/relax/test_op_distributed.py | 58 ++++++++++
.../test_transform_legalize_ops_distributed.py | 64 +++++++++++
16 files changed, 521 insertions(+), 19 deletions(-)
diff --git a/include/tvm/relax/distributed/transform.h b/include/tvm/relax/distributed/transform.h
index 335d9a62e9..5cf15ee65e 100644
--- a/include/tvm/relax/distributed/transform.h
+++ b/include/tvm/relax/distributed/transform.h
@@ -48,6 +48,13 @@ using DataflowBlock = tvm::relax::DataflowBlock;
*/
TVM_DLL Pass PropagateSharding();
+/*!
+ * \brief Legalize redistribute op to ccl op.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass LegalizeRedistribute();
+
} // namespace transform
} // namespace distributed
} // namespace relax
diff --git a/python/tvm/relax/distributed/transform/__init__.py b/python/tvm/relax/distributed/transform/__init__.py
index 9d1e780030..517189392a 100644
--- a/python/tvm/relax/distributed/transform/__init__.py
+++ b/python/tvm/relax/distributed/transform/__init__.py
@@ -16,4 +16,4 @@
# under the License.
"""Relax distributed-related transformations. """
-from .transform import PropagateSharding
+from .transform import PropagateSharding, LegalizeRedistribute
diff --git a/python/tvm/relax/distributed/transform/transform.py b/python/tvm/relax/distributed/transform/transform.py
index 7c8bc49e0d..b1f6bde044 100644
--- a/python/tvm/relax/distributed/transform/transform.py
+++ b/python/tvm/relax/distributed/transform/transform.py
@@ -30,3 +30,16 @@ def PropagateSharding() -> tvm.ir.transform.Pass:
The registered pass
"""
return _ffi_api.PropagateSharding() # type: ignore
+
+
+def LegalizeRedistribute() -> tvm.ir.transform.Pass:
+ """Legalize redistribute op to ccl op.
+ S->R: R.ccl.allgather
+ R->S: R.dist.redistribute_replica_to_shard
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass
+ """
+ return _ffi_api.LegalizeRedistribute() # type: ignore
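
For reference, a minimal sketch of how the new pass is applied, mirroring the test added later in this commit; `mod` stands for any distributed Relax IRModule and `lower_redistribute` is a hypothetical helper name:

    import tvm
    from tvm import relax

    def lower_redistribute(mod: tvm.IRModule) -> tvm.IRModule:
        # Rewrite relax.dist.redistribute calls into concrete collectives,
        # e.g. "R" -> "S[x]" becomes R.dist.redistribute_replica_to_shard.
        return relax.distributed.transform.LegalizeRedistribute()(mod)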
diff --git a/python/tvm/relax/op/distributed/__init__.py b/python/tvm/relax/op/distributed/__init__.py
index 35226b8790..cd38b90bce 100644
--- a/python/tvm/relax/op/distributed/__init__.py
+++ b/python/tvm/relax/op/distributed/__init__.py
@@ -16,4 +16,4 @@
# under the License.
"""Operators serving for distributed Relax."""
-from .distributed import annotate_sharding, redistribute
+from .distributed import annotate_sharding, redistribute, redistribute_replica_to_shard
diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py
index bcc5ca2f49..7e0e6fa23d 100644
--- a/python/tvm/relax/op/distributed/distributed.py
+++ b/python/tvm/relax/op/distributed/distributed.py
@@ -59,3 +59,29 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) ->
The tensor after redistribution.
"""
return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore
+
+
+def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr:
+ """Slice tensor into several parts along one axis,
+ and each worker takes one part.
+ input.struct_info.shape[axis] % num_workers == 0 is required.
+ Each worker must have an identical copy of the input.
+ This is a specialized version of redistribute op.
+
+ Parameters
+ ----------
+ input : relax.Expr
+ The buffer to be sliced into equal parts.
+
+ num_workers : int
+ The number of workers, i.e. the number of parts the given buffer should be sliced into.
+
+ axis : int
+ The axis of the tensor to be sliced.
+
+ Returns
+ -------
+ result : relax.Expr
+ Sliced Tensor kept by each device.
+ """
+ return _ffi_api.redistribute_replica_to_shard(input, num_workers, axis)
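
For intuition, a NumPy reference of the intended slicing semantics (an illustrative sketch, not the actual TVM kernel; `worker_id` stands for the rank each device obtains at runtime):

    import numpy as np

    def redistribute_replica_to_shard_ref(x, num_workers, axis, worker_id):
        # Each worker keeps the worker_id-th of num_workers equal parts
        # along `axis`; the axis size must divide evenly.
        size = x.shape[axis]
        assert size % num_workers == 0
        part = size // num_workers
        index = [slice(None)] * x.ndim
        index[axis] = slice(worker_id * part, (worker_id + 1) * part)
        return x[tuple(index)]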
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py
index e222e7d4ef..e3b3213a38 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -19,6 +19,7 @@ from . import binary
from . import ccl
from . import create
from . import datatype
+from . import distributed
from . import grad
from . import image
from . import index
diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py b/python/tvm/relax/transform/legalize_ops/distributed.py
new file mode 100644
index 0000000000..d540628e0e
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/distributed.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Default legalization function for distir-related operators."""
+from tvm import tir, relax
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from ...op import call_pure_packed
+from ...struct_info import ShapeStructInfo
+from .common import register_legalize
+
+
+@register_legalize("relax.dist.redistribute_replica_to_shard")
+def _redistribute_replica_to_shard(_bb: BlockBuilder, call: Call) -> Expr:
+ num_workers = call.attrs.num_workers
+ axis = call.attrs.axis
+ worker_id_symbol = tir.Var("worker_id", "int64")
+ worker_id_var = _bb.emit(
+ call_pure_packed("runtime.disco.worker_id", sinfo_args=[ShapeStructInfo(None)])
+ )
+ _bb.match_cast(worker_id_var, ShapeStructInfo([worker_id_symbol]))
+
+ split_axis_size = call.args[0].struct_info.shape[axis]
+ return relax.op.strided_slice(
+ call.args[0],
+ axes=[axis],
+ begin=[worker_id_symbol * split_axis_size // num_workers],
+ end=[(worker_id_symbol + 1) * split_axis_size // num_workers],
+ )
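
The begin/end expressions above carve the axis into contiguous equal chunks; a plain-Python paraphrase of the bound arithmetic, worked through with the 10-element, 2-worker case from the test below:

    def slice_bounds(worker_id: int, axis_size: int, num_workers: int):
        # Mirrors the TIR expressions emitted above.
        begin = worker_id * axis_size // num_workers
        end = (worker_id + 1) * axis_size // num_workers
        return begin, end

    assert slice_bounds(0, 10, 2) == (0, 5)   # worker 0 keeps [0, 5)
    assert slice_bounds(1, 10, 2) == (5, 10)  # worker 1 keeps [5, 10)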
diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py
index 51ef84620c..28ff6ae732 100644
--- a/python/tvm/script/ir_builder/relax/distributed/ir.py
+++ b/python/tvm/script/ir_builder/relax/distributed/ir.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name
+# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name, unused-import
"""IRBuilder for distributed Relax dialect"""
from typing import Union, List, Tuple, Optional
@@ -32,6 +32,7 @@ from tvm.runtime import ndarray as _nd
from tvm.relax.op.distributed import (
redistribute as _redistribute,
annotate_sharding as _annotate_sharding,
+ redistribute_replica_to_shard,
)
from tvm.relax.distributed import DeviceMesh, Placement
from . import _ffi_api
diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py
index 0b1ed168de..704189060b 100644
--- a/python/tvm/script/parser/relax/__init__.py
+++ b/python/tvm/script/parser/relax/__init__.py
@@ -34,18 +34,15 @@ if TYPE_CHECKING:
else:
from .entry import function, macro
-__all__ = (
- _relax.__all__
- + dist.__all__
- + [
- "Callable",
- "Object",
- "Prim",
- "Shape",
- "Tensor",
- "Tuple",
- "function",
- "macro",
- "match_cast",
- ]
-)
+__all__ = _relax.__all__ + [
+ "dist",
+ "Callable",
+ "Object",
+ "Prim",
+ "Shape",
+ "Tensor",
+ "Tuple",
+ "function",
+ "macro",
+ "match_cast",
+]
diff --git a/python/tvm/script/parser/relax/dist.py b/python/tvm/script/parser/relax/dist.py
index f9c78f980f..c4e81a2ae2 100644
--- a/python/tvm/script/parser/relax/dist.py
+++ b/python/tvm/script/parser/relax/dist.py
@@ -25,7 +25,13 @@ from tvm.tir import PrimExpr
from tvm.relax.distributed import DeviceMesh, Placement, DTensorStructInfo, device_mesh
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder.ir import IRModuleFrame
-from tvm.script.ir_builder.relax.distributed import call_tir, const, annotate_sharding, redistribute
+from tvm.script.ir_builder.relax.distributed import (
+ call_tir,
+ const,
+ annotate_sharding,
+ redistribute,
+ redistribute_replica_to_shard,
+)
from .entry import StructInfoProxy, TensorProxy
diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc
new file mode 100644
index 0000000000..5a67f0351c
--- /dev/null
+++ b/src/relax/distributed/transform/legalize_redistribute.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/distributed/transform/legalize_redistribute.cc
+ * \brief Pass for legalizing redistribute op to ccl op.
+ */
+
+#include <tvm/relax/attrs/ccl.h>
+#include <tvm/relax/attrs/distributed.h>
+#include <tvm/relax/distributed/axis_group_graph.h>
+#include <tvm/relax/distributed/transform.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../../tir/schedule/transform.h"
+#include "../../op/ccl/ccl.h"
+#include "../../op/distributed/distributed.h"
+
+namespace tvm {
+namespace relax {
+namespace distributed {
+
+class RedistributeLegalizer : public ExprMutator {
+ public:
+ static IRModule LegalizeRedistribute(IRModule mod) {
+ return RedistributeLegalizer(mod).Legalize();
+ }
+
+ private:
+ explicit RedistributeLegalizer(IRModule mod) : ExprMutator(mod) {}
+
+ IRModule Legalize() {
+ auto mod = builder_->GetContextIRModule();
+ for (const auto& [gv, base_func] : mod->functions) {
+ const auto* func_ = base_func.as<FunctionNode>();
+ if (func_ == nullptr) {
+ continue;
+ }
+ Expr new_func_body = VisitExpr(func_->body);
+ auto new_func = make_object<FunctionNode>(*func_);
+ new_func->body = new_func_body;
+ builder_->UpdateFunction(gv, Function(new_func));
+ }
+ return builder_->GetContextIRModule();
+ }
+ using ExprMutator::VisitExpr_;
+ Expr VisitExpr_(const CallNode* op) final {
+ Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+ static Op redistribute_op = Op::Get("relax.dist.redistribute");
+ if (call->op.same_as(redistribute_op)) {
+ const auto* attrs = call->attrs.as<DistributionAttrs>();
+ ICHECK(attrs);
+ const auto* input_sinfo = call->args[0]->struct_info_.as<DTensorStructInfoNode>();
+ ICHECK(input_sinfo);
+ // As the first step, we only support redistribute in the same device mesh,
+ // and the device mesh must be 1d.
+ // TODO: extend the ccl ops so that they can support a 2d device mesh and
+ // different sharding dimensions.
+ ICHECK(StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh));
+ ICHECK(input_sinfo->device_mesh->shape.size() == 1);
+ // only support "S[x]"-> "R" and "R" -> "S[x]"
+ PlacementSpec input_spec = input_sinfo->placement->dim_specs[0];
+ PlacementSpec output_spec = attrs->placement->dim_specs[0];
+ if (input_spec->kind == PlacementSpecKind::kReplica &&
+ output_spec->kind == PlacementSpecKind::kReplica) {
+ // "R" -> "R"
+ return call->args[0];
+ } else if (input_spec->kind == PlacementSpecKind::kSharding &&
+ output_spec->kind == PlacementSpecKind::kSharding) {
+ // "S[x]" -> "S[y]"
+ if (input_spec->axis != output_spec->axis) {
+ LOG(FATAL) << "AlltoAll not implemented yet";
+ } else {
+ return call->args[0];
+ }
+ } else if (input_spec->kind == PlacementSpecKind::kSharding &&
+ output_spec->kind == PlacementSpecKind::kReplica) {
+ // "S[x]" -> "R"
+ LOG(FATAL) << "Allgather not implemented yet";
+ } else if (input_spec->kind == PlacementSpecKind::kReplica &&
+ output_spec->kind == PlacementSpecKind::kSharding) {
+ // "R" -> "S[x]"
+ return redistribute_replica_to_shard(call->args[0], attrs->device_mesh->shape[0], output_spec->axis);
+ } else {
+ LOG(FATAL) << "Unsupported redistribute op";
+ }
+ }
+ return call;
+ }
+};
+
+namespace transform {
+
+Pass LegalizeRedistribute() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return RedistributeLegalizer::LegalizeRedistribute(m); };
+ return CreateModulePass(pass_func, 1, "LegalizeRedistribute", {});
+}
+TVM_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute")
+ .set_body_typed(LegalizeRedistribute);
+} // namespace transform
+
+} // namespace distributed
+} // namespace relax
+} // namespace tvm
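
In summary, the mutator currently handles the following 1-D mesh placement transitions; a hedged Python paraphrase of the C++ dispatch above, with placement texts such as "R" or "S[0]" standing in for PlacementSpec:

    def legalize_redistribute_1d(src: str, dst: str) -> str:
        # Identical placements are a no-op: "R" -> "R", "S[x]" -> "S[x]".
        if src == dst:
            return "no-op"
        if src.startswith("S") and dst.startswith("S"):
            raise NotImplementedError("AlltoAll not implemented yet")
        if src.startswith("S") and dst == "R":
            raise NotImplementedError("Allgather not implemented yet")
        if src == "R" and dst.startswith("S"):
            return "redistribute_replica_to_shard"
        raise ValueError("Unsupported redistribute op")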
diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc
index 6def4ef019..658b7cde34 100644
--- a/src/relax/op/distributed/distributed.cc
+++ b/src/relax/op/distributed/distributed.cc
@@ -24,6 +24,7 @@
#include "distributed.h"
+#include <tvm/relax/attrs/ccl.h>
#include <tvm/topi/einsum.h>
#include <algorithm>
@@ -85,5 +86,86 @@ TVM_REGISTER_OP("relax.dist.redistribute")
.set_attr<FInferStructInfo>("dist.FInferStructInfo", InferDistStructInfoRedistribute)
.set_attr<Bool>("FPurity", Bool(true));
+StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ DataType output_dtype = input_sinfo->dtype;
+
+ const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
+ int num_workers = attrs->num_workers;
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ auto input_shape = input_sinfo->GetShape();
+ CHECK(input_shape.defined())
+ << "input tensor of redistribute_replica_to_shard should have defined
shape.";
+
+ if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers)) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "redistribute_replica_to_shard expects the size of axis " << attrs->axis
+ << " of the input tensor to be divisible by num_workers. However, axis "
+ << attrs->axis << " of the input tensor is " << input_shape.value()[attrs->axis]
+ << " while num_workers is " << num_workers);
+ }
+
+ Array<PrimExpr> output_shape = input_shape.value();
+ output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
+ if (input_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice.value());
+ }
+ return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+}
+
+StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
+ using namespace distributed;
+ Array<DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+ ICHECK(input_dtensor_sinfos.size() == 1);
+ DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0];
+ TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;
+ const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
+ int num_workers = attrs->num_workers;
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ auto input_shape = tensor_sinfo->GetShape();
+ CHECK(input_shape.defined())
+ << "input tensor of redistribute_replica_to_shard should have defined
shape.";
+
+ if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers)) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "redistribute_replica_to_shard expects the size of axis " << attrs->axis
+ << " of the input tensor to be divisible by num_workers. However, axis "
+ << attrs->axis << " of the input tensor is " << input_shape.value()[attrs->axis]
+ << " while num_workers is " << num_workers);
+ }
+
+ DeviceMesh device_mesh = input_dtensor_sinfo->device_mesh;
+ // FIXME: this is a hack where there's only 1d mesh
+ ICHECK(device_mesh->shape.size() == 1);
+ ICHECK(input_dtensor_sinfo->placement->dim_specs[0]->kind == PlacementSpecKind::kReplica);
+ return DTensorStructInfo(tensor_sinfo, device_mesh,
+ Placement::FromText("S[" + std::to_string(attrs->axis) + "]"));
+}
+
+Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) {
+ ObjectPtr<ScatterCollectiveAttrs> attrs = make_object<ScatterCollectiveAttrs>();
+ attrs->num_workers = std::move(num_workers);
+ attrs->axis = std::move(axis);
+ static const Op& op = Op::Get("relax.dist.redistribute_replica_to_shard");
+
+ return Call(op, {std::move(input)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard")
+ .set_body_typed(redistribute_replica_to_shard);
+
+TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard")
+ .set_num_inputs(1)
+ .add_argument("input", "Tensor", "The buffer to be sliced.")
+ .set_attrs_type<ScatterCollectiveAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRtoS)
+ .set_attr<FInferStructInfo>("dist.FInferStructInfo",
InferDistStructInfoRtoS)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
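
The struct-info rule implemented by InferStructInfoRtoS, restated as an illustrative plain-Python sketch: the sliced axis shrinks by a factor of num_workers, with divisibility enforced:

    def infer_rtos_shape(shape, num_workers, axis):
        # Mirrors the shape rule above; errors on indivisible axes.
        if shape[axis] % num_workers != 0:
            raise ValueError("axis size must be divisible by num_workers")
        out = list(shape)
        out[axis] //= num_workers
        return tuple(out)

    assert infer_rtos_shape((3, 4), num_workers=4, axis=1) == (3, 1)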
diff --git a/src/relax/op/distributed/distributed.h b/src/relax/op/distributed/distributed.h
index 06f926edb1..67ccc0fc6b 100644
--- a/src/relax/op/distributed/distributed.h
+++ b/src/relax/op/distributed/distributed.h
@@ -27,6 +27,7 @@
#include <tvm/relax/attrs/distributed.h>
#include "../op_common.h"
+#include "utils.h"
namespace tvm {
namespace relax {
@@ -51,6 +52,17 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh,
Expr redistribute(Expr input, distributed::DeviceMesh device_mesh,
distributed::Placement placement);
+/*!
+ * \brief Slice tensor into several parts along one axis,
+ * and each worker takes one part.
+ * Assumes the input is already broadcasted, i.e. each worker holds an identical copy.
+ * This is a specialized version of the redistribute op.
+ * \param input The input tensor.
+ * \param num_workers The number of workers.
+ * \param axis The tensor axis to slice.
+ * \return The result.
+ */
+Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py b/tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py
new file mode 100644
index 0000000000..92e9e5a4d4
--- /dev/null
+++ b/tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# type: ignore
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+import tvm
+from tvm import relax
+import tvm.testing
+
+
+def test_simple():
+ @I.ir_module
+ class Before:
+ I.module_attrs({"device_num": 2})
+ I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
+
+ @R.function
+ def foo(
+ x1: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
+ x2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+ ):
+ R.func_attr({"num_input": 1})
+ # scatter
+ lv0 = R.dist.redistribute(x1, "mesh[0]", "S[1]")
+ # do nothing
+ lv1 = R.dist.redistribute(x2, "mesh[0]", "S[0]")
+ return (lv0, lv1)
+
+ @I.ir_module
+ class Expected:
+ I.module_attrs({"device_num": 2})
+ I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
+
+ @R.function
+ def foo(
+ x1: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
+ x2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+ ) -> R.Tuple(
+ R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
+ R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+ ):
+ R.func_attr({"num_input": 1})
+ lv0: R.DTensor(
+ (128, 128), "float32", "mesh[0]", "S[1]"
+ ) = R.dist.redistribute_replica_to_shard(x1, num_workers=2, axis=1)
+ lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]") = x2
+ return (lv0, lv1)
+
+ after = relax.distributed.transform.LegalizeRedistribute()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_op_distributed.py b/tests/python/relax/test_op_distributed.py
new file mode 100644
index 0000000000..dc5440c0c1
--- /dev/null
+++ b/tests/python/relax/test_op_distributed.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+from tvm._ffi.base import TVMError
+import tvm.testing
+from tvm import relax
+from tvm.script.parser import relax as R
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+ ret = bb.normalize(call)
+ tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_redistribute_R_to_S():
+ bb = relax.BlockBuilder()
+ mesh = R.device_mesh((4,), list(range(4)))
+ x = relax.Var("x", R.DTensor((3, 4), "float32", device_mesh=mesh, placement="R"))
+
+ _check_inference(
+ bb,
+ R.distributed.redistribute_replica_to_shard(x, num_workers=4, axis=1),
+ R.DTensor((3, 4), "float32", device_mesh=mesh, placement="S[1]"),
+ )
+
+ # wrong: indivisible
+ with pytest.raises(TVMError):
+ bb.normalize(R.distributed.redistribute_replica_to_shard(x, num_workers=4, axis=0))
+
+ y = relax.Var("y", R.Tensor((3, 4), "float32"))
+ _check_inference(
+ bb,
+ R.distributed.redistribute_replica_to_shard(y, num_workers=4, axis=1),
+ R.Tensor((3, 1), "float32"),
+ )
+
+ # wrong: indivisible
+ with pytest.raises(TVMError):
+ bb.normalize(R.distributed.redistribute_replica_to_shard(y, num_workers=4, axis=0))
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_distributed.py b/tests/python/relax/test_transform_legalize_ops_distributed.py
new file mode 100644
index 0000000000..c1f462e36c
--- /dev/null
+++ b/tests/python/relax/test_transform_legalize_ops_distributed.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_redistribute_replica_to_shard():
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 5),
"float32"):
+ gv0 = R.dist.redistribute_replica_to_shard(x, num_workers=2,
axis=1)
+ return gv0
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def strided_slice(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), redistribute_replica_to_shard: T.Buffer((T.int64(10), T.int64(5)), "float32"), worker_id: T.int64):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i0, i1 in T.grid(T.int64(10), T.int64(5)):
+ with T.block("redistribute_replica_to_shard"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(A[v_i0, worker_id * T.int64(5) + v_i1])
+ T.writes(redistribute_replica_to_shard[v_i0, v_i1])
+ redistribute_replica_to_shard[v_i0, v_i1] = A[v_i0, worker_id * T.int64(5) + v_i1]
+
+ @R.function
+ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
+ worker_id = T.int64()
+ cls = Expected
+ gv: R.Shape(ndim=-1) = R.call_pure_packed("runtime.disco.worker_id", sinfo_args=(R.Shape(ndim=-1),))
+ gv1: R.Shape([worker_id]) = R.match_cast(gv, R.Shape([worker_id]))
+ gv0 = R.call_tir(cls.strided_slice, (x,), out_sinfo=R.Tensor((10, 5), dtype="float32"), tir_vars=R.shape([worker_id]))
+ return gv0
+ # fmt: on
+
+ mod = LegalizeOps()(Before)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()