This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6f650db042 [Unity][DistIR] Legalize redistribute (#16098)
6f650db042 is described below

commit 6f650db0423877bda16798b0593a6ebab145c8cb
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Nov 14 10:09:39 2023 -0800

    [Unity][DistIR] Legalize redistribute (#16098)
    
    * legalize redistribute
    
    * format
    
    * add whitespace
    
    * fix lint
    
    * fix ci
    
    * address comments
    
    * fix lint
    
    * fix
    
    * fix test
---
 include/tvm/relax/distributed/transform.h          |   7 ++
 python/tvm/relax/distributed/transform/__init__.py |   2 +-
 .../tvm/relax/distributed/transform/transform.py   |  13 +++
 python/tvm/relax/op/distributed/__init__.py        |   2 +-
 python/tvm/relax/op/distributed/distributed.py     |  26 +++++
 .../tvm/relax/transform/legalize_ops/__init__.py   |   1 +
 .../relax/transform/legalize_ops/distributed.py    |  43 +++++++
 .../tvm/script/ir_builder/relax/distributed/ir.py  |   3 +-
 python/tvm/script/parser/relax/__init__.py         |  27 ++---
 python/tvm/script/parser/relax/dist.py             |   8 +-
 .../distributed/transform/legalize_redistribute.cc | 123 +++++++++++++++++++++
 src/relax/op/distributed/distributed.cc            |  82 ++++++++++++++
 src/relax/op/distributed/distributed.h             |  12 ++
 ..._distributed_transform_legalize_redistribute.py |  69 ++++++++++++
 tests/python/relax/test_op_distributed.py          |  58 ++++++++++
 .../test_transform_legalize_ops_distributed.py     |  64 +++++++++++
 16 files changed, 521 insertions(+), 19 deletions(-)

diff --git a/include/tvm/relax/distributed/transform.h 
b/include/tvm/relax/distributed/transform.h
index 335d9a62e9..5cf15ee65e 100644
--- a/include/tvm/relax/distributed/transform.h
+++ b/include/tvm/relax/distributed/transform.h
@@ -48,6 +48,13 @@ using DataflowBlock = tvm::relax::DataflowBlock;
  */
 TVM_DLL Pass PropagateSharding();
 
+/*!
+ * \brief Legalize relax.dist.redistribute calls into ccl / dist ops.
+ * \note Only 1-D device meshes; S->R and S[x]->S[y] (x != y) are not implemented yet.
+ * \return The Pass.
+ */
+TVM_DLL Pass LegalizeRedistribute();
+
 }  // namespace transform
 }  // namespace distributed
 }  // namespace relax
diff --git a/python/tvm/relax/distributed/transform/__init__.py 
b/python/tvm/relax/distributed/transform/__init__.py
index 9d1e780030..517189392a 100644
--- a/python/tvm/relax/distributed/transform/__init__.py
+++ b/python/tvm/relax/distributed/transform/__init__.py
@@ -16,4 +16,4 @@
 # under the License.
 """Relax distributed-related transformations. """
 
-from .transform import PropagateSharding
+from .transform import PropagateSharding, LegalizeRedistribute
diff --git a/python/tvm/relax/distributed/transform/transform.py 
b/python/tvm/relax/distributed/transform/transform.py
index 7c8bc49e0d..b1f6bde044 100644
--- a/python/tvm/relax/distributed/transform/transform.py
+++ b/python/tvm/relax/distributed/transform/transform.py
@@ -30,3 +30,16 @@ def PropagateSharding() -> tvm.ir.transform.Pass:
         The registered pass
     """
     return _ffi_api.PropagateSharding()  # type: ignore
+
+
+def LegalizeRedistribute() -> tvm.ir.transform.Pass:
+    """Legalize R.dist.redistribute on a 1-D mesh (R->R, S[x]->S[x] become no-ops).
+    S->R: R.ccl.allgather (not implemented yet; the pass aborts via LOG(FATAL))
+    R->S: R.dist.redistribute_replica_to_shard
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass
+    """
+    return _ffi_api.LegalizeRedistribute()  # type: ignore
diff --git a/python/tvm/relax/op/distributed/__init__.py 
b/python/tvm/relax/op/distributed/__init__.py
index 35226b8790..cd38b90bce 100644
--- a/python/tvm/relax/op/distributed/__init__.py
+++ b/python/tvm/relax/op/distributed/__init__.py
@@ -16,4 +16,4 @@
 # under the License.
 """Operators serving for distributed Relax."""
 
-from .distributed import annotate_sharding, redistribute
+from .distributed import annotate_sharding, redistribute, 
redistribute_replica_to_shard
diff --git a/python/tvm/relax/op/distributed/distributed.py 
b/python/tvm/relax/op/distributed/distributed.py
index bcc5ca2f49..7e0e6fa23d 100644
--- a/python/tvm/relax/op/distributed/distributed.py
+++ b/python/tvm/relax/op/distributed/distributed.py
@@ -59,3 +59,29 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, 
placement: Placement) ->
       The tensor after redistribution.
     """
     return _ffi_api.redistribute(input, device_mesh, placement)  # type: ignore
