abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1209241946


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> 
SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + 
fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * 
fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * 
fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorDivNode>();
+        const arith::IntSet new_provided = 
arith::IntSet::SinglePoint(div_f->a);
+        return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorModNode>();

Review Comment:
   I agree with you. I have added a condition to check for positive operands.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> 
SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + 
fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * 
fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * 
fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {

Review Comment:
   Thank you. I've added changes to pass `p_f2` as `new_provided` to floordiv.
   
   Regarding your comment on merging this with line 426, I'm seeing a 
compilation error at line 430 if I try to assign `p_f1.Eval()` to `var`. IMHO, 
this is a genuine error. Please let me know if you have any ideas for getting 
past the error.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> 
SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + 
fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * 
fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * 
fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorDivNode>();
+        const arith::IntSet new_provided = 
arith::IntSet::SinglePoint(div_f->a);
+        return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorModNode>();
+        const arith::IntSet new_provided = 
arith::IntSet::SinglePoint(div_f->a);

Review Comment:
   Sure. I've removed the use of `div_f`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to