Posted to commits@tvm.apache.org by mo...@apache.org on 2021/06/08 01:26:49 UTC

[tvm-vta] branch main updated: Chisel Pipelined ALU (#27)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-vta.git


The following commit(s) were added to refs/heads/main by this push:
     new d5e8117  Chisel Pipelined ALU  (#27)
d5e8117 is described below

commit d5e8117ce1535c527c536e115b4e58d53817b82f
Author: Abhijit Davare <ab...@intel.com>
AuthorDate: Mon Jun 7 18:26:41 2021 -0700

    Chisel Pipelined ALU  (#27)
    
    * Remove parameter values from case class
    
    * Add new blockOutFactor parameter with default value = 1
    
    * Support split access
    
    * Modify to support split interface, minor refactoring
    
    * Use split read/write interfaces
    
    * Pipelined ALU with split interfaces
    
    * Modify instantiation and usage of pipelined ALU and split interfaces
    
    * Don't use internal Random by default
    
    * Change tester name
    
    * Add generic tester class
    
    * Derive from GenericTest, minor refactoring
    
    * Test ALU index generator and pipelined ALU
    
    * Add ASF header
    
    * Bugfix: delay slicing index by a cycle to match SyncReadMem read delay
    
    * Formatting, comment, and minor refactoring changes for clarity
    
    * Fix scalastyle issues for test files
---
 hardware/chisel/src/main/scala/core/Compute.scala  | 103 +++++-
 hardware/chisel/src/main/scala/core/Configs.scala  |   1 +
 hardware/chisel/src/main/scala/core/Core.scala     |  29 +-
 hardware/chisel/src/main/scala/core/LoadUop.scala  |   3 +-
 .../chisel/src/main/scala/core/TensorAlu.scala     | 401 ++++++++++++++++++---
 .../chisel/src/main/scala/core/TensorGemm.scala    | 106 +++---
 .../chisel/src/main/scala/core/TensorLoad.scala    |  62 +++-
 .../chisel/src/main/scala/core/TensorStore.scala   |  28 +-
 .../chisel/src/main/scala/core/TensorUtil.scala    | 186 ++++++++--
 .../chisel/src/test/scala/unittest/AluTest.scala   |  42 ++-
 .../scala/unittest/Generic.scala}                  |  48 ++-
 .../chisel/src/test/scala/unittest/Launcher.scala  |   2 +-
 .../src/test/scala/unittest/TensorAluTest.scala    | 252 +++++++++++++
 .../test/scala/unittest/utils/RandomArray.scala    |   7 +-
 14 files changed, 1042 insertions(+), 228 deletions(-)
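
The central new parameter is blockOutFactor: it splits the blockOut
dimension of the wgt/acc scratchpads into groups so that a piece of wgt
can be physically localized to the piece of acc it transforms. A minimal
sketch of the resulting group geometry, in plain Scala; the object name
and the example value blockOutFactor = 4 are ours, not from the patch:

    object SplitSketch extends App {
      // assumed example values; CoreConfig uses batch = 1, blockOut = 16
      val (batch, blockOut, blockOutFactor) = (1, 16, 4)
      // acc splits along its width: splitLength = 1, splitWidth = blockOutFactor
      val (splitLength, splitWidth) = (1, blockOutFactor)
      val lenSplit   = batch / splitLength    // tensor rows per group
      val widthSplit = blockOut / splitWidth  // tensor columns per group
      println(s"groups=${splitLength * splitWidth} rows/group=$lenSplit cols/group=$widthSplit")
      // prints: groups=4 rows/group=1 cols/group=4
    }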

diff --git a/hardware/chisel/src/main/scala/core/Compute.scala b/hardware/chisel/src/main/scala/core/Compute.scala
index a1e7fad..0055a25 100644
--- a/hardware/chisel/src/main/scala/core/Compute.scala
+++ b/hardware/chisel/src/main/scala/core/Compute.scala
@@ -19,6 +19,8 @@
 
 package vta.core
 
+import scala.math.pow
+
 import chisel3._
 import chisel3.util._
 import vta.util.config._
@@ -32,7 +34,7 @@ import vta.shell._
  * - Compute ALU instructions (tensorAlu module)
  * - Compute GEMM instructions (tensorGemm module)
  */
-class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class Compute(debug: Boolean = false)(implicit val p: Parameters) extends Module {
   val mp = p(ShellKey).memParams
   val io = IO(new Bundle {
     val i_post = Vec(2, Input(Bool()))
@@ -58,6 +60,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val tensorGemm = Module(new TensorGemm)
   val tensorAlu = Module(new TensorAlu)
 
+  // try to use the acc group closest to the top-level IO
+  val topAccGrpIdx = tensorGemm.io.acc.closestIOGrpIdx
+
   val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
 
   // decode
@@ -118,44 +123,102 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   loadUop.io.baddr := io.uop_baddr
   io.vme_rd(0) <> loadUop.io.vme_rd
   loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
+  assert(!tensorGemm.io.uop.idx.valid || !tensorAlu.io.uop.idx.valid)
 
   // acc
   tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
   tensorAcc.io.inst := inst_q.io.deq.bits
   tensorAcc.io.baddr := io.acc_baddr
-  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
-  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
+  require(tensorAcc.io.tensor.lenSplit ==
+    tensorAcc.io.tensor.tensorLength, "-F- Expecting a whole batch in acc group")
+
+  // split the isGemm select into a fanout tree when there are many groups
+  val splitFactorL0 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+  val splitFactorL1 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth)
+    - log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+  require(splitFactorL0 * splitFactorL1 == tensorAcc.io.tensor.splitWidth)
+  val accRdSelectL0 = for (idx <- 0 until splitFactorL1) yield {
+    // can save one pipeline stage on a small design
+    if (splitFactorL1 > 1) RegNext(dec.io.isGemm, init = false.B) else dec.io.isGemm
+  }
+
+  for (idx <- 0 until tensorAcc.io.tensor.splitWidth) {
+    tensorAcc.io.tensor.rd(idx).idx <> Mux(
+      RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+      tensorGemm.io.acc.rd(idx).idx,
+      tensorAlu.io.acc.rd(idx).idx)
+    tensorAcc.io.tensor.wr(idx) <> Mux(
+      RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+      tensorGemm.io.acc.wr(idx),
+      tensorAlu.io.acc.wr(idx))
+  }
   io.vme_rd(1) <> tensorAcc.io.vme_rd
-  io.acc_wr_event := tensorAcc.io.tensor.wr.valid
+  io.acc_wr_event := tensorAcc.io.tensor.wr(topAccGrpIdx).valid
 
   // gemm
-  tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
-  tensorGemm.io.inst := inst_q.io.deq.bits
+  tensorGemm.io.start := RegNext(state === sIdle & start & dec.io.isGemm, init = false.B)
+  tensorGemm.io.dec := inst_q.io.deq.bits.asTypeOf(new GemmDecode)
   tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
   tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
   tensorGemm.io.inp <> io.inp
   tensorGemm.io.wgt <> io.wgt
-  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
-  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
-  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
-  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
+  for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
+    tensorGemm.io.acc.rd(idx).data.valid :=
+      tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
+    tensorGemm.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
+  }
+  for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+    tensorGemm.io.out.rd(idx).data.valid :=
+      io.out.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
+    tensorGemm.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+  }
 
   // alu
-  tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
-  tensorAlu.io.inst := inst_q.io.deq.bits
+  tensorAlu.io.start := RegNext(state === sIdle & start & dec.io.isAlu, init = false.B)
+  tensorAlu.io.dec := inst_q.io.deq.bits.asTypeOf(new AluDecode)
   tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
   tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
-  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
-  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
-  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
-  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
+  for (idx <- 0 until tensorAlu.io.acc.splitWidth) {
+    tensorAlu.io.acc.rd(idx).data.valid :=
+      tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
+    tensorAlu.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
+  }
+  for (idx <- 0 until tensorAlu.io.out.splitWidth) {
+    tensorAlu.io.out.rd(idx).data.valid :=
+      io.out.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
+    tensorAlu.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+  }
 
   // out
-  io.out.rd.idx <> Mux(dec.io.isGemm,
-    tensorGemm.io.out.rd.idx,
-    tensorAlu.io.out.rd.idx)
-  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
+  for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+    io.out.rd(idx).idx <> Mux(dec.io.isGemm,
+      tensorGemm.io.out.rd(idx).idx,
+      tensorAlu.io.out.rd(idx).idx)
+    assert(!tensorGemm.io.out.rd(idx).idx.valid || !tensorAlu.io.out.rd(idx).idx.valid)
+    assert(!tensorGemm.io.out.rd(idx).data.valid || !tensorAlu.io.out.rd(idx).data.valid)
 
+    assert(!tensorGemm.io.out.wr(idx).valid || !tensorAlu.io.out.wr(idx).valid)
+  }
+  require (tensorGemm.io.out.splitWidth == 1)
+  require (tensorAlu.io.out.splitWidth == 1)
+  io.out.wr(0).valid := Mux(
+    RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).valid, tensorAlu.io.out.wr(0).valid)
+  io.out.wr(0).bits.idx := Mux(
+    RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).bits.idx, tensorAlu.io.out.wr(0).bits.idx)
+  // put a mux/Reg into every gemm group to build a pipelined Mux-select tree over distance
+  val chunkWidth = io.out.wr(0).bits.data.getWidth / tensorGemm.io.acc.splitWidth
+  val outDataBits = Wire(Vec(tensorGemm.io.acc.splitWidth, UInt(chunkWidth.W)))
+  io.out.wr(0).bits.data := outDataBits.asTypeOf(io.out.wr(0).bits.data)
+  for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
+    val lowBitIdx = idx * chunkWidth
+    val highBitIdx = lowBitIdx + chunkWidth - 1
+    val srcAluFlat = tensorAlu.io.out.wr(0).bits.data.asUInt
+    val srcGemFlat = tensorGemm.io.out.wr(0).bits.data.asUInt
+    outDataBits(idx) := Mux(
+      RegNext(dec.io.isGemm, init = false.B),
+      srcGemFlat(highBitIdx, lowBitIdx),
+      srcAluFlat(highBitIdx, lowBitIdx))
+  }
   // semaphore
   s(0).io.spost := io.i_post(0)
   s(1).io.spost := io.i_post(1)
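
A note on the select fanout above: splitWidth is factored into two
near-equal powers of two so that the registered isGemm select forms a
small two-level pipe tree instead of one wide fanout. The same arithmetic
in plain Scala (a sketch; the helper names are ours):

    import scala.math.pow

    def log2Ceil(x: Int): Int =
      if (x <= 1) 0 else 32 - Integer.numberOfLeadingZeros(x - 1)

    def splitFactors(splitWidth: Int): (Int, Int) = {
      val l0 = pow(2, log2Ceil(splitWidth) / 2).toInt
      val l1 = pow(2, log2Ceil(splitWidth) - log2Ceil(splitWidth) / 2).toInt
      require(l0 * l1 == splitWidth) // holds when splitWidth is a power of two
      (l0, l1)
    }
    // splitFactors(16) == (4, 4); splitFactors(8) == (2, 4); splitFactors(1) == (1, 1)
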
diff --git a/hardware/chisel/src/main/scala/core/Configs.scala b/hardware/chisel/src/main/scala/core/Configs.scala
index 4ab7d85..4022dc3 100644
--- a/hardware/chisel/src/main/scala/core/Configs.scala
+++ b/hardware/chisel/src/main/scala/core/Configs.scala
@@ -32,6 +32,7 @@ class CoreConfig extends Config((site, here, up) => {
     CoreParams(
       batch = 1,
       blockOut = 16,
+      blockOutFactor = 1,
       blockIn = 16,
       inpBits = 8,
       wgtBits = 8,
diff --git a/hardware/chisel/src/main/scala/core/Core.scala b/hardware/chisel/src/main/scala/core/Core.scala
index e2ac51a..a01c805 100644
--- a/hardware/chisel/src/main/scala/core/Core.scala
+++ b/hardware/chisel/src/main/scala/core/Core.scala
@@ -25,20 +25,21 @@ import vta.shell._
 
 /** Core parameters */
 case class CoreParams(
-    batch: Int = 1,
-    blockOut: Int = 16,
-    blockIn: Int = 16,
-    inpBits: Int = 8,
-    wgtBits: Int = 8,
-    uopBits: Int = 32,
-    accBits: Int = 32,
-    outBits: Int = 8,
-    uopMemDepth: Int = 512,
-    inpMemDepth: Int = 512,
-    wgtMemDepth: Int = 512,
-    accMemDepth: Int = 512,
-    outMemDepth: Int = 512,
-    instQueueEntries: Int = 32
+    batch: Int,
+    blockOut: Int,
+    blockOutFactor: Int,
+    blockIn: Int,
+    inpBits: Int,
+    wgtBits: Int,
+    uopBits: Int,
+    accBits: Int,
+    outBits: Int,
+    uopMemDepth: Int,
+    inpMemDepth: Int,
+    wgtMemDepth: Int,
+    accMemDepth: Int,
+    outMemDepth: Int,
+    instQueueEntries: Int
 ) {
   require(uopBits % 8 == 0,
     s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
diff --git a/hardware/chisel/src/main/scala/core/LoadUop.scala b/hardware/chisel/src/main/scala/core/LoadUop.scala
index 87bd508..e9f5d40 100644
--- a/hardware/chisel/src/main/scala/core/LoadUop.scala
+++ b/hardware/chisel/src/main/scala/core/LoadUop.scala
@@ -205,7 +205,8 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
   // read-from-sram
   io.uop.data.valid := RegNext(io.uop.idx.valid)
 
-  val sIdx = io.uop.idx.bits % numUop.U
+  // delay LSB of idx by a cycle because of the one-cycle memory read latency
+  val sIdx = RegNext(io.uop.idx.bits % numUop.U)
   val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
   val memRead = mem.read(rIdx, io.uop.idx.valid)
   val sWord = memRead.asUInt.asTypeOf(wdata)
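
The bugfix above is the standard SyncReadMem alignment pattern: read data
arrives one cycle after the address, so any index bits used to slice that
data must be delayed by the same cycle. A self-contained sketch of the
pattern (our example module, not part of the patch):

    import chisel3._
    import chisel3.util._

    class SyncReadSlice extends Module {
      val io = IO(new Bundle {
        val idx  = Input(Valid(UInt(10.W)))
        val data = Output(Valid(UInt(32.W)))
      })
      val mem  = SyncReadMem(512, Vec(2, UInt(32.W)))
      val word = mem.read(io.idx.bits >> 1, io.idx.valid) // data valid next cycle
      val sel  = RegNext(io.idx.bits(0))                  // delay the slice select to match
      io.data.bits  := word(sel)
      io.data.valid := RegNext(io.idx.valid)
    }
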
diff --git a/hardware/chisel/src/main/scala/core/TensorAlu.scala b/hardware/chisel/src/main/scala/core/TensorAlu.scala
index 81abb8e..8d7aa31 100644
--- a/hardware/chisel/src/main/scala/core/TensorAlu.scala
+++ b/hardware/chisel/src/main/scala/core/TensorAlu.scala
@@ -97,38 +97,322 @@ class AluVector(implicit p: Parameters) extends Module {
   io.out.data.valid := valid.asUInt.andR
 }
 
-/** TensorAlu.
- *
- * This unit instantiate the ALU vector unit (AluVector) and go over the
- * micro-ops (uops) which are used to read the source operands (vectors)
- * from the acc-scratchpad and then they are written back the same
- * acc-scratchpad.
- */
-class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class TensorAluIndexGenerator(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val cnt_o_width = (new AluDecode).lp_0.getWidth
+  val cnt_i_width = (new AluDecode).lp_1.getWidth
+
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val last = Output(Bool())
+    val dec = Input(new AluDecode)
+    val valid = Output(Bool())
+    val src_valid = Output(Bool())
+    val dst_idx = Output(UInt(new TensorParams(tensorType="acc").memAddrBits.W))
+    val src_idx = Output(UInt(new TensorParams(tensorType="acc").memAddrBits.W))
+    val uop_idx = Output(UInt(log2Ceil(p(CoreKey).uopMemDepth).W))
+    val cnt_o = Output(UInt(cnt_o_width.W))
+    val cnt_i = Output(UInt(cnt_i_width.W))
+  })
+
+  io.last := false.B
+
+  val running = RegInit(false.B)
+  val stutter = RegInit(false.B)
+
+  val advance = io.dec.alu_use_imm || stutter
+
+  when(!running && io.start) {
+    running := true.B
+  } .elsewhen(running && !advance) {
+    stutter := true.B
+  } .elsewhen(running && advance) {
+    when (io.last) {
+      running := false.B
+    }
+    stutter := false.B
+  }
+
+  val cnt_i = Reg(chiselTypeOf(io.dec.lp_1))
+  val dst_i = Reg(chiselTypeOf(io.dst_idx))
+  val src_i = Reg(chiselTypeOf(io.src_idx))
+
+  val cnt_o = Reg(chiselTypeOf(io.dec.lp_0))
+  val dst_o = Reg(chiselTypeOf(io.dst_idx))
+  val src_o = Reg(chiselTypeOf(io.src_idx))
+
+  val uop_idx = Reg(chiselTypeOf(io.dec.uop_end))
+
+  io.valid := running && advance
+  io.src_valid := running && !advance
+  io.dst_idx := dst_i
+  io.src_idx := src_i
+  io.uop_idx := uop_idx
+  io.cnt_o := cnt_o
+  io.cnt_i := cnt_i
+
+  when(!running) {
+    cnt_i := 0.U; dst_i := 0.U; src_i := 0.U;
+    cnt_o := 0.U; dst_o := 0.U; src_o := 0.U;
+    uop_idx := io.dec.uop_begin
+  } .elsewhen (advance) {
+    when (uop_idx =/= io.dec.uop_end - 1.U) {
+      uop_idx := uop_idx + 1.U
+    }.otherwise {
+      uop_idx := io.dec.uop_begin
+      when (cnt_i =/= io.dec.lp_1 - 1.U) {
+        cnt_i := cnt_i + 1.U
+        dst_i := dst_i + io.dec.dst_1
+        src_i := src_i + io.dec.src_1
+      }.otherwise {
+        when (cnt_o =/= io.dec.lp_0 - 1.U) {
+          val dst_tmp = dst_o + io.dec.dst_0
+          val src_tmp = src_o + io.dec.src_0
+          cnt_o := cnt_o + 1.U
+          dst_o := dst_tmp
+          src_o := src_tmp
+          cnt_i := 0.U
+          dst_i := dst_tmp
+          src_i := src_tmp
+        } .otherwise {
+          io.last := true.B
+        }
+      }
+    }
+  }
+}
+
+class TensorAluIfc(implicit p: Parameters) extends Module {
   val aluBits = p(CoreKey).accBits
   val io = IO(new Bundle {
     val start = Input(Bool())
     val done = Output(Bool())
-    val inst = Input(UInt(INST_BITS.W))
+    val dec = Input(new AluDecode)
     val uop = new UopMaster
     val acc = new TensorMaster(tensorType = "acc")
     val out = new TensorMaster(tensorType = "out")
   })
+}
+
+class TensorAluPipelined(debug: Boolean = false)(implicit p: Parameters) extends TensorAluIfc {
+  val stateBits = 2
+  val inflightBits = 4
+  val dataSplitFactor = p(CoreKey).blockOutFactor
+
+  val sIdle::sRun::sWait::Nil = Enum(3)
+  val state = RegInit(init=sIdle)
+  val inflight = RegInit(0.U(inflightBits.W))
+
+  val index_generator = Module(new TensorAluIndexGenerator)
+  val aluDataReadPipeDelay = 0 // available for pipelining
+
+  // state machine to compute io.done correctly
+  io.done := false.B
+  when(state === sIdle && io.start) {
+    state := sRun
+  }.elsewhen(state === sRun && index_generator.io.last) {
+    state := sWait
+  }.elsewhen(state === sWait && inflight === 0.U) {
+    state := sIdle
+    io.done := true.B
+  }
+
+  index_generator.io.start := io.start
+  index_generator.io.dec := io.dec
+
+  // the second term works around spurious clearing of the uop register file's flopped output
+  io.uop.idx.valid := index_generator.io.valid || index_generator.io.src_valid
+  io.uop.idx.bits := index_generator.io.uop_idx
+
+  val valid_r1 = ShiftRegister(index_generator.io.valid, aluDataReadPipeDelay + 1, resetData=false.B, en = true.B)
+  val valid_r2 = RegNext(valid_r1, init=false.B)
+  val valid_r3 = RegNext(valid_r2, init=false.B)
+  val valid_r4 = RegNext(valid_r3, init=false.B)
+
+  when(index_generator.io.valid && valid_r4) {
+  }.elsewhen(index_generator.io.valid) {
+    assert(inflight =/= ((1<<inflightBits)-1).U)
+    inflight := inflight + 1.U
+  }.elsewhen(valid_r4) {
+    assert(inflight =/= 0.U)
+    inflight := inflight - 1.U
+  }
+  when(state === sIdle) {
+    assert(inflight === 0.U)
+    inflight := 0.U
+  }
+
+  val src_valid_r1 = ShiftRegister(
+    index_generator.io.src_valid,
+    aluDataReadPipeDelay + 1,
+    resetData=false.B, en = true.B)
+  val src_valid_r2 = RegNext(src_valid_r1, init=false.B)
+  val src_valid_r3 = RegNext(src_valid_r2, init=false.B)
+  val src_valid_r4 = RegNext(src_valid_r3, init=false.B)
+
+  val dst_idx_r1 = ShiftRegister(index_generator.io.dst_idx, aluDataReadPipeDelay + 1)
+  val src_idx_r1 = ShiftRegister(index_generator.io.src_idx, aluDataReadPipeDelay + 1)
+
+  val uop_data_r1 = ShiftRegister(io.uop.data, aluDataReadPipeDelay)
+
+  val dst_offset = uop_data_r1.bits.u0
+
+  val w = dst_offset.getWidth
+  val u2 = uop_data_r1.bits.u2.asTypeOf(UInt(w.W))
+  val s = log2Ceil(p(CoreKey).inpMemDepth)
+  val u1 = uop_data_r1.bits.u1.asTypeOf(UInt(w.W))
+  val src_offset = (u2 << s) | u1
+
+  // split registers of stage 2 by data groups
+  val accRdIdxValid = valid_r1 || src_valid_r1
+  for (idx <- 0 until dataSplitFactor) {
+    io.acc.rd(idx).idx.valid := RegNext(accRdIdxValid)
+  }
+
+  val new_src_idx_r1 = src_idx_r1 + src_offset
+  val src_idx_r2 = RegNext(new_src_idx_r1)
+  val src_idx_r3 = RegNext(src_idx_r2)
+
+  val new_dst_idx_r1 = dst_idx_r1 + dst_offset
+  val dst_idx_r2 = RegNext(new_dst_idx_r1)
+  val dst_idx_r3 = RegNext(dst_idx_r2)
+  val dst_idx_r4 = RegNext(dst_idx_r3)
+
+  // split registers of stage 2 by data groups
+  val accRdIdxBits = Mux(src_valid_r1 || io.dec.alu_use_imm, new_src_idx_r1, new_dst_idx_r1)
+  for (idx <- 0 until dataSplitFactor) {
+    io.acc.rd(idx).idx.bits := RegNext(accRdIdxBits)
+    assert(io.acc.rd(idx).data.valid === (valid_r3 || src_valid_r3))
+  }
+
+  require(io.out.splitWidth == 1 && io.out.splitLength == 1, "-F- Out split write is not supported")
+  val numVecUnits = dataSplitFactor
+  val outData = Wire(io.out.wr(0).bits.data.cloneType)
+  val dataRemapB = Wire(Vec(numVecUnits, io.acc.rd(0).data.bits.cloneType))
+  val dataRemapA = Wire(Vec(numVecUnits, io.acc.rd(0).data.bits.cloneType))
+  // numVecUnits is a power of 2
+  // split the dec bits pipe further if there are many vecUnits
+  val decSplitNb0 =  if (numVecUnits < 8) 1 else 2
+  val decSplit0 = Wire(Vec(decSplitNb0, io.dec.cloneType))
+  for (idx <- 0 until decSplitNb0) {
+    decSplit0(idx) := ShiftRegister(io.dec, if(aluDataReadPipeDelay < 2) 0 else 1)
+  }
+
+  for (idx <- 0 until numVecUnits) {
+    val alu = Module(new AluVector)
+
+    for(aluLenIdx <- 0 until alu.io.acc_b.lenSplit) {
+      for(aluWdtIdx <- 0 until alu.io.acc_b.widthSplit) {
+        val (accGrpIdx, accLenIdx, accWdtIdx) =
+          alu.io.acc_b.reindexDataFromGroup(idx, aluLenIdx, aluWdtIdx)
+        dataRemapB(idx)(aluLenIdx)(aluWdtIdx) :=
+          io.acc.rd(accGrpIdx).data.bits(accLenIdx)(accWdtIdx)
+      }
+    }
+    val save_src = RegNext(dataRemapB(idx))
+    val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
+    tensorImm.data.valid := valid_r3
+    val tensorImmBits_piped = ShiftRegister(
+      decSplit0(idx/(numVecUnits/decSplitNb0)).alu_imm,
+      if(aluDataReadPipeDelay < 2) aluDataReadPipeDelay else aluDataReadPipeDelay -1)
+    tensorImm.data.bits.foreach { b =>
+      b.foreach { c =>
+        c := Mux(tensorImmBits_piped(C_ALU_IMM_BITS - 1),
+          Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), tensorImmBits_piped), tensorImmBits_piped)
+      }
+    }
+
+    // alu
+    val tensorOpBits_piped = ShiftRegister(
+    decSplit0(idx/(numVecUnits/decSplitNb0)).alu_op,
+    if(aluDataReadPipeDelay < 2) aluDataReadPipeDelay else aluDataReadPipeDelay -1)
+    val isSHR = (tensorOpBits_piped === ALU_OP(3))
+    val neg_shift = isSHR & tensorImmBits_piped(C_ALU_IMM_BITS - 1)
+    val fixme_alu_op = Mux(
+      neg_shift,
+      ALU_OP(4), // use opcode = 4 for left shift
+      tensorOpBits_piped)
+    alu.io.opcode := fixme_alu_op
+
+    assert(!valid_r3 || io.acc.rd(idx).data.valid)
+
+    alu.io.acc_a.data.valid := RegNext(valid_r2) // valid_r3 split
+
+    for(aluLenIdx <- 0 until alu.io.acc_a.lenSplit) {
+      for(aluWdtIdx <- 0 until alu.io.acc_a.widthSplit) {
+        val (accGrpIdx, accLenIdx, accWdtIdx) =
+          alu.io.acc_a.reindexDataFromGroup(idx, aluLenIdx, aluWdtIdx)
+        dataRemapA(idx)(aluLenIdx)(aluWdtIdx) :=
+          io.acc.rd(accGrpIdx).data.bits(accLenIdx)(accWdtIdx)
+        alu.io.acc_a.data.bits := dataRemapA(idx)
+      }
+    }
+    val tensorUseImmBits_piped = ShiftRegister(
+    decSplit0(idx/(numVecUnits/decSplitNb0)).alu_use_imm,
+    if(aluDataReadPipeDelay < 2) aluDataReadPipeDelay else aluDataReadPipeDelay -1)
+    alu.io.acc_b.data.valid := Mux(tensorUseImmBits_piped,
+      tensorImm.data.valid,
+      valid_r3)
+    alu.io.acc_b.data.bits := Mux(tensorUseImmBits_piped,
+      tensorImm.data.bits,
+      save_src)
+
+    assert(alu.io.acc_y.data.valid === valid_r4)
+    io.acc.wr(idx).valid := valid_r4
+    io.acc.wr(idx).bits.idx := dst_idx_r4
+
+    for(aluLenIdx <- 0 until alu.io.acc_y.lenSplit) {
+      for(aluWdtIdx <- 0 until alu.io.acc_y.widthSplit) {
+        val (accGrpIdx, accLenIdx, accWdtIdx) =
+          alu.io.acc_y.reindexDataFromGroup(idx, aluLenIdx, aluWdtIdx)
+        io.acc.wr(accGrpIdx).bits.data(accLenIdx)(accWdtIdx) :=
+          alu.io.acc_y.data.bits(aluLenIdx)(aluWdtIdx)
+      }
+    }
+
+    assert(alu.io.out.data.valid === valid_r4)
+    for (idx1 <- 0 until io.out.tensorLength) {
+      for (idx2 <- 0 until io.out.tensorWidth/numVecUnits) {
+        outData(idx1)(idx*io.out.tensorWidth/numVecUnits + idx2) := alu.io.out.data.bits(idx1)(idx2)
+      }
+    }
+  }
+
+  // single out write interface; split write is not supported (see require above)
+  io.out.wr(0).valid := valid_r4
+  io.out.wr(0).bits.idx := dst_idx_r4
+  io.out.wr(0).bits.data := outData
+  io.out.tieoffRead()
+
+  val bypass_dst = valid_r3 && valid_r4 && (dst_idx_r4 === dst_idx_r3)
+  val bypass_src = src_valid_r3 && valid_r4 && (dst_idx_r4 === src_idx_r3)
+
+  // Do we need a bypass
+  assert(!bypass_dst, s"Bypass required on dst_idx read $dst_idx_r3 RAW with write $dst_idx_r4\n")
+  assert(!bypass_src, s"Bypass required on src_idx read $src_idx_r3 RAW with write $dst_idx_r4\n")
+}
+
+/** TensorAluOrig.
+ * This unit instantiates the ALU vector unit (AluVector) and iterates over
+ * the micro-ops (uops), which are used to read the source operands (vectors)
+ * from the acc-scratchpad; the results are then written back to the same
+ * acc-scratchpad.
+ */
+class TensorAluOrig(debug: Boolean = false)(implicit p: Parameters) extends TensorAluIfc {
   val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
     Enum(6)
   val state = RegInit(sIdle)
   val alu = Module(new AluVector)
-  val dec = io.inst.asTypeOf(new AluDecode)
+  val dec = io.dec
   val uop_idx = Reg(chiselTypeOf(dec.uop_end))
   val uop_end = dec.uop_end
-  val uop_dst = Reg(chiselTypeOf(dec.uop_end))
-  val uop_src = Reg(chiselTypeOf(dec.uop_end))
+  val uop_dst = Reg(chiselTypeOf(io.uop.data.bits.u0)) // width can address entire acc
+  val uop_src = Reg(chiselTypeOf(io.uop.data.bits.u0)) // width can address entire acc
   val cnt_o = Reg(chiselTypeOf(dec.lp_0))
-  val dst_o = Reg(chiselTypeOf(dec.uop_end))
-  val src_o = Reg(chiselTypeOf(dec.uop_end))
+  val dst_o = Reg(chiselTypeOf(io.uop.data.bits.u0))
+  val src_o = Reg(chiselTypeOf(io.uop.data.bits.u0))
   val cnt_i = Reg(chiselTypeOf(dec.lp_1))
-  val dst_i = Reg(chiselTypeOf(dec.uop_end))
-  val src_i = Reg(chiselTypeOf(dec.uop_end))
+  val dst_i = Reg(chiselTypeOf(io.uop.data.bits.u0))
+  val src_i = Reg(chiselTypeOf(io.uop.data.bits.u0))
   val done =
     state === sExe &
       alu.io.out.data.valid &
@@ -208,57 +492,62 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   when(state === sComputeIdx && io.uop.data.valid) {
     uop_dst := io.uop.data.bits.u0 + dst_i
-    uop_src := io.uop.data.bits.u1 + src_i
+    uop_src := ((io.uop.data.bits.u2.asTypeOf(UInt(width = uop_dst.getWidth.W)) << log2Ceil(p(CoreKey).inpMemDepth))
+      | io.uop.data.bits.u1.asTypeOf(UInt(width = uop_dst.getWidth.W))) + src_i
   }
 
   // uop
   io.uop.idx.valid := state === sReadUop
   io.uop.idx.bits := uop_idx
 
-  // acc (input)
-  io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
-  io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
-
-  // imm
-  val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
-  tensorImm.data.valid := state === sReadTensorB
-  tensorImm.data.bits.foreach { b =>
-    b.foreach { c =>
-      c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1),
-        Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm)
+  val dataSplitFactor = p(CoreKey).blockOutFactor
+
+  val accRdValid = state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
+  val accRdIdx = Mux(state === sReadTensorA, uop_dst, uop_src)
+  for (idx <- 0 until dataSplitFactor) {
+    // acc (input)
+    io.acc.rd(idx).idx.valid := accRdValid
+    io.acc.rd(idx).idx.bits := accRdIdx
+
+    // imm
+    val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
+    tensorImm.data.valid := state === sReadTensorB
+    tensorImm.data.bits.foreach { b =>
+      b.foreach { c =>
+        c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1),
+          Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm)
+      }
     }
-  }
 
-  // alu
-  val isSHR = dec.alu_op === ALU_OP(3)
-  val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
-  // opcode - min:0, max:1, add:2, shr:3, shl:4
-  val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
-  alu.io.opcode := fixme_alu_op
-  alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
-  alu.io.acc_a.data.bits <> io.acc.rd.data.bits
-  alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
-    tensorImm.data.valid,
-    io.acc.rd.data.valid & state === sExe)
-  alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
-    tensorImm.data.bits,
-    io.acc.rd.data.bits)
-
-  // acc (output)
-  io.acc.wr.valid := alu.io.acc_y.data.valid
-  io.acc.wr.bits.idx := uop_dst
-  io.acc.wr.bits.data <> alu.io.acc_y.data.bits
-
-  // out
-  io.out.wr.valid := alu.io.out.data.valid
-  io.out.wr.bits.idx := uop_dst
-  io.out.wr.bits.data <> alu.io.out.data.bits
+    // alu
+    val isSHR = (dec.alu_op === ALU_OP(3))
+    val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
+    // opcode - min:0, max:1, add:2, shr:3, shl:4
+    val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
+    alu.io.opcode := fixme_alu_op
+    alu.io.acc_a.data.valid := io.acc.rd(idx).data.valid & state === sReadTensorB
+    alu.io.acc_a.data.bits <> io.acc.rd(idx).data.bits
+    alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
+      tensorImm.data.valid,
+      io.acc.rd(idx).data.valid & state === sExe)
+    alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
+      tensorImm.data.bits,
+      io.acc.rd(idx).data.bits)
+
+    // acc (output)
+    io.acc.wr(idx).valid := alu.io.acc_y.data.valid
+    io.acc.wr(idx).bits.idx := uop_dst
+    io.acc.wr(idx).bits.data <> alu.io.acc_y.data.bits
+
+    // out
+    io.out.wr(idx).valid := alu.io.out.data.valid
+    io.out.wr(idx).bits.idx := uop_dst
+    io.out.wr(idx).bits.data <> alu.io.out.data.bits
+  }
   io.out.tieoffRead() // write-only
-
   io.done := done
 
   if (debug) {
-
     when(state === sReadUop) {
       printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
     }
@@ -308,3 +597,5 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
     }
   }
 }
+
+class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends TensorAluPipelined(debug)
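
TensorAluIndexGenerator replaces the old FSM's index bookkeeping with a
free-running two-level loop nest. A software model of the dst/src index
stream it produces, in plain Scala (our sketch; it ignores the src-fetch
stutter cycles, and the uop-supplied u0/u1/u2 offsets are added later in
the pipeline):

    def indexStream(uopBegin: Int, uopEnd: Int,
                    lp0: Int, dst0: Int, src0: Int,
                    lp1: Int, dst1: Int, src1: Int): Seq[(Int, Int, Int)] =
      for {
        cntO <- 0 until lp0           // outer loop
        cntI <- 0 until lp1           // inner loop
        uop  <- uopBegin until uopEnd // uops are innermost
      } yield (uop,                        // uop_idx
               cntO * dst0 + cntI * dst1,  // dst_idx before the uop offset
               cntO * src0 + cntI * src1)  // src_idx before the uop offset
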
diff --git a/hardware/chisel/src/main/scala/core/TensorGemm.scala b/hardware/chisel/src/main/scala/core/TensorGemm.scala
index f2d295f..e977552 100644
--- a/hardware/chisel/src/main/scala/core/TensorGemm.scala
+++ b/hardware/chisel/src/main/scala/core/TensorGemm.scala
@@ -173,31 +173,41 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
   io.out.data.valid := vld.asUInt.andR
 }
 
-/** TensorGemm.
- *
- * This unit instantiate the MatrixVectorMultiplication and go over the
- * micro-ops (uops) which are used to read inputs, weights and biases,
- * and writes results back to the acc and out scratchpads.
- *
- * Also, the TensorGemm uses the reset field in the Gemm instruction to
- * clear or zero-out the acc-scratchpad locations based on the micro-ops.
- */
-class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
+abstract class TensorGemmIfc(implicit p: Parameters) extends Module {
+  val stateBits = 3
+  val inflightBits = 4
   val io = IO(new Bundle {
     val start = Input(Bool())
     val done = Output(Bool())
-    val inst = Input(UInt(INST_BITS.W))
+    val dec = Input(new GemmDecode)
     val uop = new UopMaster
     val inp = new TensorMaster(tensorType = "inp")
     val wgt = new TensorMaster(tensorType = "wgt")
     val acc = new TensorMaster(tensorType = "acc")
     val out = new TensorMaster(tensorType = "out")
+    val state = Output(UInt(stateBits.W))
+    val inflight = Output(UInt(inflightBits.W))
   })
-  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
-    Enum(6)
+}
+
+/** TensorGemm.
+ *
+ * This unit instantiates the MatrixVectorMultiplication and iterates over
+ * the micro-ops (uops), which are used to read inputs, weights, and biases,
+ * and writes results back to the acc and out scratchpads.
+ *
+ * Also, the TensorGemm uses the reset field in the Gemm instruction to
+ * clear or zero-out the acc-scratchpad locations based on the micro-ops.
+ */
+class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends TensorGemmIfc {
+
+  require(p(CoreKey).blockOutFactor == 1,
+    "-F- Split GEMM not supported. Use TensorGemmPipelinedSplit or set blockOutFactor to 1")
+  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
   val state = RegInit(sIdle)
+  io.state := state
   val mvc = Module(new MatrixVectorMultiplication)
-  val dec = io.inst.asTypeOf(new GemmDecode)
+  val dec = io.dec
   val uop_idx = Reg(chiselTypeOf(dec.uop_end))
   val uop_end = dec.uop_end
   val uop_acc = Reg(chiselTypeOf(dec.uop_end))
@@ -211,18 +221,18 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
   val acc_i = Reg(chiselTypeOf(dec.uop_end))
   val inp_i = Reg(chiselTypeOf(dec.uop_end))
   val wgt_i = Reg(chiselTypeOf(dec.uop_end))
-  val pBits = log2Ceil(p(CoreKey).blockOut) + 1
-  val inflight = Reg(UInt(pBits.W))
+
+  val inflight = Reg(UInt(inflightBits.W))
+  io.inflight := inflight
   // Latency is defined as two in the following, because there is one cycle in the MAC module,
   // and another cycle in the pipelined adders as the first layer of the accumulator
   val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))
+  val cond_last = cnt_o === dec.lp_0 - 1.U &
+    cnt_i === dec.lp_1 - 1.U &
+    uop_idx === uop_end - 1.U
+
   val done = inflight === 0.U &
-    ((state === sExe &
-      cnt_o === dec.lp_0 - 1.U &
-      cnt_i === dec.lp_1 - 1.U &
-      uop_idx === uop_end - 1.U &
-      inflight === 0.U) |
-      state === sWait)
+    ((state === sExe) & cond_last | state === sWait)
 
   switch(state) {
     is(sIdle) {
@@ -240,10 +250,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
       state := sExe
     }
     is(sExe) {
-      when(
-        (cnt_o === dec.lp_0 - 1.U) &&
-          (cnt_i === dec.lp_1 - 1.U) &&
-          (uop_idx === uop_end - 1.U)) {
+      when(cond_last) {
         when(inflight =/= 0.U) {
           state := sWait
         }.otherwise {
@@ -264,10 +271,11 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
     inflight := 0.U
   }.elsewhen(!dec.reset) {
     when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit
-      inflight := inflight
     }.elsewhen(state === sReadTensor) { // issue a tensor
+      assert(inflight =/= ((1<<inflightBits)-1).U)
       inflight := inflight + 1.U
     }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
+      assert(inflight =/= 0.U)
       inflight := inflight - 1.U
     }
   }
@@ -327,40 +335,42 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
   io.uop.idx.bits := uop_idx
 
   // inp
-  io.inp.rd.idx.valid := state === sReadTensor
-  io.inp.rd.idx.bits := uop_inp
+  io.inp.rd(0).idx.valid := state === sReadTensor
+  io.inp.rd(0).idx.bits := uop_inp
   io.inp.tieoffWrite() // read-only
 
   // wgt
-  io.wgt.rd.idx.valid := state === sReadTensor
-  io.wgt.rd.idx.bits := uop_wgt
+  io.wgt.rd(0).idx.valid := state === sReadTensor
+  io.wgt.rd(0).idx.bits := uop_wgt
   io.wgt.tieoffWrite() // read-only
 
   // acc_i
-  io.acc.rd.idx.valid := state === sReadTensor
-  io.acc.rd.idx.bits := uop_acc
+  io.acc.rd(0).idx.valid := state === sReadTensor
+  io.acc.rd(0).idx.bits := uop_acc
 
   // mvc
   mvc.io.reset := dec.reset & state === sExe
-  mvc.io.inp.data <> io.inp.rd.data
-  mvc.io.wgt.data <> io.wgt.rd.data
-  mvc.io.acc_i.data <> io.acc.rd.data
+  mvc.io.inp.data <> io.inp.rd(0).data
+  mvc.io.wgt.data <> io.wgt.rd(0).data
+  mvc.io.acc_i.data <> io.acc.rd(0).data
 
   // acc_o
-  io.acc.wr.valid := mvc.io.acc_o.data.valid &
+  io.acc.wr(0).valid := mvc.io.acc_o.data.valid &
     Mux(dec.reset, true.B, wrpipe.io.deq.valid)
-  io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
-  io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
+  io.acc.wr(0).bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
+  io.acc.wr(0).bits.data <> mvc.io.acc_o.data.bits
 
   // out
-  io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
-  io.out.wr.bits.idx := wrpipe.io.deq.bits
-  io.out.wr.bits.data <> mvc.io.out.data.bits
+  io.out.wr(0).valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
+  io.out.wr(0).bits.idx := wrpipe.io.deq.bits
+  io.out.wr(0).bits.data <> mvc.io.out.data.bits
   io.out.tieoffRead() // write-only
 
   io.done := done
 
   if (debug) {
+    printf("[TensorGemm] [state]:%d [inflight]:%d\n", state, inflight)
+
     when(state === sReadUop && ~dec.reset) {
       printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
     }
@@ -369,24 +379,24 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
       printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
     }
 
-    io.inp.rd.data.bits.zipWithIndex.foreach {
+    io.inp.rd(0).data.bits.zipWithIndex.foreach {
       case (r, i) =>
-        when(io.inp.rd.data.valid && ~dec.reset) {
+        when(io.inp.rd(0).data.valid && ~dec.reset) {
           printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
         }
     }
 
-    io.wgt.rd.data.bits.zipWithIndex.foreach {
+    io.wgt.rd(0).data.bits.zipWithIndex.foreach {
       case (r, i) =>
-        when(io.wgt.rd.data.valid && ~dec.reset) {
+        when(io.wgt.rd(0).data.valid && ~dec.reset) {
           printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
         }
     }
 
-    io.acc.rd.data.bits.foreach { tensor =>
+    io.acc.rd(0).data.bits.foreach { tensor =>
       tensor.zipWithIndex.foreach {
         case (elem, i) =>
-          when(io.acc.rd.data.valid && ~dec.reset) {
+          when(io.acc.rd(0).data.valid && ~dec.reset) {
             printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
           }
       }
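
A note on the inflight counter the new asserts guard: it counts
issued-but-uncommitted tensors, holds steady when an issue and a commit
coincide, and must neither wrap nor underflow. A minimal sketch of the
pattern (our module, assuming 4 counter bits as in inflightBits):

    import chisel3._

    class Inflight(bits: Int = 4) extends Module {
      val io = IO(new Bundle {
        val issue  = Input(Bool())  // a tensor is issued
        val commit = Input(Bool())  // a result is committed
        val idle   = Output(Bool())
      })
      val inflight = RegInit(0.U(bits.W))
      when(io.issue && !io.commit) {
        assert(inflight =/= ((1 << bits) - 1).U) // must not overflow
        inflight := inflight + 1.U
      }.elsewhen(io.commit && !io.issue) {
        assert(inflight =/= 0.U)                 // must not underflow
        inflight := inflight - 1.U
      }
      io.idle := inflight === 0.U
    }
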
diff --git a/hardware/chisel/src/main/scala/core/TensorLoad.scala b/hardware/chisel/src/main/scala/core/TensorLoad.scala
index 5ab690d..8b31253 100644
--- a/hardware/chisel/src/main/scala/core/TensorLoad.scala
+++ b/hardware/chisel/src/main/scala/core/TensorLoad.scala
@@ -23,7 +23,6 @@ import chisel3._
 import chisel3.util._
 import vta.util.config._
 import vta.shell._
-
 /** TensorLoad.
  *
  * Load 1D and 2D tensors from main memory (DRAM) to input/weight
@@ -45,6 +44,10 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
     val vme_rd = new VMEReadMaster
     val tensor = new TensorClient(tensorType)
   })
+
+  require(tp.numMemBlock > 0, s"-F- Unexpected data-to-tensor bit size ratio. ${tensorType} ${tp.numMemBlock}")
+  require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- Cannot do split direct access")
+
   val sizeFactor = tp.tensorLength * tp.numMemBlock
   val strideFactor = tp.tensorLength * tp.tensorWidth
 
@@ -82,12 +85,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
         when(dec.xpad_0 =/= 0.U) {
           state := sXPad0
         }.otherwise {
+          assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
           state := sReadCmd
         }
       }
     }
     is(sXPad0) {
       when(xPadCtrl0.io.done) {
+        assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
         state := sReadCmd
       }
     }
@@ -112,6 +117,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
           }.elsewhen(dec.xpad_0 =/= 0.U) {
             state := sXPad0
           }.otherwise {
+            assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
             state := sReadCmd
           }
         }.elsewhen(dataCtrl.io.split) {
@@ -131,6 +137,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
           when(dec.xpad_0 =/= 0.U) {
             state := sXPad0
           }.otherwise {
+            assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
             state := sReadCmd
           }
         }
@@ -191,13 +198,16 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
     state === sXPad1 |
     state === sYPad1
 
-  when(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
+  when(state === sReadCmd && tag =/= (tp.numMemBlock - 1).U) { // split read inside row of mem blocks
+    tag := tag
+  }.elsewhen(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
     tag := 0.U
   }.elsewhen(io.vme_rd.data.fire() || isZeroPad) {
     tag := tag + 1.U
   }
 
-  when(state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
+  when(state === sIdle || (dataCtrlDone && ~isZeroPad) ||
+    (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
     set := 0.U
   }.elsewhen((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
     set := set + 1.U
@@ -221,6 +231,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
   val tensorFile = Seq.fill(tp.tensorLength) {
     SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
   }
+
   val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
   val wdata = Seq.fill(tp.tensorLength) {
     Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
@@ -235,12 +246,12 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
       wmask(i)(j) := tag === j.U
       wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
     }
-    val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
+    val tdata = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata(i))
     val muxWen =
       Mux(state === sIdle,
-        io.tensor.wr.valid,
+        io.tensor.wr(0).valid,
         (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
-    val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
+    val muxWaddr = Mux(state === sIdle, io.tensor.wr(0).bits.idx, waddr_cur)
     val muxWdata = Mux(state === sIdle, tdata, wdata(i))
     val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
     when(muxWen) {
@@ -249,14 +260,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
   }
 
   // read-from-sram
-  val rvalid = RegNext(io.tensor.rd.idx.valid)
-  io.tensor.rd.data.valid := rvalid
+  val rvalid = RegNext(io.tensor.rd(0).idx.valid)
+  io.tensor.rd(0).data.valid := rvalid
 
   val rdata =
-    tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
+    tensorFile.map(_.read(io.tensor.rd(0).idx.bits, io.tensor.rd(0).idx.valid))
   rdata.zipWithIndex.foreach {
     case (r, i) =>
-      io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
+      io.tensor.rd(0).data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd(0).data.bits(i))
   }
 
   // done
@@ -291,11 +302,42 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
           dataCtrl.io.addr,
           dataCtrl.io.len)
       }
+      when(state === sYPad0) {
+        printf("[TensorLoad] [wgt] sYPad0\n")
+      }
+      when(state === sYPad1) {
+        printf("[TensorLoad] [wgt] sYPad1\n")
+      }
+      when(state === sXPad0) {
+        printf("[TensorLoad] [wgt] sXPad0\n")
+      }
+      when(state === sXPad1) {
+        printf("[TensorLoad] [wgt] sXPad1\n")
+      }
     } else if (tensorType == "acc") {
       when(io.vme_rd.cmd.fire()) {
         printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
           dataCtrl.io.addr,
           dataCtrl.io.len)
+        printf("[TensorLoad] [acc info] dec.xsize: %d, dec.ysize: %d, dec.xstride: %d\n",
+          dec.xsize, dec.ysize, dec.xstride)
+        printf("[TensorLoad] [acc i2fo] dec.xpad_1: %d dec.xpad_0: %d dec.ypad_1: %d dec.ypad_0: %d\n",
+          dec.xpad_1, dec.xpad_0, dec.ypad_1, dec.ypad_0)
+
+        printf("tp.tensorLength: %d, tp.numMemBlock: %d, tp.tensorLength: %d, tp.tensorWidth: %d\n",
+          tp.tensorLength.U, tp.numMemBlock.U, tp.tensorLength.U, tp.tensorWidth.U)
+      }
+      when(state === sYPad0) {
+        printf("[TensorLoad] [acc] sYPad0\n")
+      }
+      when(state === sYPad1) {
+        printf("[TensorLoad] [acc] sYPad1\n")
+      }
+      when(state === sXPad0) {
+        printf("[TensorLoad] [acc] sXPad0\n")
+      }
+      when(state === sXPad1) {
+        printf("[TensorLoad] [acc] sXPad1\n")
       }
     }
   }
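
For context on the tag/set fixes above: tag steps through the numMemBlock
memory blocks of one tensor row and set steps through the tensorLength
rows, so the pair forms a nested counter; the change keeps tag from being
cleared in the middle of a split row read. The intended visit order,
modeled in plain Scala (our sketch):

    def walkOrder(tensorLength: Int, numMemBlock: Int): Seq[(Int, Int)] =
      for (set <- 0 until tensorLength; tag <- 0 until numMemBlock)
        yield (set, tag)
    // walkOrder(2, 4): (0,0) (0,1) (0,2) (0,3) (1,0) (1,1) (1,2) (1,3)
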
diff --git a/hardware/chisel/src/main/scala/core/TensorStore.scala b/hardware/chisel/src/main/scala/core/TensorStore.scala
index 9b4bf74..f1556ef 100644
--- a/hardware/chisel/src/main/scala/core/TensorStore.scala
+++ b/hardware/chisel/src/main/scala/core/TensorStore.scala
@@ -47,6 +47,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
   val memBlockBits = tp.memBlockBits
   val memDepth = tp.memDepth
   val numMemBlock = tp.numMemBlock
+  require(numMemBlock > 0, s"-F- TensorStore doesn't support a pulse width " +
+    s"wider than the tensor width. Needed for stride support, tensorWidth=${tensorWidth}")
+  require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- ${tensorType} Cannot do split direct access")
 
   val dec = io.inst.asTypeOf(new MemDecode)
   val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
@@ -54,7 +57,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
   val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
   val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
   val xrem = Reg(chiselTypeOf(dec.xsize))
-  val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U
+  val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock))
   val xmax = (1 << mp.lenBits).U
   val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
   val ycnt = Reg(chiselTypeOf(dec.ysize))
@@ -89,10 +92,13 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
       when (io.start) {
         state := sWriteCmd
         when (xsize < xfer_init_pulses) {
-          xlen := xsize
+          assert(xsize > 0.U, "Idle => WriteCmd, init, without xrem: must have positive xsize")
+          xlen := xsize - 1.U
           xrem := 0.U
         }.otherwise {
           xlen := xfer_init_pulses - 1.U
+          assert(xsize >= xfer_init_pulses,
+            "Idle => WriteCmd, init, with xrem: must have xsize no smaller than xfer_init_pulses")
           xrem := xsize - xfer_init_pulses
         }
       }
@@ -123,10 +129,13 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
             state := sWriteCmd
             xfer_bytes := xfer_stride_bytes
             when(xsize < xfer_stride_pulses) {
-              xlen := xsize
+              assert(xsize > 0.U, "WriteAck => WriteCmd, stride, without xrem: must have positive xsize")
+              xlen := xsize - 1.U
               xrem := 0.U
             }.otherwise {
               xlen := xfer_stride_pulses - 1.U
+              assert(xsize >= xfer_stride_pulses,
+                "WriteAck => WriteCmd, stride, with xrem: must have xsize no smaller than xfer_stride_pulses")
               xrem := xsize - xfer_stride_pulses
             }
           }
@@ -134,13 +143,16 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
         .elsewhen(xrem < xfer_split_pulses) {
           state := sWriteCmd
           xfer_bytes := xfer_split_bytes
-          xlen := xrem
+          assert(xrem > 0.U, "WriteAck => WriteCmd, split, without xrem: must have positive xrem")
+          xlen := xrem - 1.U
           xrem := 0.U
         }
         .otherwise {
           state := sWriteCmd
           xfer_bytes := xfer_split_bytes
           xlen := xfer_split_pulses - 1.U
+          assert(xrem >= xfer_split_pulses,
+            "WriteAck => WriteCmd, split, with xrem: must have xrem no smaller than xfer_split_pulses")
           xrem := xrem - xfer_split_pulses
         }
       }
@@ -160,9 +172,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
   }
 
   for (i <- 0 until tensorLength) {
-    val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
-    when(io.tensor.wr.valid) {
-      tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
+    val inWrData = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata_t)
+    when(io.tensor.wr(0).valid) {
+      tensorFile(i).write(io.tensor.wr(0).bits.idx, inWrData, no_mask)
     }
   }
 
@@ -186,7 +198,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
   }
 
   when(
-    state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
+    state === sWriteCmd || (state =/= sReadMem && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
     set := 0.U
   }.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
     set := set + 1.U
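
The xsize/xlen change normalizes the length encoding: xsize is now the raw
pulse count, and each command issue performs the VME/AXI-style
"len = pulses - 1" conversion that the new asserts check. A plain-Scala
sketch of how a transfer of xsize pulses decomposes into len values,
assuming a single uniform maximum (the real FSM distinguishes the
init/stride/split maxima):

    def burstLens(xsize: Int, maxPulses: Int): Seq[Int] = {
      require(xsize > 0 && maxPulses > 0)
      val full = Seq.fill(xsize / maxPulses)(maxPulses - 1) // full bursts
      val rem  = xsize % maxPulses                          // leftover pulses
      if (rem == 0) full else full :+ (rem - 1)
    }
    // burstLens(10, 4) == Seq(3, 3, 1): bursts of 4, 4, and 2 pulses
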
diff --git a/hardware/chisel/src/main/scala/core/TensorUtil.scala b/hardware/chisel/src/main/scala/core/TensorUtil.scala
index d0a8ba7..dfdecf4 100644
--- a/hardware/chisel/src/main/scala/core/TensorUtil.scala
+++ b/hardware/chisel/src/main/scala/core/TensorUtil.scala
@@ -35,7 +35,7 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
     s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
 
   require(tensorType == "inp" || tensorType == "wgt"
-    || tensorType == "acc" || tensorType == "out",
+    || tensorType == "acc" || tensorType == "out" || tensorType == "fetch" || tensorType == "uop",
     errorMsg)
 
   val (tensorLength, tensorWidth, tensorElemBits) =
@@ -45,6 +45,15 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
       (p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits)
     else if (tensorType == "acc")
       (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits)
+    else if (tensorType == "fetch") {
+      // make fetch a 64 bit data to be able to read
+      // 64 bit aligned address. It works for wide cacheline
+      require(p(ShellKey).memParams.dataBits >= INST_BITS,
+        "-F- Cannot make fetch tensor narrower than data pulse. TODO: narrow fetch with tensors")
+      (1, 1, 64)
+    }
+    else if (tensorType == "uop")
+      (1, 1, p(CoreKey).uopBits)
     else
       (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits)
 
@@ -58,10 +67,118 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
       p(CoreKey).wgtMemDepth
     else if (tensorType == "acc")
       p(CoreKey).accMemDepth
+    else if (tensorType == "fetch") {
+      require(p(ShellKey).memParams.dataBits >= INST_BITS,
+        "-F- Cannot make fetch tensor narrower than data pulse. TODO: narrow fetch with tensors")
+      // still should be one data line
+      (1 << p(ShellKey).memParams.lenBits)*(INST_BITS / 64)
+    }
+    else if (tensorType == "uop") {
+      p(CoreKey).uopMemDepth
+    }
     else
       p(CoreKey).outMemDepth
 
+  // acc/wgt parts are grouped to form
+  // a physically compact compute entity
+
+  val (splitLength, splitWidth) =
+    if (tensorType == "inp") {
+      (1, 1)
+    } else if (tensorType == "wgt") {
+      (p(CoreKey).blockOutFactor, 1)
+    } else if (tensorType == "acc") {
+      // the acc scratchpad is batch rows of blockOut columns.
+      // A GEMM/ALU operation group is based on wgt tiling of blockOut,
+      // which means acc data for one group (if batch > 1) is not
+      // contiguous and may be placed into different memory
+      // modules. But the whole idea of a group is to localize
+      // a piece-of-wgt to piece-of-acc data transformation.
+      //
+      (1, p(CoreKey).blockOutFactor)
+    } else if (tensorType == "fetch") {
+      (1, 1)
+    } else if (tensorType == "uop") {
+      (1, 1)
+    } else if (tensorType == "out") {
+      (1, 1) // narrow store doesn't support split
+    } else {
+      (1, 1)
+    }
+  require (splitLength == 1 || splitWidth == 1, "-F- Can split only one dimension.")
+
+  // provide index of a group closes to IO
+  // expect 2 columns of groups io on top and indexing from bottom
+  val closestIOGrpIdx =
+    if (tensorType == "inp") {
+      splitLength - 1
+    } else if (tensorType == "wgt") {
+      if (splitLength < 2) 0 else splitLength / 2 - 1
+    } else if (tensorType == "acc") {
+      if (splitWidth < 2) 0 else splitWidth / 2 - 1
+    } else if (tensorType == "fetch") {
+      0
+    } else if (tensorType == "uop") {
+      0
+    } else if (tensorType == "out") {
+      0
+    } else {
+      0
+    }
+
   val memAddrBits = log2Ceil(memDepth)
+
+  val tensorSizeBits = tensorLength * tensorWidth * tensorElemBits
+  val tsSizeRatio = tensorSizeBits / memBlockBits
+  val clSizeRatio = memBlockBits / tensorSizeBits
+
+  val lenSplit = tensorLength / splitLength // tensor rows in a group
+  val widthSplit = tensorWidth / splitWidth // tensor columns in a group
+  require(lenSplit > 0 && widthSplit > 0, "-F- wrong split")
+
+  // the tensor considers groups as contiguous data; gemm generates a data window.
+  // Map a data index from a window index to a contiguous group index
+  def reindexDataFromGroup (grpIdx : Int, lenIdx: Int, wdtIdx: Int) = {
+
+    val grpLen = lenSplit // tensor rows in a group
+    val grpWdt = widthSplit // tensor columns in a group
+    val srcGrpRow = grpIdx / splitWidth // group row
+    val srcGrpCol = grpIdx % splitWidth // group column
+    val tnzRow = srcGrpRow * grpLen
+    val tnzCol = srcGrpCol * grpWdt
+    val flatIdx = (tnzRow + lenIdx) * tensorWidth + tnzCol + wdtIdx
+
+    val outGroupIdx = flatIdx / (grpLen * grpWdt)
+    val outGroupOffset = flatIdx % (grpLen * grpWdt)
+    val outGroupLenIdx = outGroupOffset / grpWdt
+    val outGroupWdthIdx = outGroupOffset % grpWdt
+    (outGroupIdx, outGroupLenIdx, outGroupWdthIdx)
+  }
+  // map a data index from a contiguous index to a window index
+  def reindexDataToGroup (grpIdx : Int, lenIdx: Int, wdtIdx: Int) = {
+    val outGrpLen = tensorLength / splitLength // data rows in a group
+    val outGrpWdt = tensorWidth / splitWidth // data columns in a group
+
+    val outIdx = grpIdx * outGrpLen * outGrpWdt + lenIdx * outGrpWdt + wdtIdx
+    val outRow = outIdx / tensorWidth
+    val outCol = outIdx % tensorWidth
+
+
+    val outGrpRow = outRow / outGrpLen
+    val outLenIdx = outRow % outGrpLen
+
+    val outGrpCol = outCol / outGrpWdt
+    val outColIdx = outCol % outGrpWdt
+
+    val outGrpIdx = outGrpRow * splitWidth + outGrpCol
+
+    (outGrpIdx, outLenIdx, outColIdx)
+
+  }
+  def paramsStr () = {
+    s" ${tensorType} ${tensorSizeBits*memDepth/8} Byte. length:${tensorLength} width:${tensorWidth}" +
+    s" data bits:${tensorElemBits} mem depth:${memDepth} groups split length:${splitLength}"
+  }
 }
 
 /** TensorMaster.
@@ -73,25 +190,29 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
  */
 class TensorMaster(tensorType: String = "none")
   (implicit p: Parameters) extends TensorParams(tensorType) {
-  val rd = new Bundle {
+  val rd = Vec(splitLength * splitWidth, new Bundle {
     val idx = ValidIO(UInt(memAddrBits.W))
     val data = Flipped(
-      ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
-  }
-  val wr = ValidIO(new Bundle {
-    val idx = UInt(memAddrBits.W)
-    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+      ValidIO(Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))))
   })
+  val wr = Vec(splitLength * splitWidth, ValidIO(new Bundle {
+    val idx = UInt(memAddrBits.W)
+    val data = Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))
+  }))
   def tieoffRead() {
-    rd.idx.valid := false.B
-    rd.idx.bits := 0.U
+    for (idx <- 0 until splitLength * splitWidth) {
+      rd(idx).idx.valid := false.B
+      rd(idx).idx.bits := 0.U
+    }
   }
   def tieoffWrite() {
-    wr.valid := false.B
-    wr.bits.idx := 0.U
-    wr.bits.data.foreach { b =>
-      b.foreach { c =>
-        c := 0.U
+    for (idx <- 0 until splitLength * splitWidth) {
+      wr(idx).valid := false.B
+      wr(idx).bits.idx := 0.U
+      wr(idx).bits.data.foreach { b =>
+        b.foreach { c =>
+          c := 0.U
+        }
       }
     }
   }
@@ -107,20 +228,33 @@ class TensorMaster(tensorType: String = "none")
  */
 class TensorClient(tensorType: String = "none")
   (implicit p: Parameters) extends TensorParams(tensorType) {
-  val rd = new Bundle {
+  val rd = Vec(splitLength * splitWidth, new Bundle {
     val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
     val data = ValidIO(
-      Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
-  }
-  val wr = Flipped(ValidIO(new Bundle {
+      Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W))))
+  })
+  val wr = Vec(splitLength * splitWidth, Flipped(ValidIO(new Bundle {
     val idx = UInt(memAddrBits.W)
-    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
-  }))
+    val data = Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))
+  })))
   def tieoffRead() {
-    rd.data.valid := false.B
-    rd.data.bits.foreach { b =>
-      b.foreach { c =>
-        c := 0.U
+    for (idx <- 0 until splitLength * splitWidth) {
+      rd(idx).data.valid := false.B
+      rd(idx).data.bits.foreach { b =>
+        b.foreach { c =>
+          c := 0.U
+        }
+      }
+    }
+  }
+  def tieoffWrite() {
+    for (idx <- 0 until splitLength * splitWidth) {
+      wr(idx).valid := false.B
+      wr(idx).bits.idx := 0.U
+      wr(idx).bits.data.foreach { b =>
+        b.foreach { c =>
+          c := 0.U
+        }
       }
     }
   }
@@ -137,7 +271,7 @@ class TensorClient(tensorType: String = "none")
 class TensorMasterData(tensorType: String = "none")
   (implicit p: Parameters) extends TensorParams(tensorType) {
   val data = Flipped(
-    ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+    ValidIO(Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))))
   override def cloneType =
     new TensorMasterData(tensorType).asInstanceOf[this.type]
 }
@@ -151,7 +285,7 @@ class TensorMasterData(tensorType: String = "none")
 class TensorClientData(tensorType: String = "none")
   (implicit p: Parameters) extends TensorParams(tensorType) {
   val data = ValidIO(
-    Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+    Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W))))
   override def cloneType =
     new TensorClientData(tensorType).asInstanceOf[this.type]
 }
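
The split interfaces above carve each tensorLength x tensorWidth tensor into
splitLength x splitWidth groups, and the index helper maps a flat element index
to a (group, row-within-group, column-within-group) triple. A minimal plain-Scala
sketch of that arithmetic, with illustrative sizes (not VTA defaults):

    object SplitIndexSketch extends App {
      val tensorLength = 4  // rows per tensor (illustrative)
      val tensorWidth  = 4  // columns per tensor (illustrative)
      val splitLength  = 2  // groups along the length dimension
      val splitWidth   = 2  // groups along the width dimension

      val grpLen = tensorLength / splitLength // rows per group
      val grpWdt = tensorWidth / splitWidth   // columns per group

      // Mirrors TensorParams: recover row/col first, then group and offsets.
      def locate(outIdx: Int): (Int, Int, Int) = {
        val outRow = outIdx / tensorWidth
        val outCol = outIdx % tensorWidth
        val grpIdx = (outRow / grpLen) * splitWidth + (outCol / grpWdt)
        (grpIdx, outRow % grpLen, outCol % grpWdt)
      }

      // Element 13 is row 3, col 1: group row 1, group col 0 -> group index 2.
      assert(locate(13) == (2, 1, 1))
    }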
diff --git a/hardware/chisel/src/test/scala/unittest/AluTest.scala b/hardware/chisel/src/test/scala/unittest/AluTest.scala
index a4274c2..c874b01 100644
--- a/hardware/chisel/src/test/scala/unittest/AluTest.scala
+++ b/hardware/chisel/src/test/scala/unittest/AluTest.scala
@@ -19,61 +19,66 @@
 
 package unittest
 
-import chisel3._
 import chisel3.util._
-import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
+import chisel3.iotesters.PeekPokeTester
 import scala.util.Random
 import unittest.util._
 import vta.core._
+import vta.util.config._
 
-class TestAluVector(c: AluVector) extends PeekPokeTester(c) {
-
+object Alu_ref {
   /* alu_ref
    *
    * This is a software function used as a reference for the hardware
    */
-  def aluRef(opcode: Int, a: Array[Int], b: Array[Int], width: Int) : Array[Int] = {
+  def alu(opcode: Int, a: Array[Int], b: Array[Int], width: Int) : Array[Int] = {
     val size = a.length
     val mask = Helper.getMask(log2Ceil(width))
     val res = Array.fill(size) {0}
 
-    if (opcode == 1) {
+    if (opcode == 0) {
+      for (i <- 0 until size) { // min
+        res(i) = if (a(i) < b(i)) a(i) else b(i)
+      }
+    } else if (opcode == 1) { // max
       for (i <- 0 until size) {
         res(i) = if (a(i) < b(i)) b(i) else a(i)
       }
-    } else if (opcode == 2) {
+    } else if (opcode == 2) { // add
       for (i <- 0 until size) {
         res(i) = a(i) + b(i)
       }
-    } else if (opcode == 3) {
+    } else if (opcode == 3) { // right shift
       for (i <- 0 until size) {
         res(i) = a(i) >> (b(i) & mask).toInt
       }
-    } else if (opcode == 4) {
+    } else if (opcode == 4) { // left shift
       // HLS shift left by >> negative number
       // b always < 0 when opcode == 4
       for (i <- 0 until size) {
         res(i) = a(i) << ((-1*b(i)) & mask)
       }
-    } else {
-      // default
+    } else { // default
       for (i <- 0 until size) {
-        res(i) = if (a(i) < b(i)) a(i) else b(i)
+        res(i) = 0
       }
     }
     res
   }
+}
+
+class AluVectorTester(c: AluVector, seed: Int = 47) extends PeekPokeTester(c) {
+  val r = new Random(seed)
 
   val num_ops = ALU_OP_NUM
-  for (i <- 0 until num_ops) {
+  for (op <- 0 until num_ops) {
     // generate data based on bits
     val bits = c.io.acc_a.tensorElemBits
-    val dataGen = new RandomArray(c.blockOut, bits)
-    val op = i
+    val dataGen = new RandomArray(c.blockOut, bits, r)
     val in_a = dataGen.any
     val in_b = if (op != 4) dataGen.any else dataGen.negative
     val mask = Helper.getMask(bits)
-    val res = aluRef(op, in_a, in_b, bits)
+    val res = Alu_ref.alu(op, in_a, in_b, bits)
 
     for (i <- 0 until c.blockOut) {
       poke(c.io.acc_a.data.bits(0)(i), in_a(i) & mask)
@@ -83,13 +88,11 @@ class TestAluVector(c: AluVector) extends PeekPokeTester(c) {
 
     poke(c.io.acc_a.data.valid, 1)
     poke(c.io.acc_b.data.valid, 1)
-    poke(c.io.acc_y.data.valid, 1)
 
     step(1)
 
     poke(c.io.acc_a.data.valid, 0)
     poke(c.io.acc_b.data.valid, 0)
-    poke(c.io.acc_y.data.valid, 0)
 
     // wait for valid signal
     while (peek(c.io.acc_y.data.valid) == BigInt(0)) {
@@ -102,3 +105,6 @@ class TestAluVector(c: AluVector) extends PeekPokeTester(c) {
     }
   }
 }
+
+class AluTest extends GenericTest("AluTest", (p:Parameters) =>
+  new AluVector()(p), (c:AluVector) => new AluVectorTester(c, 48))
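
The reference model now enumerates all five opcodes explicitly; note that left
shift (opcode 4) is encoded as a right shift by a negative amount to match the
HLS behavior. A small usage sketch with made-up inputs, assuming
Helper.getMask(n) returns an n-bit ones mask:

    import unittest.Alu_ref

    object AluRefSketch extends App {
      val bits = 8
      val a = Array(3, -4)
      // opcode 4 is "shift left by a negative amount", so b must be < 0
      val shifted = Alu_ref.alu(4, a, Array(-2, -1), bits)
      assert(shifted.sameElements(Array(3 << 2, -4 << 1)))
      // opcode 2 is elementwise add
      assert(Alu_ref.alu(2, a, Array(1, 1), bits).sameElements(Array(4, -3)))
    }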
diff --git a/hardware/chisel/src/main/scala/core/Configs.scala b/hardware/chisel/src/test/scala/unittest/Generic.scala
old mode 100644
new mode 100755
similarity index 54%
copy from hardware/chisel/src/main/scala/core/Configs.scala
copy to hardware/chisel/src/test/scala/unittest/Generic.scala
index 4ab7d85..3dc0b34
--- a/hardware/chisel/src/main/scala/core/Configs.scala
+++ b/hardware/chisel/src/test/scala/unittest/Generic.scala
@@ -17,32 +17,30 @@
  * under the License.
  */
 
-package vta.core
+package unittest
 
+import chisel3._
+import chisel3.util._
 import vta.util.config._
+import chisel3.iotesters._
+import vta.{DefaultPynqConfig}
 
-/** CoreConfig.
- *
- * This is one supported configuration for VTA. This file will
- * be eventually filled out with class configurations that can be
- * mixed/matched with Shell configurations for different backends.
- */
-class CoreConfig extends Config((site, here, up) => {
-  case CoreKey =>
-    CoreParams(
-      batch = 1,
-      blockOut = 16,
-      blockIn = 16,
-      inpBits = 8,
-      wgtBits = 8,
-      uopBits = 32,
-      accBits = 32,
-      outBits = 8,
-      uopMemDepth = 2048,
-      inpMemDepth = 2048,
-      wgtMemDepth = 1024,
-      accMemDepth = 2048,
-      outMemDepth = 2048,
-      instQueueEntries = 512
+import org.scalatest.{Matchers, FlatSpec}
+
+class GenericTest[T <: Module, P <: PeekPokeTester[T], C <: Parameters]
+  (tag: String, dutFactory: (Parameters) => T, testerFactory: (T) => P) extends FlatSpec with Matchers {
+
+  implicit val p: Parameters = new DefaultPynqConfig
+
+  val arguments = Array(
+    "--backend-name", "treadle",
+    // "--backend-name", "vcs",
+    // "--is-verbose",
+    "--test-seed", "0"
     )
-})
+
+  behavior of tag
+  it should "not have expect violations" in {
+    chisel3.iotesters.Driver.execute(arguments, () => dutFactory(p))(testerFactory) should be (true)
+  }
+}
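
With the factory pattern above, adding a unit test for another module is a
one-liner. A sketch of the shape (MyModule and MyModuleTester are hypothetical
stand-ins, not modules in the tree; assumes the same unittest package imports):

    class MyModuleTest extends GenericTest("MyModule",
      (p: Parameters) => new MyModule()(p),
      (c: MyModule) => new MyModuleTester(c))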
diff --git a/hardware/chisel/src/test/scala/unittest/Launcher.scala b/hardware/chisel/src/test/scala/unittest/Launcher.scala
index 2d10c52..1b0d6da 100644
--- a/hardware/chisel/src/test/scala/unittest/Launcher.scala
+++ b/hardware/chisel/src/test/scala/unittest/Launcher.scala
@@ -49,7 +49,7 @@ object Launcher {
     },
     "alu" -> { (manager: TesterOptionsManager) =>
       Driver.execute(() => new AluVector, manager) {
-        (c) => new TestAluVector(c)
+        (c) => new AluVectorTester(c)
       }
     }
   )
diff --git a/hardware/chisel/src/test/scala/unittest/TensorAluTest.scala b/hardware/chisel/src/test/scala/unittest/TensorAluTest.scala
new file mode 100755
index 0000000..21da5f0
--- /dev/null
+++ b/hardware/chisel/src/test/scala/unittest/TensorAluTest.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package unittest
+
+import chisel3._
+import chisel3.util._
+import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
+import scala.util.Random
+import unittest.util._
+import vta.core._
+import vta.util.config._
+
+class TensorAluIndexGeneratorTester(c: TensorAluIndexGenerator, alu_use_imm: Int = 0) extends PeekPokeTester(c) {
+
+
+  val uop_begin = 0
+  val uop_end = 2
+  assert(uop_begin < uop_end)
+
+  val lp_0 = 2
+  val lp_1 = 3
+  val dst_0 = 1*lp_1
+  val src_0 = 2*lp_1
+  val dst_1 = 1
+  val src_1 = 2
+
+  poke(c.io.dec.reset, 0)
+  poke(c.io.dec.alu_use_imm, alu_use_imm)
+  poke(c.io.dec.uop_begin, uop_begin)
+  poke(c.io.dec.uop_end, uop_end)
+  poke(c.io.dec.lp_0, lp_0)
+  poke(c.io.dec.lp_1, lp_1)
+  poke(c.io.dec.dst_0, dst_0)
+  poke(c.io.dec.dst_1, dst_1)
+  poke(c.io.dec.src_0, src_0)
+  poke(c.io.dec.src_1, src_1)
+  // Don't need empty_0,{push,pop}_{next,prev},op
+
+
+  class Mocks {
+    val uop_indices = new scala.collection.mutable.Queue[BigInt]
+    val dst_indices = new scala.collection.mutable.Queue[BigInt]
+    val src_indices = new scala.collection.mutable.Queue[BigInt]
+
+    def logical_step() {
+      step(1)
+      if (peek(c.io.valid) == 1) {
+        expect(c.io.uop_idx, uop_indices.dequeue())
+        expect(c.io.dst_idx, dst_indices.dequeue())
+      }
+      if (peek(c.io.src_valid) == 1) {
+        expect(c.io.src_idx, src_indices.dequeue())
+      }
+    }
+
+    def test_if_done() {
+      println(s"uop_indices remaining: ${uop_indices.size}")
+      println(s"dst_indices remaining: ${dst_indices.size}")
+      println(s"src_indices remaining: ${src_indices.size}")
+      assert(uop_indices.isEmpty)
+      assert(dst_indices.isEmpty)
+      assert(src_indices.isEmpty)
+    }
+  }
+
+  val mocks = new Mocks
+  for {
+    cnt_o <- 0 until lp_0
+    cnt_i <- 0 until lp_1
+    uop_idx <- uop_begin until uop_end
+  } {
+    mocks.uop_indices.enqueue(uop_idx)
+    mocks.dst_indices.enqueue(dst_0*cnt_o + dst_1*cnt_i)
+    if (alu_use_imm == 0) {
+      mocks.src_indices.enqueue(src_0*cnt_o + src_1*cnt_i)
+    }
+  }
+
+  poke(c.io.start, 1)
+  mocks.logical_step()
+  poke(c.io.start, 0)
+
+  val end = (uop_end-uop_begin)*lp_0*lp_1
+  var count = 0
+  while(peek(c.io.last) == 0 && count < 10*end + 100) {
+    mocks.logical_step()
+    count += 1
+  }
+  mocks.test_if_done()
+  step(1)
+}
+
+class TensorAluIndexGenerator_0_Test extends GenericTest("TensorAluIndexGenerator_0", (p:Parameters) =>
+  new TensorAluIndexGenerator()(p), (c:TensorAluIndexGenerator) => new TensorAluIndexGeneratorTester(c, 0))
+
+class TensorAluIndexGenerator_1_Test extends GenericTest("TensorAluIndexGenerator_1", (p:Parameters) =>
+  new TensorAluIndexGenerator()(p), (c:TensorAluIndexGenerator) => new TensorAluIndexGeneratorTester(c, 1))
+
+class TensorAluPipelinedTester(c: TensorAlu) extends PeekPokeTester(c) {
+  poke(c.io.start, 0)
+
+  val uop_begin = 0
+  val uop_end = 1
+  assert(uop_begin < uop_end)
+  val alu_use_imm = 1
+  val lp_0 = 2
+  val lp_1 = 3
+  val dst_0 = 1*lp_1
+  val src_0 = 2*lp_1
+  val dst_1 = 1
+  val src_1 = 2
+
+  val dst_offset = BigInt("000", 16)
+  val src_offset = BigInt("100", 16)
+
+  val u0 = dst_offset
+  val u1 = src_offset
+  val u2 = 0 // upper bits of src_offset spill into u2 if the offset is large
+
+  poke(c.io.dec.reset, 0)
+  poke(c.io.dec.alu_op, 2) // ADD; with alu_use_imm = 1 this behaves as ADDI 1
+  poke(c.io.dec.alu_imm, 1)
+  poke(c.io.dec.alu_use_imm, alu_use_imm)
+  poke(c.io.dec.uop_begin, uop_begin)
+  poke(c.io.dec.uop_end, uop_end)
+  poke(c.io.dec.lp_0, lp_0)
+  poke(c.io.dec.lp_1, lp_1)
+  poke(c.io.dec.dst_0, dst_0)
+  poke(c.io.dec.dst_1, dst_1)
+  poke(c.io.dec.src_0, src_0)
+  poke(c.io.dec.src_1, src_1)
+
+  // Don't need empty_0,{push,pop}_{next,prev},op
+
+  poke(c.io.uop.data.bits.u0, u0)
+  poke(c.io.uop.data.bits.u1, u1)
+  poke(c.io.uop.data.bits.u2, u2)
+
+  require(c.io.acc.splitWidth == 1, "-F- Test doesn't support acc data access split")
+  require(c.io.acc.splitLength == 1, "-F- Test doesn't support acc data access split")
+
+  val acc = IndexedSeq.tabulate(c.io.acc.rd(0).data.bits(0).size){ i => BigInt(i) }
+  for { lhs <- c.io.acc.rd(0).data.bits} {
+    poke(lhs, acc.reverse)
+  }
+
+  class TensorMasterMock(tm: TensorMaster) {
+    poke(tm.rd(0).data.valid, 0)
+    var valid = peek(tm.rd(0).idx.valid)
+    def logical_step(v: Option[BigInt]) {
+      poke(tm.rd(0).data.valid, valid)
+      valid = peek(tm.rd(0).idx.valid)
+      for { x <- v} expect(tm.rd(0).idx.valid, x)
+    }
+  }
+
+  class UopMasterMock(um: UopMaster) {
+    poke(um.data.valid, 0)
+    var valid = peek(um.idx.valid)
+    def logical_step(v: Option[BigInt]) {
+      poke(um.data.valid, valid)
+      valid = peek(um.idx.valid)
+      for { x <- v} expect(um.idx.valid, x)
+    }
+  }
+
+  class Mocks {
+    val uop_mock = new UopMasterMock(c.io.uop)
+    val acc_mock = new TensorMasterMock(c.io.acc)
+
+    val uop_indices = new scala.collection.mutable.Queue[BigInt]
+    val acc_indices = new scala.collection.mutable.Queue[BigInt]
+    val accout_indices = new scala.collection.mutable.Queue[BigInt]
+    val out_indices = new scala.collection.mutable.Queue[BigInt]
+
+    def logical_step() {
+      step(1)
+      uop_mock.logical_step(None)
+      acc_mock.logical_step(None)
+      if (peek(c.io.uop.idx.valid) == 1) {
+        expect(c.io.uop.idx.bits, uop_indices.dequeue())
+      }
+      if (peek(c.io.acc.rd(0).idx.valid) == 1) {
+        expect(c.io.acc.rd(0).idx.bits, acc_indices.dequeue())
+      }
+      if (peek(c.io.acc.wr(0).valid) == 1) {
+        expect(c.io.acc.wr(0).bits.idx, accout_indices.dequeue())
+      }
+      if (peek(c.io.out.wr(0).valid) == 1) {
+        expect(c.io.out.wr(0).bits.idx, out_indices.dequeue())
+      }
+    }
+
+    def test_if_done() {
+      println(s"uop_indices remaining: ${uop_indices.size}")
+      println(s"acc_indices remaining: ${acc_indices.size}")
+      println(s"accout_indices remaining: ${accout_indices.size}")
+      println(s"out_indices remaining: ${out_indices.size}")
+      assert(uop_indices.isEmpty)
+      assert(acc_indices.isEmpty)
+      assert(accout_indices.isEmpty)
+      assert(out_indices.isEmpty)
+    }
+  }
+
+  val mocks = new Mocks
+  for {
+    cnt_o <- 0 until lp_0
+    cnt_i <- 0 until lp_1
+    uop_idx <- uop_begin until uop_end
+  } {
+    mocks.uop_indices.enqueue(uop_idx)
+    mocks.acc_indices.enqueue(src_offset + src_0*cnt_o + src_1*cnt_i)
+    mocks.accout_indices.enqueue(dst_offset + dst_0*cnt_o + dst_1*cnt_i)
+    mocks.out_indices.enqueue(dst_offset + dst_0*cnt_o + dst_1*cnt_i)
+  }
+
+  poke(c.io.start, 0)
+  step(1)
+  poke(c.io.start, 1)
+
+  var count = 0
+  val end = (uop_end-uop_begin)*lp_0*lp_1
+
+  while (peek(c.io.done) == 0 && count < 10*end + 100) {
+    mocks.logical_step()
+    poke(c.io.start, 0)
+    count += 1
+  }
+  expect(c.io.done, 1)
+  mocks.test_if_done()
+}
+
+class TensorAluPipelinedTest extends GenericTest("TensorAluPipelined", (p:Parameters) =>
+  new TensorAlu()(p), (c:TensorAlu) => new TensorAluPipelinedTester(c))
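
For the decode fields hardcoded above (uop_end = 1, lp_0 = 2, lp_1 = 3), the
tester enqueues (uop_end - uop_begin) * lp_0 * lp_1 = 6 expectations per queue.
A standalone sketch enumerating the same dst/src index stream:

    object AluIndexSketch extends App {
      val (lp0, lp1) = (2, 3)
      val (dst0, dst1, src0, src1) = (3, 1, 6, 2) // dst_0 = lp_1, src_0 = 2*lp_1
      val srcOffset = 0x100
      for (o <- 0 until lp0; i <- 0 until lp1) {
        val dst = dst0 * o + dst1 * i             // 0, 1, 2, 3, 4, 5
        val src = srcOffset + src0 * o + src1 * i // 0x100, 0x102, ..., 0x10a
        println(f"dst=$dst%d src=0x$src%x")
      }
    }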
diff --git a/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala b/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala
index 2852d4e..55f2f52 100644
--- a/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala
+++ b/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala
@@ -22,10 +22,13 @@ package unittest.util
 import scala.util.Random
 import scala.math.pow
 
-class RandomArray(val len: Int, val bits: Int) {
-  val r = new Random
+class RandomArray(val len: Int, val bits: Int, val r: Random) {
   if (bits < 1) throw new IllegalArgumentException ("bits should be greater than 1")
 
+  def this(len: Int, bits: Int) {
+    this(len, bits, new Random)
+  }
+
   def any : Array[Int] = {
     Array.fill(len) { r.nextInt(pow(2, bits).toInt) - pow(2, bits-1).toInt }
   }
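
Passing the Random in explicitly makes test data reproducible across runs,
while the secondary constructor preserves the old unseeded behavior. A sketch:

    import scala.util.Random
    import unittest.util.RandomArray

    object RandomArraySketch extends App {
      val gen1 = new RandomArray(16, 8, new Random(47))
      val gen2 = new RandomArray(16, 8, new Random(47))
      assert(gen1.any.sameElements(gen2.any)) // same seed, same sequence
    }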