adavare commented on a change in pull request #27:
URL: https://github.com/apache/tvm-vta/pull/27#discussion_r646843304



##########
File path: hardware/chisel/src/main/scala/core/Compute.scala
##########
@@ -118,44 +123,102 @@ class Compute(debug: Boolean = false)(implicit p: 
Parameters) extends Module {
   loadUop.io.baddr := io.uop_baddr
   io.vme_rd(0) <> loadUop.io.vme_rd
   loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, 
tensorAlu.io.uop.idx)
+  assert( !tensorGemm.io.uop.idx.valid || !tensorAlu.io.uop.idx.valid)
 
   // acc
   tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
   tensorAcc.io.inst := inst_q.io.deq.bits
   tensorAcc.io.baddr := io.acc_baddr
-  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, 
tensorAlu.io.acc.rd.idx)
-  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, 
tensorAlu.io.acc.wr)
+  require(tensorAcc.io.tensor.lenSplit ==
+    tensorAcc.io.tensor.tensorLength, "-F- Expecting a whole batch in acc 
group")
+
+  // split factor of isGemm for many groups
+  val splitFactorL0 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+  val splitFactorL1 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth)
+    - log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+  require(splitFactorL0 * splitFactorL1 == tensorAcc.io.tensor.splitWidth)
+  val accRdSelectL0 = for (idx <- 0 until splitFactorL1) yield {
+    // can save 1 stage on small design
+    if (splitFactorL1 > 1) RegNext(dec.io.isGemm, init = false.B) else 
dec.io.isGemm
+  }
+
+  for (idx <- 0 until tensorAcc.io.tensor.splitWidth) {
+    tensorAcc.io.tensor.rd(idx).idx <> Mux(
+      RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+      tensorGemm.io.acc.rd(idx).idx,
+      tensorAlu.io.acc.rd(idx).idx)
+    tensorAcc.io.tensor.wr(idx) <> Mux(
+      RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+      tensorGemm.io.acc.wr(idx),
+      tensorAlu.io.acc.wr(idx))
+  }
   io.vme_rd(1) <> tensorAcc.io.vme_rd
-  io.acc_wr_event := tensorAcc.io.tensor.wr.valid
+  io.acc_wr_event := tensorAcc.io.tensor.wr(topAccGrpIdx).valid
 
   // gemm
-  tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
-  tensorGemm.io.inst := inst_q.io.deq.bits
+  tensorGemm.io.start := RegNext(state === sIdle & start & dec.io.isGemm, init 
= false.B)
+  tensorGemm.io.dec := inst_q.io.deq.bits.asTypeOf(new GemmDecode)
   tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
   tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
   tensorGemm.io.inp <> io.inp
   tensorGemm.io.wgt <> io.wgt
-  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & 
dec.io.isGemm
-  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
-  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
-  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
+  for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
+    tensorGemm.io.acc.rd(idx).data.valid :=
+      tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isGemm, init = 
false.B)
+    tensorGemm.io.acc.rd(idx).data.bits <> 
tensorAcc.io.tensor.rd(idx).data.bits
+  }
+  for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+    tensorGemm.io.out.rd(idx).data.valid :=
+      io.out.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
+    tensorGemm.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+  }
 
   // alu
-  tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
-  tensorAlu.io.inst := inst_q.io.deq.bits
+  tensorAlu.io.start := RegNext(state === sIdle & start & dec.io.isAlu, init = 
false.B)
+  tensorAlu.io.dec := inst_q.io.deq.bits.asTypeOf(new AluDecode)
   tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
   tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
-  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & 
dec.io.isAlu
-  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
-  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
-  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
+  for (idx <- 0 until tensorAlu.io.acc.splitWidth) {
+    tensorAlu.io.acc.rd(idx).data.valid :=
+      tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isAlu, init = 
false.B)
+    tensorAlu.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
+  }
+  for (idx <- 0 until tensorAlu.io.out.splitWidth) {
+    tensorAlu.io.out.rd(idx).data.valid :=
+      io.out.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
+    tensorAlu.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+  }
 
   // out
-  io.out.rd.idx <> Mux(dec.io.isGemm,
-    tensorGemm.io.out.rd.idx,
-    tensorAlu.io.out.rd.idx)
-  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
+  for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+    io.out.rd(idx).idx <> Mux(dec.io.isGemm,
+      tensorGemm.io.out.rd(idx).idx,
+      tensorAlu.io.out.rd(idx).idx)
+    assert( !tensorGemm.io.out.rd(idx).idx.valid || 
!tensorAlu.io.out.rd(idx).idx.valid)
+    assert( !tensorGemm.io.out.rd(idx).data.valid || 
!tensorAlu.io.out.rd(idx).data.valid)
 
+    assert( !tensorGemm.io.out.wr(idx).valid || 
!tensorAlu.io.out.wr(idx).valid)
+  }
+  require (tensorGemm.io.out.splitWidth == 1)
+  require (tensorAlu.io.out.splitWidth == 1)
+  io.out.wr(0).valid := Mux(
+    RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).valid, 
tensorAlu.io.out.wr(0).valid)
+  io.out.wr(0).bits.idx := Mux(
+    RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).bits.idx, 
tensorAlu.io.out.wr(0).bits.idx)
+  //put mux/Reg into every gemm group to build pipe (for Mux select) tree over 
distance

Review comment:
       All occurrences of "//\S" (comments lacking a space after "//") have been replaced throughout the PR.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to