Lunderberg commented on code in PR #16098:
URL: https://github.com/apache/tvm/pull/16098#discussion_r1391651539


##########
src/relax/op/distributed/distributed.cc:
##########
@@ -85,5 +86,84 @@ TVM_REGISTER_OP("relax.dist.redistribute")
     .set_attr<FInferStructInfo>("dist.FInferStructInfo", 
InferDistStructInfoRedistribute)
     .set_attr<Bool>("FPurity", Bool(true));
 
+StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+
+  const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
+  int num_workers = attrs->num_workers;
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = input_sinfo->GetShape();
+  CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should 
have defined shape.";
+
+  if (analyzer->CanProve(floormod(input_shape.value()[0], 
PrimExpr(num_workers))) != 0) {

Review Comment:
   `input_shape.value()[0]` should be `input_shape.value()[attrs->axis]`.  As 
currently written, we check the divisibility of the outermost axis instead of 
the axis being split.



##########
src/relax/op/distributed/distributed.cc:
##########
@@ -85,5 +86,84 @@ TVM_REGISTER_OP("relax.dist.redistribute")
     .set_attr<FInferStructInfo>("dist.FInferStructInfo", 
InferDistStructInfoRedistribute)
     .set_attr<Bool>("FPurity", Bool(true));
 
+StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+
+  const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
+  int num_workers = attrs->num_workers;
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = input_sinfo->GetShape();
+  CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should 
have defined shape.";
+
+  if (analyzer->CanProve(floormod(input_shape.value()[0], 
PrimExpr(num_workers))) != 0) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "scatter_from_worker0 expects the size of axis 0 of 
input tensor to be "
+                        "divisible by the "
+                        "num_workers. However, the axis 0 of input tensor is "
+                     << input_shape.value() << " while num_workers is " << 
num_workers);
+  }
+
+  Array<PrimExpr> output_shape = input_shape.value();
+  output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
+  if (input_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), output_dtype, 
input_sinfo->vdevice.value());
+  }
+  return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+}
+
+StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
+  using namespace distributed;
+  Array<DTensorStructInfo> input_dtensor_sinfos = 
GetInputDTensorStructInfo(call, ctx);
+  ICHECK(input_dtensor_sinfos.size() == 1);
+  DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0];
+  TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;
+  const auto* attrs = call->attrs.as<ScatterCollectiveAttrs>();
+  int num_workers = attrs->num_workers;
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = tensor_sinfo->GetShape();
+  CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should 
have defined shape.";
+
+  if (analyzer->CanProve(floormod(input_shape.value()[0], 
PrimExpr(num_workers))) != 0) {

Review Comment:
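   Same here: the divisibility check is applied to `input_shape.value()[0]` 
instead of `input_shape.value()[attrs->axis]`, the axis actually being split.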



##########
tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py:
##########
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#  type: ignore
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+import tvm
+from tvm import relax
+import tvm.testing
+
+
+def test_simple():
+    @I.ir_module
+    class Before:
+        I.module_attrs({"device_num": 2})
+        I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
+
+        @R.function
+        def foo(
+            x1: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
+            x2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+        ):
+            R.func_attr({"num_input": 1})
+            # scatter
+            lv0 = R.dist.redistribute(x1, "mesh[0]", "S[1]")
+            # do nothing
+            lv1 = R.dist.redistribute(x2, "mesh[0]", "S[0]")
+            return (lv0, lv1)
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"device_num": 2})
+        I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
+
+        @R.function
+        def foo(
+            x1: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
+            x2: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+        ) -> R.Tuple(
+            R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"),
+            R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
+        ):
+            R.func_attr({"num_input": 1})
+            lv0: R.DTensor(
+                (128, 64), "float32", "mesh[0]", "S[1]"
+            ) = R.dist.redistribute_replica_to_shard(x1, num_workers=2, axis=1)
+            lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]") = x2
+            return (lv0, lv1)
+
+    after = relax.distributed.transform.LegalizeRedistribute()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+

Review Comment:
   Can we add a unit test that would have caught the error in the divisibility 
check?  If we run shape inference on `R.dist.redistribute_replica_to_shard(arg, 
num_workers=2, axis=1)`, where `arg: R.Tensor([1, 8])`, then the correct 
divisibility check on `arg.shape[1]` passes (8 is divisible by 2), but the 
check as currently written, applied to `arg.shape[0]`, would incorrectly 
trigger the exception (1 is not divisible by 2).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to