tmoreau89 commented on a change in pull request #32:
URL: https://github.com/apache/tvm-vta/pull/32#discussion_r701576523



##########
File path: hardware/chisel/src/main/scala/core/TensorLoadNarrowVME.scala
##########
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import scala.math.pow
+import scala.math.sqrt
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+
+/** TensorLoad.
+ *
+ * Load 1D and 2D tensors from main memory (DRAM) to input/weight
+ * scratchpads (SRAM). Also, there is support for zero padding, while
+ * doing the load.
+ */
+class TensorLoadNarrowVME(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = new VMEReadMaster
+    val tensor = new TensorClient(tensorType)
+  })
+  val writePipeLatency = tp.writePipeLatency
+
+  val sIdle :: sBusy :: Nil =
+    Enum(2)
+  val state = RegInit(sIdle)
+
+  val isBusy = state === sBusy
+
+  val localDone = Wire(Bool())
+  when(io.start) {
+    state := sBusy
+  }.elsewhen(localDone) {
+    state := sIdle
+  }
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+
+  val vmeDataBitsPipe = RegNext(io.vme_rd.data.bits)
+  val vmeDataValidPipe = RegNext(io.vme_rd.data.valid, init = false.B)
+  val vmeDataReadyPipe = RegNext(io.vme_rd.data.ready, init = false.B)
+  val vmeDataFirePipe = vmeDataValidPipe & vmeDataReadyPipe
+
+  //--------------------------------------
+  //--- Generate data load VME command ---
+  //--------------------------------------
+  val vmeCmd = Module (new GenVMECmd(tensorType, debug))
+  vmeCmd.io.start := io.start
+  vmeCmd.io.isBusy := isBusy
+  vmeCmd.io.inst := io.inst
+  vmeCmd.io.baddr := io.baddr
+  vmeCmd.io.vmeCmd <> io.vme_rd.cmd
+  val readLen = vmeCmd.io.readLen
+  val commandsDone = vmeCmd.io.done
+
+  // count how many blocks are not yet received
+  val blkIdxWdth = log2Ceil(tp.tsSizeRatio * tp.memDepth) // the size of 
scratchpad in blocks
+  // Nb of data blocks requested, not received. TODO: smaller width parameter
+  val blocksInFlight = Reg(UInt(blkIdxWdth.W))
+  when(io.start) {
+    blocksInFlight := 0.U
+  }.elsewhen(isBusy && io.vme_rd.cmd.fire() && !vmeDataFirePipe) {
+    blocksInFlight := blocksInFlight + readLen
+  }.elsewhen(isBusy && io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+    blocksInFlight := blocksInFlight + readLen - 1.U
+  }.elsewhen(isBusy && !io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+    assert(blocksInFlight > 0.U)
+    blocksInFlight := blocksInFlight - 1.U
+  }.otherwise {
+    blocksInFlight := blocksInFlight
+  }
+
+  //---------------------
+  //--- Read VME data ---
+  //---------------------
+
+  val readData = Module(new ReadVMEData(tensorType, debug))
+  readData.io.start := io.start
+  readData.io.vmeData.valid := vmeDataValidPipe
+  readData.io.vmeData.bits := vmeDataBitsPipe
+  assert(!readData.io.vmeData.valid || readData.io.vmeData.ready,
+    "-F- Expecting const ready. Fix ReadVMEData to receive data 1 cycle after 
ready")
+  io.vme_rd.data.ready := readData.io.vmeData.ready
+  val rdDataDestCol = readData.io.col // this is an index of a col in tensor
+  val rdDataDestIdx = readData.io.idx // this is an index of a tensor
+
+  //-------------------------
+  //--- Fill zero padding ---
+  //-------------------------
+
+  val fillPadding = Module(new ZeroPadding(tensorType, debug))
+  fillPadding.io.canWriteMem := !vmeDataFirePipe
+  fillPadding.io.inst := RegNext(io.inst) // stage it to move from instr queue
+  fillPadding.io.start := RegNext(io.start, init = false.B)// stage it to move 
from instr queue
+
+  val isZeroPadWrite = fillPadding.io.tensorIdx.valid // Store zero filled 
tensor, zpDestIdx is valid
+  val zpDestIdx = fillPadding.io.tensorIdx.bits // Tensor index
+  val paddingDone = fillPadding.io.done
+
+  //--------------------
+  //--- Write memory ---
+  //--------------------
+
+  val memSizeRatio = tp.tsSizeRatio
+  val splitDataFactor = tp.splitWidth * tp.splitLength
+  val splitMemBlockFactor = if (splitDataFactor > memSizeRatio) {
+    require((splitDataFactor/memSizeRatio) * memSizeRatio == splitDataFactor,
+      "-F- Cannot split tensor data memBlockBits further.")
+    splitDataFactor/memSizeRatio
+  }else {
+    1
+  }
+  val groupMemBlockFactor = if (splitDataFactor > memSizeRatio) {
+    1
+  }else {
+    require((memSizeRatio/splitDataFactor) * splitDataFactor == memSizeRatio,
+      "-F- Cannot group tensor data memBlockBits into groups.")
+    memSizeRatio/splitDataFactor
+  }
+  // one macro has a VME memory read bit width or read/write group bit width
+  //different groups can read/write scratchpad separately
+  val tensorFile = Seq.fill(memSizeRatio * splitMemBlockFactor
+  ) {
+    SyncReadMem(tp.memDepth, UInt((tp.memBlockBits/splitMemBlockFactor).W))
+  }
+
+
+  require(splitDataFactor * groupMemBlockFactor == memSizeRatio * 
splitMemBlockFactor,
+    "-F- Wrong split of data")
+  //-------------------------------
+  //--- Write address vector ------
+  //-------------------------------
+  // split data to build pipe tree
+  val splitFactorL0 = pow(2,log2Ceil(memSizeRatio) / 2).toInt
+  val splitFactorL1 = pow(2,log2Ceil(memSizeRatio) - log2Ceil(memSizeRatio) / 
2).toInt
+  require(splitFactorL0 * splitFactorL1 == memSizeRatio)
+  // tensor load instruction writes a VME data block or a whole tensor
+  val waddrTensInstrTmp = Mux(isZeroPadWrite, zpDestIdx, rdDataDestIdx)
+  val waddrTensInstrPipe = VecInit((for (j <- 0 until splitFactorL1) yield {
+    ShiftRegister(waddrTensInstrTmp, if (writePipeLatency > 0) 1 else 0)
+  }).flatMap(elem => for (k <- 0 until splitFactorL0) yield {
+    elem
+  }).flatMap(elem => for (k <- 0 until splitMemBlockFactor) yield {
+    ShiftRegister(elem, if (writePipeLatency < 2) 0 else writePipeLatency - 1)
+  }))
+  require(waddrTensInstrPipe.size == memSizeRatio * splitMemBlockFactor)
+
+  val waddrDirect = (VecInit((for (grIdx <- 0 until splitDataFactor) yield {
+    io.tensor.wr(grIdx).bits.idx
+  }).flatMap(elem => for (k <- 0 until groupMemBlockFactor) yield 
{elem}))).asTypeOf(
+    Vec(memSizeRatio * splitMemBlockFactor, io.tensor.wr(0).bits.idx.cloneType)
+  )
+
+
+  val waddr = Wire(Vec(memSizeRatio * splitMemBlockFactor, 
waddrTensInstrTmp.cloneType))
+  for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+    waddr(j) := Mux(
+      ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en 
= true.B),
+      waddrDirect(j),
+      waddrTensInstrPipe(j))
+  }
+
+  //-------------------------------
+  //--- Write enable vector -------
+  //-------------------------------
+  val dataOffset = rdDataDestCol
+  // get the write-enable signal and duplicate it
+  val wenTensInstr = VecInit((for (j <- 0 until memSizeRatio) yield {
+    Mux(isZeroPadWrite, true.B, dataOffset === j.U && vmeDataFirePipe)
+  }).flatMap(elem => for (k <- 0 until splitMemBlockFactor) yield {elem}))
+
+  val wenDirect = VecInit((for (grIdx <- 0 until splitDataFactor) yield {
+    io.tensor.wr(grIdx).valid
+  }).flatMap(elem => for (k <- 0 until groupMemBlockFactor) yield {elem}))
+
+  val wen = Wire(Vec(memSizeRatio * splitMemBlockFactor, Bool()))
+  for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+    wen(j) := Mux(
+      ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en 
= true.B),
+      wenDirect(j),
+      ShiftRegister(wenTensInstr(j), writePipeLatency))
+  }
+
+  require(tp.memBlockBits % tp.tensorElemBits == 0)
+
+
+  //-------------------------------
+  //--- Write data vector ---------
+  //-------------------------------
+  val wdataTensInstrDataPipe = VecInit((for (j <- 0 until splitFactorL0) yield 
{
+    ShiftRegister(vmeDataBitsPipe.data, if (writePipeLatency > 0) 1 else 0)
+  }).flatMap(elem => for (k <- 0 until splitFactorL1) yield {
+    elem
+  }).flatMap(elem => for (k <- 0 until splitMemBlockFactor) yield {
+    require(elem.getWidth == tp.memBlockBits)
+    ShiftRegister(
+      elem.asTypeOf(Vec(splitMemBlockFactor, 
UInt((tp.memBlockBits/splitMemBlockFactor).W)))(k),
+      if (writePipeLatency < 2) 0 else writePipeLatency - 1)
+  }))
+  require(wdataTensInstrDataPipe.size == memSizeRatio * splitMemBlockFactor)
+  val wdataTensInstr = Wire(Vec(memSizeRatio * splitMemBlockFactor, 
UInt((tp.memBlockBits/splitMemBlockFactor).W)))
+  for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+    // pipe 1 stage paddingControl per group
+    val padValue = 0.U
+
+    wdataTensInstr(j) := Mux(
+      ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en 
= true.B),
+      ShiftRegister(padValue /* a single group total data bits */, 
writePipeLatency),
+      wdataTensInstrDataPipe(j))
+  }
+
+  // THIS wdataDirect writes continuous scratchpad data space
+  // It is WRONG for ACC batch > 1
+  // maps group data bits to a continuous sequence of mem blocks
+  // but wr(x).bits.data is a window in a tensor
+  val wdataDirect = VecInit((for (grIdx <- 0 until splitDataFactor) yield {
+    io.tensor.wr(grIdx).bits.data
+  }).flatMap(elem => for (k <- 0 until groupMemBlockFactor) yield {
+    elem.asTypeOf(Vec(groupMemBlockFactor, 
UInt((tp.memBlockBits/splitMemBlockFactor).W)))(k)
+  }))
+  val wdata = Wire(Vec(memSizeRatio * splitMemBlockFactor, 
UInt((tp.memBlockBits/splitMemBlockFactor).W)))
+  for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+    wdata(j) := Mux(
+      ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en 
= true.B),
+      wdataDirect(j),
+      wdataTensInstr(j))
+  }
+
+  for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+    when(wen(j)) {
+      tensorFile(j).write(waddr(j), wdata(j))
+    }
+  }
+  if (debug) {
+    when(isZeroPadWrite) {
+      printf(s"[TensorLoad] $tensorType isZeroPadWrite data zpDestIdx:%d\n",
+        zpDestIdx)
+    }
+    when (vmeDataFirePipe) {
+      printf(s"[TensorLoad] $tensorType data rdDataDestCol:%d 
rdDataDestIdx:%d\n",
+        rdDataDestCol,
+        rdDataDestIdx)
+    }
+  }
+
+  // read-from-sram
+  for (grIdx <- 0 until splitDataFactor) {
+    val rvalid = ShiftRegister(
+      io.tensor.rd(grIdx).idx.valid, tp.readTensorLatency + 1, resetData = 
false.B, en = true.B)
+    io.tensor.rd(grIdx).data.valid := rvalid
+  }
+
+  val memsInGroup = memSizeRatio * splitMemBlockFactor / splitDataFactor
+  for (grIdx <- 0 until splitDataFactor) {
+    io.tensor.rd(grIdx).data.bits :=
+      VecInit(for (memBlkIdx <- 0 until memsInGroup) yield {
+        tensorFile(grIdx * memsInGroup + memBlkIdx).read(
+          ShiftRegister(io.tensor.rd(grIdx).idx.bits, tp.readTensorLatency),
+          ShiftRegister(io.tensor.rd(grIdx).idx.valid, tp.readTensorLatency, 
resetData = false.B, en = true.B))
+      }).asTypeOf(io.tensor.rd(grIdx).data.bits)
+  }
+
+  // done
+  val loadDone = blocksInFlight === 0.U && commandsDone && state === sBusy
+  localDone := loadDone && paddingDone
+  io.done := ShiftRegister(localDone, writePipeLatency, resetData = false.B, 
en = true.B)
+
+}
+
+//-------------------------
+//--- Fill zero padding ---
+//-------------------------
+
+//----------------------------------------------------------------------------
+// Fill tensors with zeros if padding is defined
+// stride must be used (xstride and ysize) if xpad_0 or xpad_1
+// are not zero and matrix xas more than one row of tensors

Review comment:
       typo: xas




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to