adavare commented on a change in pull request #27:
URL: https://github.com/apache/tvm-vta/pull/27#discussion_r646843304
##########
File path: hardware/chisel/src/main/scala/core/Compute.scala
##########
@@ -118,44 +123,102 @@ class Compute(debug: Boolean = false)(implicit p:
Parameters) extends Module {
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx,
tensorAlu.io.uop.idx)
+ assert( !tensorGemm.io.uop.idx.valid || !tensorAlu.io.uop.idx.valid)
// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
- tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx,
tensorAlu.io.acc.rd.idx)
- tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr,
tensorAlu.io.acc.wr)
+ require(tensorAcc.io.tensor.lenSplit ==
+ tensorAcc.io.tensor.tensorLength, "-F- Expecting a whole batch in acc
group")
+
+ // split factor of isGemm for many groups
+ val splitFactorL0 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+ val splitFactorL1 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth)
+ - log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+ require(splitFactorL0 * splitFactorL1 == tensorAcc.io.tensor.splitWidth)
+ val accRdSelectL0 = for (idx <- 0 until splitFactorL1) yield {
+ // can save 1 stage on small design
+ if (splitFactorL1 > 1) RegNext(dec.io.isGemm, init = false.B) else
dec.io.isGemm
+ }
+
+ for (idx <- 0 until tensorAcc.io.tensor.splitWidth) {
+ tensorAcc.io.tensor.rd(idx).idx <> Mux(
+ RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+ tensorGemm.io.acc.rd(idx).idx,
+ tensorAlu.io.acc.rd(idx).idx)
+ tensorAcc.io.tensor.wr(idx) <> Mux(
+ RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+ tensorGemm.io.acc.wr(idx),
+ tensorAlu.io.acc.wr(idx))
+ }
io.vme_rd(1) <> tensorAcc.io.vme_rd
- io.acc_wr_event := tensorAcc.io.tensor.wr.valid
+ io.acc_wr_event := tensorAcc.io.tensor.wr(topAccGrpIdx).valid
// gemm
- tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
- tensorGemm.io.inst := inst_q.io.deq.bits
+ tensorGemm.io.start := RegNext(state === sIdle & start & dec.io.isGemm, init
= false.B)
+ tensorGemm.io.dec := inst_q.io.deq.bits.asTypeOf(new GemmDecode)
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorGemm.io.inp <> io.inp
tensorGemm.io.wgt <> io.wgt
- tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid &
dec.io.isGemm
- tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
- tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
- tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
+ for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
+ tensorGemm.io.acc.rd(idx).data.valid :=
+ tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isGemm, init =
false.B)
+ tensorGemm.io.acc.rd(idx).data.bits <>
tensorAcc.io.tensor.rd(idx).data.bits
+ }
+ for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+ tensorGemm.io.out.rd(idx).data.valid :=
+ io.out.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
+ tensorGemm.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+ }
// alu
- tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
- tensorAlu.io.inst := inst_q.io.deq.bits
+ tensorAlu.io.start := RegNext(state === sIdle & start & dec.io.isAlu, init =
false.B)
+ tensorAlu.io.dec := inst_q.io.deq.bits.asTypeOf(new AluDecode)
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
- tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid &
dec.io.isAlu
- tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
- tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
- tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
+ for (idx <- 0 until tensorAlu.io.acc.splitWidth) {
+ tensorAlu.io.acc.rd(idx).data.valid :=
+ tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isAlu, init =
false.B)
+ tensorAlu.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
+ }
+ for (idx <- 0 until tensorAlu.io.out.splitWidth) {
+ tensorAlu.io.out.rd(idx).data.valid :=
+ io.out.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
+ tensorAlu.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+ }
// out
- io.out.rd.idx <> Mux(dec.io.isGemm,
- tensorGemm.io.out.rd.idx,
- tensorAlu.io.out.rd.idx)
- io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
+ for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+ io.out.rd(idx).idx <> Mux(dec.io.isGemm,
+ tensorGemm.io.out.rd(idx).idx,
+ tensorAlu.io.out.rd(idx).idx)
+ assert( !tensorGemm.io.out.rd(idx).idx.valid ||
!tensorAlu.io.out.rd(idx).idx.valid)
+ assert( !tensorGemm.io.out.rd(idx).data.valid ||
!tensorAlu.io.out.rd(idx).data.valid)
+ assert( !tensorGemm.io.out.wr(idx).valid ||
!tensorAlu.io.out.wr(idx).valid)
+ }
+ require (tensorGemm.io.out.splitWidth == 1)
+ require (tensorAlu.io.out.splitWidth == 1)
+ io.out.wr(0).valid := Mux(
+ RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).valid,
tensorAlu.io.out.wr(0).valid)
+ io.out.wr(0).bits.idx := Mux(
+ RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).bits.idx,
tensorAlu.io.out.wr(0).bits.idx)
+ //put mux/Reg into every gemm group to build pipe (for Mux select) tree over
distance
Review comment:
All "//\S" occurrences (comments with no space after "//") have been replaced throughout the PR.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org