vegaluisjose commented on a change in pull request #27:
URL: https://github.com/apache/tvm-vta/pull/27#discussion_r646700419



##########
File path: hardware/chisel/src/main/scala/core/TensorAlu.scala
##########
@@ -97,38 +97,330 @@ class AluVector(implicit p: Parameters) extends Module {
   io.out.data.valid := valid.asUInt.andR
 }
 
-/** TensorAlu.
- *
- * This unit instantiate the ALU vector unit (AluVector) and go over the
- * micro-ops (uops) which are used to read the source operands (vectors)
- * from the acc-scratchpad and then they are written back the same
- * acc-scratchpad.
- */
-class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module 
{
+class TensorAluIndexGenerator(debug: Boolean = false)(implicit p: Parameters) 
extends Module {
+  val cnt_o_width = (new AluDecode).lp_0.getWidth
+  val cnt_i_width = (new AluDecode).lp_1.getWidth
+
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val last = Output(Bool())
+    val dec = Input(new AluDecode)
+    val valid = Output(Bool())
+    val src_valid = Output(Bool())
+    val dst_idx = Output(UInt(new 
TensorParams(tensorType="acc").memAddrBits.W))
+    val src_idx = Output(UInt(new 
TensorParams(tensorType="acc").memAddrBits.W))
+    val uop_idx = Output(UInt(log2Ceil(p(CoreKey).uopMemDepth).W))
+    val cnt_o = Output(UInt(cnt_o_width.W))
+    val cnt_i = Output(UInt(cnt_i_width.W))
+  })
+
+  io.last := false.B
+
+  val running = RegInit( false.B)
+  val stutter = RegInit( false.B)
+
+  val advance = io.dec.alu_use_imm || stutter
+
+  when( !running && io.start) {
+    running := true.B
+  } .elsewhen( running && !advance) {
+    stutter := true.B
+  } .elsewhen( running && advance) {
+    when ( io.last) {
+      running := false.B
+    }
+    stutter := false.B
+  }
+
+  val cnt_i = Reg( chiselTypeOf(io.dec.lp_1))
+  val dst_i = Reg( chiselTypeOf(io.dst_idx))
+  val src_i = Reg( chiselTypeOf(io.src_idx))
+
+  val cnt_o = Reg( chiselTypeOf(io.dec.lp_0))
+  val dst_o = Reg( chiselTypeOf(io.dst_idx))
+  val src_o = Reg( chiselTypeOf(io.src_idx))
+
+  val uop_idx = Reg( chiselTypeOf(io.dec.uop_end))
+
+  io.valid := running && advance
+  io.src_valid := running && !advance
+  io.dst_idx := dst_i
+  io.src_idx := src_i
+  io.uop_idx := uop_idx
+  io.cnt_o := cnt_o
+  io.cnt_i := cnt_i
+
+  when( !running) {
+    cnt_i := 0.U; dst_i := 0.U; src_i := 0.U;
+    cnt_o := 0.U; dst_o := 0.U; src_o := 0.U;
+    uop_idx := io.dec.uop_begin
+  } .elsewhen (advance) {
+    when (uop_idx =/= io.dec.uop_end - 1.U) {
+      uop_idx := uop_idx + 1.U
+    }.otherwise {
+      uop_idx := io.dec.uop_begin
+      when ( cnt_i =/= io.dec.lp_1 - 1.U) {
+        cnt_i := cnt_i + 1.U
+        dst_i := dst_i + io.dec.dst_1
+        src_i := src_i + io.dec.src_1
+      }.otherwise {
+        when ( cnt_o =/= io.dec.lp_0 - 1.U) {
+          val dst_tmp = dst_o + io.dec.dst_0
+          val src_tmp = src_o + io.dec.src_0
+          cnt_o := cnt_o + 1.U
+          dst_o := dst_tmp
+          src_o := src_tmp
+          cnt_i := 0.U
+          dst_i := dst_tmp
+          src_i := src_tmp
+        } .otherwise {
+          io.last := true.B
+        }
+      }
+    }
+  }
+}
+
+class TensorAluIfc(implicit p: Parameters) extends Module {
   val aluBits = p(CoreKey).accBits
   val io = IO(new Bundle {
     val start = Input(Bool())
     val done = Output(Bool())
-    val inst = Input(UInt(INST_BITS.W))
+    val dec = Input(new AluDecode)
     val uop = new UopMaster
     val acc = new TensorMaster(tensorType = "acc")
     val out = new TensorMaster(tensorType = "out")
   })
+}
+
+class TensorAluPipelined(debug: Boolean = false)(implicit p: Parameters) 
extends TensorAluIfc {
+  val stateBits = 2
+  val inflightBits = 4
+  val dataSplitFactor = p(CoreKey).blockOutFactor
+
+  val sIdle::sRun::sWait::Nil = Enum(3)
+  val state = RegInit(init=sIdle)
+  val inflight = RegInit(0.U(inflightBits.W))
+
+  val index_generator = Module(new TensorAluIndexGenerator)
+  val aluDataReadPipeDelay = 0 // available for pipelining
+
+  // State Machine for compute io.done correctly
+  io.done := false.B
+  when( state === sIdle && io.start) {
+    state := sRun
+  }.elsewhen( state === sRun && index_generator.io.last) {
+    state := sWait
+  }.elsewhen( state === sWait && inflight === 0.U) {
+    state := sIdle
+    io.done := true.B
+  }
+
+  index_generator.io.start := io.start
+  index_generator.io.dec := io.dec
+
+  // second term works around funny clearing in uop register file flopped 
output
+  io.uop.idx.valid := index_generator.io.valid || index_generator.io.src_valid
+  io.uop.idx.bits := index_generator.io.uop_idx
+
+  val valid_001 = ShiftRegister( index_generator.io.valid, 
aluDataReadPipeDelay + 1, resetData=false.B, en = true.B)
+  val valid_002 = RegNext( valid_001, init=false.B)
+  val valid_003 = RegNext( valid_002, init=false.B)
+  val valid_004 = RegNext( valid_003, init=false.B)

Review comment:
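       Remove the extra space after the opening parenthesis here as well (i.e. `RegNext(valid_003, ...)` rather than `RegNext( valid_003, ...)`).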




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to