You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/06/08 01:26:49 UTC
[tvm-vta] branch main updated: Chisel Pipelined ALU (#27)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-vta.git
The following commit(s) were added to refs/heads/main by this push:
new d5e8117 Chisel Pipelined ALU (#27)
d5e8117 is described below
commit d5e8117ce1535c527c536e115b4e58d53817b82f
Author: Abhijit Davare <ab...@intel.com>
AuthorDate: Mon Jun 7 18:26:41 2021 -0700
Chisel Pipelined ALU (#27)
* Remove parameter values from case class
* Add new blockOutFactor parameter with default value = 1
* Support split access
* Modify to support split interface, minor refactoring
* Use split read/write interfaces
* Pipelined ALU with split interfaces
* Modify instantiation and usage of pipelined ALU and split interfaces
* Don't use internal Random by default
* Change tester name
* Add generic tester class
* Derive from GenericTest, minor refactoring
* Test ALU index generator and pipelined ALU
* Add ASF header
* Bugfix: delay slicing index by a cycle to match SyncReadMem read delay
* Formatting, comment, and minor refactoring changes for clarity
* Fix scalastyle issues for test files
---
hardware/chisel/src/main/scala/core/Compute.scala | 103 +++++-
hardware/chisel/src/main/scala/core/Configs.scala | 1 +
hardware/chisel/src/main/scala/core/Core.scala | 29 +-
hardware/chisel/src/main/scala/core/LoadUop.scala | 3 +-
.../chisel/src/main/scala/core/TensorAlu.scala | 401 ++++++++++++++++++---
.../chisel/src/main/scala/core/TensorGemm.scala | 106 +++---
.../chisel/src/main/scala/core/TensorLoad.scala | 62 +++-
.../chisel/src/main/scala/core/TensorStore.scala | 28 +-
.../chisel/src/main/scala/core/TensorUtil.scala | 186 ++++++++--
.../chisel/src/test/scala/unittest/AluTest.scala | 42 ++-
.../scala/unittest/Generic.scala} | 48 ++-
.../chisel/src/test/scala/unittest/Launcher.scala | 2 +-
.../src/test/scala/unittest/TensorAluTest.scala | 252 +++++++++++++
.../test/scala/unittest/utils/RandomArray.scala | 7 +-
14 files changed, 1042 insertions(+), 228 deletions(-)
diff --git a/hardware/chisel/src/main/scala/core/Compute.scala b/hardware/chisel/src/main/scala/core/Compute.scala
index a1e7fad..0055a25 100644
--- a/hardware/chisel/src/main/scala/core/Compute.scala
+++ b/hardware/chisel/src/main/scala/core/Compute.scala
@@ -19,6 +19,8 @@
package vta.core
+import scala.math.pow
+
import chisel3._
import chisel3.util._
import vta.util.config._
@@ -32,7 +34,7 @@ import vta.shell._
* - Compute ALU instructions (tensorAlu module)
* - Compute GEMM instructions (tensorGemm module)
*/
-class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class Compute(debug: Boolean = false)(implicit val p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val i_post = Vec(2, Input(Bool()))
@@ -58,6 +60,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val tensorGemm = Module(new TensorGemm)
val tensorAlu = Module(new TensorAlu)
+ // try to use the acc closest to top IO
+ val topAccGrpIdx = tensorGemm.io.acc.closestIOGrpIdx
+
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
// decode
@@ -118,44 +123,102 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
+ assert(!tensorGemm.io.uop.idx.valid || !tensorAlu.io.uop.idx.valid)
// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
- tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
- tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
+ require(tensorAcc.io.tensor.lenSplit ==
+ tensorAcc.io.tensor.tensorLength, "-F- Expecting a whole batch in acc group")
+
+ // split factor of isGemm for many groups
+ val splitFactorL0 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+ val splitFactorL1 = pow(2,log2Ceil(tensorAcc.io.tensor.splitWidth)
+ - log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
+ require(splitFactorL0 * splitFactorL1 == tensorAcc.io.tensor.splitWidth)
+ val accRdSelectL0 = for (idx <- 0 until splitFactorL1) yield {
+ // can save 1 stage on small design
+ if (splitFactorL1 > 1) RegNext(dec.io.isGemm, init = false.B) else dec.io.isGemm
+ }
+
+ for (idx <- 0 until tensorAcc.io.tensor.splitWidth) {
+ tensorAcc.io.tensor.rd(idx).idx <> Mux(
+ RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+ tensorGemm.io.acc.rd(idx).idx,
+ tensorAlu.io.acc.rd(idx).idx)
+ tensorAcc.io.tensor.wr(idx) <> Mux(
+ RegNext(accRdSelectL0(idx/splitFactorL0), init = false.B),
+ tensorGemm.io.acc.wr(idx),
+ tensorAlu.io.acc.wr(idx))
+ }
io.vme_rd(1) <> tensorAcc.io.vme_rd
- io.acc_wr_event := tensorAcc.io.tensor.wr.valid
+ io.acc_wr_event := tensorAcc.io.tensor.wr(topAccGrpIdx).valid
// gemm
- tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
- tensorGemm.io.inst := inst_q.io.deq.bits
+ tensorGemm.io.start := RegNext(state === sIdle & start & dec.io.isGemm, init = false.B)
+ tensorGemm.io.dec := inst_q.io.deq.bits.asTypeOf(new GemmDecode)
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorGemm.io.inp <> io.inp
tensorGemm.io.wgt <> io.wgt
- tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
- tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
- tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
- tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
+ for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
+ tensorGemm.io.acc.rd(idx).data.valid :=
+ tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
+ tensorGemm.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
+ }
+ for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+ tensorGemm.io.out.rd(idx).data.valid :=
+ io.out.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
+ tensorGemm.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+ }
// alu
- tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
- tensorAlu.io.inst := inst_q.io.deq.bits
+ tensorAlu.io.start := RegNext(state === sIdle & start & dec.io.isAlu, init = false.B)
+ tensorAlu.io.dec := inst_q.io.deq.bits.asTypeOf(new AluDecode)
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
- tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
- tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
- tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
- tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
+ for (idx <- 0 until tensorAlu.io.acc.splitWidth) {
+ tensorAlu.io.acc.rd(idx).data.valid :=
+ tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
+ tensorAlu.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
+ }
+ for (idx <- 0 until tensorAlu.io.out.splitWidth) {
+ tensorAlu.io.out.rd(idx).data.valid :=
+ io.out.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
+ tensorAlu.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
+ }
// out
- io.out.rd.idx <> Mux(dec.io.isGemm,
- tensorGemm.io.out.rd.idx,
- tensorAlu.io.out.rd.idx)
- io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
+ for (idx <- 0 until tensorGemm.io.out.splitWidth) {
+ io.out.rd(idx).idx <> Mux(dec.io.isGemm,
+ tensorGemm.io.out.rd(idx).idx,
+ tensorAlu.io.out.rd(idx).idx)
+ assert(!tensorGemm.io.out.rd(idx).idx.valid || !tensorAlu.io.out.rd(idx).idx.valid)
+ assert(!tensorGemm.io.out.rd(idx).data.valid || !tensorAlu.io.out.rd(idx).data.valid)
+ assert(!tensorGemm.io.out.wr(idx).valid || !tensorAlu.io.out.wr(idx).valid)
+ }
+ require (tensorGemm.io.out.splitWidth == 1)
+ require (tensorAlu.io.out.splitWidth == 1)
+ io.out.wr(0).valid := Mux(
+ RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).valid, tensorAlu.io.out.wr(0).valid)
+ io.out.wr(0).bits.idx := Mux(
+ RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).bits.idx, tensorAlu.io.out.wr(0).bits.idx)
+ // put mux/Reg into every gemm group to build pipe (for Mux select) tree over distance
+ val chunkWidth = io.out.wr(0).bits.data.getWidth / tensorGemm.io.acc.splitWidth
+ val outDataBits = Wire(Vec(tensorGemm.io.acc.splitWidth, UInt(chunkWidth.W)))
+ io.out.wr(0).bits.data := outDataBits.asTypeOf(io.out.wr(0).bits.data)
+ for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
+ val lowBitIdx = idx * chunkWidth
+ val highBitIdx = lowBitIdx + chunkWidth - 1
+ val srcAluFlat = tensorAlu.io.out.wr(0).bits.data.asUInt
+ val srcGemFlat = tensorGemm.io.out.wr(0).bits.data.asUInt
+ outDataBits(idx) := Mux(
+ RegNext(dec.io.isGemm, init = false.B),
+ srcGemFlat(highBitIdx, lowBitIdx),
+ srcAluFlat(highBitIdx, lowBitIdx))
+ }
// semaphore
s(0).io.spost := io.i_post(0)
s(1).io.spost := io.i_post(1)
diff --git a/hardware/chisel/src/main/scala/core/Configs.scala b/hardware/chisel/src/main/scala/core/Configs.scala
index 4ab7d85..4022dc3 100644
--- a/hardware/chisel/src/main/scala/core/Configs.scala
+++ b/hardware/chisel/src/main/scala/core/Configs.scala
@@ -32,6 +32,7 @@ class CoreConfig extends Config((site, here, up) => {
CoreParams(
batch = 1,
blockOut = 16,
+ blockOutFactor = 1,
blockIn = 16,
inpBits = 8,
wgtBits = 8,
diff --git a/hardware/chisel/src/main/scala/core/Core.scala b/hardware/chisel/src/main/scala/core/Core.scala
index e2ac51a..a01c805 100644
--- a/hardware/chisel/src/main/scala/core/Core.scala
+++ b/hardware/chisel/src/main/scala/core/Core.scala
@@ -25,20 +25,21 @@ import vta.shell._
/** Core parameters */
case class CoreParams(
- batch: Int = 1,
- blockOut: Int = 16,
- blockIn: Int = 16,
- inpBits: Int = 8,
- wgtBits: Int = 8,
- uopBits: Int = 32,
- accBits: Int = 32,
- outBits: Int = 8,
- uopMemDepth: Int = 512,
- inpMemDepth: Int = 512,
- wgtMemDepth: Int = 512,
- accMemDepth: Int = 512,
- outMemDepth: Int = 512,
- instQueueEntries: Int = 32
+ batch: Int,
+ blockOut: Int,
+ blockOutFactor: Int,
+ blockIn: Int,
+ inpBits: Int,
+ wgtBits: Int,
+ uopBits: Int,
+ accBits: Int,
+ outBits: Int,
+ uopMemDepth: Int,
+ inpMemDepth: Int,
+ wgtMemDepth: Int,
+ accMemDepth: Int,
+ outMemDepth: Int,
+ instQueueEntries: Int
) {
require(uopBits % 8 == 0,
s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
diff --git a/hardware/chisel/src/main/scala/core/LoadUop.scala b/hardware/chisel/src/main/scala/core/LoadUop.scala
index 87bd508..e9f5d40 100644
--- a/hardware/chisel/src/main/scala/core/LoadUop.scala
+++ b/hardware/chisel/src/main/scala/core/LoadUop.scala
@@ -205,7 +205,8 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
// read-from-sram
io.uop.data.valid := RegNext(io.uop.idx.valid)
- val sIdx = io.uop.idx.bits % numUop.U
+ // delay LSB of idx by a cycle because of the one-cycle memory read latency
+ val sIdx = RegNext(io.uop.idx.bits % numUop.U)
val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
val memRead = mem.read(rIdx, io.uop.idx.valid)
val sWord = memRead.asUInt.asTypeOf(wdata)
diff --git a/hardware/chisel/src/main/scala/core/TensorAlu.scala b/hardware/chisel/src/main/scala/core/TensorAlu.scala
index 81abb8e..8d7aa31 100644
--- a/hardware/chisel/src/main/scala/core/TensorAlu.scala
+++ b/hardware/chisel/src/main/scala/core/TensorAlu.scala
@@ -97,38 +97,322 @@ class AluVector(implicit p: Parameters) extends Module {
io.out.data.valid := valid.asUInt.andR
}
-/** TensorAlu.
- *
- * This unit instantiate the ALU vector unit (AluVector) and go over the
- * micro-ops (uops) which are used to read the source operands (vectors)
- * from the acc-scratchpad and then they are written back the same
- * acc-scratchpad.
- */
-class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class TensorAluIndexGenerator(debug: Boolean = false)(implicit p: Parameters) extends Module {
+ val cnt_o_width = (new AluDecode).lp_0.getWidth
+ val cnt_i_width = (new AluDecode).lp_1.getWidth
+
+ val io = IO(new Bundle {
+ val start = Input(Bool())
+ val last = Output(Bool())
+ val dec = Input(new AluDecode)
+ val valid = Output(Bool())
+ val src_valid = Output(Bool())
+ val dst_idx = Output(UInt(new TensorParams(tensorType="acc").memAddrBits.W))
+ val src_idx = Output(UInt(new TensorParams(tensorType="acc").memAddrBits.W))
+ val uop_idx = Output(UInt(log2Ceil(p(CoreKey).uopMemDepth).W))
+ val cnt_o = Output(UInt(cnt_o_width.W))
+ val cnt_i = Output(UInt(cnt_i_width.W))
+ })
+
+ io.last := false.B
+
+ val running = RegInit(false.B)
+ val stutter = RegInit(false.B)
+
+ val advance = io.dec.alu_use_imm || stutter
+
+ when(!running && io.start) {
+ running := true.B
+ } .elsewhen(running && !advance) {
+ stutter := true.B
+ } .elsewhen(running && advance) {
+ when (io.last) {
+ running := false.B
+ }
+ stutter := false.B
+ }
+
+ val cnt_i = Reg(chiselTypeOf(io.dec.lp_1))
+ val dst_i = Reg(chiselTypeOf(io.dst_idx))
+ val src_i = Reg(chiselTypeOf(io.src_idx))
+
+ val cnt_o = Reg(chiselTypeOf(io.dec.lp_0))
+ val dst_o = Reg(chiselTypeOf(io.dst_idx))
+ val src_o = Reg(chiselTypeOf(io.src_idx))
+
+ val uop_idx = Reg(chiselTypeOf(io.dec.uop_end))
+
+ io.valid := running && advance
+ io.src_valid := running && !advance
+ io.dst_idx := dst_i
+ io.src_idx := src_i
+ io.uop_idx := uop_idx
+ io.cnt_o := cnt_o
+ io.cnt_i := cnt_i
+
+ when(!running) {
+ cnt_i := 0.U; dst_i := 0.U; src_i := 0.U;
+ cnt_o := 0.U; dst_o := 0.U; src_o := 0.U;
+ uop_idx := io.dec.uop_begin
+ } .elsewhen (advance) {
+ when (uop_idx =/= io.dec.uop_end - 1.U) {
+ uop_idx := uop_idx + 1.U
+ }.otherwise {
+ uop_idx := io.dec.uop_begin
+ when (cnt_i =/= io.dec.lp_1 - 1.U) {
+ cnt_i := cnt_i + 1.U
+ dst_i := dst_i + io.dec.dst_1
+ src_i := src_i + io.dec.src_1
+ }.otherwise {
+ when (cnt_o =/= io.dec.lp_0 - 1.U) {
+ val dst_tmp = dst_o + io.dec.dst_0
+ val src_tmp = src_o + io.dec.src_0
+ cnt_o := cnt_o + 1.U
+ dst_o := dst_tmp
+ src_o := src_tmp
+ cnt_i := 0.U
+ dst_i := dst_tmp
+ src_i := src_tmp
+ } .otherwise {
+ io.last := true.B
+ }
+ }
+ }
+ }
+}
+
+class TensorAluIfc(implicit p: Parameters) extends Module {
val aluBits = p(CoreKey).accBits
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
- val inst = Input(UInt(INST_BITS.W))
+ val dec = Input(new AluDecode)
val uop = new UopMaster
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
+}
+
+class TensorAluPipelined(debug: Boolean = false)(implicit p: Parameters) extends TensorAluIfc {
+ val stateBits = 2
+ val inflightBits = 4
+ val dataSplitFactor = p(CoreKey).blockOutFactor
+
+ val sIdle::sRun::sWait::Nil = Enum(3)
+ val state = RegInit(init=sIdle)
+ val inflight = RegInit(0.U(inflightBits.W))
+
+ val index_generator = Module(new TensorAluIndexGenerator)
+ val aluDataReadPipeDelay = 0 // available for pipelining
+
+ // State machine that computes io.done correctly
+ io.done := false.B
+ when(state === sIdle && io.start) {
+ state := sRun
+ }.elsewhen(state === sRun && index_generator.io.last) {
+ state := sWait
+ }.elsewhen(state === sWait && inflight === 0.U) {
+ state := sIdle
+ io.done := true.B
+ }
+
+ index_generator.io.start := io.start
+ index_generator.io.dec := io.dec
+
+ // second term works around funny clearing in uop register file flopped output
+ io.uop.idx.valid := index_generator.io.valid || index_generator.io.src_valid
+ io.uop.idx.bits := index_generator.io.uop_idx
+
+ val valid_r1 = ShiftRegister(index_generator.io.valid, aluDataReadPipeDelay + 1, resetData=false.B, en = true.B)
+ val valid_r2 = RegNext(valid_r1, init=false.B)
+ val valid_r3 = RegNext(valid_r2, init=false.B)
+ val valid_r4 = RegNext(valid_r3, init=false.B)
+
+ when(index_generator.io.valid && valid_r4) {
+ }.elsewhen(index_generator.io.valid) {
+ assert(inflight =/= ((1<<inflightBits)-1).U)
+ inflight := inflight + 1.U
+ }.elsewhen(valid_r4) {
+ assert(inflight =/= 0.U)
+ inflight := inflight - 1.U
+ }
+ when(state === sIdle) {
+ assert(inflight === 0.U)
+ inflight := 0.U
+ }
+
+ val src_valid_r1 = ShiftRegister(
+ index_generator.io.src_valid,
+ aluDataReadPipeDelay + 1,
+ resetData=false.B, en = true.B)
+ val src_valid_r2 = RegNext(src_valid_r1, init=false.B)
+ val src_valid_r3 = RegNext(src_valid_r2, init=false.B)
+ val src_valid_r4 = RegNext(src_valid_r3, init=false.B)
+
+ val dst_idx_r1 = ShiftRegister(index_generator.io.dst_idx, aluDataReadPipeDelay + 1)
+ val src_idx_r1 = ShiftRegister(index_generator.io.src_idx, aluDataReadPipeDelay + 1)
+
+ val uop_data_r1 = ShiftRegister(io.uop.data, aluDataReadPipeDelay)
+
+ val dst_offset = uop_data_r1.bits.u0
+
+ val w = dst_offset.getWidth
+ val u2 = uop_data_r1.bits.u2.asTypeOf(UInt(w.W))
+ val s = log2Ceil(p(CoreKey).inpMemDepth)
+ val u1 = uop_data_r1.bits.u1.asTypeOf(UInt(w.W))
+ val src_offset = (u2 << s) | u1
+
+ // split registers of stage 2 by data groups
+ val accRdIdxValid = valid_r1 || src_valid_r1
+ for (idx <- 0 until dataSplitFactor) {
+ io.acc.rd(idx).idx.valid := RegNext(accRdIdxValid)
+ }
+
+ val new_src_idx_r1 = src_idx_r1 + src_offset
+ val src_idx_r2 = RegNext(new_src_idx_r1)
+ val src_idx_r3 = RegNext(src_idx_r2)
+
+ val new_dst_idx_r1 = dst_idx_r1 + dst_offset
+ val dst_idx_r2 = RegNext(new_dst_idx_r1)
+ val dst_idx_r3 = RegNext(dst_idx_r2)
+ val dst_idx_r4 = RegNext(dst_idx_r3)
+
+ // split registers of stage 2 by data groups
+ val accRdIdxBits = Mux(src_valid_r1 || io.dec.alu_use_imm, new_src_idx_r1, new_dst_idx_r1)
+ for (idx <- 0 until dataSplitFactor) {
+ io.acc.rd(idx).idx.bits := RegNext(accRdIdxBits)
+ assert(io.acc.rd(idx).data.valid === (valid_r3 || src_valid_r3))
+ }
+
+ require(io.out.splitWidth == 1 && io.out.splitLength == 1, "-F- Out split write is not supported")
+ val numVecUnits = dataSplitFactor
+ val outData = Wire(io.out.wr(0).bits.data.cloneType)
+ val dataRemapB = Wire(Vec(numVecUnits, io.acc.rd(0).data.bits.cloneType))
+ val dataRemapA = Wire(Vec(numVecUnits, io.acc.rd(0).data.bits.cloneType))
+ // numVecUnits is a power of 2
+ // split dec bits pipe further if there are many vecUnits
+ val decSplitNb0 = if (numVecUnits < 8) 1 else 2
+ val decSplit0 = Wire(Vec(decSplitNb0, io.dec.cloneType))
+ for (idx <- 0 until decSplitNb0) {
+ decSplit0(idx) := ShiftRegister(io.dec, if(aluDataReadPipeDelay < 2) 0 else 1)
+ }
+
+ for (idx <- 0 until numVecUnits) {
+ val alu = Module(new AluVector)
+
+ for(aluLenIdx <- 0 until alu.io.acc_b.lenSplit) {
+ for(aluWdtIdx <- 0 until alu.io.acc_b.widthSplit) {
+ val (accGrpIdx, accLenIdx, accWdtIdx) =
+ alu.io.acc_b.reindexDataFromGroup(idx, aluLenIdx, aluWdtIdx)
+ dataRemapB(idx)(aluLenIdx)(aluWdtIdx) :=
+ io.acc.rd(accGrpIdx).data.bits(accLenIdx)(accWdtIdx)
+ }
+ }
+ val save_src = RegNext(dataRemapB(idx))
+ val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
+ tensorImm.data.valid := valid_r3
+ val tensorImmBits_piped = ShiftRegister(
+ decSplit0(idx/(numVecUnits/decSplitNb0)).alu_imm,
+ if(aluDataReadPipeDelay < 2) aluDataReadPipeDelay else aluDataReadPipeDelay -1)
+ tensorImm.data.bits.foreach { b =>
+ b.foreach { c =>
+ c := Mux(tensorImmBits_piped(C_ALU_IMM_BITS - 1),
+ Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), tensorImmBits_piped), tensorImmBits_piped)
+ }
+ }
+
+ // alu
+ val tensorOpBits_piped = ShiftRegister(
+ decSplit0(idx/(numVecUnits/decSplitNb0)).alu_op,
+ if(aluDataReadPipeDelay < 2) aluDataReadPipeDelay else aluDataReadPipeDelay -1)
+ val isSHR = (tensorOpBits_piped === ALU_OP(3))
+ val neg_shift = isSHR & tensorImmBits_piped(C_ALU_IMM_BITS - 1)
+ val fixme_alu_op = Mux(
+ neg_shift,
+ ALU_OP(4), // use opcode = 4 for left shift
+ tensorOpBits_piped)
+ alu.io.opcode := fixme_alu_op
+
+ assert(!valid_r3 || io.acc.rd(idx).data.valid)
+
+ alu.io.acc_a.data.valid := RegNext(valid_r2) // valid_r3 split
+
+ for(aluLenIdx <- 0 until alu.io.acc_a.lenSplit) {
+ for(aluWdtIdx <- 0 until alu.io.acc_a.widthSplit) {
+ val (accGrpIdx, accLenIdx, accWdtIdx) =
+ alu.io.acc_a.reindexDataFromGroup(idx, aluLenIdx, aluWdtIdx)
+ dataRemapA(idx)(aluLenIdx)(aluWdtIdx) :=
+ io.acc.rd(accGrpIdx).data.bits(accLenIdx)(accWdtIdx)
+ alu.io.acc_a.data.bits := dataRemapA(idx)
+ }
+ }
+ val tensorUseImmBits_piped = ShiftRegister(
+ decSplit0(idx/(numVecUnits/decSplitNb0)).alu_use_imm,
+ if(aluDataReadPipeDelay < 2) aluDataReadPipeDelay else aluDataReadPipeDelay -1)
+ alu.io.acc_b.data.valid := Mux(tensorUseImmBits_piped,
+ tensorImm.data.valid,
+ valid_r3)
+ alu.io.acc_b.data.bits := Mux(tensorUseImmBits_piped,
+ tensorImm.data.bits,
+ save_src)
+
+ assert(alu.io.acc_y.data.valid === valid_r4)
+ io.acc.wr(idx).valid := valid_r4
+ io.acc.wr(idx).bits.idx := dst_idx_r4
+
+ for(aluLenIdx <- 0 until alu.io.acc_y.lenSplit) {
+ for(aluWdtIdx <- 0 until alu.io.acc_y.widthSplit) {
+ val (accGrpIdx, accLenIdx, accWdtIdx) =
+ alu.io.acc_y.reindexDataFromGroup(idx, aluLenIdx, aluWdtIdx)
+ io.acc.wr(accGrpIdx).bits.data(accLenIdx)(accWdtIdx) :=
+ alu.io.acc_y.data.bits(aluLenIdx)(aluWdtIdx)
+ }
+ }
+
+ assert(alu.io.out.data.valid === valid_r4)
+ for (idx1 <- 0 until io.out.tensorLength) {
+ for (idx2 <- 0 until io.out.tensorWidth/numVecUnits) {
+ outData(idx1)(idx*io.out.tensorWidth/numVecUnits + idx2) := alu.io.out.data.bits(idx1)(idx2)
+ }
+ }
+ }
+
+// comment for split write
+ io.out.wr(0).valid := valid_r4
+ io.out.wr(0).bits.idx := dst_idx_r4
+ io.out.wr(0).bits.data := outData
+ io.out.tieoffRead()
+
+ val bypass_dst = valid_r3 && valid_r4 && (dst_idx_r4 === dst_idx_r3)
+ val bypass_src = src_valid_r3 && valid_r4 && (dst_idx_r4 === src_idx_r3)
+
+ // Do we need a bypass
+ assert(!bypass_dst, s"Bypass required on dst_idx read $dst_idx_r3 RAW with write $dst_idx_r4\n")
+ assert(!bypass_src, s"Bypass required on src_idx read $src_idx_r3 RAW with write $dst_idx_r4\n")
+}
+
+/** TensorAluOrig.
+ * This unit instantiates the ALU vector unit (AluVector) and goes over the
+ * micro-ops (uops), which are used to read the source operands (vectors)
+ * from the acc-scratchpad; the results are then written back to the same
+ * acc-scratchpad.
+ */
+class TensorAluOrig(debug: Boolean = false)(implicit p: Parameters) extends TensorAluIfc {
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
Enum(6)
val state = RegInit(sIdle)
val alu = Module(new AluVector)
- val dec = io.inst.asTypeOf(new AluDecode)
+ val dec = io.dec
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
- val uop_dst = Reg(chiselTypeOf(dec.uop_end))
- val uop_src = Reg(chiselTypeOf(dec.uop_end))
+ val uop_dst = Reg(chiselTypeOf(io.uop.data.bits.u0)) // width can address entire acc
+ val uop_src = Reg(chiselTypeOf(io.uop.data.bits.u0)) // width can address entire acc
val cnt_o = Reg(chiselTypeOf(dec.lp_0))
- val dst_o = Reg(chiselTypeOf(dec.uop_end))
- val src_o = Reg(chiselTypeOf(dec.uop_end))
+ val dst_o = Reg(chiselTypeOf(io.uop.data.bits.u0))
+ val src_o = Reg(chiselTypeOf(io.uop.data.bits.u0))
val cnt_i = Reg(chiselTypeOf(dec.lp_1))
- val dst_i = Reg(chiselTypeOf(dec.uop_end))
- val src_i = Reg(chiselTypeOf(dec.uop_end))
+ val dst_i = Reg(chiselTypeOf(io.uop.data.bits.u0))
+ val src_i = Reg(chiselTypeOf(io.uop.data.bits.u0))
val done =
state === sExe &
alu.io.out.data.valid &
@@ -208,57 +492,62 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i
- uop_src := io.uop.data.bits.u1 + src_i
+ uop_src := ((io.uop.data.bits.u2.asTypeOf(UInt(width = uop_dst.getWidth.W)) << log2Ceil(p(CoreKey).inpMemDepth))
+ | io.uop.data.bits.u1.asTypeOf(UInt(width = uop_dst.getWidth.W))) + src_i
}
// uop
io.uop.idx.valid := state === sReadUop
io.uop.idx.bits := uop_idx
- // acc (input)
- io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
- io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
-
- // imm
- val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
- tensorImm.data.valid := state === sReadTensorB
- tensorImm.data.bits.foreach { b =>
- b.foreach { c =>
- c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1),
- Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm)
+ val dataSplitFactor = p(CoreKey).blockOutFactor
+
+ val accRdValid = state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
+ val accRdIdx = Mux(state === sReadTensorA, uop_dst, uop_src)
+ for (idx <- 0 until dataSplitFactor) {
+ // acc (input)
+ io.acc.rd(idx).idx.valid := accRdValid
+ io.acc.rd(idx).idx.bits := accRdIdx
+
+ // imm
+ val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
+ tensorImm.data.valid := state === sReadTensorB
+ tensorImm.data.bits.foreach { b =>
+ b.foreach { c =>
+ c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1),
+ Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm)
+ }
}
- }
- // alu
- val isSHR = dec.alu_op === ALU_OP(3)
- val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
- // opcode - min:0, max:1, add:2, shr:3, shl:4
- val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
- alu.io.opcode := fixme_alu_op
- alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
- alu.io.acc_a.data.bits <> io.acc.rd.data.bits
- alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
- tensorImm.data.valid,
- io.acc.rd.data.valid & state === sExe)
- alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
- tensorImm.data.bits,
- io.acc.rd.data.bits)
-
- // acc (output)
- io.acc.wr.valid := alu.io.acc_y.data.valid
- io.acc.wr.bits.idx := uop_dst
- io.acc.wr.bits.data <> alu.io.acc_y.data.bits
-
- // out
- io.out.wr.valid := alu.io.out.data.valid
- io.out.wr.bits.idx := uop_dst
- io.out.wr.bits.data <> alu.io.out.data.bits
+ // alu
+ val isSHR = (dec.alu_op === ALU_OP(3))
+ val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
+ // opcode - min:0, max:1, add:2, shr:3, shl:4
+ val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
+ alu.io.opcode := fixme_alu_op
+ alu.io.acc_a.data.valid := io.acc.rd(idx).data.valid & state === sReadTensorB
+ alu.io.acc_a.data.bits <> io.acc.rd(idx).data.bits
+ alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
+ tensorImm.data.valid,
+ io.acc.rd(idx).data.valid & state === sExe)
+ alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
+ tensorImm.data.bits,
+ io.acc.rd(idx).data.bits)
+
+ // acc (output)
+ io.acc.wr(idx).valid := alu.io.acc_y.data.valid
+ io.acc.wr(idx).bits.idx := uop_dst
+ io.acc.wr(idx).bits.data <> alu.io.acc_y.data.bits
+
+ // out
+ io.out.wr(idx).valid := alu.io.out.data.valid
+ io.out.wr(idx).bits.idx := uop_dst
+ io.out.wr(idx).bits.data <> alu.io.out.data.bits
+ }
io.out.tieoffRead() // write-only
-
io.done := done
if (debug) {
-
when(state === sReadUop) {
printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
}
@@ -308,3 +597,5 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
}
}
}
+
+class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends TensorAluPipelined(debug)
diff --git a/hardware/chisel/src/main/scala/core/TensorGemm.scala b/hardware/chisel/src/main/scala/core/TensorGemm.scala
index f2d295f..e977552 100644
--- a/hardware/chisel/src/main/scala/core/TensorGemm.scala
+++ b/hardware/chisel/src/main/scala/core/TensorGemm.scala
@@ -173,31 +173,41 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
io.out.data.valid := vld.asUInt.andR
}
-/** TensorGemm.
- *
- * This unit instantiate the MatrixVectorMultiplication and go over the
- * micro-ops (uops) which are used to read inputs, weights and biases,
- * and writes results back to the acc and out scratchpads.
- *
- * Also, the TensorGemm uses the reset field in the Gemm instruction to
- * clear or zero-out the acc-scratchpad locations based on the micro-ops.
- */
-class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
+abstract class TensorGemmIfc(implicit p: Parameters) extends Module {
+ val stateBits = 3
+ val inflightBits = 4
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
- val inst = Input(UInt(INST_BITS.W))
+ val dec = Input(new GemmDecode)
val uop = new UopMaster
val inp = new TensorMaster(tensorType = "inp")
val wgt = new TensorMaster(tensorType = "wgt")
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
+ val state = Output(UInt(stateBits.W))
+ val inflight = Output(UInt(inflightBits.W))
})
- val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
- Enum(6)
+}
+
+/** TensorGemm.
+ *
+ * This unit instantiates the MatrixVectorMultiplication and goes over the
+ * micro-ops (uops) which are used to read inputs, weights and biases,
+ * and writes results back to the acc and out scratchpads.
+ *
+ * Also, the TensorGemm uses the reset field in the Gemm instruction to
+ * clear or zero-out the acc-scratchpad locations based on the micro-ops.
+ */
+class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends TensorGemmIfc {
+
+ require(p(CoreKey).blockOutFactor == 1,
+ "-F- Split GEMM not supported. Use TensorGemmPipelinedSplit or set blockOutFactor to 1")
+ val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
val state = RegInit(sIdle)
+ io.state := state
val mvc = Module(new MatrixVectorMultiplication)
- val dec = io.inst.asTypeOf(new GemmDecode)
+ val dec = io.dec
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
val uop_acc = Reg(chiselTypeOf(dec.uop_end))
@@ -211,18 +221,18 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
val acc_i = Reg(chiselTypeOf(dec.uop_end))
val inp_i = Reg(chiselTypeOf(dec.uop_end))
val wgt_i = Reg(chiselTypeOf(dec.uop_end))
- val pBits = log2Ceil(p(CoreKey).blockOut) + 1
- val inflight = Reg(UInt(pBits.W))
+
+ val inflight = Reg(UInt(inflightBits.W))
+ io.inflight := inflight
// Latency is defined as two in the following, because there is one cycle in the MAC module,
// and another cycle in the pipelined adders as the first layer of the accumulator
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))
+ val cond_last = cnt_o === dec.lp_0 - 1.U &
+ cnt_i === dec.lp_1 - 1.U &
+ uop_idx === uop_end - 1.U
+
val done = inflight === 0.U &
- ((state === sExe &
- cnt_o === dec.lp_0 - 1.U &
- cnt_i === dec.lp_1 - 1.U &
- uop_idx === uop_end - 1.U &
- inflight === 0.U) |
- state === sWait)
+ ((state === sExe) & cond_last | state === sWait)
switch(state) {
is(sIdle) {
@@ -240,10 +250,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
state := sExe
}
is(sExe) {
- when(
- (cnt_o === dec.lp_0 - 1.U) &&
- (cnt_i === dec.lp_1 - 1.U) &&
- (uop_idx === uop_end - 1.U)) {
+ when(cond_last) {
when(inflight =/= 0.U) {
state := sWait
}.otherwise {
@@ -264,10 +271,11 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
inflight := 0.U
}.elsewhen(!dec.reset) {
when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit
- inflight := inflight
}.elsewhen(state === sReadTensor) { // issue a tensor
+ assert(inflight =/= ((1<<inflightBits)-1).U)
inflight := inflight + 1.U
}.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
+ assert(inflight =/= 0.U)
inflight := inflight - 1.U
}
}
@@ -327,40 +335,42 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
io.uop.idx.bits := uop_idx
// inp
- io.inp.rd.idx.valid := state === sReadTensor
- io.inp.rd.idx.bits := uop_inp
+ io.inp.rd(0).idx.valid := state === sReadTensor
+ io.inp.rd(0).idx.bits := uop_inp
io.inp.tieoffWrite() // read-only
// wgt
- io.wgt.rd.idx.valid := state === sReadTensor
- io.wgt.rd.idx.bits := uop_wgt
+ io.wgt.rd(0).idx.valid := state === sReadTensor
+ io.wgt.rd(0).idx.bits := uop_wgt
io.wgt.tieoffWrite() // read-only
// acc_i
- io.acc.rd.idx.valid := state === sReadTensor
- io.acc.rd.idx.bits := uop_acc
+ io.acc.rd(0).idx.valid := state === sReadTensor
+ io.acc.rd(0).idx.bits := uop_acc
// mvc
mvc.io.reset := dec.reset & state === sExe
- mvc.io.inp.data <> io.inp.rd.data
- mvc.io.wgt.data <> io.wgt.rd.data
- mvc.io.acc_i.data <> io.acc.rd.data
+ mvc.io.inp.data <> io.inp.rd(0).data
+ mvc.io.wgt.data <> io.wgt.rd(0).data
+ mvc.io.acc_i.data <> io.acc.rd(0).data
// acc_o
- io.acc.wr.valid := mvc.io.acc_o.data.valid &
+ io.acc.wr(0).valid := mvc.io.acc_o.data.valid &
Mux(dec.reset, true.B, wrpipe.io.deq.valid)
- io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
- io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
+ io.acc.wr(0).bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
+ io.acc.wr(0).bits.data <> mvc.io.acc_o.data.bits
// out
- io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
- io.out.wr.bits.idx := wrpipe.io.deq.bits
- io.out.wr.bits.data <> mvc.io.out.data.bits
+ io.out.wr(0).valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
+ io.out.wr(0).bits.idx := wrpipe.io.deq.bits
+ io.out.wr(0).bits.data <> mvc.io.out.data.bits
io.out.tieoffRead() // write-only
io.done := done
if (debug) {
+ printf("[TensorGemm] [state]:%d [inflight]:%d\n", state, inflight)
+
when(state === sReadUop && ~dec.reset) {
printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
}
@@ -369,24 +379,24 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
}
- io.inp.rd.data.bits.zipWithIndex.foreach {
+ io.inp.rd(0).data.bits.zipWithIndex.foreach {
case (r, i) =>
- when(io.inp.rd.data.valid && ~dec.reset) {
+ when(io.inp.rd(0).data.valid && ~dec.reset) {
printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
}
}
- io.wgt.rd.data.bits.zipWithIndex.foreach {
+ io.wgt.rd(0).data.bits.zipWithIndex.foreach {
case (r, i) =>
- when(io.wgt.rd.data.valid && ~dec.reset) {
+ when(io.wgt.rd(0).data.valid && ~dec.reset) {
printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
}
}
- io.acc.rd.data.bits.foreach { tensor =>
+ io.acc.rd(0).data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
- when(io.acc.rd.data.valid && ~dec.reset) {
+ when(io.acc.rd(0).data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
}
}
diff --git a/hardware/chisel/src/main/scala/core/TensorLoad.scala b/hardware/chisel/src/main/scala/core/TensorLoad.scala
index 5ab690d..8b31253 100644
--- a/hardware/chisel/src/main/scala/core/TensorLoad.scala
+++ b/hardware/chisel/src/main/scala/core/TensorLoad.scala
@@ -23,7 +23,6 @@ import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
-
/** TensorLoad.
*
* Load 1D and 2D tensors from main memory (DRAM) to input/weight
@@ -45,6 +44,10 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
val vme_rd = new VMEReadMaster
val tensor = new TensorClient(tensorType)
})
+
+ require(tp.numMemBlock > 0, s"-F- Unexpected data to tensor bit size ratio. ${tensorType} ${tp.numMemBlock}")
+ require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- Cannot do split direct access")
+
val sizeFactor = tp.tensorLength * tp.numMemBlock
val strideFactor = tp.tensorLength * tp.tensorWidth
@@ -82,12 +85,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
when(dec.xpad_0 =/= 0.U) {
state := sXPad0
}.otherwise {
+ assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
state := sReadCmd
}
}
}
is(sXPad0) {
when(xPadCtrl0.io.done) {
+ assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
state := sReadCmd
}
}
@@ -112,6 +117,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
}.elsewhen(dec.xpad_0 =/= 0.U) {
state := sXPad0
}.otherwise {
+ assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
state := sReadCmd
}
}.elsewhen(dataCtrl.io.split) {
@@ -131,6 +137,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
when(dec.xpad_0 =/= 0.U) {
state := sXPad0
}.otherwise {
+ assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
state := sReadCmd
}
}
@@ -191,13 +198,16 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
state === sXPad1 |
state === sYPad1
- when(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
+ when(state === sReadCmd && tag =/= (tp.numMemBlock - 1).U) { // split read inside row of mem blocks
+ tag := tag
+ }.elsewhen(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
tag := 0.U
}.elsewhen(io.vme_rd.data.fire() || isZeroPad) {
tag := tag + 1.U
}
- when(state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
+ when(state === sIdle || (dataCtrlDone && ~isZeroPad) ||
+ (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
set := 0.U
}.elsewhen((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
set := set + 1.U
@@ -221,6 +231,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
val tensorFile = Seq.fill(tp.tensorLength) {
SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
}
+
val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
val wdata = Seq.fill(tp.tensorLength) {
Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
@@ -235,12 +246,12 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
wmask(i)(j) := tag === j.U
wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
}
- val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
+ val tdata = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata(i))
val muxWen =
Mux(state === sIdle,
- io.tensor.wr.valid,
+ io.tensor.wr(0).valid,
(io.vme_rd.data.fire() | isZeroPad) & set === i.U)
- val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
+ val muxWaddr = Mux(state === sIdle, io.tensor.wr(0).bits.idx, waddr_cur)
val muxWdata = Mux(state === sIdle, tdata, wdata(i))
val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
when(muxWen) {
@@ -249,14 +260,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
}
// read-from-sram
- val rvalid = RegNext(io.tensor.rd.idx.valid)
- io.tensor.rd.data.valid := rvalid
+ val rvalid = RegNext(io.tensor.rd(0).idx.valid)
+ io.tensor.rd(0).data.valid := rvalid
val rdata =
- tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
+ tensorFile.map(_.read(io.tensor.rd(0).idx.bits, io.tensor.rd(0).idx.valid))
rdata.zipWithIndex.foreach {
case (r, i) =>
- io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
+ io.tensor.rd(0).data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd(0).data.bits(i))
}
// done
@@ -291,11 +302,42 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
dataCtrl.io.addr,
dataCtrl.io.len)
}
+ when(state === sYPad0) {
+ printf("[TensorLoad] [wgt] sYPad0\n")
+ }
+ when(state === sYPad1) {
+ printf("[TensorLoad] [wgt] sYPad1\n")
+ }
+ when(state === sXPad0) {
+ printf("[TensorLoad] [wgt] sXPad0\n")
+ }
+ when(state === sXPad1) {
+ printf("[TensorLoad] [wgt] sXPad1\n")
+ }
} else if (tensorType == "acc") {
when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
dataCtrl.io.addr,
dataCtrl.io.len)
+ printf("[TensorLoad] [acc info] dec.xsize: %d, dec.ysize: %d, dec.xstride: %d\n",
+ dec.xsize, dec.ysize, dec.xstride)
+ printf("[TensorLoad] [acc i2fo] dec.xpad_1: %d dec.xpad_0: %d dec.ypad_1: %d dec.ypad_0: %d\n",
+ dec.xpad_1, dec.xpad_0, dec.ypad_1, dec.ypad_0)
+
+ printf("tp.tensorLength: %d, tp.numMemBlock: %d, tp.tensorLength: %d, tp.tensorWidth: %d\n",
+ tp.tensorLength.U, tp.numMemBlock.U, tp.tensorLength.U, tp.tensorWidth.U)
+ }
+ when(state === sYPad0) {
+ printf("[TensorLoad] [acc] sYPad0\n")
+ }
+ when(state === sYPad1) {
+ printf("[TensorLoad] [acc] sYPad1\n")
+ }
+ when(state === sXPad0) {
+ printf("[TensorLoad] [acc] sXPad0\n")
+ }
+ when(state === sXPad1) {
+ printf("[TensorLoad] [acc] sXPad1\n")
}
}
}
diff --git a/hardware/chisel/src/main/scala/core/TensorStore.scala b/hardware/chisel/src/main/scala/core/TensorStore.scala
index 9b4bf74..f1556ef 100644
--- a/hardware/chisel/src/main/scala/core/TensorStore.scala
+++ b/hardware/chisel/src/main/scala/core/TensorStore.scala
@@ -47,6 +47,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
val memBlockBits = tp.memBlockBits
val memDepth = tp.memDepth
val numMemBlock = tp.numMemBlock
+ require(numMemBlock > 0, s"-F- TensorStore doesnt support pulse width" +
+ s"wider than tensor width. Needed for stride support tensorWidth=${tensorWidth}")
+ require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- ${tensorType} Cannot do split direct access")
val dec = io.inst.asTypeOf(new MemDecode)
val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
@@ -54,7 +57,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
- val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U
+ val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock))
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val ycnt = Reg(chiselTypeOf(dec.ysize))
@@ -89,10 +92,13 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
when (io.start) {
state := sWriteCmd
when (xsize < xfer_init_pulses) {
- xlen := xsize
+ assert(xsize > 0.U, "Idle => WriteCmd, init, without xrem: must have positive xsize")
+ xlen := xsize - 1.U
xrem := 0.U
}.otherwise {
xlen := xfer_init_pulses - 1.U
+ assert(xsize >= xfer_init_pulses,
+ "Idle => WriteCmd, init, with xrem: must have xsize no smaller than xfer_init_pulses")
xrem := xsize - xfer_init_pulses
}
}
@@ -123,10 +129,13 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
state := sWriteCmd
xfer_bytes := xfer_stride_bytes
when(xsize < xfer_stride_pulses) {
- xlen := xsize
+ assert(xsize > 0.U, "WriteAck => WriteCmd, stride, without xrem: must have positive xsize")
+ xlen := xsize - 1.U
xrem := 0.U
}.otherwise {
xlen := xfer_stride_pulses - 1.U
+ assert(xsize >= xfer_stride_pulses,
+ "WriteAck => WriteCmd, stride, with xrem: must have xsize no smaller than xfer_stride_pulses")
xrem := xsize - xfer_stride_pulses
}
}
@@ -134,13 +143,16 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
.elsewhen(xrem < xfer_split_pulses) {
state := sWriteCmd
xfer_bytes := xfer_split_bytes
- xlen := xrem
+ assert(xrem > 0.U, "WriteAck => WriteCmd, split, without xrem: must have positive xrem")
+ xlen := xrem - 1.U
xrem := 0.U
}
.otherwise {
state := sWriteCmd
xfer_bytes := xfer_split_bytes
xlen := xfer_split_pulses - 1.U
+ assert(xrem >= xfer_split_pulses,
+ "WriteAck => WriteCmd, split, with xrem: must have xrem no smaller than xfer_split_pulses")
xrem := xrem - xfer_split_pulses
}
}
@@ -160,9 +172,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
}
for (i <- 0 until tensorLength) {
- val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
- when(io.tensor.wr.valid) {
- tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
+ val inWrData = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata_t)
+ when(io.tensor.wr(0).valid) {
+ tensorFile(i).write(io.tensor.wr(0).bits.idx, inWrData, no_mask)
}
}
@@ -186,7 +198,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
}
when(
- state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
+ state === sWriteCmd || (state =/= sReadMem && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
set := 0.U
}.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
set := set + 1.U
diff --git a/hardware/chisel/src/main/scala/core/TensorUtil.scala b/hardware/chisel/src/main/scala/core/TensorUtil.scala
index d0a8ba7..dfdecf4 100644
--- a/hardware/chisel/src/main/scala/core/TensorUtil.scala
+++ b/hardware/chisel/src/main/scala/core/TensorUtil.scala
@@ -35,7 +35,7 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
require(tensorType == "inp" || tensorType == "wgt"
- || tensorType == "acc" || tensorType == "out",
+ || tensorType == "acc" || tensorType == "out" || tensorType == "fetch" || tensorType == "uop",
errorMsg)
val (tensorLength, tensorWidth, tensorElemBits) =
@@ -45,6 +45,15 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
(p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits)
else if (tensorType == "acc")
(p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits)
+ else if (tensorType == "fetch") {
+ // make fetch a 64 bit data to be able to read
+ // 64 bit aligned address. It works for wide cacheline
+ require(p(ShellKey).memParams.dataBits >= INST_BITS,
+ "-F- Cannot make fetch tensor narrower than data pulse. TODO: narrow fetch with tensors")
+ (1, 1, 64)
+ }
+ else if (tensorType == "uop")
+ (1, 1, p(CoreKey).uopBits)
else
(p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits)
@@ -58,10 +67,118 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
p(CoreKey).wgtMemDepth
else if (tensorType == "acc")
p(CoreKey).accMemDepth
+ else if (tensorType == "fetch") {
+ require(p(ShellKey).memParams.dataBits >= INST_BITS,
+ "-F- Cannot make fetch tensor narrower than data pulse. TODO: narrow fetch with tensors")
+ // still should be one data line
+ (1 << p(ShellKey).memParams.lenBits)*(INST_BITS / 64)
+ }
+ else if (tensorType == "uop") {
+ p(CoreKey).uopMemDepth
+ }
else
p(CoreKey).outMemDepth
+ // acc/wgt parts are grouped to form
+ // a physically compact compute entity
+
+ val (splitLength, splitWidth) =
+ if (tensorType == "inp") {
+ (1, 1)
+ } else if (tensorType == "wgt") {
+ (p(CoreKey).blockOutFactor, 1)
+ } else if (tensorType == "acc") {
+ // The acc scratchpad is batch rows of blockOut columns.
+ // A GEMM/ALU operation group is based on the wgt tiling of blockOut,
+ // which means the acc output of a group, if batch > 1, is not
+ // contiguous data and may be placed into different memory
+ // modules. But the whole idea of a group is to localize the
+ // transformation of a piece of wgt into a piece of acc data.
+ //
+ (1, p(CoreKey).blockOutFactor)
+ } else if (tensorType == "fetch") {
+ (1, 1)
+ } else if (tensorType == "uop") {
+ (1, 1)
+ } else if (tensorType == "out") {
+ (1, 1) // narrow store doesnt support split
+ } else {
+ (1, 1)
+ }
+ require (splitLength == 1 || splitWidth == 1, "-F- Can split only one dimension.")
+
+ // provide the index of the group closest to IO
+ // expect 2 columns of groups, with IO on top and indexing from the bottom
+ val closestIOGrpIdx =
+ if (tensorType == "inp") {
+ splitLength - 1
+ } else if (tensorType == "wgt") {
+ if (splitLength < 2) 0 else splitLength / 2 - 1
+ } else if (tensorType == "acc") {
+ if (splitWidth < 2) 0 else splitWidth / 2 - 1
+ } else if (tensorType == "fetch") {
+ 0
+ } else if (tensorType == "uop") {
+ 0
+ } else if (tensorType == "out") {
+ 0
+ } else {
+ 0
+ }
+
val memAddrBits = log2Ceil(memDepth)
+
+ val tensorSizeBits = tensorLength * tensorWidth * tensorElemBits
+ val tsSizeRatio = tensorSizeBits / memBlockBits
+ val clSizeRatio = memBlockBits / tensorSizeBits
+
+ val lenSplit = tensorLength / splitLength // tensor rows in a group
+ val widthSplit = tensorWidth / splitWidth // tensor colums in a group
+ require(lenSplit > 0 && widthSplit > 0, "-F- wrong split")
+
+ // The tensor considers groups as contiguous data; gemm generates a data window.
+ // Map a data index from a window index to a contiguous groups index.
+ def reindexDataFromGroup (grpIdx : Int, lenIdx: Int, wdtIdx: Int) = {
+
+ val grpLen = lenSplit // tensor rows in a group
+ val grpWdt = widthSplit // tensor colums in a group
+ val srcGrpRow = grpIdx / splitWidth // group row
+ val srcGrpCol = grpIdx % splitWidth // group column
+ val tnzRow = srcGrpRow * grpLen
+ val tnzCol = srcGrpCol * grpWdt
+ val flatIdx = (tnzRow + lenIdx) * tensorWidth + tnzCol + wdtIdx
+
+ val outGroupIdx = flatIdx / (grpLen * grpWdt)
+ val outGroupOffset = flatIdx % (grpLen * grpWdt)
+ val outGroupLenIdx = outGroupOffset / grpWdt
+ val outGroupWdthIdx = outGroupOffset % grpWdt
+ (outGroupIdx, outGroupLenIdx, outGroupWdthIdx)
+ }
+ // Map a data index from a contiguous index to a window index.
+ def reindexDataToGroup (grpIdx : Int, lenIdx: Int, wdtIdx: Int) = {
+ val outGrpLen = tensorLength / splitLength // data rows in a group
+ val outGrpWdt = tensorWidth / splitWidth // data colums in a group
+
+ val outIdx = grpIdx * outGrpLen * outGrpWdt + lenIdx * outGrpWdt + wdtIdx
+ val outRow = outIdx / tensorWidth
+ val outCol = outIdx % tensorWidth
+
+
+ val outGrpRow = outRow / outGrpLen
+ val outLenIdx = outRow % outGrpLen
+
+ val outGrpCol = outCol / outGrpWdt
+ val outColIdx = outCol % outGrpWdt
+
+ val outGrpIdx = outGrpRow * splitWidth + outGrpCol
+
+ (outGrpIdx, outLenIdx, outColIdx)
+
+ }
+ def paramsStr () = {
+ s" ${tensorType} ${tensorSizeBits*memDepth/8} Byte. length:${tensorLength} width:${tensorWidth}" +
+ s" data bits:${tensorElemBits} mem depth:${memDepth} groups split length:${splitLength}"
+ }
}
/** TensorMaster.
@@ -73,25 +190,29 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
*/
class TensorMaster(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
- val rd = new Bundle {
+ val rd = Vec(splitLength * splitWidth, new Bundle {
val idx = ValidIO(UInt(memAddrBits.W))
val data = Flipped(
- ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
- }
- val wr = ValidIO(new Bundle {
- val idx = UInt(memAddrBits.W)
- val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+ ValidIO(Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))))
})
+ val wr = Vec(splitLength * splitWidth, ValidIO(new Bundle {
+ val idx = UInt(memAddrBits.W)
+ val data = Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))
+ }))
def tieoffRead() {
- rd.idx.valid := false.B
- rd.idx.bits := 0.U
+ for (idx <- 0 until splitLength * splitWidth) {
+ rd(idx).idx.valid := false.B
+ rd(idx).idx.bits := 0.U
+ }
}
def tieoffWrite() {
- wr.valid := false.B
- wr.bits.idx := 0.U
- wr.bits.data.foreach { b =>
- b.foreach { c =>
- c := 0.U
+ for (idx <- 0 until splitLength * splitWidth) {
+ wr(idx).valid := false.B
+ wr(idx).bits.idx := 0.U
+ wr(idx).bits.data.foreach { b =>
+ b.foreach { c =>
+ c := 0.U
+ }
}
}
}
@@ -107,20 +228,33 @@ class TensorMaster(tensorType: String = "none")
*/
class TensorClient(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
- val rd = new Bundle {
+ val rd = Vec(splitLength * splitWidth, new Bundle {
val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
val data = ValidIO(
- Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
- }
- val wr = Flipped(ValidIO(new Bundle {
+ Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W))))
+ })
+ val wr = Vec(splitLength * splitWidth, Flipped(ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
- val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
- }))
+ val data = Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))
+ })))
def tieoffRead() {
- rd.data.valid := false.B
- rd.data.bits.foreach { b =>
- b.foreach { c =>
- c := 0.U
+ for (idx <- 0 until splitLength * splitWidth) {
+ rd(idx).data.valid := false.B
+ rd(idx).data.bits.foreach { b =>
+ b.foreach { c =>
+ c := 0.U
+ }
+ }
+ }
+ }
+ def tieoffWrite() {
+ for (idx <- 0 until splitLength * splitWidth) {
+ wr(idx).valid := false.B
+ wr(idx).bits.idx := 0.U
+ wr(idx).bits.data.foreach { b =>
+ b.foreach { c =>
+ c := 0.U
+ }
}
}
}
@@ -137,7 +271,7 @@ class TensorClient(tensorType: String = "none")
class TensorMasterData(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val data = Flipped(
- ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+ ValidIO(Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W)))))
override def cloneType =
new TensorMasterData(tensorType).asInstanceOf[this.type]
}
@@ -151,7 +285,7 @@ class TensorMasterData(tensorType: String = "none")
class TensorClientData(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val data = ValidIO(
- Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+ Vec(lenSplit, Vec(widthSplit, UInt(tensorElemBits.W))))
override def cloneType =
new TensorClientData(tensorType).asInstanceOf[this.type]
}
diff --git a/hardware/chisel/src/test/scala/unittest/AluTest.scala b/hardware/chisel/src/test/scala/unittest/AluTest.scala
index a4274c2..c874b01 100644
--- a/hardware/chisel/src/test/scala/unittest/AluTest.scala
+++ b/hardware/chisel/src/test/scala/unittest/AluTest.scala
@@ -19,61 +19,66 @@
package unittest
-import chisel3._
import chisel3.util._
-import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
+import chisel3.iotesters.PeekPokeTester
import scala.util.Random
import unittest.util._
import vta.core._
+import vta.util.config._
-class TestAluVector(c: AluVector) extends PeekPokeTester(c) {
-
+object Alu_ref {
/* alu_ref
*
* This is a software function used as a reference for the hardware
*/
- def aluRef(opcode: Int, a: Array[Int], b: Array[Int], width: Int) : Array[Int] = {
+ def alu(opcode: Int, a: Array[Int], b: Array[Int], width: Int) : Array[Int] = {
val size = a.length
val mask = Helper.getMask(log2Ceil(width))
val res = Array.fill(size) {0}
- if (opcode == 1) {
+ if (opcode == 0) {
+ for (i <- 0 until size) { // min
+ res(i) = if (a(i) < b(i)) a(i) else b(i)
+ }
+ } else if (opcode == 1) { // max
for (i <- 0 until size) {
res(i) = if (a(i) < b(i)) b(i) else a(i)
}
- } else if (opcode == 2) {
+ } else if (opcode == 2) { // add
for (i <- 0 until size) {
res(i) = a(i) + b(i)
}
- } else if (opcode == 3) {
+ } else if (opcode == 3) { // right shift
for (i <- 0 until size) {
res(i) = a(i) >> (b(i) & mask).toInt
}
- } else if (opcode == 4) {
+ } else if (opcode == 4) { // left shift
// HLS shift left by >> negative number
// b always < 0 when opcode == 4
for (i <- 0 until size) {
res(i) = a(i) << ((-1*b(i)) & mask)
}
- } else {
- // default
+ } else { // default
for (i <- 0 until size) {
- res(i) = if (a(i) < b(i)) a(i) else b(i)
+ res(i) = 0
}
}
res
}
+}
+
+class AluVectorTester(c: AluVector, seed: Int = 47) extends PeekPokeTester(c) {
+ val r = new Random(seed)
val num_ops = ALU_OP_NUM
- for (i <- 0 until num_ops) {
+ for (op <- 0 until num_ops) {
// generate data based on bits
val bits = c.io.acc_a.tensorElemBits
- val dataGen = new RandomArray(c.blockOut, bits)
- val op = i
+ val dataGen = new RandomArray(c.blockOut, bits, r)
val in_a = dataGen.any
val in_b = if (op != 4) dataGen.any else dataGen.negative
val mask = Helper.getMask(bits)
- val res = aluRef(op, in_a, in_b, bits)
+ val res = Alu_ref.alu(op, in_a, in_b, bits)
for (i <- 0 until c.blockOut) {
poke(c.io.acc_a.data.bits(0)(i), in_a(i) & mask)
@@ -83,13 +88,11 @@ class TestAluVector(c: AluVector) extends PeekPokeTester(c) {
poke(c.io.acc_a.data.valid, 1)
poke(c.io.acc_b.data.valid, 1)
- poke(c.io.acc_y.data.valid, 1)
step(1)
poke(c.io.acc_a.data.valid, 0)
poke(c.io.acc_b.data.valid, 0)
- poke(c.io.acc_y.data.valid, 0)
// wait for valid signal
while (peek(c.io.acc_y.data.valid) == BigInt(0)) {
@@ -102,3 +105,6 @@ class TestAluVector(c: AluVector) extends PeekPokeTester(c) {
}
}
}
+
+class AluTest extends GenericTest("AluTest", (p:Parameters) =>
+ new AluVector()(p), (c:AluVector) => new AluVectorTester(c, 48))
diff --git a/hardware/chisel/src/main/scala/core/Configs.scala b/hardware/chisel/src/test/scala/unittest/Generic.scala
old mode 100644
new mode 100755
similarity index 54%
copy from hardware/chisel/src/main/scala/core/Configs.scala
copy to hardware/chisel/src/test/scala/unittest/Generic.scala
index 4ab7d85..3dc0b34
--- a/hardware/chisel/src/main/scala/core/Configs.scala
+++ b/hardware/chisel/src/test/scala/unittest/Generic.scala
@@ -17,32 +17,30 @@
* under the License.
*/
-package vta.core
+package unittest
+import chisel3._
+import chisel3.util._
import vta.util.config._
+import chisel3.iotesters._
+import vta.{DefaultPynqConfig}
-/** CoreConfig.
- *
- * This is one supported configuration for VTA. This file will
- * be eventually filled out with class configurations that can be
- * mixed/matched with Shell configurations for different backends.
- */
-class CoreConfig extends Config((site, here, up) => {
- case CoreKey =>
- CoreParams(
- batch = 1,
- blockOut = 16,
- blockIn = 16,
- inpBits = 8,
- wgtBits = 8,
- uopBits = 32,
- accBits = 32,
- outBits = 8,
- uopMemDepth = 2048,
- inpMemDepth = 2048,
- wgtMemDepth = 1024,
- accMemDepth = 2048,
- outMemDepth = 2048,
- instQueueEntries = 512
+import org.scalatest.{Matchers, FlatSpec}
+
+class GenericTest[T <: Module, P <: PeekPokeTester[T], C <: Parameters]
+ (tag : String, dutFactory : (Parameters) => T, testerFactory : (T) => P) extends FlatSpec with Matchers {
+
+ implicit val p: Parameters = new DefaultPynqConfig
+
+ val arguments = Array(
+ "--backend-name", "treadle",
+ // "--backend-name", "vcs",
+ // "--is-verbose",
+ "--test-seed", "0"
)
-})
+
+ behavior of tag
+ it should "not have expect violations" in {
+ chisel3.iotesters.Driver.execute(arguments, ()=> dutFactory(p))(testerFactory) should be (true)
+ }
+}
diff --git a/hardware/chisel/src/test/scala/unittest/Launcher.scala b/hardware/chisel/src/test/scala/unittest/Launcher.scala
index 2d10c52..1b0d6da 100644
--- a/hardware/chisel/src/test/scala/unittest/Launcher.scala
+++ b/hardware/chisel/src/test/scala/unittest/Launcher.scala
@@ -49,7 +49,7 @@ object Launcher {
},
"alu" -> { (manager: TesterOptionsManager) =>
Driver.execute(() => new AluVector, manager) {
- (c) => new TestAluVector(c)
+ (c) => new AluVectorTester(c)
}
}
)
diff --git a/hardware/chisel/src/test/scala/unittest/TensorAluTest.scala b/hardware/chisel/src/test/scala/unittest/TensorAluTest.scala
new file mode 100755
index 0000000..21da5f0
--- /dev/null
+++ b/hardware/chisel/src/test/scala/unittest/TensorAluTest.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package unittest
+
+import chisel3._
+import chisel3.util._
+import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
+import scala.util.Random
+import unittest.util._
+import vta.core._
+import vta.util.config._
+
+class TensorAluIndexGeneratorTester(c: TensorAluIndexGenerator, alu_use_imm : Int = 0) extends PeekPokeTester(c) {
+
+
+ val uop_begin = 0
+ val uop_end = 2
+ assert(uop_begin < uop_end)
+
+ val lp_0 = 2
+ val lp_1 = 3
+ val dst_0 = 1*lp_1
+ val src_0 = 2*lp_1
+ val dst_1 = 1
+ val src_1 = 2
+
+ poke(c.io.dec.reset, 0)
+ poke(c.io.dec.alu_use_imm, alu_use_imm)
+ poke(c.io.dec.uop_begin, uop_begin)
+ poke(c.io.dec.uop_end, uop_end)
+ poke(c.io.dec.lp_0, lp_0)
+ poke(c.io.dec.lp_1, lp_1)
+ poke(c.io.dec.dst_0, dst_0)
+ poke(c.io.dec.dst_1, dst_1)
+ poke(c.io.dec.src_0, src_0)
+ poke(c.io.dec.src_1, src_1)
+ // Don't need empty_0,{push,pop}_{next,prev},op
+
+
+ class Mocks {
+ val uop_indices = new scala.collection.mutable.Queue[BigInt]
+ val dst_indices = new scala.collection.mutable.Queue[BigInt]
+ val src_indices = new scala.collection.mutable.Queue[BigInt]
+
+ def logical_step() {
+ step(1)
+ if (peek(c.io.valid) == 1) {
+ expect(c.io.uop_idx, uop_indices.dequeue())
+ expect(c.io.dst_idx, dst_indices.dequeue())
+ }
+ if (peek(c.io.src_valid) == 1) {
+ expect(c.io.src_idx, src_indices.dequeue())
+ }
+ }
+
+ def test_if_done() {
+ println(s"uop_indices remaining: ${uop_indices.size}")
+ println(s"dst_indices remaining: ${dst_indices.size}")
+ println(s"src_indices remaining: ${src_indices.size}")
+ assert(uop_indices.isEmpty)
+ assert(dst_indices.isEmpty)
+ assert(src_indices.isEmpty)
+ }
+ }
+
+ val mocks = new Mocks
+ for {
+ cnt_o <- 0 until lp_0
+ cnt_i <- 0 until lp_1
+ uop_idx <- uop_begin until uop_end
+ } {
+ mocks.uop_indices.enqueue(uop_idx)
+ mocks.dst_indices.enqueue(dst_0*cnt_o + dst_1*cnt_i)
+ if (alu_use_imm == 0) {
+ mocks.src_indices.enqueue(src_0*cnt_o + src_1*cnt_i)
+ }
+ }
+
+ poke(c.io.start, 1)
+ mocks.logical_step()
+ poke(c.io.start, 0)
+
+ val end = (uop_end-uop_begin)*lp_0*lp_1
+ var count = 0
+ while(peek(c.io.last) == 0 && count < 10*end + 100) {
+ mocks.logical_step()
+ count += 1
+ }
+ mocks.test_if_done()
+ step(1)
+}
+
+class TensorAluIndexGenerator_0_Test extends GenericTest("TensorAluIndexGenerator_0", (p:Parameters) =>
+ new TensorAluIndexGenerator()(p), (c:TensorAluIndexGenerator) => new TensorAluIndexGeneratorTester(c, 0))
+
+class TensorAluIndexGenerator_1_Test extends GenericTest("TensorAluIndexGenerator_1", (p:Parameters) =>
+ new TensorAluIndexGenerator()(p), (c:TensorAluIndexGenerator) => new TensorAluIndexGeneratorTester(c, 1))
+
+class TensorAluPipelinedTester(c: TensorAlu) extends PeekPokeTester(c) {
+ poke(c.io.start, 0)
+
+ val uop_begin = 0
+ val uop_end = 1
+ assert(uop_begin < uop_end)
+ val alu_use_imm = 1
+ val lp_0 = 2
+ val lp_1 = 3
+ val dst_0 = 1*lp_1
+ val src_0 = 2*lp_1
+ val dst_1 = 1
+ val src_1 = 2
+
+ val dst_offset = BigInt("000", 16)
+ val src_offset = BigInt("100", 16)
+
+ val u0 = dst_offset
+ val u1 = src_offset
+ val u2 = 0 // if src_offset is big, some bits go here
+
+ poke(c.io.dec.reset, 0)
+ poke(c.io.dec.alu_op, 2) // ADD or ADDI 1
+ poke(c.io.dec.alu_imm, 1)
+ poke(c.io.dec.alu_use_imm, alu_use_imm)
+ poke(c.io.dec.uop_begin, uop_begin)
+ poke(c.io.dec.uop_end, uop_end)
+ poke(c.io.dec.lp_0, lp_0)
+ poke(c.io.dec.lp_1, lp_1)
+ poke(c.io.dec.dst_0, dst_0)
+ poke(c.io.dec.dst_1, dst_1)
+ poke(c.io.dec.src_0, src_0)
+ poke(c.io.dec.src_1, src_1)
+
+ // Don't need empty_0,{push,pop}_{next,prev},op
+
+ poke(c.io.uop.data.bits.u0, u0)
+ poke(c.io.uop.data.bits.u1, u1)
+ poke(c.io.uop.data.bits.u2, u2)
+
+ require(c.io.acc.splitWidth == 1, "-F- Test doesnt support acc data access split")
+ require(c.io.acc.splitLength == 1, "-F- Test doesnt support acc data access split")
+
+ val acc = IndexedSeq.tabulate(c.io.acc.rd(0).data.bits(0).size){ i => BigInt(i) }
+ for { lhs <- c.io.acc.rd(0).data.bits} {
+ poke(lhs, acc.reverse)
+ }
+
+ class TensorMasterMock(tm: TensorMaster) {
+ poke(tm.rd(0).data.valid, 0)
+ var valid = peek(tm.rd(0).idx.valid)
+ def logical_step(v: Option[BigInt]) {
+ poke(tm.rd(0).data.valid, valid)
+ valid = peek(tm.rd(0).idx.valid)
+ for { x <- v} expect(tm.rd(0).idx.valid, x)
+ }
+ }
+
+ class UopMasterMock(um: UopMaster) {
+ poke(um.data.valid, 0)
+ var valid = peek(um.idx.valid)
+ def logical_step(v: Option[BigInt]) {
+ poke(um.data.valid, valid)
+ valid = peek(um.idx.valid)
+ for { x <- v} expect(um.idx.valid, x)
+ }
+ }
+
+ class Mocks {
+ val uop_mock = new UopMasterMock(c.io.uop)
+ val acc_mock = new TensorMasterMock(c.io.acc)
+
+ val uop_indices = new scala.collection.mutable.Queue[BigInt]
+ val acc_indices = new scala.collection.mutable.Queue[BigInt]
+ val accout_indices = new scala.collection.mutable.Queue[BigInt]
+ val out_indices = new scala.collection.mutable.Queue[BigInt]
+
+ def logical_step() {
+ step(1)
+ uop_mock.logical_step(None)
+ acc_mock.logical_step(None)
+ if (peek(c.io.uop.idx.valid) == 1) {
+ expect(c.io.uop.idx.bits, uop_indices.dequeue())
+ }
+ if (peek(c.io.acc.rd(0).idx.valid) == 1) {
+ expect(c.io.acc.rd(0).idx.bits, acc_indices.dequeue())
+ }
+ if (peek(c.io.acc.wr(0).valid) == 1) {
+ expect(c.io.acc.wr(0).bits.idx, accout_indices.dequeue())
+ }
+ if (peek(c.io.out.wr(0).valid) == 1) {
+ expect(c.io.out.wr(0).bits.idx, out_indices.dequeue())
+ }
+ }
+
+ def test_if_done() {
+ println(s"uop_indices remaining: ${uop_indices.size}")
+ println(s"acc_indices remaining: ${acc_indices.size}")
+ println(s"accout_indices remaining: ${accout_indices.size}")
+ println(s"out_indices remaining: ${out_indices.size}")
+ assert(uop_indices.isEmpty)
+ assert(acc_indices.isEmpty)
+ assert(accout_indices.isEmpty)
+ assert(out_indices.isEmpty)
+ }
+ }
+
+ val mocks = new Mocks
+ for {
+ cnt_o <- 0 until lp_0
+ cnt_i <- 0 until lp_1
+ uop_idx <- uop_begin until uop_end
+ } {
+ mocks.uop_indices.enqueue(uop_idx)
+ mocks.acc_indices.enqueue(src_offset + src_0*cnt_o + src_1*cnt_i)
+ mocks.accout_indices.enqueue(dst_offset + dst_0*cnt_o + dst_1*cnt_i)
+ mocks.out_indices.enqueue(dst_offset + dst_0*cnt_o + dst_1*cnt_i)
+ }
+
+ poke(c.io.start, 0)
+ step(1)
+ poke(c.io.start, 1)
+
+ var count = 0
+ val end = (uop_end-uop_begin)*lp_0*lp_1
+
+ while (peek(c.io.done) == 0 && count < 10*end + 100) {
+ mocks.logical_step()
+ poke(c.io.start, 0)
+ count += 1
+ }
+ expect(c.io.done, 1)
+ mocks.test_if_done()
+}
+
+class TensorAluPipelinedTest extends GenericTest("TensorAluPipelined", (p:Parameters) =>
+ new TensorAlu()(p), (c:TensorAlu) => new TensorAluPipelinedTester(c))
diff --git a/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala b/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala
index 2852d4e..55f2f52 100644
--- a/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala
+++ b/hardware/chisel/src/test/scala/unittest/utils/RandomArray.scala
@@ -22,10 +22,13 @@ package unittest.util
import scala.util.Random
import scala.math.pow
-class RandomArray(val len: Int, val bits: Int) {
- val r = new Random
+class RandomArray(val len: Int, val bits: Int, val r: Random) {
if (bits < 1) throw new IllegalArgumentException ("bits should be greater than 1")
+ def this(len: Int, bits: Int) {
+ this(len, bits, new Random)
+ }
+
def any : Array[Int] = {
Array.fill(len) { r.nextInt(pow(2, bits).toInt) - pow(2, bits-1).toInt }
}