+
+
+def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> 
Expr:
+    """Slice a replicated tensor into num_workers equal parts along one axis,
+        and each worker takes one part.
+        input.struct_info.shape[axis] % num_workers == 0 is required.
+        Each worker must have an identical copy of the input.
+        This is a specialized version of redistribute op (R -> S[axis]).
+
+    Parameters
+    ----------
+    input : relax.Expr
+      The buffer to be sliced into equal parts.
+
+    num_workers : int
+      The number of workers, i.e. the number of parts the given buffer should 
be sliced into.
+
+    axis : int
+      The axis of the tensor to be sliced.
+
+    Returns
+    -------
+    result : relax.Expr
+      Sliced Tensor kept by each device.
+    """
+    return _ffi_api.redistribute_replica_to_shard(input, num_workers, axis)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py 
b/python/tvm/relax/transform/legalize_ops/__init__.py
index e222e7d4ef..e3b3213a38 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -19,6 +19,7 @@ from . import binary
 from . import ccl
 from . import create
 from . import datatype
+from . import distributed
 from . import grad
 from . import image
 from . import index
diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py 
b/python/tvm/relax/transform/legalize_ops/distributed.py
new file mode 100644
index 0000000000..d540628e0e
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/distributed.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Default legalization function for distir-related operators."""
+from tvm import tir, relax
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from ...op import call_pure_packed
+from ...struct_info import ShapeStructInfo
+from .common import register_legalize
+
+
+@register_legalize("relax.dist.redistribute_replica_to_shard")
+def _redistribute_replica_to_shard(_bb: BlockBuilder, call: Call) -> Expr:  # per-worker slice
+    num_workers = call.attrs.num_workers
+    axis = call.attrs.axis
+    worker_id_symbol = tir.Var("worker_id", "int64")  # bound by the match_cast below
+    worker_id_var = _bb.emit(
+        call_pure_packed("runtime.disco.worker_id", 
sinfo_args=[ShapeStructInfo(None)])
+    )
+    _bb.match_cast(worker_id_var, ShapeStructInfo([worker_id_symbol]))  # 1-element shape
+
+    split_axis_size = call.args[0].struct_info.shape[axis]  # op requires % num_workers == 0
+    return relax.op.strided_slice(
+        call.args[0],
+        axes=[axis],
+        begin=[worker_id_symbol * split_axis_size // num_workers],  # (id * size) // n
+        end=[(worker_id_symbol + 1) * split_axis_size // num_workers],
+    )
diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py 
b/python/tvm/script/ir_builder/relax/distributed/ir.py
index 51ef84620c..28ff6ae732 100644
--- a/python/tvm/script/ir_builder/relax/distributed/ir.py
+++ b/python/tvm/script/ir_builder/relax/distributed/ir.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-builtin, wrong-import-order, no-member, 
invalid-name
+# pylint: disable=redefined-builtin, wrong-import-order, no-member, 
invalid-name, unused-import
 
 """IRBuilder for distributed Relax dialect"""
 from typing import Union, List, Tuple, Optional
@@ -32,6 +32,7 @@ from tvm.runtime import ndarray as _nd
 from tvm.relax.op.distributed import (
     redistribute as _redistribute,
     annotate_sharding as _annotate_sharding,
+    redistribute_replica_to_shard,
 )
 from tvm.relax.distributed import DeviceMesh, Placement
 from . import _ffi_api
diff --git a/python/tvm/script/parser/relax/__init__.py 
b/python/tvm/script/parser/relax/__init__.py
index 0b1ed168de..704189060b 100644
--- a/python/tvm/script/parser/relax/__init__.py
+++ b/python/tvm/script/parser/relax/__init__.py
@@ -34,18 +34,15 @@ if TYPE_CHECKING:
 else:
     from .entry import function, macro
 
-__all__ = (
-    _relax.__all__
-    + dist.__all__
-    + [
-        "Callable",
-        "Object",
-        "Prim",
-        "Shape",
-        "Tensor",
-        "Tuple",
-        "function",
-        "macro",
-        "match_cast",
-    ]
-)
+__all__ = _relax.__all__ + [
+    "dist",  # export the dist submodule itself (R.dist.*) instead of its members
+    "Callable",
+    "Object",
+    "Prim",
+    "Shape",
+    "Tensor",
+    "Tuple",
+    "function",
+    "macro",
+    "match_cast",
+]
diff --git a/python/tvm/script/parser/relax/dist.py 
b/python/tvm/script/parser/relax/dist.py
index f9c78f980f..c4e81a2ae2 100644
--- a/python/tvm/script/parser/relax/dist.py
+++ b/python/tvm/script/parser/relax/dist.py
@@ -25,7 +25,13 @@ from tvm.tir import PrimExpr
 from tvm.relax.distributed import DeviceMesh, Placement, DTensorStructInfo, 
device_mesh
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder.ir import IRModuleFrame
-from tvm.script.ir_builder.relax.distributed import call_tir, const, 
annotate_sharding, redistribute
+from tvm.script.ir_builder.relax.distributed import (
+    call_tir,
+    const,
+    annotate_sharding,
+    redistribute,
+    redistribute_replica_to_shard,
+)
 from .entry import StructInfoProxy, TensorProxy
 
 
diff --git a/src/relax/distributed/transform/legalize_redistribute.cc 
b/src/relax/distributed/transform/legalize_redistribute.cc
new file mode 100644
index 0000000000..5a67f0351c
--- /dev/null
+++ b/src/relax/distributed/transform/legalize_redistribute.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/distributed/transform/legalize_redistribute.cc
+ * \brief Pass for legalizing redistribute op to ccl op.
+ */
+
+#include <tvm/relax/attrs/ccl.h>
+#include <tvm/relax/attrs/distributed.h>
+#include <tvm/relax/distributed/axis_group_graph.h>
+#include <tvm/relax/distributed/transform.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../../tir/schedule/transform.h"
+#include "../../op/ccl/ccl.h"
+#include "../../op/distributed/distributed.h"
+
+namespace tvm {
+namespace relax {
+namespace distributed {
+
+class RedistributeLegalizer : public ExprMutator {  // rewrites relax.dist.redistribute calls
+ public:
+  static IRModule LegalizeRedistribute(IRModule mod) {
+    return RedistributeLegalizer(mod).Legalize();
+  }
+
+ private:
+  explicit RedistributeLegalizer(IRModule mod) : ExprMutator(mod) {}
+
+  IRModule Legalize() {  // visit every Relax function body and write it back to the module
+    auto mod = builder_->GetContextIRModule();
+    for (const auto& [gv, base_func] : mod->functions) {
+      const auto* func_ = base_func.as<FunctionNode>();
+      if (func_ == nullptr) {
+        continue;
+      }
+      Expr new_func_body = VisitExpr(func_->body);
+      auto new_func = make_object<FunctionNode>(*func_);
+      new_func->body = new_func_body;
+      builder_->UpdateFunction(gv, Function(new_func));
+    }
+    return builder_->GetContextIRModule();
+  }
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const CallNode* op) final {
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));  // operands first, via base mutator
+    static Op redistribute_op = Op::Get("relax.dist.redistribute");
+    if (call->op.same_as(redistribute_op)) {
+      const auto* attrs = call->attrs.as<DistributionAttrs>();
+      ICHECK(attrs);
+      const auto* input_sinfo = 
call->args[0]->struct_info_.as<DTensorStructInfoNode>();
+      ICHECK(input_sinfo);
+      // As the first step, we only support redistribute in the same device 
mesh,
+      // and the device mesh must be 1d
+      // TODO: extend the ccl ops so that it can support 2d device mesh, and 
different sharding
+      // dimension
+      ICHECK(StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh));
+      ICHECK(input_sinfo->device_mesh->shape.size() == 1);
+      // supported: "R" -> "R" and "S[x]" -> "S[x]" (no-ops), "R" -> "S[x]" (slice)
+      PlacementSpec input_spec = input_sinfo->placement->dim_specs[0];  // 1-D mesh: dim 0 only
+      PlacementSpec output_spec = attrs->placement->dim_specs[0];
+      if (input_spec->kind == PlacementSpecKind::kReplica &&
+          output_spec->kind == PlacementSpecKind::kReplica) {
+        // "R" -> "R"
+        return call->args[0];
+      } else if (input_spec->kind == PlacementSpecKind::kSharding &&
+                 output_spec->kind == PlacementSpecKind::kSharding) {
+        // "S[x]" -> "S[y]"
+        if (input_spec->axis != output_spec->axis) {
+          LOG(FATAL) << "AlltoAll not implemented yet";
+        } else {
+          return call->args[0];
+        }
+      } else if (input_spec->kind == PlacementSpecKind::kSharding &&
+                 output_spec->kind == PlacementSpecKind::kReplica) {
+        // "S[x]" -> "R" (would need an allgather; not supported yet)
+        LOG(FATAL) << "Allgather not implemented yet";
+      } else if (input_spec->kind == PlacementSpecKind::kReplica &&
+                 output_spec->kind == PlacementSpecKind::kSharding) {
+        // "R" -> "S[x]": each worker slices its own part of the replicated input
+        return redistribute_replica_to_shard(call->args[0], 
attrs->device_mesh->shape[0],
+                                             output_spec->axis);
+      } else {
+        LOG(FATAL) << "Unsupported redistribute op";
+      }
+    }
+    return call;
+  }
+};
+
+namespace transform {
+
+Pass LegalizeRedistribute() {  // module pass wrapping RedistributeLegalizer
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return 
RedistributeLegalizer::LegalizeRedistribute(m); };
+  return CreateModulePass(pass_func, 1, "LegalizeRedistribute", {});
+}
+TVM_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute")  // Python FFI entry
+    .set_body_typed(LegalizeRedistribute);
+}  // namespace transform
+
+}  // namespace distributed
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/distributed/distributed.cc 
b/src/relax/op/distributed/distributed.cc
index 6def4ef019..658b7cde34 100644
--- a/src/relax/op/distributed/distributed.cc
+++ b/src/relax/op/distributed/distributed.cc
@@ -24,6 +24,7 @@
 
 #include "distributed.h"
 
+#include <tvm/relax/attrs/ccl.h>
 #include <tvm/topi/einsum.h>
 
 #include <algorithm>
@@ -85,5 +86,86 @@ TVM_REGISTER_OP("relax.dist.redistribute")
     .set_attr<FInferStructInfo>("dist.FInferStructInfo", 
InferDistStructInfoRedistribute)
     .set_attr<Bool>("FPurity", Bool(true));
 
+StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {  // plain Tensor input
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+
+  const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();  // NOTE(review): not null-checked
+  int num_workers = attrs->num_workers;
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = input_sinfo->GetShape();
+  CHECK(input_shape.defined())
+      << "input tensor of redistribute_replica_to_shard should have defined 
shape.";
+
+  if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], 
PrimExpr(num_workers))) != 0) {  // NOTE(review): presumably meant CanProve(floormod(...) != 0)
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "redistribute_replica_to_shard expects the size of 
axis " << attrs->axis
+                     << " of input tensor to be "
+                        "divisible by the "
+                        "num_workers. However, the axis "
+                     << attrs->axis << " of input tensor is " << 
input_shape.value()[attrs->axis]
+                     << " while num_workers is " << num_workers);
+  }
+
+  Array<PrimExpr> output_shape = input_shape.value();  // same shape except axis / num_workers
+  output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
+  if (input_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), output_dtype, 
input_sinfo->vdevice.value());
+  }
+  return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+}
+
+StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {  // DTensor: R -> S[axis]
+  using namespace distributed;
+  Array<DTensorStructInfo> input_dtensor_sinfos = 
GetInputDTensorStructInfo(call, ctx);
+  ICHECK(input_dtensor_sinfos.size() == 1);
+  DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0];
+  TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;  // reused unchanged below
+  const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
+  int num_workers = attrs->num_workers;
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = tensor_sinfo->GetShape();
+  CHECK(input_shape.defined())
+      << "input tensor of redistribute_replica_to_shard should have defined 
shape.";
+
+  if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], 
PrimExpr(num_workers))) != 0) {  // NOTE(review): presumably meant CanProve(floormod(...) != 0)
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "redistribute_replica_to_shard expects the size of 
axis " << attrs->axis
+                     << " of input tensor to be "
+                        "divisible by the "
+                        "num_workers. However, the axis "
+                     << attrs->axis << " of input tensor is " << 
input_shape.value()[attrs->axis]
+                     << " while num_workers is " << num_workers);
+  }
+
+  DeviceMesh device_mesh = input_dtensor_sinfo->device_mesh;
+  // FIXME: only 1-D device meshes are supported for now
+  ICHECK(device_mesh->shape.size() == 1);
+  ICHECK(input_dtensor_sinfo->placement->dim_specs[0]->kind == 
PlacementSpecKind::kReplica);
+  return DTensorStructInfo(tensor_sinfo, device_mesh,  // NOTE(review): num_workers vs mesh size unchecked
+                           Placement::FromText("S[" + 
std::to_string(attrs->axis) + "]"));
+}
+
+Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) {  // build the op call
+  ObjectPtr<ScatterCollectiveAttrs> attrs = 
make_object<ScatterCollectiveAttrs>();
+  attrs->num_workers = std::move(num_workers);  // std::move on an int is a plain copy
+  attrs->axis = std::move(axis);
+  static const Op& op = Op::Get("relax.dist.redistribute_replica_to_shard");
+
+  return Call(op, {std::move(input)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard")  // presumably the Python _ffi_api target
+    .set_body_typed(redistribute_replica_to_shard);
+
+TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard")
+    .set_num_inputs(1)
+    .add_argument("input", "Tensor", "The buffer to be sliced.")
+    .set_attrs_type<ScatterCollectiveAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRtoS)  // Tensor input
+    .set_attr<FInferStructInfo>("dist.FInferStructInfo", 
InferDistStructInfoRtoS)  // DTensor input
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/distributed/distributed.h 
b/src/relax/op/distributed/distributed.h
index 06f926edb1..67ccc0fc6b 100644
--- a/src/relax/op/distributed/distributed.h
+++ b/src/relax/op/distributed/distributed.h
@@ -27,6 +27,7 @@
 #include <tvm/relax/attrs/distributed.h>
 
 #include "../op_common.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relax {
@@ -51,6 +52,17 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh 
device_mesh,
 Expr redistribute(Expr input, distributed::DeviceMesh device_mesh,
                   distributed::Placement placement);
 
+/*!
+ * \brief Slice a tensor into num_workers equal parts along one axis;
+ *        each worker takes one part.
+ *        Assumes the input is replicated (identical) on every worker.
+ *        This is a specialized version of the redistribute op (R -> S[axis]).
+ * \param input The input tensor.
+ * \param num_workers The number of workers (parts); must divide input.shape[axis].
+ * \param axis The tensor axis to slice.
+ * \return The relax.dist.redistribute_replica_to_shard call.
+ */
+Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis);
 }  // namespace relax
 }  // namespace tvm
 
diff --git 
a/tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py
 
b/tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py
new file mode 100644
index 0000000000..92e9e5a4d4
--- /dev/null
+++ 
b/tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#  type: ignore
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+import tvm
+from tvm import relax
+import tvm.testing
+
+
+def test_simple():
+    @I.ir_module
+    class Before:
+        I.module_attrs({"device_num": 2})
+        I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
+
+        @R.function
+        def foo(
+            x1: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
+            x2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+        ):
+            R.func_attr({"num_input": 1})
+            # R -> S[1]: expected to become redistribute_replica_to_shard
+            lv0 = R.dist.redistribute(x1, "mesh[0]", "S[1]")
+            # S[0] -> S[0]: no-op, expected to become x2 itself
+            lv1 = R.dist.redistribute(x2, "mesh[0]", "S[0]")
+            return (lv0, lv1)
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"device_num": 2})
+        I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
+
+        @R.function
+        def foo(
+            x1: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
+            x2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+        ) -> R.Tuple(
+            R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
+            R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+        ):
+            R.func_attr({"num_input": 1})
+            lv0: R.DTensor(
+                (128, 128), "float32", "mesh[0]", "S[1]"
+            ) = R.dist.redistribute_replica_to_shard(x1, num_workers=2, axis=1)
+            lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]") = x2
+            return (lv0, lv1)
+
+    after = relax.distributed.transform.LegalizeRedistribute()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_op_distributed.py 
b/tests/python/relax/test_op_distributed.py
new file mode 100644
index 0000000000..dc5440c0c1
--- /dev/null
+++ b/tests/python/relax/test_op_distributed.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+from tvm._ffi.base import TVMError
+import tvm.testing
+from tvm import relax
+from tvm.script.parser import relax as R
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
+    ret = bb.normalize(call)  # normalization fills in the call's struct info
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_redistribute_R_to_S():
+    bb = relax.BlockBuilder()
+    mesh = R.device_mesh((4,), list(range(4)))
+    x = relax.Var("x", R.DTensor((3, 4), "float32", device_mesh=mesh, 
placement="R"))
+
+    _check_inference(  # DTensor input: shape stays (3, 4), placement becomes S[1]
+        bb,
+        R.distributed.redistribute_replica_to_shard(x, num_workers=4, axis=1),
+        R.DTensor((3, 4), "float32", device_mesh=mesh, placement="S[1]"),
+    )
+
+    # invalid: axis 0 has size 3, which is not divisible by num_workers=4
+    with pytest.raises(TVMError):
+        bb.normalize(R.distributed.redistribute_replica_to_shard(x, 
num_workers=4, axis=0))
+
+    y = relax.Var("y", R.Tensor((3, 4), "float32"))
+    _check_inference(  # plain Tensor input: axis 1 shrinks from 4 to 4 / 4 = 1
+        bb,
+        R.distributed.redistribute_replica_to_shard(y, num_workers=4, axis=1),
+        R.Tensor((3, 1), "float32"),
+    )
+
+    # invalid: axis 0 has size 3, which is not divisible by num_workers=4
+    with pytest.raises(TVMError):
+        bb.normalize(R.distributed.redistribute_replica_to_shard(y, 
num_workers=4, axis=0))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_distributed.py 
b/tests/python/relax/test_transform_legalize_ops_distributed.py
new file mode 100644
index 0000000000..c1f462e36c
--- /dev/null
+++ b/tests/python/relax/test_transform_legalize_ops_distributed.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_redistribute_replica_to_shard():
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((10, 10), "float32"))  -> R.Tensor((10, 5), 
"float32"):
+            gv0 = R.dist.redistribute_replica_to_shard(x, num_workers=2, 
axis=1)
+            return gv0
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def strided_slice(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), 
redistribute_replica_to_shard: T.Buffer((T.int64(10), T.int64(5)), "float32"), 
worker_id: T.int64):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(10), T.int64(5)):
+                with T.block("redistribute_replica_to_shard"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, worker_id * T.int64(5) + v_i1])
+                    T.writes(redistribute_replica_to_shard[v_i0, v_i1])
+                    redistribute_replica_to_shard[v_i0, v_i1] = A[v_i0, 
worker_id * T.int64(5) + v_i1]
+
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), 
dtype="float32"):
+            worker_id = T.int64()
+            cls = Expected
+            gv: R.Shape(ndim=-1) = 
R.call_pure_packed("runtime.disco.worker_id", sinfo_args=(R.Shape(ndim=-1),))
+            gv1: R.Shape([worker_id]) = R.match_cast(gv, R.Shape([worker_id]))
+            gv0 = R.call_tir(cls.strided_slice, (x,), out_sinfo=R.Tensor((10, 
5), dtype="float32"), tir_vars=R.shape([worker_id]))
+            return gv0
+    # fmt: on
+
+    mod = LegalizeOps()(Before)  # uses the registered legalization (strided_slice)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to