tmoreau89 commented on a change in pull request #32:
URL: https://github.com/apache/tvm-vta/pull/32#discussion_r701577282



##########
File path: hardware/chisel/src/main/scala/core/TensorLoadWideVME.scala
##########
@@ -0,0 +1,767 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import scala.math.pow
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+
+/** TensorLoad.
+ *
+ * Load Cachelines from main memory (DRAM) into SRAM
+ * Mux Cachelines to tensor size memory blocks in
+ * scratchpads (SRAM). Also, there is support for zero padding, while
+ * doing the load. Zero-padding works on the y and x axis, and it is
+ * managed by ZeroPadding.
+ * Read tensors from SRAM.
+
+ * banks number (BN) = CacheLineSize (CS) / Tensor bit size (TS)
+ * the number of banks is pow of 2
+ * Scratchpad: Seq(BN) {Mem(TensorsNb/BN, TS)}
+ * Cacheline: Vec(BN,CS/BN)
+
+ * Load:
+ *          Scratchpad
+ *       bank1      bank2
+ *         |          |
+ *        ---        ---
+ * wmask-/   \     -/   \
+ *       -----      -----
+ *        | |        | |
+ *  c     | |        | |
+ *  a  -----|--------  |
+ *  c       |          |
+ *  h       |          |
+ *  e       |          |
+ *  l       |          |
+ *  i ------------------
+ *  n
+ *  e
+
+
+
+
+ */
+class TensorLoadWideVME(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = new VMEReadMaster
+    val tensor = new TensorClient(tensorType)
+  })
+  // the delay cycles of the write pipe. Needed to deliver a signal over a physical distance
+  val writePipeLatency = tp.writePipeLatency
+
+  val sIdle :: sBusy :: Nil =
+    Enum(2)
+  val state = RegInit(sIdle)
+
+  val isBusy = state === sBusy
+  val localDone = Wire(Bool())
+  when(io.start) {
+    state := sBusy
+  }.elsewhen(localDone) {
+    state := sIdle
+  }
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+
+  val readVMEDataLatency = tp.readVMEDataLatency
+  val vmeDataBitsPipe = ShiftRegister(io.vme_rd.data.bits, readVMEDataLatency, 
en = true.B)
+  val vmeDataValidPipe = ShiftRegister(io.vme_rd.data.valid, 
readVMEDataLatency, resetData = false.B, en = true.B)
+  val vmeDataReadyPipe = ShiftRegister(io.vme_rd.data.ready, 
readVMEDataLatency, resetData = true.B, en = true.B)
+  val vmeDataFirePipe = vmeDataValidPipe & vmeDataReadyPipe
+
+  //--------------------------------------
+  //--- Generate data load VME command ---
+  //--------------------------------------
+  val vmeCmd = Module (new GenVMECmdWideTL(tensorType, debug))
+  vmeCmd.io.start := io.start
+  vmeCmd.io.isBusy := isBusy
+  vmeCmd.io.inst := io.inst
+  vmeCmd.io.baddr := io.baddr
+  vmeCmd.io.vmeCmd <> io.vme_rd.cmd
+  val readLen = vmeCmd.io.readLen
+  val commandsDone = vmeCmd.io.done
+
+  require (mp.dataBits >= tp.tensorSizeBits,
+    "-F- Chacheline width must be larger than tensor bit size")
+  require(pow(2, log2Ceil(mp.dataBits)) == mp.dataBits,
+    "-F- Chacheline width must be pow of 2")
+  require(pow(2, log2Ceil(tp.tensorSizeBits)) == tp.tensorSizeBits,
+    "-F- Tensor size bits must be pow of 2")
+
+  // the mux puts tensors in a single memory line of Cacheline (CL) bits
+  val tensorsInClNb = tp.clSizeRatio
+  val tensorsInClNbWidth = log2Ceil(tensorsInClNb)
+
+  //--------------------------------------
+  //--- count how many CLs not received ---
+  //--------------------------------------
+
+  // the address size of scratchpad memory
+  val clCntIdxWdth = log2Ceil(tp.memDepth/tensorsInClNb) + 1
+  // Nb of CLs requested, not received.
+  val clInFlight = Reg(UInt(clCntIdxWdth.W))
+  when(io.start) {
+    clInFlight := 0.U
+  }.elsewhen(isBusy && io.vme_rd.cmd.fire() && !vmeDataFirePipe) {
+    clInFlight := clInFlight + readLen
+  }.elsewhen(isBusy && io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+    clInFlight := clInFlight + readLen - 1.U
+  }.elsewhen(isBusy && !io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+    assert(clInFlight > 0.U)
+    clInFlight := clInFlight - 1.U
+  }.otherwise {
+    clInFlight := clInFlight
+  }
+
+  //---------------------
+  //--- Read VME data ---
+  //---------------------
+
+  val readData = Module(new ReadVMEDataWide(tensorType, debug))
+  readData.io.start := io.start
+  readData.io.vmeData.valid := vmeDataValidPipe
+  readData.io.vmeData.bits := vmeDataBitsPipe
+  assert(!readData.io.vmeData.valid || readData.io.vmeData.ready,
+    "-F- Expecting const ready. Fix ReadVMEData to receive data piped after 
ready")
+  io.vme_rd.data.ready := readData.io.vmeData.ready
+  // write mask defines the number of elems starting with offset in SRAM line
+  val rdDataWrIdx  = readData.io.destIdx // SP index vector
+  val rdDataWrData = readData.io.destData // SP data vector
+  val rdDataWrEn   = readData.io.destMask // write enable vector
+
+  //-------------------------
+  //--- Fill zero padding ---
+  //-------------------------
+
+  val fillPadding = Module(new ZeroPadding(tensorType, debug))
+  fillPadding.io.canWriteMem := !vmeDataFirePipe
+  fillPadding.io.inst := io.inst
+  fillPadding.io.start := io.start
+
+  val isZeroPadWrite = fillPadding.io.tensorIdx.valid // Store zero filled 
tensor, zpDestIdx is valid
+  val zpDestIdx = fillPadding.io.tensorIdx.bits >>  tensorsInClNbWidth // SP 
idx
+  val zpDestMask =
+    if (tensorsInClNb == 1) 1.U
+    else  UIntToOH(fillPadding.io.tensorIdx.bits (tensorsInClNbWidth - 1, 0)) 
// tensor in a memory line
+  val paddingDone = fillPadding.io.done
+
+  //--------------------
+  //--- Write memory ---
+  //--------------------
+
+  // depth is reduced by dataBlock/tensorSize ratio
+  // width is dataBlock bits split into tensor bits
+  // each tensor is split into group bits
+  // group bits can be read/written independently
+
+
+  val splitDataFactor = tp.splitWidth * tp.splitLength
+  val splitMemFactor = tp.splitMemsFactor
+  val groupSizeBits = tp.tensorSizeBits/splitDataFactor
+  val memSizeBits = groupSizeBits/splitMemFactor
+  val tensorFile = Seq.fill(tensorsInClNb * splitDataFactor*splitMemFactor) {
+    SyncReadMem(tp.memDepth/tensorsInClNb, UInt(memSizeBits.W))
+  }
+
+  // direct write
+  val directWrIdx = for (grpIdx <- 0 until splitDataFactor) yield {
+    io.tensor.wr(grpIdx).bits.idx >> tensorsInClNbWidth // SP idx
+  }
+  val directWrMask = for (grpIdx <- 0 until splitDataFactor) yield {
+    Mux(
+      io.tensor.wr(grpIdx).valid,
+      if(tensorsInClNb == 1) 1.U
+      else UIntToOH(io.tensor.wr(grpIdx).bits.idx(tensorsInClNbWidth - 1, 
0)),// tensor in a memory line
+      0.U)
+  }
+
+  // THIS directWrData writes a continuous scratchpad data space
+  // It is WRONG for ACC if batch is > 1
+  // maps group data bits to a continuous sequence of mem blocks
+  // but wr(x).bits.data is a window in a tensor
+  val directWrData = VecInit(for (grpIdx <- 0 until splitDataFactor) yield {
+    io.tensor.wr(grpIdx).bits.data
+  }).asTypeOf(UInt(tp.tensorSizeBits.W))
+
+
+  val wmask = Wire(Vec(tensorsInClNb*splitDataFactor*splitMemFactor, Bool()))
+  for (i <- 0 until tensorsInClNb) {
+    for (grpIdx <- 0 until splitDataFactor) {
+      for (memIdx <- 0 until splitMemFactor) { // duplicate control
+        wmask(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + 
memIdx) :=
+          Mux(
+            ShiftRegister(state === sIdle, writePipeLatency, resetData = 
true.B, en = true.B),
+            directWrMask(grpIdx)(i),
+            Mux(
+              ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = 
false.B, en = true.B),
+              ShiftRegister(zpDestMask(i), writePipeLatency),
+              Mux(
+                ShiftRegister(vmeDataFirePipe, writePipeLatency, resetData = 
false.B, en = true.B),
+                ShiftRegister(rdDataWrEn(i), writePipeLatency),
+                false.B)))
+      }
+    }
+  }
+
+  val wdata = Wire(Vec(tensorsInClNb*splitDataFactor, UInt(groupSizeBits.W)))
+  for (i <- 0 until tensorsInClNb){
+    for (grpIdx <- 0 until splitDataFactor) {
+      val zpDestData = 0.U
+      wdata(i*splitDataFactor + grpIdx) := Mux(
+        ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, 
en = true.B),
+        io.tensor.wr(grpIdx).bits.data.asTypeOf(UInt(groupSizeBits.W)),
+        Mux(
+          ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, 
en = true.B),
+          ShiftRegister(zpDestData /* group size zero */, writePipeLatency),
+          ShiftRegister(
+            (rdDataWrData(i).asTypeOf(Vec(splitDataFactor, 
UInt(groupSizeBits.W))))(grpIdx), writePipeLatency)))
+    }
+  }
+
+  val widx = Wire(Vec(tensorsInClNb*splitDataFactor*splitMemFactor, 
UInt(tp.memAddrBits.W)))
+  for (i <- 0 until tensorsInClNb) {
+    for (grpIdx <- 0 until splitDataFactor) {
+      for (memIdx <- 0 until splitMemFactor) { // duplicate control
+        widx(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + 
memIdx) :=
+          Mux(
+            ShiftRegister(state === sIdle, writePipeLatency, resetData = 
true.B, en = true.B),
+            directWrIdx(grpIdx),
+            Mux(
+              ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = 
false.B, en = true.B),
+              ShiftRegister(zpDestIdx, writePipeLatency),
+              ShiftRegister(rdDataWrIdx(i), writePipeLatency)))
+      }
+    }
+  }
+
+  for (i <- 0 until tensorsInClNb) {
+    for (grpIdx <- 0 until splitDataFactor) {
+      for (memIdx <- 0 until splitMemFactor) { // duplicate control
+        when(wmask(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor 
+ memIdx)) {
+          tensorFile(i*splitDataFactor*splitMemFactor + grpIdx * 
splitMemFactor + memIdx).write(
+            widx(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + 
memIdx),
+            wdata(i*splitDataFactor + grpIdx).asTypeOf(
+              Vec(splitMemFactor, UInt(memSizeBits.W)))(memIdx))
+        }
+      }
+    }
+  }
+  if (debug) {
+    when(isZeroPadWrite) {
+      printf(s"[TensorLoad] $tensorType isZeroPadWrite data zpDestIdx:%d\n",
+        zpDestIdx)
+    }
+  }
+
+  // read-from-sram
+  for (grpIdx <- 0 until splitDataFactor) {
+    val rIdx = io.tensor.rd(grpIdx).idx.bits >> tensorsInClNbWidth // SP idx
+    val rMask =
+      Mux(
+        io.tensor.rd(grpIdx).idx.valid,
+        if(tensorsInClNb == 1) 1.U
+        else UIntToOH(io.tensor.rd(grpIdx).idx.bits(tensorsInClNbWidth - 1, 
0)),// tensor in a memory line
+        0.U)
+
+    val rdataVec =   for (i <- 0 until tensorsInClNb) yield {
+      VecInit(for (memIdx <- 0 until splitMemFactor) yield {
+        tensorFile(
+          i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + 
memIdx).read(
+            ShiftRegister(rIdx, tp.readTensorLatency),
+            ShiftRegister(VecInit(rMask.asBools)(i), tp.readTensorLatency, 
resetData = false.B, en = true.B))
+      }).asUInt
+    }
+
+    val rdata = Wire(UInt(tp.tensorSizeBits.W))
+    rdata := Mux1H(ShiftRegister(rMask, tp.readTensorLatency + 1), rdataVec)
+    io.tensor.rd(grpIdx).data.bits := 
rdata.asTypeOf(io.tensor.rd(grpIdx).data.bits.cloneType)
+
+    val rvalid = ShiftRegister(
+      io.tensor.rd(grpIdx).idx.valid, tp.readTensorLatency + 1, resetData = 
false.B, en = true.B)
+    io.tensor.rd(grpIdx).data.valid := rvalid
+  }
+
+  // done
+  val loadDone = clInFlight === 0.U && commandsDone && state === sBusy
+  localDone := loadDone && paddingDone
+  io.done := ShiftRegister(localDone, writePipeLatency, resetData = false.B, 
en = true.B)
+}
+
+//---------------------
+//--- Read VME data ---
+//---------------------
+//----------------------------------------------------------------------------
+// Read VME data. Generate Memory index and data
+// transaction TAG is a data block offset in scratchpad
+// Different transactions are identified by a tag change
+// SAME DESTINATION SUBSEQUENT REQUESTS IN ONE INSTRUCTION LEADS TO UNDEFINED 
BEHAVIOR
+//----------------------------------------------------------------------------
+class ReadVMEDataWide(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val wmaskWidth = mp.dataBits/tp.tensorSizeBits
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val vmeData = Flipped(Decoupled(new VMEData))
+
+    val destIdx  = Output(Vec(tp.clSizeRatio, UInt(tp.memAddrBits.W)))
+    val destData = Output(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
+    val destMask = Output(Vec(tp.clSizeRatio, Bool()))
+  })
+
+  io.vmeData.ready := true.B // always ready to read VME data
+
+  require(pow(2, log2Ceil(tp.tensorLength)) == tp.tensorLength,
+    "-F- Tensor length must be 2^. Using shift and bits to divide.")
+  val blkIdxWdth = log2Ceil(tp.memDepth) // the size of scratchpad in cache 
lines
+
+  //decode data destination
+  val vmeTagDecode = io.vmeData.bits.tag
+  val vmeTagDecodeLast = Reg(vmeTagDecode.cloneType) // store tag to identify 
a new burst
+  val clBytes = mp.dataBits / 8 // cacheline bytes
+  val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // 
bytes in tensor
+  val rdDataMaskDecodeWidth = if (wmaskWidth == 1) 1 else 
(log2Ceil(wmaskWidth) + 1)
+  val rdDataElemIdx = vmeTagDecode(vmeTagDecode.getWidth - 1, 2 * 
rdDataMaskDecodeWidth)
+  val rdFstOffsetNb = if (rdDataMaskDecodeWidth == 0) {
+    0.U
+  } else {
+    val readOffset  = vmeTagDecode(2 * rdDataMaskDecodeWidth - 1, 
rdDataMaskDecodeWidth)
+    readOffset
+  }
+  val rdLstNb = if (rdDataMaskDecodeWidth == 0) {
+    1.U
+  } else {
+    val readNb  = vmeTagDecode(rdDataMaskDecodeWidth - 1, 0)
+    assert(!io.vmeData.valid || readNb > 0.U,"-F- Expecting some elements to 
read")
+    readNb
+  }
+  val wrMask1st = if (rdDataMaskDecodeWidth == 0) {
+    1.U
+  } else {
+    Reverse(VecInit(for(idx <- 0 until wmaskWidth) yield {
+      idx.U < tp.clSizeRatio.U - rdFstOffsetNb
+    }).asUInt)
+  }
+  val wrMaskLast = if (rdDataMaskDecodeWidth == 0) {
+    1.U
+  } else {
+    VecInit(for(idx <- 0 until wmaskWidth) yield {
+      idx.U < rdLstNb
+    }).asUInt
+  }
+  val rdDataElemDestIdx = Wire(UInt(tp.memAddrBits.W)) // this is an idx  of a 
tensor
+  val rdDataElemDestIdxNext = Reg(UInt(tp.memAddrBits.W))
+  val rdDataClDestIdx = rdDataElemDestIdx >> log2Ceil(tp.clSizeRatio)
+  val rdDataDestElemOffset = rdDataElemDestIdx % tp.clSizeRatio.U
+
+  val vmeTagDecodeLastValid = Wire(Bool())
+  val vmeTagDecodeLastValidNext = RegNext(
+    next = vmeTagDecodeLastValid,
+    init = false.B)
+  when(io.start) {
+    vmeTagDecodeLastValid :=false.B // reset tag valid
+  }.elsewhen(io.vmeData.fire()) {
+    vmeTagDecodeLastValid := true.B // set tag valid on a new read
+  }.otherwise {
+    vmeTagDecodeLastValid := vmeTagDecodeLastValidNext // keep value
+  }
+
+  val isFirstPulse = Wire(Bool())
+  val isLastPulse = io.vmeData.bits.last
+  val wmaskSel =
+    Mux(
+      isFirstPulse && isLastPulse,
+      wrMask1st & wrMaskLast,
+      Mux(
+        isFirstPulse,
+        wrMask1st,
+        Mux(
+          isLastPulse,
+          wrMaskLast,
+          ((1 << wmaskWidth) - 1).U)))
+  val wmask = Mux(io.vmeData.fire(), wmaskSel, 0.U)
+  rdDataElemDestIdx := DontCare
+  isFirstPulse := false.B
+  when(io.vmeData.fire()) {
+    when (
+      !vmeTagDecodeLastValidNext ||
+      (vmeTagDecodeLastValidNext &&
+        vmeTagDecode.asUInt =/= vmeTagDecodeLast.asUInt)) {
+
+      vmeTagDecodeLast := vmeTagDecode // a new burst
+      isFirstPulse := true.B
+      rdDataElemDestIdx := rdDataElemIdx
+      // don't increment the first partial read pulse
+      rdDataElemDestIdxNext := rdDataElemIdx + PopCount(wmask)
+    }.otherwise {
+      rdDataElemDestIdxNext := rdDataElemDestIdxNext + PopCount(wmask)
+      rdDataElemDestIdx := rdDataElemDestIdxNext
+    }
+  }
+
+
+  val srcData  = io.vmeData.bits.data.asTypeOf(Vec(tp.clSizeRatio, 
UInt(tp.tensorSizeBits.W)))
+  val srcOffset = Wire(Vec(tp.clSizeRatio, UInt((log2Ceil(tp.clSizeRatio) + 
1).W)))
+  val srcIdx = Wire(Vec(tp.clSizeRatio, UInt(log2Ceil(tp.clSizeRatio).W)))
+
+  // D(j+d) = S(j+s)  replace i=j+d --> D(i) = S(i-d+s)
+  for (i <- 0 until tp.clSizeRatio) {
+    srcOffset(i) := i.U + Mux(isFirstPulse, rdFstOffsetNb, 0.U)
+    srcIdx(i) := srcOffset(i) -% rdDataDestElemOffset
+    val srcIdxOH = UIntToOH(srcIdx(i))
+    io.destData(i) := Mux1H(srcIdxOH,srcData)
+    io.destMask(i) := Mux1H(srcIdxOH, wmask)
+
+    //if dest offset overflow, incr that dest idx
+    val incrIdx = if (tp.clSizeRatio == 1 ) {
+      0.U
+    } else {
+      Mux(srcOffset(i) >= rdDataDestElemOffset, 0.U, 1.U)
+    }
+    io.destIdx(i) := rdDataClDestIdx + incrIdx
+
+
+  }
+
+
+}
+
+// transaction TAG is a data block offset in scratchpad
+// Different transactions are identified by a tag change
+// SAME DESTINATION SUBSEQUENT REQUESTS IN ONE INSTRUCTION LEADS TO UNDEFINED 
BEHAVIOR
+class GenVMECmdWide(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val isBusy = Input(Bool())
+    val updateState = Input(Bool())
+    val canSendCmd = Input(Bool())
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vmeCmd = Decoupled(new VMECmd)
+    val readLen = Output(UInt((mp.lenBits + 1).W))
+    val done = Output(Bool())
+    val fstPulseDataStart = Output(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
+    val lstPulseDataEnd = Output(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
+    val spElemIdx = Output(UInt(tp.memAddrBits.W))
+
+    val ysize = Input(UInt(M_SIZE_BITS.W))
+    val xsize = Input(UInt(M_SIZE_BITS.W))
+    val xstride = Input(UInt(M_STRIDE_BITS.W))
+    val dram_offset = Input(UInt(M_DRAM_OFFSET_BITS.W))
+    val sram_offset = Input(UInt(M_SRAM_OFFSET_BITS.W))
+    val xpad_0 = Input(UInt(M_PAD_BITS.W))
+    val xpad_1 = Input(UInt(M_PAD_BITS.W))
+    val ypad_0 = Input(UInt(M_PAD_BITS.W))
+  })
+
+  val clBytes = mp.dataBits / 8 // cacheline bytes
+  val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // 
bytes in tensor
+  val stride = Wire(Bool()) // flags change to the next row to read
+
+  //----------------------------------------
+  //--- Count lines of DRAM memory lines ---
+  //----------------------------------------
+
+  // set which source row of data to read. io.ysize defines the number of rows
+  val dramLineIdx = Reg(UInt(io.ysize.getWidth.W)) // current row of stride 
read
+  when (io.start) {
+    dramLineIdx := 0.U // 1st row
+  }.elsewhen (stride) {
+    dramLineIdx := dramLineIdx + 1.U // increment row
+  }.otherwise {
+    dramLineIdx := dramLineIdx // stay in the row
+  }
+
+  // calculate address of DRAM memory line begin (initial/stride)
+  val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
+  val dramInitialAddr = (io.dram_offset << 
log2Ceil(elemBytes)).asTypeOf(UInt(mp.addrBits.W))
+  val xferElemInitAddr = io.baddr | dramInitialAddr // SHOULD have + here?
+  // align address to CL size
+  // lower bits - elem offset in a cacheline
+  val dramClAddrAlignNotMask = ((BigInt(1) << log2Ceil(clBytes)) - 
1).U.asTypeOf(xferElemInitAddr)
+  // upper bits - cacheline alignment
+  val dramClAddrAlignMask = ~dramClAddrAlignNotMask
+  val xferClInitAddr = xferElemInitAddr & dramClAddrAlignMask
+  val rdLineElemBeginAddr = Reg(UInt(mp.addrBits.W)) // DRAM address of xsize 
tensors memory line
+  val rdLineClBeginAddr = rdLineElemBeginAddr & dramClAddrAlignMask
+  // begin of the next DRAM memory line
+  val nextLineBeginElemAddr = rdLineElemBeginAddr + (io.xstride << 
log2Ceil(elemBytes))
+  val nextLineBeginClAddr = nextLineBeginElemAddr & dramClAddrAlignMask
+  when (io.start) {
+    rdLineElemBeginAddr := xferElemInitAddr
+  }.elsewhen (stride) {
+    rdLineElemBeginAddr := nextLineBeginElemAddr
+  }.otherwise {
+    rdLineElemBeginAddr := rdLineElemBeginAddr
+  }
+
+  //-----------------------------------------------------
+  //--- Calculate current DRAM address of transaction ---
+  //-----------------------------------------------------
+
+  val rdLen = Wire(UInt((mp.lenBits + 1).W)) // read cmd transaction length. 
It is <= maxTransfer
+  val rdLineAddr = Reg(UInt(mp.addrBits.W)) // current DRAM address of command
+  when (io.start) {
+    rdLineAddr := xferClInitAddr
+  }.elsewhen (io.updateState) {
+    when(stride) {
+      rdLineAddr := nextLineBeginClAddr
+    }.otherwise {
+      rdLineAddr := rdLineAddr + (rdLen << log2Ceil(clBytes))
+    }
+  }.otherwise {
+    rdLineAddr := rdLineAddr
+  }
+
+  //total load length in cachelines
+  val rdLineBytes = io.xsize << log2Ceil(elemBytes)
+
+  //First transaction in a line length (1st or stride)
+  val maxTransfer = (1 << mp.lenBits).U // max number of pulses in transfer
+  val maxTrBytes = maxTransfer << log2Ceil(clBytes)
+  val rdLen1stMaxTransBytes = maxTrBytes - rdLineClBeginAddr % maxTrBytes
+  // get the number of cachelines till maxTrBytes aligned address
+  val rdLen1stMaxTransClNb = rdLen1stMaxTransBytes >> log2Ceil(clBytes)
+
+  //Transaction begin mask. Number of tensors to read from right
+  val rd1stPulseOffsetBytes = rdLineElemBeginAddr % clBytes.U
+  assert(rd1stPulseOffsetBytes >> log2Ceil(elemBytes) <= tp.clSizeRatio.U,
+    "-F- Expecting the number of tensors to skip in CL")
+  val rd1stPulseOffsetTensNb =  Wire(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
+  rd1stPulseOffsetTensNb := rd1stPulseOffsetBytes >> log2Ceil(elemBytes)
+
+  val rdLineClNbTmp = (rdLineBytes + rd1stPulseOffsetBytes) >> 
log2Ceil(clBytes)
+  val rdLineClNb =
+    Mux((rdLineBytes + rd1stPulseOffsetBytes) % clBytes.U === 0.U, 
rdLineClNbTmp, rdLineClNbTmp + 1.U)
+
+  //Transaction end mask. Number of tensors to read from left
+  val rdLastPulseBytes =  (rdLineElemBeginAddr + rdLineBytes) % clBytes.U
+  assert(rdLastPulseBytes >> log2Ceil(elemBytes) <= (clBytes/elemBytes).U,
+    "-F- Expecting the number of active tensors in CL")
+  val rdLastPulseTensNb =  Wire(UInt((log2Ceil(clBytes/elemBytes) + 1).W))
+  val rdLastPulseTensNbTmp =  rdLastPulseBytes >> log2Ceil(elemBytes)
+  rdLastPulseTensNb :=  Mux(rdLastPulseTensNbTmp === 0.U, 
(clBytes/elemBytes).U, rdLastPulseTensNbTmp)
+
+
+
+  //--------------------------------------
+  //--- Generate data load VME command ---
+  //--------------------------------------
+
+  val rdCmdStartIdxValid = Wire(Bool()) // Command is valid
+  val startIssueCmdRead = Wire(Bool()) // First transaction in io.xsize 
transfer
+  val rdCmdStartIdx = Reg(UInt(log2Ceil(tp.memDepth).W)) // Scratchpad data 
block index for the first transaction
+  val commandsDone = RegInit(true.B) // Done generating VME commands
+  // counts the number of CLs read in a xsize line
+  val clReadIdx = Reg(UInt((io.xsize.getWidth + log2Ceil(elemBytes) - 
log2Ceil(clBytes)).W))
+  val newReadRow = clReadIdx === 0.U // flags the first read of io.xsize
+
+  // set how many blocks of data being loaded
+  commandsDone := commandsDone
+  when (io.start || stride) {
+    clReadIdx := 0.U
+    commandsDone := false.B
+  }.elsewhen (io.updateState) {
+    val nextClIdx = clReadIdx + rdLen
+    clReadIdx := nextClIdx // THIS IS WHEN A NEW VME CMD HAPPENS
+    when (nextClIdx === rdLineClNb && dramLineIdx === io.ysize - 1.U) {
+      commandsDone := true.B
+    }
+  }.otherwise {
+    clReadIdx := clReadIdx
+  }
+
+  //when the whole xsize row read commands are sent, go for the next src row
+  when((clReadIdx === rdLineClNb - rdLen) && (dramLineIdx =/= io.ysize - 1.U) 
&& io.updateState) {
+    stride := true.B
+  }.otherwise {
+    stride := false.B
+  }
+
+  // current transaction tensors to read nb in 1st and last pulses
+  val rdCmd1stPluseOffsetTensNb = Wire(rd1stPulseOffsetTensNb.cloneType)
+  val rdCmdLastPluseTensNb = Wire(rdLastPulseTensNb.cloneType)
+  when(newReadRow) {
+    // first read in line
+    rdCmd1stPluseOffsetTensNb := rd1stPulseOffsetTensNb
+  }.otherwise {
+    // any other read
+    rdCmd1stPluseOffsetTensNb := 0.U
+  }
+  when (clReadIdx === rdLineClNb - rdLen) {
+    // last read in line
+    rdCmdLastPluseTensNb := rdLastPulseTensNb
+  }.otherwise {
+    // any other read
+    rdCmdLastPluseTensNb := (clBytes/elemBytes).U
+  }
+
+  //when the whole xsize row read commands are sent, go for the next src row
+  when((clReadIdx === rdLineClNb - rdLen) && (dramLineIdx =/= io.ysize - 1.U) 
&& io.updateState) {
+    stride := true.B
+  }.otherwise {
+    stride := false.B
+  }
+
+  assert(!io.isBusy || rdLineClNb >= clReadIdx)// define how many cachelines 
to read at this cycle
+  val clRemained = rdLineClNb - clReadIdx
+  when (newReadRow) {
+    when(clRemained < rdLen1stMaxTransClNb) {
+      rdLen := clRemained
+    }.otherwise {
+      rdLen := rdLen1stMaxTransClNb
+    }
+  }.otherwise {
+    when(clRemained < maxTransfer) {
+      rdLen := clRemained
+    }.otherwise {
+      rdLen := maxTransfer
+    }
+  }
+  // block index of the read data row (xsize). Modified by zero padding
+  val totalWidth = io.xsize + io.xpad_0 + io.xpad_1 // width of scratchpad 
matrix in tensors
+  // instead of multiplying total width by ypad_0 do incremental addition.
+  //Should cost ypad_0 cycles to issue 1st read cmd
+  // counts src matrix with y padding rows of tensors
+  val currentRowIdx = Reg(UInt((io.ysize.getWidth + io.ypad_0.getWidth).W))
+  // start to issue read cmd
+  rdCmdStartIdxValid := currentRowIdx >= io.ypad_0 &&
+    currentRowIdx < (io.ysize + io.ypad_0) &&
+    io.isBusy &&
+    !commandsDone
+  when (io.start) {
+    currentRowIdx := 0.U
+    rdCmdStartIdx := io.sram_offset + io.xpad_0 // this index is in tensors
+  }.elsewhen (io.isBusy && (currentRowIdx < io.ypad_0 || stride)) {
+    rdCmdStartIdx := rdCmdStartIdx + totalWidth
+    currentRowIdx := currentRowIdx + 1.U
+  }
+  startIssueCmdRead := false.B
+  when(newReadRow && rdCmdStartIdxValid) {
+    startIssueCmdRead := true.B
+  }
+
+  //-------------------------------------
+  //--- execute VME data load command ---
+  //-------------------------------------
+
+  require(pow(2, log2Ceil(tp.tensorLength)) == tp.tensorLength,
+    "-F- Tensor length must be 2^. Using shift and bits to divide.")
+  val blkIdxWdth = log2Ceil(tp.memDepth) // the size of scratchpad
+
+  val rdCmdDestElemIdx = Wire(UInt(tp.memAddrBits.W)) // element(tensor) size 
block index in a scratchpad
+  val rdCmdDestElemIdxNext = Reg(rdCmdDestElemIdx.cloneType)
+  rdCmdDestElemIdxNext := rdCmdDestElemIdxNext
+  rdCmdDestElemIdx := rdCmdDestElemIdxNext
+
+  val rdCmdValid = Wire(Bool())
+  // the number of tensors read in transaction
+  val rdCmdTransactionTensNb = (rdLen << log2Ceil(clBytes/elemBytes)) - 
rdCmd1stPluseOffsetTensNb
+  //increment scratch pad destination index
+  when(rdCmdStartIdxValid) {
+    rdCmdValid := true.B
+    when(startIssueCmdRead) {
+      rdCmdDestElemIdx := rdCmdStartIdx
+      rdCmdDestElemIdxNext:= rdCmdStartIdx + rdCmdTransactionTensNb
+    }.elsewhen (io.updateState) {
+      // increment block position by transaction length
+      rdCmdDestElemIdxNext:= rdCmdDestElemIdxNext + rdCmdTransactionTensNb
+    }
+  }.otherwise {
+    rdCmdValid := false.B
+  }
+
+  // read-from-dram
+  require(io.vmeCmd.bits.tag.getWidth >= rdCmdDestElemIdx.getWidth +
+    rdCmdLastPluseTensNb.getWidth + rdCmd1stPluseOffsetTensNb.getWidth,
+    s"-F- Tensor ${tensorType} Not enough VME tag bits to store transaction" +
+    s" tag. need:${rdCmdDestElemIdx.getWidth + rdCmdLastPluseTensNb.getWidth + 
rdCmd1stPluseOffsetTensNb.getWidth}")
+  io.vmeCmd.valid := rdCmdValid && io.canSendCmd
+  io.vmeCmd.bits.addr := rdLineAddr
+  io.vmeCmd.bits.len := rdLen - 1.U
+  assert(!io.vmeCmd.valid || ((rdLen << log2Ceil(clBytes)) <= maxTrBytes - 
rdLineAddr % maxTrBytes),
+    s"-F- ${tensorType} DRAM page alignment failure. DRAM " +
+    s"address + len overlaps mp.lenBits*memBlockSize alignment %x %x",
+    rdLineAddr, rdLen)
+  io.vmeCmd.bits.tag := Cat(rdCmdDestElemIdx, Cat(rdCmd1stPluseOffsetTensNb, 
rdCmdLastPluseTensNb))
+  io.readLen := rdLen
+  io.spElemIdx := rdCmdDestElemIdx // scratchpad tensor idx
+  io.fstPulseDataStart := rdCmd1stPluseOffsetTensNb // first pulse data start
+  io.lstPulseDataEnd := rdCmdLastPluseTensNb // last pulse data end
+  io.done := commandsDone
+}
+class GenVMECmdWideTL(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val isBusy = Input(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vmeCmd = Decoupled(new VMECmd)
+    val readLen = Output(UInt((mp.lenBits + 1).W))
+    val done = Output(Bool())
+  })
+  //del val sizeFactor = tp.tensorLength * tp.numMemBlock
+

Review comment:
       Please remove the commented-out "//del val sizeFactor" line above.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to