You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/09/09 01:10:21 UTC
[tvm-vta] branch main updated: VTA Chisel Wide memory interface.
(#32)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-vta.git
The following commit(s) were added to refs/heads/main by this push:
new 36a9157 VTA Chisel Wide memory interface. (#32)
36a9157 is described below
commit 36a91576edf633479c78649e050f18dd2ddc8103
Author: Anton Sorokin <an...@intel.com>
AuthorDate: Wed Sep 8 18:10:14 2021 -0700
VTA Chisel Wide memory interface. (#32)
* VTA Chisel Wide memory interface.
* Added SyncQueue with tests - Implementation uses sync memory to implement larger queues.
* AXI 64/128/256/512 data bits support by AXIParams->dataBits
A wide implementation of load/store is used when AXI interface data width
is larger than number of bits in a tensor.
Instructions are stored as 64bit tensors to allow 64bit address alignment
* TensorLoad is modified to replace all VME load operations.
Multiple simultaneous requests can be generated. Load is pipelined
and separated from request generation.
* TensorStore -> TensorStoreNarrowVME TensorStoreWideVME. The narrow one is the original one
* TensorLoad -> TensorLoadSimple (original) TensorLoadWideVME TensorLoadNarrowVME
* LoadUop -> LoadUopSimple is the original one. The new one is based on TensorLoad
* Fetch -> FetchVME64 FetchWideVME. Reuse communication part from TensorLoad.
* DPI interface changed to transfer more than 64bit. svOpenArrayHandle is used. tsim library compilation now requires verilator
* Compute is changed to use TensorLoad style of load uop.
* VME changed to generate/queue/respond to multiple simultaneous load requests
* code formatting fix
* Update to Chisel 3.4.3 PR Port to the latest stable Chisel release (#33)
* Fix Makefile to use Chisel -o instead of top name and .sv instead of .v
* fix reset to reset.asBool
* fix SyncQueue to deprecated module.io
* fix toBools to asBools
* include verilated.cpp verilated_dpi.cpp directly in module.cc to provide verilator array access functionality and avoid compilation warnings
* fix module io warnings
* comments
* Jenkinsfile ci pipeline fix
* Jenkinsfile ci pipeline fix. only for lint,cpu,i386
* Reenable tsim tests
* style fix
* comments cleanup
* AXI constants commented. Moved write id to VME
* comments cleanup
* comments cleanup
---
.../chisel/src/main/resources/verilog/VTAMemDPI.v | 169 +++--
hardware/chisel/src/main/scala/core/Compute.scala | 11 +-
hardware/chisel/src/main/scala/core/Fetch.scala | 145 +---
.../scala/core/{Fetch.scala => FetchVME64.scala} | 17 +-
.../chisel/src/main/scala/core/FetchWideVME.scala | 351 ++++++++++
hardware/chisel/src/main/scala/core/LoadUop.scala | 181 +----
.../chisel/src/main/scala/core/LoadUopSimple.scala | 250 +++++++
.../chisel/src/main/scala/core/TensorLoad.scala | 321 +--------
.../src/main/scala/core/TensorLoadNarrowVME.scala | 740 ++++++++++++++++++++
.../{TensorLoad.scala => TensorLoadSimple.scala} | 21 +-
.../src/main/scala/core/TensorLoadWideVME.scala | 765 +++++++++++++++++++++
.../chisel/src/main/scala/core/TensorStore.scala | 238 +------
...ensorStore.scala => TensorStoreNarrowVME.scala} | 37 +-
.../src/main/scala/core/TensorStoreWideVME.scala | 289 ++++++++
.../chisel/src/main/scala/core/TensorUtil.scala | 73 +-
hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala | 254 ++++---
.../chisel/src/main/scala/interface/axi/AXI.scala | 9 +-
hardware/chisel/src/main/scala/shell/VME.scala | 328 ++++++---
.../scala/shell/{VME.scala => VMESimple.scala} | 121 +---
.../chisel/src/main/scala/util/SyncQueue.scala | 508 ++++++++++++++
.../scala/unittest/SyncQueue2PortMemTest.scala | 207 ++++++
.../src/test/scala/unittest/SyncQueueTest.scala | 264 +++++++
hardware/dpi/tsim_device.cc | 26 +-
include/vta/dpi/tsim.h | 18 +-
src/dpi/module.cc | 196 ++++--
tests/scripts/docker_bash.sh | 6 -
26 files changed, 4283 insertions(+), 1262 deletions(-)
diff --git a/hardware/chisel/src/main/resources/verilog/VTAMemDPI.v b/hardware/chisel/src/main/resources/verilog/VTAMemDPI.v
index e0ed949..9823550 100644
--- a/hardware/chisel/src/main/resources/verilog/VTAMemDPI.v
+++ b/hardware/chisel/src/main/resources/verilog/VTAMemDPI.v
@@ -18,51 +18,80 @@
*/
module VTAMemDPI #
-( parameter LEN_BITS = 8,
- parameter ADDR_BITS = 64,
- parameter DATA_BITS = 64
-)
-(
- input clock,
- input reset,
- input dpi_req_valid,
- input dpi_req_opcode,
- input [LEN_BITS-1:0] dpi_req_len,
- input [ADDR_BITS-1:0] dpi_req_addr,
- input dpi_wr_valid,
- input [DATA_BITS-1:0] dpi_wr_bits,
- output logic dpi_rd_valid,
- output logic [DATA_BITS-1:0] dpi_rd_bits,
- input dpi_rd_ready
-);
-
- import "DPI-C" function void VTAMemDPI
+ ( parameter LEN_BITS = 8,
+ parameter ADDR_BITS = 64,
+ parameter DATA_BITS = 64,
+ parameter STRB_BITS = DATA_BITS/8
+ )
(
- input byte unsigned req_valid,
- input byte unsigned req_opcode,
- input byte unsigned req_len,
- input longint unsigned req_addr,
- input byte unsigned wr_valid,
- input longint unsigned wr_value,
- output byte unsigned rd_valid,
- output longint unsigned rd_value,
- input byte unsigned rd_ready
+ input clock,
+ input reset,
+ input dpi_req_ar_valid,
+ input [LEN_BITS-1:0] dpi_req_ar_len,
+ input [7:0] dpi_req_ar_id,
+ input [ADDR_BITS-1:0] dpi_req_ar_addr,
+ input dpi_req_aw_valid,
+ input [LEN_BITS-1:0] dpi_req_aw_len,
+ input [ADDR_BITS-1:0] dpi_req_aw_addr,
+ input dpi_wr_valid,
+ input [DATA_BITS-1:0] dpi_wr_bits_data,
+ input [STRB_BITS-1:0] dpi_wr_bits_strb,
+ output logic dpi_rd_valid,
+ output logic [7:0] dpi_rd_bits_id,
+ output logic [DATA_BITS-1:0] dpi_rd_bits_data,
+ input dpi_rd_ready
);
- typedef logic dpi1_t;
- typedef logic [7:0] dpi8_t;
- typedef logic [31:0] dpi32_t;
- typedef logic [63:0] dpi64_t;
+ import "DPI-C" function void VTAMemDPI
+ (
+ input byte unsigned rd_req_valid,
+ input byte unsigned rd_req_len,
+ input byte unsigned rd_req_id,
+ input longint unsigned rd_req_addr,
+ input byte unsigned wr_req_valid,
+ input byte unsigned wr_req_len,
+ input longint unsigned wr_req_addr,
+ input byte unsigned wr_valid,
+ input longint unsigned wr_value[],
+ input longint unsigned wr_strb,
+ output byte unsigned rd_valid,
+ output byte unsigned rd_id,
+ output longint unsigned rd_value[],
+ input byte unsigned rd_ready
+ );
+ parameter blockNb = DATA_BITS/64;
+
+ generate
+ if (blockNb*64 != DATA_BITS) begin
+ $error("-F- 64 bit data blocks expected.");
+ end
+ endgenerate
+ generate
+ if (STRB_BITS > 64) begin
+ $error("-F- Strb bits should not exceed 64. Fix strb transfer");
+ end
+ endgenerate
+
+ typedef logic dpi1_t;
+ typedef logic [7:0] dpi8_t;
+ typedef logic [31:0] dpi32_t;
+ typedef logic [63:0] dpi64_t;
+ typedef longint dpi_data_t [blockNb-1:0];
dpi1_t __reset;
- dpi8_t __req_valid;
- dpi8_t __req_opcode;
- dpi8_t __req_len;
- dpi64_t __req_addr;
+ dpi8_t __rd_req_valid;
+ dpi8_t __rd_req_len;
+ dpi8_t __rd_req_id;
+ dpi64_t __rd_req_addr;
+ dpi8_t __wr_req_valid;
+ dpi8_t __wr_req_len;
+ dpi64_t __wr_req_addr;
dpi8_t __wr_valid;
- dpi64_t __wr_value;
+ dpi_data_t __wr_value;
+ dpi64_t __wr_strb;
dpi8_t __rd_valid;
- dpi64_t __rd_value;
+ dpi_data_t __rd_value;
+ dpi8_t __rd_id;
dpi8_t __rd_ready;
always_ff @(posedge clock) begin
@@ -71,36 +100,70 @@ module VTAMemDPI #
// delaying outputs by one-cycle
// since verilator does not support delays
+ integer i;
always_ff @(posedge clock) begin
dpi_rd_valid <= dpi1_t ' (__rd_valid);
- dpi_rd_bits <= __rd_value;
+ for (i = 0; i < blockNb; i = i +1) begin
+ dpi_rd_bits_data[64 * i +: 64] <= __rd_value[i];
+ end
+ dpi_rd_bits_id <= __rd_id;
end
- assign __req_valid = dpi8_t ' (dpi_req_valid);
- assign __req_opcode = dpi8_t ' (dpi_req_opcode);
- assign __req_len = dpi_req_len;
- assign __req_addr = dpi_req_addr;
+ assign __rd_req_valid = dpi8_t ' (dpi_req_ar_valid);
+ assign __rd_req_len = dpi8_t ' (dpi_req_ar_len);
+ assign __rd_req_id = dpi_req_ar_id;
+ assign __rd_req_addr = dpi64_t ' (dpi_req_ar_addr);
+ assign __wr_req_valid = dpi8_t ' (dpi_req_aw_valid);
+ assign __wr_req_len = dpi8_t ' (dpi_req_aw_len);
+ assign __wr_req_addr = dpi64_t ' (dpi_req_aw_addr);
+
+ generate
+ if (STRB_BITS != 64) begin
+ localparam [63 - STRB_BITS:0] strbfill = 0;
+ assign __wr_strb = {strbfill, dpi_wr_bits_strb};
+ end
+ else begin
+ assign __wr_strb = dpi_wr_bits_strb;
+ end
+ endgenerate
assign __wr_valid = dpi8_t ' (dpi_wr_valid);
- assign __wr_value = dpi_wr_bits;
+ genvar j;
+ generate
+ for (j = 0; j < blockNb; j = j +1) begin
+ assign __wr_value[j] = dpi_wr_bits_data[64 * j +: 64];
+ end
+ endgenerate
assign __rd_ready = dpi8_t ' (dpi_rd_ready);
// evaluate DPI function
always_ff @(posedge clock) begin
- if (reset | __reset) begin
- __rd_valid = 0;
- __rd_value = 0;
+ if(reset) begin
+ dpi_rd_valid <= 0;
+ dpi_rd_bits_data <= 0;
+ dpi_rd_bits_id <= 0;
end
else begin
VTAMemDPI(
- __req_valid,
- __req_opcode,
- __req_len,
- __req_addr,
+ __rd_req_valid,
+ __rd_req_len,
+ __rd_req_id,
+ __rd_req_addr,
+ __wr_req_valid,
+ __wr_req_len,
+ __wr_req_addr,
__wr_valid,
__wr_value,
+ __wr_strb,
__rd_valid,
+ __rd_id,
__rd_value,
__rd_ready);
- end
- end
+ dpi_rd_valid <= dpi1_t ' (__rd_valid);
+ for (i = 0; i < blockNb; i = i +1) begin
+ dpi_rd_bits_data[64 * i +: 64] <= __rd_value[i];
+ end
+ dpi_rd_bits_id <= __rd_id;
+ end // else: !if(reset | __reset)
+ end // always_ff @
+
endmodule
diff --git a/hardware/chisel/src/main/scala/core/Compute.scala b/hardware/chisel/src/main/scala/core/Compute.scala
index 0055a25..94669bc 100644
--- a/hardware/chisel/src/main/scala/core/Compute.scala
+++ b/hardware/chisel/src/main/scala/core/Compute.scala
@@ -20,10 +20,12 @@
package vta.core
import scala.math.pow
+import scala.math.sqrt
import chisel3._
import chisel3.util._
import vta.util.config._
+import vta.util._
import vta.shell._
/** Compute.
@@ -55,15 +57,16 @@ class Compute(debug: Boolean = false)(implicit val p: Parameters) extends Module
val s = Seq.tabulate(2)(_ =>
Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
- val loadUop = Module(new LoadUop)
+
+ val loadUop = Module(new LoadUopTop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
val tensorGemm = Module(new TensorGemm)
val tensorAlu = Module(new TensorAlu)
- // try to use the acc closest to top IO
+ //try to use the acc closest to top IO
val topAccGrpIdx = tensorGemm.io.acc.closestIOGrpIdx
- val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+ val inst_q = Module(new SyncQueue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
// decode
val dec = Module(new ComputeDecode)
@@ -205,7 +208,7 @@ class Compute(debug: Boolean = false)(implicit val p: Parameters) extends Module
RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).valid, tensorAlu.io.out.wr(0).valid)
io.out.wr(0).bits.idx := Mux(
RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).bits.idx, tensorAlu.io.out.wr(0).bits.idx)
- // put mux/Reg into every gemm group to build pipe (for Mux select) tree over distance
+ //put mux/Reg into every gemm group to build pipe (for Mux select) tree over distance
val chunkWidth = io.out.wr(0).bits.data.getWidth / tensorGemm.io.acc.splitWidth
val outDataBits = Wire(Vec(tensorGemm.io.acc.splitWidth, UInt(chunkWidth.W)))
io.out.wr(0).bits.data := outDataBits.asTypeOf(io.out.wr(0).bits.data)
diff --git a/hardware/chisel/src/main/scala/core/Fetch.scala b/hardware/chisel/src/main/scala/core/Fetch.scala
index 0ea35a3..66eaa43 100644
--- a/hardware/chisel/src/main/scala/core/Fetch.scala
+++ b/hardware/chisel/src/main/scala/core/Fetch.scala
@@ -23,6 +23,7 @@ import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
+import vta.util._
/** Fetch.
*
@@ -53,143 +54,21 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val st = Decoupled(UInt(INST_BITS.W))
}
})
- val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word
- val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q))
- val dec = Module(new FetchDecode)
- val s1_launch = RegNext(io.launch)
- val pulse = io.launch & ~s1_launch
+ val forceSimpleFetch = false // Force use original implementation of fetch
- val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
- val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
- val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
-
- val xrem = Reg(chiselTypeOf(io.ins_count))
- val xsize = (io.ins_count << 1.U) - 1.U
- val xmax = (1 << mp.lenBits).U
- val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
-
- val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
- val state = RegInit(sIdle)
-
- // control
- switch(state) {
- is(sIdle) {
- when(pulse) {
- state := sReadCmd
- when(xsize < xmax) {
- rlen := xsize
- ilen := xsize >> 1.U
- xrem := 0.U
- }.otherwise {
- rlen := xmax - 1.U
- ilen := (xmax >> 1.U) - 1.U
- xrem := xsize - xmax
- }
- }
- }
- is(sReadCmd) {
- when(io.vme_rd.cmd.ready) {
- state := sReadLSB
- }
- }
- is(sReadLSB) {
- when(io.vme_rd.data.valid) {
- state := sReadMSB
- }
- }
- is(sReadMSB) {
- when(io.vme_rd.data.valid) {
- when(inst_q.io.count === ilen) {
- state := sDrain
- }.otherwise {
- state := sReadLSB
- }
- }
- }
- is(sDrain) {
- when(inst_q.io.count === 0.U) {
- when(xrem === 0.U) {
- state := sIdle
- }.elsewhen(xrem < xmax) {
- state := sReadCmd
- rlen := xrem
- ilen := xrem >> 1.U
- xrem := 0.U
- }.otherwise {
- state := sReadCmd
- rlen := xmax - 1.U
- ilen := (xmax >> 1.U) - 1.U
- xrem := xrem - xmax
- }
- }
- }
+ if (forceSimpleFetch) {
+ require (mp.dataBits <= 128, "-F- Simple VME data transfer doesnt support fetch data wider than instruction.")
}
- // read instructions from dram
- when(state === sIdle) {
- raddr := io.ins_baddr
- }.elsewhen(state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
- raddr := raddr + xmax_bytes
+ if (mp.dataBits >= 128 && !forceSimpleFetch) {
+ // wide cacheline
+ val fetch = Module(new FetchWideVME(debug))
+ io <> fetch.io
+ } else {
+ require(mp.dataBits == 64, "-F- Cannot make simple Fetch for more than 64 bit data read")
+ val fetch = Module(new Fetch64Bit(debug)) // Simple
+ io <> fetch.io
}
- io.vme_rd.cmd.valid := state === sReadCmd
- io.vme_rd.cmd.bits.addr := raddr
- io.vme_rd.cmd.bits.len := rlen
-
- io.vme_rd.data.ready := inst_q.io.enq.ready
-
- val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits))
- val msb = io.vme_rd.data.bits
- val inst = Cat(msb, lsb)
-
- when(state === sReadLSB) { lsb := io.vme_rd.data.bits }
-
- inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
- inst_q.io.enq.bits := inst
-
- // decode
- dec.io.inst := inst_q.io.deq.bits
-
- // instruction queues
- io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain
- io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain
- io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain
-
- io.inst.ld.bits := inst_q.io.deq.bits
- io.inst.co.bits := inst_q.io.deq.bits
- io.inst.st.bits := inst_q.io.deq.bits
-
- // check if selected queue is ready
- val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
- val deq_ready =
- MuxLookup(deq_sel,
- false.B, // default
- Array(
- "h_01".U -> io.inst.ld.ready,
- "h_02".U -> io.inst.st.ready,
- "h_04".U -> io.inst.co.ready
- ))
-
- // dequeue instruction
- inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
-
- // debug
- if (debug) {
- when(state === sIdle && pulse) {
- printf("[Fetch] Launch\n")
- }
- // instruction
- when(inst_q.io.deq.fire()) {
- when(dec.io.isLoad) {
- printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
- }
- when(dec.io.isCompute) {
- printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
- }
- when(dec.io.isStore) {
- printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
- }
- }
- }
}
diff --git a/hardware/chisel/src/main/scala/core/Fetch.scala b/hardware/chisel/src/main/scala/core/FetchVME64.scala
similarity index 90%
copy from hardware/chisel/src/main/scala/core/Fetch.scala
copy to hardware/chisel/src/main/scala/core/FetchVME64.scala
index 0ea35a3..2f51472 100644
--- a/hardware/chisel/src/main/scala/core/Fetch.scala
+++ b/hardware/chisel/src/main/scala/core/FetchVME64.scala
@@ -23,6 +23,7 @@ import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
+import vta.util._
/** Fetch.
*
@@ -39,7 +40,7 @@ import vta.shell._
* more than one instruction at the time. Finally, the instruction queue is
* sized (entries_q), depending on the maximum burst allowed in the memory.
*/
-class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class Fetch64Bit(debug: Boolean = false)(implicit p: Parameters) extends Module {
val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
@@ -54,10 +55,10 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
}
})
val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word
- val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q))
+ val inst_q = Module(new SyncQueue(UInt(INST_BITS.W), entries_q))
val dec = Module(new FetchDecode)
- val s1_launch = RegNext(io.launch)
+ val s1_launch = RegNext(io.launch, init = false.B)
val pulse = io.launch & ~s1_launch
val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
@@ -136,14 +137,15 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
io.vme_rd.cmd.valid := state === sReadCmd
io.vme_rd.cmd.bits.addr := raddr
io.vme_rd.cmd.bits.len := rlen
+ io.vme_rd.cmd.bits.tag := 0.U // Cannot reorder requests as a queue is used
io.vme_rd.data.ready := inst_q.io.enq.ready
- val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits))
- val msb = io.vme_rd.data.bits
+ val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits.data))
+ val msb = io.vme_rd.data.bits.data
val inst = Cat(msb, lsb)
- when(state === sReadLSB) { lsb := io.vme_rd.data.bits }
+ when(state === sReadLSB) { lsb := io.vme_rd.data.bits.data }
inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
inst_q.io.enq.bits := inst
@@ -156,6 +158,9 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain
io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain
+ assert(!(inst_q.io.deq.valid & state === sDrain) || dec.io.isLoad || dec.io.isCompute || dec.io.isStore,
+ "-F- Fetch: Unknown instruction type")
+
io.inst.ld.bits := inst_q.io.deq.bits
io.inst.co.bits := inst_q.io.deq.bits
io.inst.st.bits := inst_q.io.deq.bits
diff --git a/hardware/chisel/src/main/scala/core/FetchWideVME.scala b/hardware/chisel/src/main/scala/core/FetchWideVME.scala
new file mode 100644
index 0000000..1f55b91
--- /dev/null
+++ b/hardware/chisel/src/main/scala/core/FetchWideVME.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Fetch.
+ *
+ * The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
+ * VTA Memory Engine (VME), and push them into an instruction queue called
+ * inst_q. Once the instruction queue is full, instructions are dispatched to
+ * the Load, Compute and Store module queues based on the instruction opcode.
+ * After draining the queue, the fetch unit checks if there are more instructions
+ * via the ins_count register which is written by the host.
+ *
+ * Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB)
+ * because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
+ * This should be configurable for larger payloads, i.e. 64-bytes, which can load
+ * more than one instruction at the time. Finally, the instruction queue is
+ * sized (entries_q), depending on the maximum burst allowed in the memory.
+ */
+class FetchWideVME(debug: Boolean = false)(implicit p: Parameters) extends Module {
+ val vp = p(ShellKey).vcrParams
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val launch = Input(Bool())
+ val ins_baddr = Input(UInt(mp.addrBits.W))
+ val ins_count = Input(UInt(vp.regBits.W))
+ val vme_rd = new VMEReadMaster
+ val inst = new Bundle {
+ val ld = Decoupled(UInt(INST_BITS.W))
+ val co = Decoupled(UInt(INST_BITS.W))
+ val st = Decoupled(UInt(INST_BITS.W))
+ }
+ })
+
+ val tp = new TensorParams("fetch")
+ val tensorsInClNb = tp.clSizeRatio
+ val tensorsInClNbWidth = log2Ceil(tensorsInClNb)
+ val inst_q = Seq.fill(tensorsInClNb) {
+ require((tp.memDepth/tensorsInClNb) * tensorsInClNb == tp.memDepth,
+ "-F- Unexpected queue depth to instructions in cacheline ratio")
+ SyncReadMem(tp.memDepth/tensorsInClNb, UInt(tp.tensorSizeBits.W))
+ }
+
+ // sample start
+ val s1_launch = RegNext(io.launch, init = false.B)
+ val start = io.launch & ~s1_launch
+
+
+ val xrem = Reg(chiselTypeOf(io.ins_count))
+ // fit instruction into 64bit chunks
+ val elemsInInstr = INST_BITS/64
+ val xsize = io.ins_count << log2Ceil(elemsInInstr)
+ // max size of transfer is limited by a buffer size
+ val xmax = (((1 << mp.lenBits) << log2Ceil(tp.clSizeRatio)).min(tp.memDepth)).U
+ val elemNb = Reg(xsize.cloneType)
+
+ val sIdle :: sRead :: sDrain :: Nil = Enum(3)
+ val state = RegInit(sIdle)
+ val isBusy = state === sRead
+
+ val vmeStart = start || (state === sRead && RegNext(state, init = sIdle) === sDrain)
+ val dramOffset = RegInit(UInt(mp.addrBits.W), init = 0.U)
+ val vmeCmd = Module (new GenVMECmdWideFetch(debug))
+ vmeCmd.io.start := vmeStart
+ vmeCmd.io.isBusy := isBusy & ~vmeStart
+ vmeCmd.io.ins_baddr := Mux(start, io.ins_baddr, io.ins_baddr + (dramOffset << log2Ceil(tp.tensorSizeBits / 8)))
+ vmeCmd.io.vmeCmd <> io.vme_rd.cmd
+ val readLen = vmeCmd.io.readLen
+ val vmeCmdDone = vmeCmd.io.done & ~vmeStart
+
+ vmeCmd.io.xsize := elemNb
+ vmeCmd.io.sram_offset := 0.U // this is a queue we reload
+
+ io.vme_rd.data.ready := true.B
+ val pipeDelayQueueDeqV = RegNext(io.vme_rd.data.valid, init = false.B)
+ val pipeDelayQueueDeqF = pipeDelayQueueDeqV // fire()
+ val pipeDelayQueueDeqB = RegNext(io.vme_rd.data.bits)
+
+ // Nb of CLs requested, not received.
+ val clCntIdxWdth = log2Ceil(tp.memDepth/tensorsInClNb) + 1
+ val clInFlight = Reg(UInt(clCntIdxWdth.W))
+ when(start) {
+ clInFlight := 0.U
+ }.elsewhen(isBusy && io.vme_rd.cmd.fire() && !pipeDelayQueueDeqF) {
+ clInFlight := clInFlight + readLen
+ }.elsewhen(isBusy && io.vme_rd.cmd.fire() && pipeDelayQueueDeqF) {
+ clInFlight := clInFlight + readLen - 1.U
+ }.elsewhen(isBusy && !io.vme_rd.cmd.fire() && pipeDelayQueueDeqF) {
+ assert(clInFlight > 0.U)
+ clInFlight := clInFlight - 1.U
+ }.otherwise {
+ clInFlight := clInFlight
+ }
+
+ // number of entries in a queue
+ val queueCount = Reg(UInt((tp.memAddrBits + 1).W))
+ val queueHead = Wire(UInt(tp.memAddrBits.W))
+ val queueHeadNext = Reg(UInt(tp.memAddrBits.W))
+ val forceRead = Wire(Bool())
+ forceRead := false.B
+ // control
+ switch(state) {
+ is(sIdle) {
+ when(start) {
+ state := sRead
+ dramOffset := 0.U
+ when(xsize < xmax) {
+ elemNb := xsize
+ xrem := 0.U
+ }.otherwise {
+ elemNb := xmax
+ xrem := xsize - xmax
+ }
+ }
+ }
+ is(sRead) {
+ when(vmeCmdDone && clInFlight === 0.U) {
+ forceRead := true.B
+ state := sDrain
+ }
+ }
+ is(sDrain) {
+ when(queueCount === 0.U) {
+ dramOffset := dramOffset + elemNb
+ when(xrem === 0.U) {
+ state := sIdle
+ }.elsewhen(xrem < xmax) {
+ state := sRead
+ elemNb := xrem
+ xrem := 0.U
+ }.otherwise {
+ state := sRead
+ elemNb := xmax
+ xrem := xrem - xmax
+ }
+ }
+ }
+ }
+
+
+ //---------------------
+ //--- Read VME data ---
+ //---------------------
+
+ val readData = Module(new ReadVMEDataWide("fetch", debug))
+ readData.io.start := vmeStart
+ readData.io.vmeData.valid := pipeDelayQueueDeqV
+ readData.io.vmeData.bits := pipeDelayQueueDeqB
+ assert(readData.io.vmeData.ready === true.B)
+
+ //--------------------
+ //--- Write memory ---
+ //--------------------
+
+ val wmask = readData.io.destMask
+ val wdata = readData.io.destData
+ val widx = readData.io.destIdx
+
+ for (i <- 0 until tensorsInClNb) {
+ when(wmask(i) && pipeDelayQueueDeqF) {
+ inst_q(i).write(widx(i), wdata(i))
+ }
+ }
+ if (debug) {
+ when (io.vme_rd.data.fire()) {
+ printf(s"[TensorLoad] fetch data rdDataDestIdx:%x rdDataDestMask:%b\n",
+ widx.asUInt,
+ wmask.asUInt)
+ }
+ }
+
+ // read-from-sram
+ // queue head points to the first elem of instruction
+ val rIdx = queueHead >> tensorsInClNbWidth // SP idx
+ // rMask selects the first elem of instruction
+ val rMask = if (tensorsInClNbWidth > 0) {
+ UIntToOH(queueHead(tensorsInClNbWidth - 1, 0))
+ } else {
+ 1.U
+ }
+
+ val deqElem = Wire(Bool())
+ val rdataVec = for (i <- 0 until tensorsInClNb) yield {
+ // expand mask to select all elems of instruction
+ val maskShift = i%elemsInInstr
+ inst_q(i).read(rIdx, VecInit((rMask<<maskShift).asTypeOf(rMask).asBools)(i) && (deqElem || forceRead))
+
+ }
+
+ // instruction is a elemsInInstr number of elements
+ // combine them into one instruction
+ val rdata = Wire(Vec(elemsInInstr, UInt((tp.tensorSizeBits).W)))
+ for (i <- 0 until elemsInInstr) {
+ // expand mask to select all elems of instruction
+ rdata(i) := Mux1H(RegNext((rMask << i).asTypeOf(rMask)), rdataVec)
+ }
+
+
+ val canRead = queueCount >= elemsInInstr.U && state === sDrain
+ // instruction queues
+
+ // use 2-entry queue to create one pipe stage for valid-ready interface
+ val readInstrPipe = Module(new Queue(UInt(INST_BITS.W), 2))
+
+ // decode
+ val dec = Module(new FetchDecode)
+ dec.io.inst := readInstrPipe.io.deq.bits
+ readInstrPipe.io.enq.valid := canRead
+ readInstrPipe.io.enq.bits := rdata.asTypeOf(UInt(INST_BITS.W))
+ deqElem := readInstrPipe.io.enq.fire()
+ readInstrPipe.io.deq.ready := (
+ (dec.io.isLoad & io.inst.ld.ready) ||
+ (dec.io.isCompute & io.inst.co.ready) ||
+ (dec.io.isStore & io.inst.st.ready))
+ io.inst.ld.valid := dec.io.isLoad & readInstrPipe.io.deq.valid
+ io.inst.co.valid := dec.io.isCompute & readInstrPipe.io.deq.valid
+ io.inst.st.valid := dec.io.isStore & readInstrPipe.io.deq.valid
+
+ io.inst.ld.bits := readInstrPipe.io.deq.bits
+ io.inst.co.bits := readInstrPipe.io.deq.bits
+ io.inst.st.bits := readInstrPipe.io.deq.bits
+
+ when(start) {
+ queueCount := 0.U
+ }.elsewhen(deqElem && pipeDelayQueueDeqF) {
+ assert(queueCount > 0.U, "-F- Decrement zero counter")
+ val readCount = PopCount(wmask)
+ assert(readCount > 0.U, "-F- Must push something")
+ queueCount := queueCount + readCount - elemsInInstr.U
+ }.elsewhen(deqElem) {
+ assert(queueCount > 0.U, "-F- Decrement zero counter")
+ queueCount := queueCount - elemsInInstr.U
+ }.elsewhen (pipeDelayQueueDeqF) {
+ val numLoaded = PopCount(wmask)
+ assert(tp.memDepth.U - numLoaded >= queueCount, "-F- Counter overflow")
+ queueCount := queueCount + PopCount(wmask)
+ }.otherwise {
+ queueCount := queueCount
+ }
+ when(start) {
+ queueHead := 0.U
+ queueHeadNext := 0.U
+ }.elsewhen(deqElem) {
+ queueHead := queueHeadNext + elemsInInstr.U // read ahead
+ when (queueCount - elemsInInstr.U === 0.U) {
+ queueHeadNext := 0.U
+ }.otherwise {
+ queueHeadNext := queueHeadNext + elemsInInstr.U
+ }
+ }.otherwise {
+ assert(reset.asBool || state === sIdle || queueCount =/= 0.U ||
+ (queueCount === 0.U && queueHeadNext === 0.U))
+ queueHead := queueHeadNext
+ queueHeadNext := queueHeadNext
+ }
+
+ // check if selected queue is ready
+ val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
+ val deq_ready =
+ MuxLookup(deq_sel,
+ false.B, // default
+ Array(
+ "h_01".U -> io.inst.ld.ready,
+ "h_02".U -> io.inst.st.ready,
+ "h_04".U -> io.inst.co.ready
+ ))
+
+
+ // debug
+ if (debug) {
+ when(start) {
+ printf("[Fetch] Launch\n")
+ }
+ when(io.inst.ld.fire()) {
+ printf("[Fetch] [instruction decode] [L] %x\n", dec.io.inst)
+ }
+ when(io.inst.co.fire()) {
+ printf("[Fetch] [instruction decode] [C] %x\n", dec.io.inst)
+ }
+ when(io.inst.st.fire()) {
+ printf("[Fetch] [instruction decode] [S] %x\n", dec.io.inst)
+ }
+ }
+}
+class GenVMECmdWideFetch(debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool())
+ val isBusy = Input(Bool())
+ val ins_baddr = Input(UInt(mp.addrBits.W))
+ val vmeCmd = Decoupled(new VMECmd)
+ val readLen = Output(UInt((mp.lenBits + 1).W))
+ val done = Output(Bool())
+
+ val xsize = Input(UInt(M_SIZE_BITS.W))
+ val sram_offset = Input(UInt(M_SRAM_OFFSET_BITS.W))
+
+ })
+
+
+ val cmdGen = Module (new GenVMECmdWide(tensorType = "fetch", debug))
+
+ cmdGen.io.start := io.start
+ cmdGen.io.isBusy := io.isBusy
+ cmdGen.io.baddr := io.ins_baddr
+ io.vmeCmd <> cmdGen.io.vmeCmd
+ io.readLen := cmdGen.io.readLen
+ io.done := cmdGen.io.done
+
+ cmdGen.io.ysize := 1.U
+ cmdGen.io.xsize := io.xsize
+ cmdGen.io.xstride := io.xsize
+ cmdGen.io.dram_offset := 0.U
+ cmdGen.io.sram_offset := io.sram_offset
+ cmdGen.io.xpad_0 := 0.U
+ cmdGen.io.xpad_1 := 0.U
+ cmdGen.io.ypad_0 := 0.U
+ cmdGen.io.updateState := io.vmeCmd.fire()
+ cmdGen.io.canSendCmd := true.B
+
+ when(io.start) {
+ val tp = new TensorParams("fetch")
+ assert(io.ins_baddr%(tp.tensorSizeBits/8).U === 0.U, "-F- Fetch DRAM address expected to be tensor size aligned.")
+ }
+
+}
+
diff --git a/hardware/chisel/src/main/scala/core/LoadUop.scala b/hardware/chisel/src/main/scala/core/LoadUop.scala
index e9f5d40..74e08cd 100644
--- a/hardware/chisel/src/main/scala/core/LoadUop.scala
+++ b/hardware/chisel/src/main/scala/core/LoadUop.scala
@@ -50,14 +50,11 @@ class UopClient(implicit p: Parameters) extends Bundle {
override def cloneType = new UopClient().asInstanceOf[this.type]
}
-/** LoadUop.
+/** LoadUopTop.
*
- * Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
- * uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
- * group of 2 given the fact that the DRAM payload is 8-bytes. This module
- * should be modified later on to support different DRAM sizes efficiently.
+ * Top wrapper of load uop implementations.
*/
-class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class LoadUopTop(debug: Boolean = false)(implicit val p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
@@ -67,160 +64,40 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val vme_rd = new VMEReadMaster
val uop = new UopClient
})
- val numUop = 2 // store two uops per sram word
- val uopBits = p(CoreKey).uopBits
- val uopBytes = uopBits / 8
- val uopDepth = p(CoreKey).uopMemDepth / numUop
- val dec = io.inst.asTypeOf(new MemDecode)
- val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
- val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
- val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
- val xrem = Reg(chiselTypeOf(dec.xsize))
- val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
- val xmax = (1 << mp.lenBits).U
- val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
+ // Force the simple load uop implementation. Be careful if the
+ // simple tensor load is also defined.
+ val forceSimpleLoadUop = false;
- val dram_even = (dec.dram_offset % 2.U) === 0.U
- val sram_even = (dec.sram_offset % 2.U) === 0.U
- val sizeIsEven = (dec.xsize % 2.U) === 0.U
+ if (forceSimpleLoadUop) {
+ require(mp.dataBits == 64, "-F- Original LoadUop supports only 64 bit memory data transfer")
- val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3)
- val state = RegInit(sIdle)
+ val loadUop = Module(new LoadUopSimple(debug))
- // control
- switch(state) {
- is(sIdle) {
- when(io.start) {
- state := sReadCmd
- when(xsize < xmax) {
- xlen := xsize
- xrem := 0.U
- }.otherwise {
- xlen := xmax - 1.U
- xrem := xsize - xmax
- }
- }
- }
- is(sReadCmd) {
- when(io.vme_rd.cmd.ready) {
- state := sReadData
- }
- }
- is(sReadData) {
- when(io.vme_rd.data.valid) {
- when(xcnt === xlen) {
- when(xrem === 0.U) {
- state := sIdle
- }.otherwise {
- raddr := raddr + xmax_bytes
- when(xrem < xmax) {
- state := sReadCmd
- xlen := xrem
- xrem := 0.U
- }
- .otherwise {
- state := sReadCmd
- xlen := xmax - 1.U
- xrem := xrem - xmax
- }
- }
- }
- }
- }
- }
+ loadUop.io.start := io.start
+ io.done := loadUop.io.done
+ loadUop.io.baddr := io.baddr
+ loadUop.io.vme_rd <> io.vme_rd
- // read-from-dram
- val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
- when(state === sIdle) {
- when(dram_even) {
- raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))
- }.otherwise {
- raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U
- }
- }
+ loadUop.io.dec := io.inst.asTypeOf(new MemDecode)
+ loadUop.io.uop.idx <> io.uop.idx
+ io.uop <> loadUop.io.uop
- io.vme_rd.cmd.valid := state === sReadCmd
- io.vme_rd.cmd.bits.addr := raddr
- io.vme_rd.cmd.bits.len := xlen
+ } else {
+ val loadUop = Module(new TensorLoad(tensorType = "uop"))
+ loadUop.io.tensor.tieoffWrite()
- io.vme_rd.data.ready := state === sReadData
+ loadUop.io.start := io.start
+ io.done := loadUop.io.done
+ loadUop.io.baddr := io.baddr
+ loadUop.io.vme_rd <> io.vme_rd
- when(state =/= sReadData) {
- xcnt := 0.U
- }.elsewhen(io.vme_rd.data.fire()) {
- xcnt := xcnt + 1.U
- }
+ loadUop.io.inst := io.inst
+ require(loadUop.tp.splitWidth == 1 && loadUop.tp.splitLength == 1, "-F- UOP tensor split read is not expected")
+ loadUop.io.tensor.rd(0).idx <> io.uop.idx
+ io.uop.data.valid := loadUop.io.tensor.rd(0).data.valid
+ io.uop.data.bits <> loadUop.io.tensor.rd(0).data.bits.asTypeOf(new UopDecode)
- val waddr = Reg(UInt(log2Ceil(uopDepth).W))
- when(state === sIdle) {
- waddr := dec.sram_offset >> log2Ceil(numUop)
- }.elsewhen(io.vme_rd.data.fire()) {
- waddr := waddr + 1.U
- }
-
- val wdata = Wire(Vec(numUop, UInt(uopBits.W)))
- val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
- val wmask = Reg(Vec(numUop, Bool()))
-
- when(sram_even) {
- when(sizeIsEven) {
- wmask := "b_11".U.asTypeOf(wmask)
- }.elsewhen(io.vme_rd.cmd.fire()) {
- when(dec.xsize === 1.U) {
- wmask := "b_01".U.asTypeOf(wmask)
- }.otherwise {
- wmask := "b_11".U.asTypeOf(wmask)
- }
- }.elsewhen(io.vme_rd.data.fire()) {
- when((xcnt === xlen - 1.U) && (xrem === 0.U)) {
- wmask := "b_01".U.asTypeOf(wmask)
- }.otherwise {
- wmask := "b_11".U.asTypeOf(wmask)
- }
- }
- }.otherwise {
- when(io.vme_rd.cmd.fire()) {
- wmask := "b_10".U.asTypeOf(wmask)
- }.elsewhen(io.vme_rd.data.fire()) {
- when(sizeIsEven && (xcnt === xlen - 1.U) && (xrem === 0.U)) {
- wmask := "b_01".U.asTypeOf(wmask)
- }.otherwise {
- wmask := "b_11".U.asTypeOf(wmask)
- }
- }
- }
-
- wdata := io.vme_rd.data.bits.asTypeOf(wdata)
- when(dram_even === false.B && sram_even) {
- wdata(0) := io.vme_rd.data.bits.asTypeOf(wdata)(1)
- }.elsewhen(sram_even === false.B && dram_even) {
- wdata(1) := io.vme_rd.data.bits.asTypeOf(wdata)(0)
- }
-
- when(io.vme_rd.data.fire()) {
- mem.write(waddr, wdata, wmask)
- }
-
- // read-from-sram
- io.uop.data.valid := RegNext(io.uop.idx.valid)
-
- // delay LSB of idx by a cycle because of the one-cycle memory read latency
- val sIdx = RegNext(io.uop.idx.bits % numUop.U)
- val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
- val memRead = mem.read(rIdx, io.uop.idx.valid)
- val sWord = memRead.asUInt.asTypeOf(wdata)
- val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits)
-
- io.uop.data.bits <> sUop
-
- // done
- io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U
-
- // debug
- if (debug) {
- when(io.vme_rd.cmd.fire()) {
- printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
- }
}
}
+
diff --git a/hardware/chisel/src/main/scala/core/LoadUopSimple.scala b/hardware/chisel/src/main/scala/core/LoadUopSimple.scala
new file mode 100644
index 0000000..7de7b1b
--- /dev/null
+++ b/hardware/chisel/src/main/scala/core/LoadUopSimple.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+ /** Loads micro-ops (uops) from DRAM over the VME read port into SyncReadMem banks and serves reads through io.uop. */
+ class LoadUopSimple(debug: Boolean = false)(implicit val p: Parameters) extends Module {
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool()) // starts a load when asserted while in sIdle
+ val done = Output(Bool()) // high on the final data beat of the load
+ val dec = Input(new MemDecode) // decoded instruction; xsize, dram_offset and sram_offset are read
+ val baddr = Input(UInt(mp.addrBits.W)) // base DRAM address, OR-ed with the uop byte offset
+ val vme_rd = new VMEReadMaster
+ val uop = new UopClient
+ })
+ val uopsPerMemXfer = p(ShellKey).memParams.dataBits / p(CoreKey).uopBits // uops carried by one data beat
+ require(p(ShellKey).memParams.dataBits % p(CoreKey).uopBits == 0)
+ // NOTE(review): the logic below only special-cases uopsPerMemXfer == 1 and == 2 - confirm no other ratio is configured.
+ val uopBits = p(CoreKey).uopBits
+ val uopBytes = uopBits / 8
+ val uopDepth = p(CoreKey).uopMemDepth / uopsPerMemXfer
+ val dataBytes = mp.dataBits / 8
+
+ val dec = io.dec
+ val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr)) // address of the current read command
+ val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) // data beats received for the current command
+ val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) // len field of the current command (beats - 1)
+ val xrem = Reg(chiselTypeOf(dec.xsize)) // beats still to be requested after the current command
+ val xmax = (1 << mp.lenBits).U // NOTE(review): not referenced elsewhere in this class
+ val xmax_bytes = ((1 << mp.lenBits) * dataBytes).U // bytes covered by a maximum-length command
+ // Align DRAM address to data and do not cross page boundary.
+ val data_align_bits = WireInit(UInt(raddr.getWidth.W), dataBytes.U - 1.U) // mask of byte-offset bits within a beat
+ val beat_bytes_bits = log2Ceil(mp.dataBits >> 3) // log2 of bytes per data beat
+ val xfer_bytes = Reg(chiselTypeOf(xmax_bytes)) // bytes from raddr up to the next xmax_bytes boundary
+ // DRAM address width must be the same as AXI araddr since anything above will
+ // be silently truncated, enforce this here.
+ val dram_byte_addr = WireInit(UInt(raddr.getWidth.W), dec.dram_offset << log2Ceil(uopBytes))
+ // Here we are assuming io.baddr | dram_byte_addr === io.baddr + dram_byte_addr.
+ val unaligned_addr = io.baddr | dram_byte_addr
+ val xfer_init_addr = unaligned_addr & ~data_align_bits // rounded down to a data-beat boundary
+ val xfer_next_addr = raddr + xfer_bytes // start address of the following command
+ val xfer_init_bytes = xmax_bytes - xfer_init_addr % xmax_bytes // bytes left before the next xmax_bytes boundary
+ val xfer_init_beats = xfer_init_bytes >> beat_bytes_bits // the same distance expressed in beats
+ val xfer_next_bytes = xmax_bytes - xfer_next_addr % xmax_bytes
+ val xfer_next_beats = xfer_next_bytes >> beat_bytes_bits
+
+ val dram_even = (dec.dram_offset % 2.U) === 0.U
+ val sram_even = (dec.sram_offset % 2.U) === 0.U
+ val sizeIsEven = (dec.xsize % 2.U) === 0.U
+
+ val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3)
+ val state = RegInit(sIdle)
+ val first = RegInit(init=false.B) // set on start, cleared on the first data beat
+
+ // control: sIdle -> sReadCmd -> sReadData, returning to sReadCmd while xrem is non-zero
+ switch(state) {
+ is(sIdle) {
+ xfer_bytes := xfer_init_bytes // assigned on every idle cycle, not only on start
+ when(io.start) {
+ state := sReadCmd
+ first := true.B
+ raddr := xfer_init_addr
+ // Number of total beats in the load transfer.
+ val xsize = if (uopsPerMemXfer == 1) {
+ dec.xsize
+ } else {
+ ((dec.xsize +& 1.U + dec.dram_offset(0)) >> 1) // ceil((xsize + dram_offset(0)) / 2)
+ }
+
+ when(xsize <= xfer_init_beats) { // the whole load fits in the first command
+ xlen := xsize - 1.U
+ xrem := 0.U
+ }.otherwise { // stop at the boundary and request the rest later
+ xlen := xfer_init_beats - 1.U
+ xrem := xsize - xfer_init_beats
+ }
+ }
+ }
+ is(sReadCmd) {
+ when(io.vme_rd.cmd.ready) { // cmd.valid is high in this state, so the command fires
+ state := sReadData
+ }
+ }
+ is(sReadData) {
+ when(io.vme_rd.data.valid) {
+ when(xcnt === xlen) { // last beat of the current command
+ when(xrem === 0.U) {
+ state := sIdle
+ }.otherwise { // beats remain: set up the next command
+ state := sReadCmd
+ raddr := xfer_next_addr
+ xfer_bytes := xfer_next_bytes
+ when(xrem <= xfer_next_beats) {
+ xlen := xrem - 1.U
+ xrem := 0.U
+ }.otherwise {
+ xlen := xfer_next_beats - 1.U
+ xrem := xrem - xfer_next_beats
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // read-from-dram
+ io.vme_rd.cmd.valid := state === sReadCmd
+ io.vme_rd.cmd.bits.addr := raddr
+ io.vme_rd.cmd.bits.len := xlen
+ io.vme_rd.cmd.bits.tag := dec.sram_offset
+
+ io.vme_rd.data.ready := state === sReadData
+
+ when(state =/= sReadData) { // the beat counter restarts for every command
+ xcnt := 0.U
+ }.elsewhen(io.vme_rd.data.fire()) {
+ xcnt := xcnt + 1.U
+ }
+
+ val waddr = IndexedSeq.fill(uopsPerMemXfer) { Reg(UInt(log2Ceil(uopDepth).W))} // one write address per uop lane
+ when(state === sIdle) {
+ val so = dec.sram_offset >> log2Ceil(uopsPerMemXfer)
+ if (uopsPerMemXfer == 1) {
+ waddr(0) := so
+ } else {
+ when (!sram_even && dram_even) { // 10
+ waddr(0) := so + 1.U
+ waddr(1) := so
+ }.elsewhen (sram_even && !dram_even) { // 01
+ waddr(0) := so
+ waddr(1) := so - 1.U
+ }.otherwise {
+ waddr(0) := so
+ waddr(1) := so
+ }
+ }
+ }.elsewhen(io.vme_rd.data.fire()) { // every lane advances by one row per accepted beat
+ for (i <- 0 until uopsPerMemXfer) {
+ waddr(i) := waddr(i) + 1.U
+ }
+ }
+
+ val mems = IndexedSeq.fill(uopsPerMemXfer) { SyncReadMem(uopDepth, UInt(uopBits.W))} // one memory per uop lane
+ val last = (xcnt === xlen) && (xrem === 0.U) // final beat of the final command
+
+ val wmask = Wire(Vec(uopsPerMemXfer, Bool())) // per-lane write enable, all lanes enabled by default
+ for (i <- 0 until uopsPerMemXfer) {
+ wmask(i) := true.B
+ }
+ // Mask one lane on the first and/or last beat (presumably the uop outside the requested range - TODO confirm).
+ when (io.vme_rd.data.fire()) {
+ when (first) {
+ first := false.B
+
+ if (uopsPerMemXfer == 2) {
+ when(!sram_even && !dram_even) {
+ wmask(0) := false.B
+ }
+ }
+ }
+ when(last) {
+ if (uopsPerMemXfer == 2) {
+ when(dram_even ^ sizeIsEven) {
+ when (sram_even ^ sizeIsEven) {
+ wmask(1) := false.B
+ }.otherwise{
+ wmask(0) := false.B
+ }
+ }
+ }
+ }
+ }
+
+ val wdata = Wire(Vec(uopsPerMemXfer, UInt(uopBits.W))) // the data beat split into uop-sized lanes
+ wdata := io.vme_rd.data.bits.data.asTypeOf(wdata)
+ if (uopsPerMemXfer == 2) {
+ when(dram_even =/= sram_even) { // swap
+ wdata(0) := io.vme_rd.data.bits.data.asTypeOf(wdata)(1)
+ wdata(1) := io.vme_rd.data.bits.data.asTypeOf(wdata)(0)
+ }
+ }
+
+ when(io.vme_rd.data.fire()) {
+ for { i <- 0 until mems.size} {
+ when (wmask(i)) {
+ mems(i).write(waddr(i), wdata(i))
+ }
+ }
+ }
+
+ io.done := io.vme_rd.data.fire() & last
+
+ // ----------- read-from-sram -------------
+
+ io.uop.data.valid := RegNext(io.uop.idx.valid) // data is flagged valid one cycle after idx
+
+ // delay LSB of idx by a cycle because of the one-cycle memory read latency
+ val rIdx = io.uop.idx.bits >> log2Ceil(uopsPerMemXfer) // memory row: idx without the lane-select bits
+ val m0 = mems(0).read(rIdx, io.uop.idx.valid)
+
+ if (uopsPerMemXfer == 2) {
+ val m1 = mems(1).read(rIdx, io.uop.idx.valid)
+ val sIdx = RegNext(io.uop.idx.bits % uopsPerMemXfer.U) // lane select, registered to line up with the read data
+ io.uop.data.bits <> Mux(sIdx =/= 0.U, m1, m0).asTypeOf(io.uop.data.bits)
+ } else {
+ io.uop.data.bits <> m0.asTypeOf(io.uop.data.bits)
+ }
+
+ if (false) { // disabled debug dump: this branch is never elaborated
+ // Report initial part of the uop state after
+ // the clock transition where io.done is high
+ val memDumpGuard = RegNext(io.done,init=false.B)
+ when (memDumpGuard) {
+ for {
+ idx <- 0 until scala.math.min(8,uopDepth)
+ i <- 0 until uopsPerMemXfer} {
+ val s = mems(i)(idx).asTypeOf(io.uop.data.bits)
+ printf(s"uop: $idx $i u0: %x u1: %x u2: %x\n", s.u0, s.u1, s.u2)
+ }
+ }
+ }
+
+ // debug
+ if (debug) {
+ when(io.vme_rd.cmd.fire()) {
+ printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
+ }
+ }
+}
diff --git a/hardware/chisel/src/main/scala/core/TensorLoad.scala b/hardware/chisel/src/main/scala/core/TensorLoad.scala
index 8b31253..24fb841 100644
--- a/hardware/chisel/src/main/scala/core/TensorLoad.scala
+++ b/hardware/chisel/src/main/scala/core/TensorLoad.scala
@@ -19,10 +19,15 @@
package vta.core
+import scala.math.pow
+
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
+
+
+
/** TensorLoad.
*
* Load 1D and 2D tensors from main memory (DRAM) to input/weight
@@ -45,300 +50,26 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
val tensor = new TensorClient(tensorType)
})
- require(tp.numMemBlock > 0, s"-F- Unexpected data to tensor bit size ratio. ${tensorType} ${tp.numMemBlock}")
- require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- Cannot do split direct access")
-
- val sizeFactor = tp.tensorLength * tp.numMemBlock
- val strideFactor = tp.tensorLength * tp.tensorWidth
-
- val dec = io.inst.asTypeOf(new MemDecode)
- val dataCtrl = Module(
- new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
- val dataCtrlDone = RegInit(false.B)
- val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
- val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
- val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor))
- val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor))
-
- val tag = Reg(UInt(log2Ceil(tp.numMemBlock).W))
- val set = Reg(UInt(log2Ceil(tp.tensorLength).W))
-
- val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil =
- Enum(7)
- val state = RegInit(sIdle)
-
- // control
- switch(state) {
- is(sIdle) {
- when(io.start) {
- when(dec.ypad_0 =/= 0.U) {
- state := sYPad0
- }.elsewhen(dec.xpad_0 =/= 0.U) {
- state := sXPad0
- }.otherwise {
- state := sReadCmd
- }
- }
- }
- is(sYPad0) {
- when(yPadCtrl0.io.done) {
- when(dec.xpad_0 =/= 0.U) {
- state := sXPad0
- }.otherwise {
- assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
- state := sReadCmd
- }
- }
- }
- is(sXPad0) {
- when(xPadCtrl0.io.done) {
- assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
- state := sReadCmd
- }
- }
- is(sReadCmd) {
- when(io.vme_rd.cmd.ready) {
- state := sReadData
- }
- }
- is(sReadData) {
- when(io.vme_rd.data.valid) {
- when(dataCtrl.io.done) {
- when(dec.xpad_1 =/= 0.U) {
- state := sXPad1
- }.elsewhen(dec.ypad_1 =/= 0.U) {
- state := sYPad1
- }.otherwise {
- state := sIdle
- }
- }.elsewhen(dataCtrl.io.stride) {
- when(dec.xpad_1 =/= 0.U) {
- state := sXPad1
- }.elsewhen(dec.xpad_0 =/= 0.U) {
- state := sXPad0
- }.otherwise {
- assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
- state := sReadCmd
- }
- }.elsewhen(dataCtrl.io.split) {
- state := sReadCmd
- }
- }
- }
- is(sXPad1) {
- when(xPadCtrl1.io.done) {
- when(dataCtrlDone) {
- when(dec.ypad_1 =/= 0.U) {
- state := sYPad1
- }.otherwise {
- state := sIdle
- }
- }.otherwise {
- when(dec.xpad_0 =/= 0.U) {
- state := sXPad0
- }.otherwise {
- assert(tag === (tp.numMemBlock - 1).U, "-F- Should not happen mid tensor row read")
- state := sReadCmd
- }
- }
- }
- }
- is(sYPad1) {
- when(yPadCtrl1.io.done && dataCtrlDone) {
- state := sIdle
- }
- }
- }
-
- // data controller
- dataCtrl.io.start := state === sIdle & io.start
- dataCtrl.io.inst := io.inst
- dataCtrl.io.baddr := io.baddr
- dataCtrl.io.xinit := io.vme_rd.cmd.fire()
- dataCtrl.io.xupdate := io.vme_rd.data.fire()
- dataCtrl.io.yupdate := io.vme_rd.data.fire()
-
- when(state === sIdle) {
- dataCtrlDone := false.B
- }.elsewhen(io.vme_rd.data.fire() && dataCtrl.io.done) {
- dataCtrlDone := true.B
- }
-
- // pad
- yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
-
- yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
- ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
- (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
-
- xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
- ((state === sIdle & io.start) |
- (state === sYPad0 & yPadCtrl0.io.done) |
- (io.vme_rd.data.fire() & ~dataCtrlDone & dataCtrl.io.stride & dec.xpad_1 === 0.U) |
- (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
-
- xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
- ((dataCtrl.io.done) | (~dataCtrl.io.done & dataCtrl.io.stride & dec.xpad_1 =/= 0.U))
-
- yPadCtrl0.io.inst := io.inst
- yPadCtrl1.io.inst := io.inst
- xPadCtrl0.io.inst := io.inst
- xPadCtrl1.io.inst := io.inst
-
- // read-from-dram
- io.vme_rd.cmd.valid := state === sReadCmd
- io.vme_rd.cmd.bits.addr := dataCtrl.io.addr
- io.vme_rd.cmd.bits.len := dataCtrl.io.len
-
- io.vme_rd.data.ready := state === sReadData
-
- // write-to-sram
- val isZeroPad = state === sYPad0 |
- state === sXPad0 |
- state === sXPad1 |
- state === sYPad1
-
- when(state === sReadCmd && tag =/= (tp.numMemBlock - 1).U) { // split read inside row of mem blocks
- tag := tag
- }.elsewhen(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
- tag := 0.U
- }.elsewhen(io.vme_rd.data.fire() || isZeroPad) {
- tag := tag + 1.U
- }
-
- when(state === sIdle || (dataCtrlDone && ~isZeroPad) ||
- (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
- set := 0.U
- }.elsewhen((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
- set := set + 1.U
- }
-
- val waddr_cur = Reg(UInt(tp.memAddrBits.W))
- val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
- when(state === sIdle) {
- waddr_cur := dec.sram_offset
- waddr_nxt := dec.sram_offset
- }.elsewhen((io.vme_rd.data.fire() || isZeroPad)
- && set === (tp.tensorLength - 1).U
- && tag === (tp.numMemBlock - 1).U)
- {
- waddr_cur := waddr_cur + 1.U
- }.elsewhen(dataCtrl.io.stride && io.vme_rd.data.fire()) {
- waddr_cur := waddr_nxt + dec.xsize
- waddr_nxt := waddr_nxt + dec.xsize
- }
-
- val tensorFile = Seq.fill(tp.tensorLength) {
- SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
- }
-
- val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
- val wdata = Seq.fill(tp.tensorLength) {
- Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
- }
- val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
- no_mask.foreach { m =>
- m := true.B
- }
-
- for (i <- 0 until tp.tensorLength) {
- for (j <- 0 until tp.numMemBlock) {
- wmask(i)(j) := tag === j.U
- wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
- }
- val tdata = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata(i))
- val muxWen =
- Mux(state === sIdle,
- io.tensor.wr(0).valid,
- (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
- val muxWaddr = Mux(state === sIdle, io.tensor.wr(0).bits.idx, waddr_cur)
- val muxWdata = Mux(state === sIdle, tdata, wdata(i))
- val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
- when(muxWen) {
- tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
- }
- }
-
- // read-from-sram
- val rvalid = RegNext(io.tensor.rd(0).idx.valid)
- io.tensor.rd(0).data.valid := rvalid
-
- val rdata =
- tensorFile.map(_.read(io.tensor.rd(0).idx.bits, io.tensor.rd(0).idx.valid))
- rdata.zipWithIndex.foreach {
- case (r, i) =>
- io.tensor.rd(0).data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd(0).data.bits(i))
- }
-
- // done
- val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
- val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
- val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
- io.done := done_no_pad | done_x_pad | done_y_pad
-
- // debug
- if (debug) {
- if (tensorType == "inp") {
- when(io.vme_rd.cmd.fire()) {
- printf("[TensorLoad] [inp] cmd addr:%x len:%x\n",
- dataCtrl.io.addr,
- dataCtrl.io.len)
- }
- when(state === sYPad0) {
- printf("[TensorLoad] [inp] sYPad0\n")
- }
- when(state === sYPad1) {
- printf("[TensorLoad] [inp] sYPad1\n")
- }
- when(state === sXPad0) {
- printf("[TensorLoad] [inp] sXPad0\n")
- }
- when(state === sXPad1) {
- printf("[TensorLoad] [inp] sXPad1\n")
- }
- } else if (tensorType == "wgt") {
- when(io.vme_rd.cmd.fire()) {
- printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n",
- dataCtrl.io.addr,
- dataCtrl.io.len)
- }
- when(state === sYPad0) {
- printf("[TensorLoad] [wgt] sYPad0\n")
- }
- when(state === sYPad1) {
- printf("[TensorLoad] [wgt] sYPad1\n")
- }
- when(state === sXPad0) {
- printf("[TensorLoad] [wgt] sXPad0\n")
- }
- when(state === sXPad1) {
- printf("[TensorLoad] [wgt] sXPad1\n")
- }
- } else if (tensorType == "acc") {
- when(io.vme_rd.cmd.fire()) {
- printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
- dataCtrl.io.addr,
- dataCtrl.io.len)
- printf("[TensorLoad] [acc info] dec.xsize: %d, dec.ysize: %d, dec.xstride: %d\n",
- dec.xsize, dec.ysize, dec.xstride)
- printf("[TensorLoad] [acc i2fo] dec.xpad_1: %d dec.xpad_0: %d dec.ypad_1: %d dec.ypad_0: %d\n",
- dec.xpad_1, dec.xpad_0, dec.ypad_1, dec.ypad_0)
-
- printf("tp.tensorLength: %d, tp.numMemBlock: %d, tp.tensorLength: %d, tp.tensorWidth: %d\n",
- tp.tensorLength.U, tp.numMemBlock.U, tp.tensorLength.U, tp.tensorWidth.U)
- }
- when(state === sYPad0) {
- printf("[TensorLoad] [acc] sYPad0\n")
- }
- when(state === sYPad1) {
- printf("[TensorLoad] [acc] sYPad1\n")
- }
- when(state === sXPad0) {
- printf("[TensorLoad] [acc] sXPad0\n")
- }
- when(state === sXPad1) {
- printf("[TensorLoad] [acc] sXPad1\n")
- }
- }
+ override def desiredName = "TensorLoad" + tensorType.capitalize
+
+ val forceSimpleTensorLoad = false // force a simple implemetation of TL
+
+ if (forceSimpleTensorLoad) {
+ // use
+ val tensorLoad = Module(new TensorLoadSimple(tensorType, debug))
+ io <> tensorLoad.io
+ } else if (mp.dataBits >= tp.tensorSizeBits) {
+ // cacheline is wider than tensor size,
+ // macro memory bitwidth by cache size
+ // bank by tensor size
+ val tensorLoad = Module(new TensorLoadWideVME(tensorType, debug))
+ io <> tensorLoad.io
+ } else {
+ // tensor is wider than cacheline, bank by
+ // macro memory bitwidth by tensor size
+ // bank by cacheline size
+ val tensorLoad = Module(new TensorLoadNarrowVME(tensorType, debug))
+ io <> tensorLoad.io
}
}
+
diff --git a/hardware/chisel/src/main/scala/core/TensorLoadNarrowVME.scala b/hardware/chisel/src/main/scala/core/TensorLoadNarrowVME.scala
new file mode 100644
index 0000000..746e54f
--- /dev/null
+++ b/hardware/chisel/src/main/scala/core/TensorLoadNarrowVME.scala
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import scala.math.pow
+import scala.math.sqrt
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+
+/** TensorLoadNarrowVME.
+ *
+ * Load 1D and 2D tensors from main memory (DRAM) to input/weight
+ * scratchpads (SRAM). Also, there is support for zero padding, while
+ * doing the load.
+ */
+class TensorLoadNarrowVME(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool())
+ val done = Output(Bool())
+ val inst = Input(UInt(INST_BITS.W))
+ val baddr = Input(UInt(mp.addrBits.W))
+ val vme_rd = new VMEReadMaster
+ val tensor = new TensorClient(tensorType)
+ })
+ val writePipeLatency = tp.writePipeLatency
+
+ val sIdle :: sBusy :: Nil =
+ Enum(2)
+ val state = RegInit(sIdle)
+
+ val isBusy = state === sBusy
+
+ val localDone = Wire(Bool())
+ when(io.start) {
+ state := sBusy
+ }.elsewhen(localDone) {
+ state := sIdle
+ }
+
+ val dec = io.inst.asTypeOf(new MemDecode)
+
+ val vmeDataBitsPipe = RegNext(io.vme_rd.data.bits)
+ val vmeDataValidPipe = RegNext(io.vme_rd.data.valid, init = false.B)
+ val vmeDataReadyPipe = RegNext(io.vme_rd.data.ready, init = false.B)
+ val vmeDataFirePipe = vmeDataValidPipe & vmeDataReadyPipe
+
+ //--------------------------------------
+ //--- Generate data load VME command ---
+ //--------------------------------------
+ val vmeCmd = Module (new GenVMECmd(tensorType, debug))
+ vmeCmd.io.start := io.start
+ vmeCmd.io.isBusy := isBusy
+ vmeCmd.io.inst := io.inst
+ vmeCmd.io.baddr := io.baddr
+ vmeCmd.io.vmeCmd <> io.vme_rd.cmd
+ val readLen = vmeCmd.io.readLen
+ val commandsDone = vmeCmd.io.done
+
+ // count how many blocks not received
+ val blkIdxWdth = log2Ceil(tp.tsSizeRatio * tp.memDepth) // the size of scratchpad in blocks
+ // Number of data blocks requested but not yet received. TODO: smaller width parameter
+ val blocksInFlight = Reg(UInt(blkIdxWdth.W))
+ when(io.start) {
+ blocksInFlight := 0.U
+ }.elsewhen(isBusy && io.vme_rd.cmd.fire() && !vmeDataFirePipe) {
+ blocksInFlight := blocksInFlight + readLen
+ }.elsewhen(isBusy && io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+ blocksInFlight := blocksInFlight + readLen - 1.U
+ }.elsewhen(isBusy && !io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+ assert(blocksInFlight > 0.U)
+ blocksInFlight := blocksInFlight - 1.U
+ }.otherwise {
+ blocksInFlight := blocksInFlight
+ }
+
+ //---------------------
+ //--- Read VME data ---
+ //---------------------
+
+ val readData = Module(new ReadVMEData(tensorType, debug))
+ readData.io.start := io.start
+ readData.io.vmeData.valid := vmeDataValidPipe
+ readData.io.vmeData.bits := vmeDataBitsPipe
+ assert(!readData.io.vmeData.valid || readData.io.vmeData.ready,
+ "-F- Expecting const ready. Fix ReadVMEData to receive data 1 cyce after ready")
+ io.vme_rd.data.ready := readData.io.vmeData.ready
+ val rdDataDestCol = readData.io.col // this is an index of a col in tensor
+ val rdDataDestIdx = readData.io.idx // this is an index of a tensor
+
+ //-------------------------
+ //--- Fill zero padding ---
+ //-------------------------
+
+ val fillPadding = Module(new ZeroPadding(tensorType, debug))
+ fillPadding.io.canWriteMem := !vmeDataFirePipe
+ fillPadding.io.inst := RegNext(io.inst) // stage it to move from instr queue
+ fillPadding.io.start := RegNext(io.start, init = false.B)// stage it to move from instr que
+
+ val isZeroPadWrite = fillPadding.io.tensorIdx.valid // Store zero filled tensor, zpDestIdx is valid
+ val zpDestIdx = fillPadding.io.tensorIdx.bits // Tensor index
+ val paddingDone = fillPadding.io.done
+
+ //--------------------
+ //--- Write memory ---
+ //--------------------
+
+ val memSizeRatio = tp.tsSizeRatio
+ val splitDataFactor = tp.splitWidth * tp.splitLength
+ val splitMemBlockFactor = if (splitDataFactor > memSizeRatio) {
+ require((splitDataFactor/memSizeRatio) * memSizeRatio == splitDataFactor,
+ "-F- Cannot split tensor data memBlockBits further.")
+ splitDataFactor/memSizeRatio
+ }else {
+ 1
+ }
+ val groupMemBlockFactor = if (splitDataFactor > memSizeRatio) {
+ 1
+ }else {
+ require((memSizeRatio/splitDataFactor) * splitDataFactor == memSizeRatio,
+ "-F- Cannot group tensor data memBlockBits into groups.")
+ memSizeRatio/splitDataFactor
+ }
+ // one macro has a VME memory read bit width or read/write group bit width
+ // different groups can read/write scratchpad separately
+ val tensorFile = Seq.fill(memSizeRatio * splitMemBlockFactor
+ ) {
+ SyncReadMem(tp.memDepth, UInt((tp.memBlockBits/splitMemBlockFactor).W))
+ }
+
+
+ require(splitDataFactor * groupMemBlockFactor == memSizeRatio * splitMemBlockFactor,
+ "-F- Wrong split of data")
+ //-------------------------------
+ //--- Write address vector ------
+ //-------------------------------
+ // split data to build pipe tree
+ val splitFactorL0 = pow(2,log2Ceil(memSizeRatio) / 2).toInt
+ val splitFactorL1 = pow(2,log2Ceil(memSizeRatio) - log2Ceil(memSizeRatio) / 2).toInt
+ require(splitFactorL0 * splitFactorL1 == memSizeRatio)
+ // tensor load instruction writes a VME data block or a whole tensor
+ val waddrTensInstrTmp = Mux(isZeroPadWrite, zpDestIdx, rdDataDestIdx)
+ val waddrTensInstrPipe = VecInit((for (j <- 0 until splitFactorL1) yield {
+ ShiftRegister(waddrTensInstrTmp, if (writePipeLatency > 0) 1 else 0)
+ }).flatMap(elem => for (k <- 0 until splitFactorL0) yield {
+ elem
+ }).flatMap(elem => for (k <- 0 until splitMemBlockFactor) yield {
+ ShiftRegister(elem, if (writePipeLatency < 2) 0 else writePipeLatency - 1)
+ }))
+ require(waddrTensInstrPipe.size == memSizeRatio * splitMemBlockFactor)
+
+ val waddrDirect = (VecInit((for (grIdx <- 0 until splitDataFactor) yield {
+ io.tensor.wr(grIdx).bits.idx
+ }).flatMap(elem => for (k <- 0 until groupMemBlockFactor) yield {elem}))).asTypeOf(
+ Vec(memSizeRatio * splitMemBlockFactor, io.tensor.wr(0).bits.idx.cloneType)
+ )
+
+
+ val waddr = Wire(Vec(memSizeRatio * splitMemBlockFactor, waddrTensInstrTmp.cloneType))
+ for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+ waddr(j) := Mux(
+ ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
+ waddrDirect(j),
+ waddrTensInstrPipe(j))
+ }
+
+ //-------------------------------
+ //--- Write enable vector -------
+ //-------------------------------
+ val dataOffset = rdDataDestCol
+ // get the enable signal and duplicate it
+ val wenTensInstr = VecInit((for (j <- 0 until memSizeRatio) yield {
+ Mux(isZeroPadWrite, true.B, dataOffset === j.U && vmeDataFirePipe)
+ }).flatMap(elem => for (k <- 0 until splitMemBlockFactor) yield {elem}))
+
+ val wenDirect = VecInit((for (grIdx <- 0 until splitDataFactor) yield {
+ io.tensor.wr(grIdx).valid
+ }).flatMap(elem => for (k <- 0 until groupMemBlockFactor) yield {elem}))
+
+ val wen = Wire(Vec(memSizeRatio * splitMemBlockFactor, Bool()))
+ for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+ wen(j) := Mux(
+ ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
+ wenDirect(j),
+ ShiftRegister(wenTensInstr(j), writePipeLatency))
+ }
+
+ require(tp.memBlockBits % tp.tensorElemBits == 0)
+
+
+ //-------------------------------
+ //--- Write data vector ---------
+ //-------------------------------
+ val wdataTensInstrDataPipe = VecInit((for (j <- 0 until splitFactorL0) yield {
+ ShiftRegister(vmeDataBitsPipe.data, if (writePipeLatency > 0) 1 else 0)
+ }).flatMap(elem => for (k <- 0 until splitFactorL1) yield {
+ elem
+ }).flatMap(elem => for (k <- 0 until splitMemBlockFactor) yield {
+ require(elem.getWidth == tp.memBlockBits)
+ ShiftRegister(
+ elem.asTypeOf(Vec(splitMemBlockFactor, UInt((tp.memBlockBits/splitMemBlockFactor).W)))(k),
+ if (writePipeLatency < 2) 0 else writePipeLatency - 1)
+ }))
+ require(wdataTensInstrDataPipe.size == memSizeRatio * splitMemBlockFactor)
+ val wdataTensInstr = Wire(Vec(memSizeRatio * splitMemBlockFactor, UInt((tp.memBlockBits/splitMemBlockFactor).W)))
+ for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+ // pipe 1 stage paddingControl per group
+ val padValue = 0.U
+
+ wdataTensInstr(j) := Mux(
+ ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
+ ShiftRegister(padValue /* a single group total data bits */, writePipeLatency),
+ wdataTensInstrDataPipe(j))
+ }
+
+ // THIS wdataDirect writes a continuous scratchpad data space
+ // It is WRONG for ACC batch > 1
+ // it maps group data bits to a continuous sequence of mem blocks
+ // but wr(x).bits.data is a window in a tensor
+ val wdataDirect = VecInit((for (grIdx <- 0 until splitDataFactor) yield {
+ io.tensor.wr(grIdx).bits.data
+ }).flatMap(elem => for (k <- 0 until groupMemBlockFactor) yield {
+ elem.asTypeOf(Vec(groupMemBlockFactor, UInt((tp.memBlockBits/splitMemBlockFactor).W)))(k)
+ }))
+ val wdata = Wire(Vec(memSizeRatio * splitMemBlockFactor, UInt((tp.memBlockBits/splitMemBlockFactor).W)))
+ for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+ wdata(j) := Mux(
+ ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
+ wdataDirect(j),
+ wdataTensInstr(j))
+ }
+
+ for (j <- 0 until memSizeRatio * splitMemBlockFactor) {
+ when(wen(j)) {
+ tensorFile(j).write(waddr(j), wdata(j))
+ }
+ }
+ if (debug) {
+ when(isZeroPadWrite) {
+ printf(s"[TensorLoad] $tensorType isZeroPadWrite data zpDestIdx:%d\n",
+ zpDestIdx)
+ }
+ when (vmeDataFirePipe) {
+ printf(s"[TensorLoad] $tensorType data rdDataDestCol:%d rdDataDestIdx:%d\n",
+ rdDataDestCol,
+ rdDataDestIdx)
+ }
+ }
+
+ // read-from-sram
+ for (grIdx <- 0 until splitDataFactor) {
+ val rvalid = ShiftRegister(
+ io.tensor.rd(grIdx).idx.valid, tp.readTensorLatency + 1, resetData = false.B, en = true.B)
+ io.tensor.rd(grIdx).data.valid := rvalid
+ }
+
+ val memsInGroup = memSizeRatio * splitMemBlockFactor / splitDataFactor
+ for (grIdx <- 0 until splitDataFactor) {
+ io.tensor.rd(grIdx).data.bits :=
+ VecInit(for (memBlkIdx <- 0 until memsInGroup) yield {
+ tensorFile(grIdx * memsInGroup + memBlkIdx).read(
+ ShiftRegister(io.tensor.rd(grIdx).idx.bits, tp.readTensorLatency),
+ ShiftRegister(io.tensor.rd(grIdx).idx.valid, tp.readTensorLatency, resetData = false.B, en = true.B))
+ }).asTypeOf(io.tensor.rd(grIdx).data.bits)
+ }
+
+ // done
+ val loadDone = blocksInFlight === 0.U && commandsDone && state === sBusy
+ localDone := loadDone && paddingDone
+ io.done := ShiftRegister(localDone, writePipeLatency, resetData = false.B, en = true.B)
+
+}
+
+//-------------------------
+//--- Fill zero padding ---
+//-------------------------
+
+//----------------------------------------------------------------------------
+// Fill tensors with zeros if padding is defined: emits one scratchpad tensor index per zero write
+// stride must be used (xstride and ysize) if xpad_0 or xpad_1
+// are not zero and matrix has more than one row of tensors
+// zp states enumerate different types of padding blocks
+// TOP - width = dec.xpad_0 + dec.xsize + dec.xpad_1; height = dec.ypad_0
+// LEFT - width = dec.xpad_0; height = dec.ysize
+// RIGHT - width = dec.xpad_1; height = dec.ysize
+// BOT - width = dec.xpad_0 + dec.xsize + dec.xpad_1; height = dec.ypad_1
+// BOTH - LEFT+RIGHT
+// SKIP - dec.xpad_0 == 0 && dec.xpad_1 == 0: data rows are only counted, nothing is written
+
+// The fill algorithm goes row by row: TOP first, then the side blocks, then BOT
+//----------------------------------------------------------------------------
+class ZeroPadding(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val canWriteMem = Input(Bool()) // gates zero writes; the fill only advances while high (except in sZpSideSkip)
+ val inst = Input(UInt(INST_BITS.W)) // instruction word, decoded below as MemDecode
+ val tensorIdx = Output(ValidIO(UInt(tp.memAddrBits.W))) // index of the tensor to zero; valid on a zero write
+ val start = Input(Bool()) // resets the iterators and selects the first fill state
+ val done = Output(Bool()) // high while zpState is sZpIdle
+ })
+
+ val dec = io.inst.asTypeOf(new MemDecode)
+
+ val isZeroPadWrite = Wire(Bool()) // Store zero filled tensor, zpDestIdx is valid
+ val zpDestIdx = Wire(dec.sram_offset.cloneType) // Tensor index
+ val sZpIdle :: sZpTop :: sZpSideLeft :: sZpSideRight :: sZpSideBoth :: sZpSideSkip :: sZpBot :: Nil =
+ Enum(7)
+ val zpState = RegInit(sZpIdle)
+ val paddingDone = zpState === sZpIdle // Done filling zero tensors
+ val zpColIdx = Reg(UInt((dec.xpad_0.getWidth + dec.xsize.getWidth + dec.xpad_1.getWidth).W))
+ val zpNewFillBlock = Wire(Bool()) // high when the fill moves to another block type; then zpColIdx is set by the state logic
+ // Define padding area iterators
+ val zpRowIdx = Reg(UInt((dec.ypad_0.getWidth + dec.ysize.getWidth + dec.ypad_1.getWidth).W)) // current padding row
+ // (zpColIdx declared above is the current padding column)
+ val zpDestRowOffset = Reg(dec.sram_offset.cloneType) // one-dimensional destination offset of row zpRowIdx
+ zpRowIdx := zpRowIdx // default: hold
+ zpColIdx := zpColIdx // default: hold
+ zpDestRowOffset := zpDestRowOffset // default: hold
+ zpNewFillBlock := false.B
+
+ // state change conditions
+ val zpLastDataRow = zpRowIdx === dec.ypad_0 + dec.ysize - 1.U
+ val zpTopLastIdx = dec.xpad_0 + dec.xsize + dec.xpad_1 - 1.U // last index of total width
+ val zpWideLineEnd = (zpState === sZpSideBoth || zpState === sZpSideRight) && zpColIdx === zpTopLastIdx
+ val zpNarwLineEnd = zpState === sZpSideLeft && zpColIdx === dec.xpad_0 - 1.U
+ val zpFillLineEnd = zpWideLineEnd || zpNarwLineEnd
+
+ when(io.start) {
+ zpRowIdx := 0.U
+ zpDestRowOffset := dec.sram_offset
+
+ zpColIdx := 0.U
+ when(dec.xpad_0 === 0.U && dec.xpad_1 =/= 0.U && dec.ypad_0 === 0.U) {
+ zpColIdx := dec.xpad_0 + dec.xsize // right-only padding without a top block starts after the data columns
+ }
+ when(dec.ypad_0 =/= 0.U) {
+ zpState := sZpTop
+ }.elsewhen(dec.xpad_0 =/= 0.U && dec.xpad_1 === 0.U) {
+ zpState := sZpSideLeft
+ }.elsewhen(dec.xpad_0 === 0.U && dec.xpad_1 =/= 0.U) {
+ zpState := sZpSideRight
+ }.elsewhen(dec.xpad_0 =/= 0.U && dec.xpad_1 =/= 0.U) {
+ zpState := sZpSideBoth
+ }.elsewhen(dec.ypad_1 =/= 0.U) {
+ zpState := sZpSideSkip
+ }.otherwise {
+ zpState := sZpIdle // nothing to fill
+ }
+ }.elsewhen(
+ io.canWriteMem &&
+ zpState === sZpTop &&
+ zpRowIdx === dec.ypad_0 - 1.U && /*we know ypad_0 > 0 */
+ zpColIdx === zpTopLastIdx) {
+ zpNewFillBlock := true.B
+
+ zpColIdx := 0.U
+ when(dec.xpad_0 === 0.U && dec.xpad_1 =/= 0.U) {
+ zpColIdx := dec.xpad_0 + dec.xsize // right-only padding starts after the data columns
+ }
+ when(dec.xpad_0 =/= 0.U && dec.xpad_1 === 0.U) {
+ zpState := sZpSideLeft
+ }.elsewhen(dec.xpad_0 === 0.U && dec.xpad_1 =/= 0.U) {
+ zpState := sZpSideRight
+ }.elsewhen(dec.xpad_0 =/= 0.U && dec.xpad_1 =/= 0.U) {
+ zpState := sZpSideBoth
+ }.elsewhen(dec.ypad_1 =/= 0.U) {
+ zpState := sZpSideSkip
+ }.otherwise {
+ zpState := sZpIdle // nothing to fill
+ }
+ }.elsewhen(
+ zpLastDataRow && // last row before ypad_1
+ ((zpFillLineEnd && io.canWriteMem) || // last zero tensor in xpad_0 or xpad_1
+ zpState === sZpSideSkip)) /* no padding in data rows */ {
+
+ zpNewFillBlock := true.B
+
+ when(dec.ypad_1 =/= 0.U) { // bottom padding rows remain
+ zpColIdx := 0.U // first index for ypad_1 area
+ zpState := sZpBot // continue with the bottom padding block
+ }.otherwise {
+ zpState := sZpIdle // nothing to fill
+ }
+ }.elsewhen(
+ io.canWriteMem &&
+ zpState === sZpBot &&
+ zpRowIdx === dec.ypad_0 + dec.ysize + dec.ypad_1 - 1.U && /*we know ypad_1 > 0 */
+ zpColIdx === zpTopLastIdx) {
+ zpNewFillBlock := true.B
+
+ zpColIdx := 0.U
+ zpState := sZpIdle
+ }.otherwise {
+ zpState := zpState
+ }
+ // a zero tensor is written whenever a writing fill state is active and io.canWriteMem is high
+ isZeroPadWrite := zpState =/= sZpIdle && zpState =/= sZpSideSkip && io.canWriteMem
+ zpDestIdx := zpDestRowOffset + zpColIdx
+
+ // increment row
+ // and set zpColIdx on a row change
+ val incrementRow = Wire(Bool())
+ incrementRow := false.B
+ when(
+ ((((zpState === sZpTop || zpState === sZpSideBoth || zpState === sZpSideRight || zpState === sZpBot) &&
+ zpColIdx === zpTopLastIdx) ||
+ (zpState === sZpSideLeft && zpColIdx === dec.xpad_0 - 1.U))&& io.canWriteMem) ||
+ zpState === sZpSideSkip) {
+
+ zpDestRowOffset := zpDestRowOffset + zpTopLastIdx + 1.U // count rows in one-dimensional destination matrix
+ zpRowIdx := zpRowIdx + 1.U
+ incrementRow := true.B
+ when(!zpNewFillBlock) { // column may be reset on block type change
+ when(zpState === sZpSideRight) {
+ zpColIdx := dec.xpad_0 + dec.xsize
+ }.otherwise {
+ zpColIdx := 0.U
+ }
+ }
+ }
+
+ // increment column if it was not already set by a block change or a row change
+ when(isZeroPadWrite && !zpNewFillBlock && !incrementRow) {
+ when(zpState === sZpSideBoth && zpColIdx === dec.xpad_0 - 1.U) {
+ zpColIdx := zpColIdx + dec.xsize + 1.U// skip data tensors
+
+ }.otherwise {
+ zpColIdx := zpColIdx + 1.U
+ }
+ }
+ io.done := zpState === sZpIdle
+ io.tensorIdx.valid := isZeroPadWrite
+ io.tensorIdx.bits := zpDestIdx
+}
+
+//---------------------
+//--- Read VME data ---
+//---------------------
+//----------------------------------------------------------------------------
+// Read VME data. Generate the scratchpad tensor index and the column inside the tensor for every data beat
+// transaction TAG is a data block offset in scratchpad
+// Different transactions are identified by tag change
+// SAME DESTINATION SUBSEQUENT REQUESTS IN ONE INSTRUCTION LEADS TO UNDEFINED BEHAVIOR
+//----------------------------------------------------------------------------
+class ReadVMEData(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool()) // invalidates the stored tag so the next beat starts a new burst
+ val vmeData = Flipped(Decoupled(new VMEData)) // incoming data beats; only valid/ready and bits.tag are used here
+ val idx = Output(UInt(tp.memAddrBits.W)) // destination tensor index of the current beat
+ val col = Output(UInt(log2Ceil(tp.tsSizeRatio).W)) // destination block (column) inside that tensor
+ })
+
+ io.vmeData.ready := true.B // always ready to read VME data
+
+ require(pow(2, log2Ceil(tp.tensorSizeBits)) == tp.tensorSizeBits,
+ "-F- Tensor bit size must be 2^. Using shift and bits to divide.")
+ require(pow(2, log2Ceil(tp.memBlockBits)) == tp.memBlockBits,
+ "-F- Memory block bit size must be 2^. Using shift and bits to divide.")
+ require(tp.tsSizeRatio >= 1,
+ "-F- Tensor bit size must equal or greater than read puls width.")
+
+ val blkOffsetWidth = log2Ceil(tp.tsSizeRatio)
+
+
+ val rdDataDestCol = Wire(UInt(blkOffsetWidth.W)) // this is an index of a cl in a tensor
+ val rdDataDestIdx = Wire(UInt(M_SRAM_OFFSET_BITS.W)) // this is an index of a tensor
+ // (io.vmeData.ready is already tied high above; the duplicated connection was dropped)
+
+ // decode data destination: tag = {tensor index, block offset in tensor}
+ val vmeTagDecode = io.vmeData.bits.tag
+ val vmeTagDecodeLast = Reg(vmeTagDecode.cloneType) // store tag to identify a new burst
+ val rdDataIdx = vmeTagDecode(vmeTagDecode.getWidth - 1, blkOffsetWidth)
+ val rdDataCol = if (tp.tsSizeRatio == 1) 0.U else vmeTagDecode(blkOffsetWidth - 1, 0)
+ val rdDataDestColNext = Reg(rdDataDestCol.cloneType) // column for the next beat of the same burst
+ val rdDataDestIdxNext = Reg(UInt(M_SRAM_OFFSET_BITS.W)) // tensor index for the next beat of the same burst
+
+ // tracks whether vmeTagDecodeLast holds a tag received since the last start
+ val vmeTagDecodeLastValid = Wire(Bool())
+ val vmeTagDecodeLastValidNext = RegNext(
+ next = vmeTagDecodeLastValid,
+ init = false.B)
+ when(io.start) {
+ vmeTagDecodeLastValid :=false.B // reset tag valid
+ }.elsewhen(io.vmeData.fire()) {
+ vmeTagDecodeLastValid := true.B // set tag valid on a new read
+ }.otherwise {
+ vmeTagDecodeLastValid := vmeTagDecodeLastValidNext // keep value
+ }
+ // outputs are only meaningful on a beat; left unspecified otherwise
+ rdDataDestCol := DontCare
+ rdDataDestIdx := DontCare
+ when(io.vmeData.fire()) {
+ when (
+ !vmeTagDecodeLastValidNext ||
+ (vmeTagDecodeLastValidNext &&
+ vmeTagDecode.asUInt =/= vmeTagDecodeLast.asUInt)) {
+
+ vmeTagDecodeLast := vmeTagDecode // a new burst
+ rdDataDestCol := rdDataCol
+ rdDataDestIdx := rdDataIdx
+ rdDataDestColNext := rdDataCol + 1.U //increment col in tensor
+ rdDataDestIdxNext := rdDataIdx // NOTE(review): not incremented even if rdDataCol is the last column - presumably bursts never start there; confirm
+ }.otherwise {
+ rdDataDestCol := rdDataDestColNext //continue burst read
+ rdDataDestColNext := rdDataDestColNext + 1.U //increment col in tensor
+ rdDataDestIdx := rdDataDestIdxNext
+ when(rdDataDestCol === (tp.tsSizeRatio - 1).U) {
+ rdDataDestIdxNext := rdDataDestIdxNext + 1.U //increment tensor index
+ }
+ }
+ }
+
+ io.idx := rdDataDestIdx
+ io.col := rdDataDestCol
+}
+
+// Generates the VME read commands of one load instruction. The transaction TAG is a data block offset in scratchpad
+// Different transactions are identified by tag change
+// SAME DESTINATION SUBSEQUENT REQUESTS IN ONE INSTRUCTION LEADS TO UNDEFINED BEHAVIOR
+class GenVMECmd(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool()) // reloads the row counters and addresses from the instruction
+ val isBusy = Input(Bool()) // commands are only valid while this is high
+ val inst = Input(UInt(INST_BITS.W)) // instruction word, decoded below as MemDecode
+ val baddr = Input(UInt(mp.addrBits.W)) // base address, OR-ed with the shifted dram_offset
+ val vmeCmd = Decoupled(new VMECmd) // generated read command (addr, len, tag)
+ val readLen = Output(UInt((mp.lenBits + 1).W)) // number of blocks requested by the current command
+ val done = Output(Bool()) // commandsDone register: set when the last command fires, cleared on start
+ })
+ val sizeFactor = tp.tsSizeRatio
+
+
+ val dec = io.inst.asTypeOf(new MemDecode)
+
+ val rdCmdExtAddr = Reg(UInt(mp.addrBits.W)) // current address in the row
+ val maxTransfer = (1 << mp.lenBits).U // max number of blocks in transfer
+ // from old data ctrl
+ val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // bytes in tensor
+ val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
+ val xfer_init_addr = io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
+ val maxTrBytes = maxTransfer << (log2Ceil(mp.dataBits) - 3)
+ // Align the first transfer to a maxTrBytes boundary. It occurs on every dec.xsize transfer;
+ // all other transfers in the row will end at a maxTrBytes boundary
+ val firstMaxTransfer = (maxTrBytes - rdCmdExtAddr % maxTrBytes) >> (log2Ceil(mp.dataBits) - 3)
+
+
+ //--------------------------------------
+ //--- Generate data load VME command ---
+ //--------------------------------------
+
+ val rdCmdStartIdxValid = Wire(Bool()) // Command is valid
+ val startIssueCmdRead = Wire(Bool()) // First transaction in dec.xsize transfer
+ val rdCmdStartIdx = Reg(UInt(log2Ceil(tp.memDepth).W)) // Scratchpad data block index for the first transaction
+ val readLen = Wire(UInt((mp.lenBits + 1).W)) // read cmd transaction length. It is <= maxTransfer
+ val commandsDone = RegInit(true.B) // Done generating VME commands
+ val stride = Wire(Bool()) // flags change to the next row to read
+ val blocksReadSize = (dec.xsize << log2Ceil(sizeFactor)) // how many blocks to read in a single src row
+ val blocksReadNb = Reg(blocksReadSize.cloneType) // blocks already requested in the current src row
+ val rdCmdExtAddrRowBegin = Reg(UInt(mp.addrBits.W)) // starting address in the row
+ val newReadRow = Reg(Bool()) // flags the first read of dec.xsize
+
+ // set which source row of data to read. dec.ysize defines the number of rows
+ val srcRowIdx = Reg(UInt(dec.ysize.getWidth.W)) // current row of stride read
+ when (io.start) {
+ srcRowIdx := 0.U // 1st row
+ }.elsewhen (stride) {
+ srcRowIdx := srcRowIdx + 1.U // increment row
+ }.otherwise {
+ srcRowIdx := srcRowIdx // stay in the row
+ }
+
+ // count how many blocks of the current row have been requested
+ commandsDone := commandsDone
+ when (io.start || stride) {
+ blocksReadNb := 0.U
+ commandsDone := false.B
+ }.elsewhen (io.vmeCmd.fire()) {
+ val nextBlRNb = blocksReadNb + readLen
+ blocksReadNb := nextBlRNb // THIS IS WHEN A NEW VME CMD HAPPENS
+ when (nextBlRNb === blocksReadSize && srcRowIdx === dec.ysize - 1.U) {
+ commandsDone := true.B // last command of the last row has been issued
+ }
+ }.otherwise {
+ blocksReadNb := blocksReadNb
+ }
+
+ // once all read commands of an xsize row are issued, go to the next src row
+ when((blocksReadNb === blocksReadSize - readLen) && (srcRowIdx =/= dec.ysize - 1.U) && io.vmeCmd.fire()) {
+ stride := true.B
+ }.otherwise {
+ stride := false.B
+ }
+
+ assert(!io.isBusy || blocksReadSize >= blocksReadNb)// requested blocks never exceed the row size; readLen for this cycle is chosen below
+ val blocksRemained = blocksReadSize - blocksReadNb
+ when (newReadRow) {
+ when(blocksRemained < firstMaxTransfer) {
+ readLen := blocksRemained
+ }.otherwise {
+ readLen := firstMaxTransfer
+ }
+ }.otherwise {
+ when(blocksRemained < maxTransfer) {
+ readLen := blocksRemained
+ }.otherwise {
+ readLen := maxTransfer
+ }
+ }
+ // block index of the read data row (xsize). Modified by zero padding
+ val totalWidth = dec.xsize + dec.xpad_0 + dec.xpad_1 // width of scratchpad matrix in tensors
+ // instead of multiplying total width by ypad_0 do incremental addition.
+ // It costs ypad_0 cycles before the 1st read cmd is issued
+ // counts src matrix with y padding rows of tensors
+ val currentRowIdx = Reg(UInt((dec.ysize.getWidth + dec.ypad_0.getWidth).W))
+ // start to issue read cmd
+ rdCmdStartIdxValid := currentRowIdx >= dec.ypad_0 &&
+ currentRowIdx < (dec.ysize + dec.ypad_0) &&
+ io.isBusy &&
+ !commandsDone
+ when (io.start) {
+ currentRowIdx := 0.U
+ rdCmdStartIdx := dec.sram_offset + dec.xpad_0 // this index is in tensors
+ }.elsewhen (io.isBusy && (currentRowIdx < dec.ypad_0 || stride)) {
+ rdCmdStartIdx := rdCmdStartIdx + totalWidth // advance by one scratchpad row
+ currentRowIdx := currentRowIdx + 1.U
+ }
+ startIssueCmdRead := false.B
+ when(blocksReadNb === 0.U && rdCmdStartIdxValid) {
+ startIssueCmdRead := true.B
+ }
+ rdCmdExtAddrRowBegin := rdCmdExtAddrRowBegin
+
+ when (io.start) {
+ rdCmdExtAddr := xfer_init_addr
+ rdCmdExtAddrRowBegin := xfer_init_addr
+ newReadRow := true.B
+ }.elsewhen (io.vmeCmd.fire()) {
+ when(stride) {
+ val memRow = rdCmdExtAddrRowBegin + (dec.xstride << log2Ceil(elemBytes))
+ rdCmdExtAddr := memRow // go to the next source matrix row with xstride tensors offset
+ rdCmdExtAddrRowBegin := memRow
+ newReadRow := true.B
+ }.otherwise {
+ newReadRow := false.B
+ // go to the next transaction in the same continuous data block
+ rdCmdExtAddr := rdCmdExtAddr + (readLen << (log2Ceil(mp.dataBits) - 3))
+ }
+ }.otherwise {
+ rdCmdExtAddr := rdCmdExtAddr
+ newReadRow := newReadRow
+ }
+
+ //-------------------------------------
+ //--- execute VME data load command ---
+ //-------------------------------------
+
+ require(pow(2, log2Ceil(tp.tensorSizeBits)) == tp.tensorSizeBits,
+ "-F- Tensor size must be 2^. Using shift and bits to divide.")
+ require(pow(2, log2Ceil(tp.memBlockBits)) == tp.memBlockBits,
+ "-F- Read pulsewidth must be 2^ . Using shift and bits to divide.")
+ // the low log2Ceil(tp.tsSizeRatio) bits of the block index encode the block offset inside a tensor,
+ // the remaining upper bits are the tensor index
+ val blkOffset = log2Ceil(tp.tsSizeRatio)
+ val blkIdxWdth = log2Ceil(tp.tsSizeRatio * tp.memDepth) // the size of scratchpad in blocks
+
+ val rdCmdDestBlockIdx = Wire(UInt(blkIdxWdth.W)) // dataBits size block index in a scratchpad
+ val rdCmdDestBlockIdxNext = Reg(rdCmdDestBlockIdx.cloneType) // destination block index of the following command
+ rdCmdDestBlockIdxNext := rdCmdDestBlockIdxNext
+ rdCmdDestBlockIdx := rdCmdDestBlockIdxNext
+
+ // block position in a scratchpad
+ val rdCmdValid = Wire(Bool())
+ // increment scratch pad destination index
+ when(rdCmdStartIdxValid) {
+ rdCmdValid := true.B
+ when(startIssueCmdRead) {
+ rdCmdDestBlockIdx := rdCmdStartIdx << blkOffset // it is aligned by tensor size
+ rdCmdDestBlockIdxNext:= rdCmdDestBlockIdx + readLen
+ }.elsewhen (io.vmeCmd.fire()) {
+ // increment block position by transaction length
+ rdCmdDestBlockIdxNext:= rdCmdDestBlockIdxNext + readLen
+ }
+ }.otherwise {
+ rdCmdValid := false.B
+ }
+ if(debug) {
+ when (io.vmeCmd.fire()) {
+ printf(s"[GenVMECmd] $tensorType cmd data rdCmdDestBlockIdx:%b " +
+ s" length:%d \n",
+ rdCmdDestBlockIdx,
+ readLen)
+ }
+ }
+ // read-from-dram
+ require(io.vmeCmd.bits.tag.getWidth >= rdCmdDestBlockIdx.getWidth,
+ "-F- Not enough VME tag bits to store transaction tag.")
+ io.vmeCmd.valid := rdCmdValid
+ io.vmeCmd.bits.addr := rdCmdExtAddr
+ io.vmeCmd.bits.len := readLen - 1.U // len field carries the block count minus one
+ assert(!io.vmeCmd.valid || ((readLen << log2Ceil(mp.dataBits/8)) <= (maxTrBytes - rdCmdExtAddr % maxTrBytes)),
+ s"-F- ${tensorType} DRAM page alignment failure. DRAM " +
+ s"address + len overlaps mp.lenBits*memBlockSize alignment %x %x",
+ rdCmdExtAddr, readLen)
+ io.vmeCmd.bits.tag := rdCmdDestBlockIdx
+ io.readLen := readLen
+ io.done := commandsDone
+}
diff --git a/hardware/chisel/src/main/scala/core/TensorLoad.scala b/hardware/chisel/src/main/scala/core/TensorLoadSimple.scala
similarity index 94%
copy from hardware/chisel/src/main/scala/core/TensorLoad.scala
copy to hardware/chisel/src/main/scala/core/TensorLoadSimple.scala
index 8b31253..c2ef007 100644
--- a/hardware/chisel/src/main/scala/core/TensorLoad.scala
+++ b/hardware/chisel/src/main/scala/core/TensorLoadSimple.scala
@@ -21,8 +21,10 @@ package vta.core
import chisel3._
import chisel3.util._
+import chisel3.util.experimental._
import vta.util.config._
import vta.shell._
+
/** TensorLoad.
*
* Load 1D and 2D tensors from main memory (DRAM) to input/weight
@@ -31,7 +33,7 @@ import vta.shell._
* managed by TensorPadCtrl. The TensorDataCtrl is in charge of
* handling the way tensors are stored on the scratchpads.
*/
-class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
+class TensorLoadSimple(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
@@ -189,6 +191,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
io.vme_rd.cmd.valid := state === sReadCmd
io.vme_rd.cmd.bits.addr := dataCtrl.io.addr
io.vme_rd.cmd.bits.len := dataCtrl.io.len
+ io.vme_rd.cmd.bits.tag := dec.sram_offset
io.vme_rd.data.ready := state === sReadData
@@ -232,6 +235,20 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
}
+ if (false) {
+ val memDumpGuard = WireInit(false.B)
+ when (memDumpGuard) {
+ for {
+ idx <- 0 until scala.math.min(64,tp.memDepth)
+ i <- 0 until tp.tensorLength} {
+ val f = (Seq.fill(tp.numMemBlock){ "%x"}).mkString(" ")
+ val s = tensorFile(i)(idx)
+ val d = Seq.tabulate(tp.numMemBlock){ j => s(j)}
+ printf(s"$tensorType: $idx $i $f\n", d:_*)
+ }
+ }
+ }
+
val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
val wdata = Seq.fill(tp.tensorLength) {
Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
@@ -244,7 +261,7 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
for (i <- 0 until tp.tensorLength) {
for (j <- 0 until tp.numMemBlock) {
wmask(i)(j) := tag === j.U
- wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
+ wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits.data)
}
val tdata = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata(i))
val muxWen =
diff --git a/hardware/chisel/src/main/scala/core/TensorLoadWideVME.scala b/hardware/chisel/src/main/scala/core/TensorLoadWideVME.scala
new file mode 100644
index 0000000..f50530a
--- /dev/null
+++ b/hardware/chisel/src/main/scala/core/TensorLoadWideVME.scala
@@ -0,0 +1,765 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import scala.math.pow
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+
+/** TensorLoadWideVME.
+ *
+ * Load Cachelines from main memory (DRAM) into SRAM
+ * Mux Cachelines to tensor size memory blocks in
+ * scratchpads (SRAM). Also, there is support for zero padding, while
+ * doing the load. Zero-padding works on the y and x axis, and it is
+ * managed by ZeroPadding.
+ * Read tensors from SRAM.
+
+ * banks number (BN) = CacheLineSize (CS) / Tensor bit size (TS)
+ * the number of banks is a power of 2
+ * Scratchpad: Seq(BN) {Mem(TensorsNb/BN, TS)}
+ * Cacheline: Vec(BN,CS/BN)
+
+ * Load:
+ * Scratchpad
+ * bank1 bank2
+ * | |
+ * --- ---
+ * wmask-/ \ -/ \
+ * ----- -----
+ * | | | |
+ * c | | | |
+ * a -----|-------- |
+ * c | |
+ * h | |
+ * e | |
+ * l | |
+ * i ------------------
+ * n
+ * e
+
+
+
+
+ */
+class TensorLoadWideVME(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool()) // moves the state to sBusy and starts the sub-modules
+ val done = Output(Bool()) // localDone delayed by writePipeLatency
+ val inst = Input(UInt(INST_BITS.W)) // instruction word forwarded to the command generator and the padder
+ val baddr = Input(UInt(mp.addrBits.W)) // base address forwarded to the command generator
+ val vme_rd = new VMEReadMaster // read command / data channel
+ val tensor = new TensorClient(tensorType) // scratchpad read ports and direct write ports
+ })
+ // the delay cycles of the write pipe. Needed to deliver the signal over physical distance
+ val writePipeLatency = tp.writePipeLatency
+
+ val sIdle :: sBusy :: Nil =
+ Enum(2)
+ val state = RegInit(sIdle)
+
+ val isBusy = state === sBusy
+ val localDone = Wire(Bool())
+ when(io.start) {
+ state := sBusy
+ }.elsewhen(localDone) {
+ state := sIdle
+ }
+
+ val dec = io.inst.asTypeOf(new MemDecode) // NOTE(review): not referenced again in this class
+
+ // VME data channel delayed by readVMEDataLatency cycles
+ val readVMEDataLatency = tp.readVMEDataLatency
+ val vmeDataBitsPipe = ShiftRegister(io.vme_rd.data.bits, readVMEDataLatency, en = true.B)
+ val vmeDataValidPipe = ShiftRegister(io.vme_rd.data.valid, readVMEDataLatency, resetData = false.B, en = true.B)
+ val vmeDataReadyPipe = ShiftRegister(io.vme_rd.data.ready, readVMEDataLatency, resetData = true.B, en = true.B)
+ val vmeDataFirePipe = vmeDataValidPipe & vmeDataReadyPipe
+
+ //--------------------------------------
+ //--- Generate data load VME command ---
+ //--------------------------------------
+ val vmeCmd = Module (new GenVMECmdWideTL(tensorType, debug))
+ vmeCmd.io.start := io.start
+ vmeCmd.io.isBusy := isBusy
+ vmeCmd.io.inst := io.inst
+ vmeCmd.io.baddr := io.baddr
+ vmeCmd.io.vmeCmd <> io.vme_rd.cmd
+ val readLen = vmeCmd.io.readLen
+ val commandsDone = vmeCmd.io.done
+
+ require (mp.dataBits >= tp.tensorSizeBits,
+ "-F- Chacheline width must be larger than tensor bit size")
+ require(pow(2, log2Ceil(mp.dataBits)) == mp.dataBits,
+ "-F- Chacheline width must be pow of 2")
+ require(pow(2, log2Ceil(tp.tensorSizeBits)) == tp.tensorSizeBits,
+ "-F- Tensor size bits must be pow of 2")
+
+ // the mux puts tensors in a single memory line of Cacheline (CL) bits
+ val tensorsInClNb = tp.clSizeRatio
+ val tensorsInClNbWidth = log2Ceil(tensorsInClNb)
+
+ //---------------------------------------
+ //--- count how many CLs not received ---
+ //---------------------------------------
+
+ // the width of the in-flight CL counter
+ val clCntIdxWdth = log2Ceil(tp.memDepth/tensorsInClNb) + 1
+ // Nb of CLs requested, not received.
+ val clInFlight = Reg(UInt(clCntIdxWdth.W))
+ when(io.start) {
+ clInFlight := 0.U
+ }.elsewhen(isBusy && io.vme_rd.cmd.fire() && !vmeDataFirePipe) {
+ clInFlight := clInFlight + readLen // new request only
+ }.elsewhen(isBusy && io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+ clInFlight := clInFlight + readLen - 1.U // new request and one CL received
+ }.elsewhen(isBusy && !io.vme_rd.cmd.fire() && vmeDataFirePipe) {
+ assert(clInFlight > 0.U)
+ clInFlight := clInFlight - 1.U // one CL received
+ }.otherwise {
+ clInFlight := clInFlight
+ }
+
+ //---------------------
+ //--- Read VME data ---
+ //---------------------
+
+ val readData = Module(new ReadVMEDataWide(tensorType, debug))
+ readData.io.start := io.start
+ readData.io.vmeData.valid := vmeDataValidPipe
+ readData.io.vmeData.bits := vmeDataBitsPipe
+ assert(!readData.io.vmeData.valid || readData.io.vmeData.ready,
+ "-F- Expecting const ready. Fix ReadVMEData to receive data piped after ready")
+ io.vme_rd.data.ready := readData.io.vmeData.ready
+ // the write mask defines the number of elems starting with an offset in the SRAM line
+ val rdDataWrIdx = readData.io.destIdx // SP index vector
+ val rdDataWrData = readData.io.destData // SP data vector
+ val rdDataWrEn = readData.io.destMask // write enable vector
+
+ //-------------------------
+ //--- Fill zero padding ---
+ //-------------------------
+
+ val fillPadding = Module(new ZeroPadding(tensorType, debug))
+ fillPadding.io.canWriteMem := !vmeDataFirePipe // padding writes yield to incoming VME data
+ fillPadding.io.inst := io.inst
+ fillPadding.io.start := io.start
+
+ val isZeroPadWrite = fillPadding.io.tensorIdx.valid // Store zero filled tensor, zpDestIdx is valid
+ val zpDestIdx = fillPadding.io.tensorIdx.bits >> tensorsInClNbWidth // SP idx
+ val zpDestMask =
+ if (tensorsInClNb == 1) 1.U
+ else UIntToOH(fillPadding.io.tensorIdx.bits (tensorsInClNbWidth - 1, 0)) // tensor in a memory line
+ val paddingDone = fillPadding.io.done
+
+ //--------------------
+ //--- Write memory ---
+ //--------------------
+
+ // depth is reduced by dataBlock/tensorSize ratio
+ // width is dataBlock bits split into tensor bits
+ // each tensor is split into group bits
+ // group bits can be read/written independently
+
+
+ val splitDataFactor = tp.splitWidth * tp.splitLength
+ val splitMemFactor = tp.splitMemsFactor
+ val groupSizeBits = tp.tensorSizeBits/splitDataFactor
+ val memSizeBits = groupSizeBits/splitMemFactor
+ val tensorFile = Seq.fill(tensorsInClNb * splitDataFactor*splitMemFactor) {
+ SyncReadMem(tp.memDepth/tensorsInClNb, UInt(memSizeBits.W))
+ }
+
+ // direct write
+ val directWrIdx = for (grpIdx <- 0 until splitDataFactor) yield {
+ io.tensor.wr(grpIdx).bits.idx >> tensorsInClNbWidth // SP idx
+ }
+ val directWrMask = for (grpIdx <- 0 until splitDataFactor) yield {
+ Mux(
+ io.tensor.wr(grpIdx).valid,
+ if(tensorsInClNb == 1) 1.U
+ else UIntToOH(io.tensor.wr(grpIdx).bits.idx(tensorsInClNbWidth - 1, 0)),// tensor in a memory line
+ 0.U)
+ }
+
+ // THIS directWrData writes a continuous scratchpad data space
+ // It is WRONG for ACC if batch is > 1
+ // it maps group data bits to a continuous sequence of mem blocks
+ // but wr(x).bits.data is a window in a tensor
+ val directWrData = VecInit(for (grpIdx <- 0 until splitDataFactor) yield {
+ io.tensor.wr(grpIdx).bits.data
+ }).asTypeOf(UInt(tp.tensorSizeBits.W)) // NOTE(review): not referenced below; wdata reads io.tensor.wr directly
+
+
+ // write enable per memory: direct write when idle, otherwise zero padding, otherwise VME data
+ val wmask = Wire(Vec(tensorsInClNb*splitDataFactor*splitMemFactor, Bool()))
+ for (i <- 0 until tensorsInClNb) {
+ for (grpIdx <- 0 until splitDataFactor) {
+ for (memIdx <- 0 until splitMemFactor) { // duplicate control
+ wmask(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx) :=
+ Mux(
+ ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
+ directWrMask(grpIdx)(i),
+ Mux(
+ ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
+ ShiftRegister(zpDestMask(i), writePipeLatency),
+ Mux(
+ ShiftRegister(vmeDataFirePipe, writePipeLatency, resetData = false.B, en = true.B),
+ ShiftRegister(rdDataWrEn(i), writePipeLatency),
+ false.B)))
+ }
+ }
+ }
+
+ // write data per group, selected with the same priority as wmask
+ val wdata = Wire(Vec(tensorsInClNb*splitDataFactor, UInt(groupSizeBits.W)))
+ for (i <- 0 until tensorsInClNb){
+ for (grpIdx <- 0 until splitDataFactor) {
+ val zpDestData = 0.U
+ wdata(i*splitDataFactor + grpIdx) := Mux(
+ ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
+ io.tensor.wr(grpIdx).bits.data.asTypeOf(UInt(groupSizeBits.W)),
+ Mux(
+ ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
+ ShiftRegister(zpDestData /* group size zero */, writePipeLatency),
+ ShiftRegister(
+ (rdDataWrData(i).asTypeOf(Vec(splitDataFactor, UInt(groupSizeBits.W))))(grpIdx), writePipeLatency)))
+ }
+ }
+
+ // write address per memory, selected with the same priority as wmask
+ val widx = Wire(Vec(tensorsInClNb*splitDataFactor*splitMemFactor, UInt(tp.memAddrBits.W)))
+ for (i <- 0 until tensorsInClNb) {
+ for (grpIdx <- 0 until splitDataFactor) {
+ for (memIdx <- 0 until splitMemFactor) { // duplicate control
+ widx(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx) :=
+ Mux(
+ ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
+ directWrIdx(grpIdx),
+ Mux(
+ ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
+ ShiftRegister(zpDestIdx, writePipeLatency),
+ ShiftRegister(rdDataWrIdx(i), writePipeLatency)))
+ }
+ }
+ }
+
+ // each memory stores its memSizeBits slice of the group data when its mask bit is set
+ for (i <- 0 until tensorsInClNb) {
+ for (grpIdx <- 0 until splitDataFactor) {
+ for (memIdx <- 0 until splitMemFactor) { // duplicate control
+ when(wmask(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx)) {
+ tensorFile(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx).write(
+ widx(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx),
+ wdata(i*splitDataFactor + grpIdx).asTypeOf(
+ Vec(splitMemFactor, UInt(memSizeBits.W)))(memIdx))
+ }
+ }
+ }
+ }
+ if (debug) {
+ when(isZeroPadWrite) {
+ printf(s"[TensorLoad] $tensorType isZeroPadWrite data zpDestIdx:%d\n",
+ zpDestIdx)
+ }
+ }
+
+ // read-from-sram
+ for (grpIdx <- 0 until splitDataFactor) {
+ val rIdx = io.tensor.rd(grpIdx).idx.bits >> tensorsInClNbWidth // SP idx
+ val rMask =
+ Mux(
+ io.tensor.rd(grpIdx).idx.valid,
+ if(tensorsInClNb == 1) 1.U
+ else UIntToOH(io.tensor.rd(grpIdx).idx.bits(tensorsInClNbWidth - 1, 0)),// tensor in a memory line
+ 0.U)
+
+ // read this group's memories for every tensor position in the line; address and enable are delayed by readTensorLatency
+ val rdataVec = for (i <- 0 until tensorsInClNb) yield {
+ VecInit(for (memIdx <- 0 until splitMemFactor) yield {
+ tensorFile(
+ i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx).read(
+ ShiftRegister(rIdx, tp.readTensorLatency),
+ ShiftRegister(VecInit(rMask.asBools)(i), tp.readTensorLatency, resetData = false.B, en = true.B))
+ }).asUInt
+ }
+
+ // the mask is delayed one cycle more than the address before selecting among the read results
+ val rdata = Wire(UInt(tp.tensorSizeBits.W))
+ rdata := Mux1H(ShiftRegister(rMask, tp.readTensorLatency + 1), rdataVec)
+ io.tensor.rd(grpIdx).data.bits := rdata.asTypeOf(io.tensor.rd(grpIdx).data.bits.cloneType)
+
+ val rvalid = ShiftRegister(
+ io.tensor.rd(grpIdx).idx.valid, tp.readTensorLatency + 1, resetData = false.B, en = true.B)
+ io.tensor.rd(grpIdx).data.valid := rvalid
+ }
+
+ // done: all commands issued, no CL outstanding, padding finished
+ val loadDone = clInFlight === 0.U && commandsDone && state === sBusy
+ localDone := loadDone && paddingDone
+ io.done := ShiftRegister(localDone, writePipeLatency, resetData = false.B, en = true.B)
+}
+
+//---------------------
+//--- Read VME data ---
+//---------------------
+//----------------------------------------------------------------------------
+// Read VME data. Generate Memory index and data
+// transaction TAG is a data block offset in scratchpad
+// Different transactions are identified by atag change
+// SAME DESTINATION SUBSEQUENT REQUESTS IN ONE INSTRUCTION LEAD TO UNDEFINED BEHAVIOR
+//----------------------------------------------------------------------------
+class ReadVMEDataWide(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module { // turns each accepted VME data beat into per-lane scratchpad index/data/mask outputs
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val wmaskWidth = mp.dataBits/tp.tensorSizeBits // memory data bits divided by tensor bits = lanes in the write mask
+ val io = IO(new Bundle {
+ val start = Input(Bool()) // clears the "stored tag is valid" flag (vmeTagDecodeLastValid)
+ val vmeData = Flipped(Decoupled(new VMEData))
+
+ val destIdx = Output(Vec(tp.clSizeRatio, UInt(tp.memAddrBits.W))) // per-lane destination index (rdDataClDestIdx, +1 on wrap)
+ val destData = Output(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W))) // per-lane tensor picked from the incoming beat
+ val destMask = Output(Vec(tp.clSizeRatio, Bool())) // per-lane write enable picked from wmask
+ })
+
+ io.vmeData.ready := true.B // always ready to read VME data
+
+ require(pow(2, log2Ceil(tp.tensorLength)) == tp.tensorLength,
+ "-F- Tensor length must be 2^. Using shift and bits to divide.")
+ val blkIdxWdth = log2Ceil(tp.memDepth) // the size of scratchpad in cache lines. NOTE(review): not referenced again in this class
+
+ //decode data destination
+ val vmeTagDecode = io.vmeData.bits.tag // tag fields are sliced out below: index | first offset | last count
+ val vmeTagDecodeLast = Reg(vmeTagDecode.cloneType) // store tag to identify a new burst
+ val clBytes = mp.dataBits / 8 // cacheline bytes
+ val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // bytes in tensor
+ val rdDataMaskDecodeWidth = if (wmaskWidth == 1) 1 else (log2Ceil(wmaskWidth) + 1) // width of each of the two low tag fields; always >= 1
+ val rdDataElemIdx = vmeTagDecode(vmeTagDecode.getWidth - 1, 2 * rdDataMaskDecodeWidth) // upper tag bits: destination tensor index
+ val rdFstOffsetNb = if (rdDataMaskDecodeWidth == 0) { // NOTE(review): width is >= 1 as computed above, so this branch looks unreachable
+ 0.U
+ } else {
+ val readOffset = vmeTagDecode(2 * rdDataMaskDecodeWidth - 1, rdDataMaskDecodeWidth) // middle tag field: first-pulse offset
+ readOffset
+ }
+ val rdLstNb = if (rdDataMaskDecodeWidth == 0) { // NOTE(review): same unreachable-looking condition as above
+ 1.U
+ } else {
+ val readNb = vmeTagDecode(rdDataMaskDecodeWidth - 1, 0) // low tag field: tensor count of the last pulse
+ assert(!io.vmeData.valid || readNb > 0.U,"-F- Expecting some elements to read")
+ readNb
+ }
+ val wrMask1st = if (rdDataMaskDecodeWidth == 0) { // lane mask applied on the first pulse of a burst
+ 1.U
+ } else {
+ Reverse(VecInit(for(idx <- 0 until wmaskWidth) yield { // low (clSizeRatio - offset) bits set, then bit-reversed
+ idx.U < tp.clSizeRatio.U - rdFstOffsetNb
+ }).asUInt)
+ }
+ val wrMaskLast = if (rdDataMaskDecodeWidth == 0) { // lane mask applied on the last pulse of a burst
+ 1.U
+ } else {
+ VecInit(for(idx <- 0 until wmaskWidth) yield { // low rdLstNb bits set
+ idx.U < rdLstNb
+ }).asUInt
+ }
+ val rdDataElemDestIdx = Wire(UInt(tp.memAddrBits.W)) // this is an idx of a tensor
+ val rdDataElemDestIdxNext = Reg(UInt(tp.memAddrBits.W)) // running tensor index for the following pulses
+ val rdDataClDestIdx = rdDataElemDestIdx >> log2Ceil(tp.clSizeRatio) // tensor index divided by clSizeRatio
+ val rdDataDestElemOffset = rdDataElemDestIdx % tp.clSizeRatio.U // tensor index modulo clSizeRatio
+
+ val vmeTagDecodeLastValid = Wire(Bool())
+ val vmeTagDecodeLastValidNext = RegNext( // registered copy of vmeTagDecodeLastValid, false after reset
+ next = vmeTagDecodeLastValid,
+ init = false.B)
+ when(io.start) {
+ vmeTagDecodeLastValid :=false.B // reset tag valid
+ }.elsewhen(io.vmeData.fire()) {
+ vmeTagDecodeLastValid := true.B // set tag valid on a new read
+ }.otherwise {
+ vmeTagDecodeLastValid := vmeTagDecodeLastValidNext // keep value
+ }
+
+ val isFirstPulse = Wire(Bool()) // driven below: true when the fired beat starts a new burst
+ val isLastPulse = io.vmeData.bits.last
+ val wmaskSel =
+ Mux(
+ isFirstPulse && isLastPulse, // single-pulse burst: both masks apply
+ wrMask1st & wrMaskLast,
+ Mux(
+ isFirstPulse,
+ wrMask1st,
+ Mux(
+ isLastPulse,
+ wrMaskLast,
+ ((1 << wmaskWidth) - 1).U))) // neither first nor last: all lanes enabled
+ val wmask = Mux(io.vmeData.fire(), wmaskSel, 0.U) // no lane enabled unless a beat fires
+ rdDataElemDestIdx := DontCare // default; overridden when a beat fires
+ isFirstPulse := false.B // default; overridden when a new burst is detected
+ when(io.vmeData.fire()) {
+ when ( // new burst: no stored tag yet, or the tag differs from the stored one
+ !vmeTagDecodeLastValidNext ||
+ (vmeTagDecodeLastValidNext &&
+ vmeTagDecode.asUInt =/= vmeTagDecodeLast.asUInt)) {
+
+ vmeTagDecodeLast := vmeTagDecode // a new burst
+ isFirstPulse := true.B
+ rdDataElemDestIdx := rdDataElemIdx
+ // don't increment first partial read pulse
+ rdDataElemDestIdxNext := rdDataElemIdx + PopCount(wmask) // advance by the number of enabled lanes
+ }.otherwise {
+ rdDataElemDestIdxNext := rdDataElemDestIdxNext + PopCount(wmask) // advance by the number of enabled lanes
+ rdDataElemDestIdx := rdDataElemDestIdxNext
+ }
+ }
+
+
+ val srcData = io.vmeData.bits.data.asTypeOf(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W))) // incoming beat viewed as clSizeRatio tensors
+ val srcOffset = Wire(Vec(tp.clSizeRatio, UInt((log2Ceil(tp.clSizeRatio) + 1).W)))
+ val srcIdx = Wire(Vec(tp.clSizeRatio, UInt(log2Ceil(tp.clSizeRatio).W)))
+
+ // D(j+d) = S(j+s) replace i=j+d --> D(i) = S(i-d+s)
+ for (i <- 0 until tp.clSizeRatio) {
+ srcOffset(i) := i.U + Mux(isFirstPulse, rdFstOffsetNb, 0.U) // lane number plus the first-pulse offset (0 on later pulses)
+ srcIdx(i) := srcOffset(i) -% rdDataDestElemOffset // -% subtracts without width growth
+ val srcIdxOH = UIntToOH(srcIdx(i))
+ io.destData(i) := Mux1H(srcIdxOH,srcData)
+ io.destMask(i) := Mux1H(srcIdxOH, wmask)
+
+ //if dest offset overflow, incr that dest idx
+ val incrIdx = if (tp.clSizeRatio == 1 ) {
+ 0.U
+ } else {
+ Mux(srcOffset(i) >= rdDataDestElemOffset, 0.U, 1.U)
+ }
+ io.destIdx(i) := rdDataClDestIdx + incrIdx
+
+
+ }
+
+
+}
+
+// transaction TAG is a data block offset in scratchpad
+// Different transactions are identified by atag change
+// SAME DESTINATION SUBSEQUENT REQUESTS IN ONE INSTRUCTION LEAD TO UNDEFINED BEHAVIOR
+class GenVMECmdWide(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module { // generates VME commands (addr/len/tag), one xsize row at a time, over ysize rows
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool()) // (re)initializes the row, address and counter registers
+ val isBusy = Input(Bool()) // gates command validity and scratchpad row stepping
+ val updateState = Input(Bool()) // advances address/counters to the next transaction
+ val canSendCmd = Input(Bool()) // extra gate on vmeCmd.valid
+ val baddr = Input(UInt(mp.addrBits.W))
+ val vmeCmd = Decoupled(new VMECmd)
+ val readLen = Output(UInt((mp.lenBits + 1).W)) // rdLen of the current transaction
+ val done = Output(Bool()) // commandsDone register
+ val fstPulseDataStart = Output(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
+ val lstPulseDataEnd = Output(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
+ val spElemIdx = Output(UInt(tp.memAddrBits.W))
+
+ val ysize = Input(UInt(M_SIZE_BITS.W))
+ val xsize = Input(UInt(M_SIZE_BITS.W))
+ val xstride = Input(UInt(M_STRIDE_BITS.W))
+ val dram_offset = Input(UInt(M_DRAM_OFFSET_BITS.W))
+ val sram_offset = Input(UInt(M_SRAM_OFFSET_BITS.W))
+ val xpad_0 = Input(UInt(M_PAD_BITS.W))
+ val xpad_1 = Input(UInt(M_PAD_BITS.W))
+ val ypad_0 = Input(UInt(M_PAD_BITS.W))
+ })
+
+ val clBytes = mp.dataBits / 8 // cacheline bytes
+ val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // bytes in tensor
+ val stride = Wire(Bool()) // flags change to the next row to read
+
+ //----------------------------------------
+ //--- Count lines of DRAM memory lines ---
+ //----------------------------------------
+
+ // set which source row of data to read. io.ysize defines the number of rows
+ val dramLineIdx = Reg(UInt(io.ysize.getWidth.W)) // current row of stride read
+ when (io.start) {
+ dramLineIdx := 0.U // 1st row
+ }.elsewhen (stride) {
+ dramLineIdx := dramLineIdx + 1.U // increment row
+ }.otherwise {
+ dramLineIdx := dramLineIdx // stay in the row
+ }
+
+ // calculate address of DRAM memory line begin (initial/stride)
+ val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt // NOTE(review): not referenced elsewhere in this class
+ val dramInitialAddr = (io.dram_offset << log2Ceil(elemBytes)).asTypeOf(UInt(mp.addrBits.W))
+ val xferElemInitAddr = io.baddr | dramInitialAddr // SHOULD have + here?
+ //align address to CL size
+ // lower bits - elem offset in a cacheline
+ val dramClAddrAlignNotMask = ((BigInt(1) << log2Ceil(clBytes)) - 1).U.asTypeOf(xferElemInitAddr)
+ // upper bits - cacheline alignment
+ val dramClAddrAlignMask = ~dramClAddrAlignNotMask
+ val xferClInitAddr = xferElemInitAddr & dramClAddrAlignMask
+ val rdLineElemBeginAddr = Reg(UInt(mp.addrBits.W)) // DRAM address of xsize tensors memory line
+ val rdLineClBeginAddr = rdLineElemBeginAddr & dramClAddrAlignMask
+ // begin of the next DRAM memory line
+ val nextLineBeginElemAddr = rdLineElemBeginAddr + (io.xstride << log2Ceil(elemBytes))
+ val nextLineBeginClAddr = nextLineBeginElemAddr & dramClAddrAlignMask
+ when (io.start) {
+ rdLineElemBeginAddr := xferElemInitAddr
+ }.elsewhen (stride) {
+ rdLineElemBeginAddr := nextLineBeginElemAddr
+ }.otherwise {
+ rdLineElemBeginAddr := rdLineElemBeginAddr
+ }
+
+ //-----------------------------------------------------
+ //--- Calculate current DRAM address of transaction ---
+ //-----------------------------------------------------
+
+ val rdLen = Wire(UInt((mp.lenBits + 1).W)) // read cmd transaction length. It is <= maxTransfer
+ val rdLineAddr = Reg(UInt(mp.addrBits.W)) // current DRAM address of command
+ when (io.start) {
+ rdLineAddr := xferClInitAddr
+ }.elsewhen (io.updateState) {
+ when(stride) {
+ rdLineAddr := nextLineBeginClAddr
+ }.otherwise {
+ rdLineAddr := rdLineAddr + (rdLen << log2Ceil(clBytes)) // step by the bytes covered by this transaction
+ }
+ }.otherwise {
+ rdLineAddr := rdLineAddr
+ }
+
+ //total length of one xsize row in bytes
+ val rdLineBytes = io.xsize << log2Ceil(elemBytes)
+
+ //First transaction in a line length (1st or stride)
+ val maxTransfer = (1 << mp.lenBits).U // max number of pulses in transfer
+ val maxTrBytes = maxTransfer << log2Ceil(clBytes)
+ val rdLen1stMaxTransBytes = maxTrBytes - rdLineClBeginAddr % maxTrBytes
+ // get the number of cachelines till maxTrBytes aligned address
+ val rdLen1stMaxTransClNb = rdLen1stMaxTransBytes >> log2Ceil(clBytes)
+
+ //Transaction begin mask. Number of tensors to read from right
+ val rd1stPulseOffsetBytes = rdLineElemBeginAddr % clBytes.U
+ assert(rd1stPulseOffsetBytes >> log2Ceil(elemBytes) <= tp.clSizeRatio.U,
+ "-F- Expecting the number of tensors to skip in CL")
+ val rd1stPulseOffsetTensNb = Wire(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
+ rd1stPulseOffsetTensNb := rd1stPulseOffsetBytes >> log2Ceil(elemBytes)
+
+ val rdLineClNbTmp = (rdLineBytes + rd1stPulseOffsetBytes) >> log2Ceil(clBytes)
+ val rdLineClNb = // cachelines touched by the row: round up when the byte span is not a multiple of clBytes
+ Mux((rdLineBytes + rd1stPulseOffsetBytes) % clBytes.U === 0.U, rdLineClNbTmp, rdLineClNbTmp + 1.U)
+
+ //Transaction end mask. Number of tensors to read from left
+ val rdLastPulseBytes = (rdLineElemBeginAddr + rdLineBytes) % clBytes.U
+ assert(rdLastPulseBytes >> log2Ceil(elemBytes) <= (clBytes/elemBytes).U,
+ "-F- Expecting the number of active tensors in CL")
+ val rdLastPulseTensNb = Wire(UInt((log2Ceil(clBytes/elemBytes) + 1).W))
+ val rdLastPulseTensNbTmp = rdLastPulseBytes >> log2Ceil(elemBytes)
+ rdLastPulseTensNb := Mux(rdLastPulseTensNbTmp === 0.U, (clBytes/elemBytes).U, rdLastPulseTensNbTmp) // a zero remainder means the whole cacheline is used
+
+
+
+ //--------------------------------------
+ //--- Generate data load VME command ---
+ //--------------------------------------
+
+ val rdCmdStartIdxValid = Wire(Bool()) // Command is valid
+ val startIssueCmdRead = Wire(Bool()) // First transaction in io.xsize transfer
+ val rdCmdStartIdx = Reg(UInt(log2Ceil(tp.memDepth).W)) // Scratchpad data block index for the first transaction
+ val commandsDone = RegInit(true.B) // Done generating VME commands
+ // counts the number of CLs read in a xsize line
+ val clReadIdx = Reg(UInt((io.xsize.getWidth + log2Ceil(elemBytes) - log2Ceil(clBytes)).W))
+ val newReadRow = clReadIdx === 0.U // flags the first read of io.xsize
+
+ // set how many blocks of data being loaded
+ commandsDone := commandsDone // default: hold
+ when (io.start || stride) {
+ clReadIdx := 0.U
+ commandsDone := false.B
+ }.elsewhen (io.updateState) {
+ val nextClIdx = clReadIdx + rdLen
+ clReadIdx := nextClIdx // THIS IS WHEN A NEW VME CMD HAPPENS
+ when (nextClIdx === rdLineClNb && dramLineIdx === io.ysize - 1.U) { // last transaction of the last row
+ commandsDone := true.B
+ }
+ }.otherwise {
+ clReadIdx := clReadIdx
+ }
+
+ //when the whole xsize row read commands are sent, go for the next src row
+ when((clReadIdx === rdLineClNb - rdLen) && (dramLineIdx =/= io.ysize - 1.U) && io.updateState) {
+ stride := true.B
+ }.otherwise {
+ stride := false.B
+ }
+
+ // current transaction tensors to read nb in 1st and last pulses
+ val rdCmd1stPluseOffsetTensNb = Wire(rd1stPulseOffsetTensNb.cloneType)
+ val rdCmdLastPluseTensNb = Wire(rdLastPulseTensNb.cloneType)
+ when(newReadRow) {
+ // first read in line
+ rdCmd1stPluseOffsetTensNb := rd1stPulseOffsetTensNb
+ }.otherwise {
+ // any other read
+ rdCmd1stPluseOffsetTensNb := 0.U
+ }
+ when (clReadIdx === rdLineClNb - rdLen) {
+ // last read in line
+ rdCmdLastPluseTensNb := rdLastPulseTensNb
+ }.otherwise {
+ // any other read
+ rdCmdLastPluseTensNb := (clBytes/elemBytes).U
+ }
+
+ // NOTE(review): same condition and assignments as the stride block above; redundant under last-connect semantics
+ when((clReadIdx === rdLineClNb - rdLen) && (dramLineIdx =/= io.ysize - 1.U) && io.updateState) {
+ stride := true.B
+ }.otherwise {
+ stride := false.B
+ }
+
+ assert(!io.isBusy || rdLineClNb >= clReadIdx)// define how many cachelines to read at this cycle
+ val clRemained = rdLineClNb - clReadIdx
+ when (newReadRow) { // first transaction of a row is capped at the next maxTrBytes boundary
+ when(clRemained < rdLen1stMaxTransClNb) {
+ rdLen := clRemained
+ }.otherwise {
+ rdLen := rdLen1stMaxTransClNb
+ }
+ }.otherwise { // later transactions are capped at maxTransfer
+ when(clRemained < maxTransfer) {
+ rdLen := clRemained
+ }.otherwise {
+ rdLen := maxTransfer
+ }
+ }
+ // block index of the read data row (xsize). Modified by zero padding
+ val totalWidth = io.xsize + io.xpad_0 + io.xpad_1 // width of scratchpad matrix in tensors
+ // instead of multiplying total width by ypad_0 do incremental addition.
+ //Should cost ypad_0 cycles to issue 1st read cmd
+ // counts src matrix with y padding rows of tensors
+ val currentRowIdx = Reg(UInt((io.ysize.getWidth + io.ypad_0.getWidth).W))
+ // start to issue read cmd
+ rdCmdStartIdxValid := currentRowIdx >= io.ypad_0 &&
+ currentRowIdx < (io.ysize + io.ypad_0) &&
+ io.isBusy &&
+ !commandsDone
+ when (io.start) {
+ currentRowIdx := 0.U
+ rdCmdStartIdx := io.sram_offset + io.xpad_0 // this index is in tensors
+ }.elsewhen (io.isBusy && (currentRowIdx < io.ypad_0 || stride)) { // step one row: while below ypad_0, then on each stride
+ rdCmdStartIdx := rdCmdStartIdx + totalWidth
+ currentRowIdx := currentRowIdx + 1.U
+ }
+ startIssueCmdRead := false.B
+ when(newReadRow && rdCmdStartIdxValid) {
+ startIssueCmdRead := true.B
+ }
+
+ //-------------------------------------
+ //--- execute VME data load command ---
+ //-------------------------------------
+
+ require(pow(2, log2Ceil(tp.tensorLength)) == tp.tensorLength,
+ "-F- Tensor length must be 2^. Using shift and bits to divide.")
+ val blkIdxWdth = log2Ceil(tp.memDepth) // the size of scratchpad. NOTE(review): not referenced again in this class
+
+ val rdCmdDestElemIdx = Wire(UInt(tp.memAddrBits.W)) // element(tensor) size block index in a scratchpad
+ val rdCmdDestElemIdxNext = Reg(rdCmdDestElemIdx.cloneType)
+ rdCmdDestElemIdxNext := rdCmdDestElemIdxNext // default: hold
+ rdCmdDestElemIdx := rdCmdDestElemIdxNext // default: continue from the running index
+
+ val rdCmdValid = Wire(Bool())
+ // the number of tensors read in transaction
+ val rdCmdTransactionTensNb = (rdLen << log2Ceil(clBytes/elemBytes)) - rdCmd1stPluseOffsetTensNb
+ //increment scratch pad destination index
+ when(rdCmdStartIdxValid) {
+ rdCmdValid := true.B
+ when(startIssueCmdRead) { // first transaction of a row restarts from rdCmdStartIdx
+ rdCmdDestElemIdx := rdCmdStartIdx
+ rdCmdDestElemIdxNext:= rdCmdStartIdx + rdCmdTransactionTensNb
+ }.elsewhen (io.updateState) {
+ // increment block position by transaction length
+ rdCmdDestElemIdxNext:= rdCmdDestElemIdxNext + rdCmdTransactionTensNb
+ }
+ }.otherwise {
+ rdCmdValid := false.B
+ }
+
+ // read-from-dram
+ require(io.vmeCmd.bits.tag.getWidth >= rdCmdDestElemIdx.getWidth +
+ rdCmdLastPluseTensNb.getWidth + rdCmd1stPluseOffsetTensNb.getWidth,
+ s"-F- Tensor ${tensorType} Not enough VME tag bits to store transaction" +
+ s" tag. need:${rdCmdDestElemIdx.getWidth + rdCmdLastPluseTensNb.getWidth + rdCmd1stPluseOffsetTensNb.getWidth}")
+ io.vmeCmd.valid := rdCmdValid && io.canSendCmd
+ io.vmeCmd.bits.addr := rdLineAddr
+ io.vmeCmd.bits.len := rdLen - 1.U // len field carries the transaction length minus one
+ assert(!io.vmeCmd.valid || ((rdLen << log2Ceil(clBytes)) <= maxTrBytes - rdLineAddr % maxTrBytes),
+ s"-F- ${tensorType} DRAM page alignment failure. DRAM " +
+ s"address + len overlaps mp.lenBits*memBlockSize alignment %x %x",
+ rdLineAddr, rdLen)
+ io.vmeCmd.bits.tag := Cat(rdCmdDestElemIdx, Cat(rdCmd1stPluseOffsetTensNb, rdCmdLastPluseTensNb)) // MSB..LSB: dest tensor idx | first-pulse offset | last-pulse count
+ io.readLen := rdLen
+ io.spElemIdx := rdCmdDestElemIdx // scratchpad tensor idx
+ io.fstPulseDataStart := rdCmd1stPluseOffsetTensNb // first pulse data start
+ io.lstPulseDataEnd := rdCmdLastPluseTensNb // last pulse data end
+ io.done := commandsDone
+}
+
+class GenVMECmdWideTL(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module { // wraps GenVMECmdWide and feeds it the fields decoded from io.inst
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool())
+ val isBusy = Input(Bool())
+ val inst = Input(UInt(INST_BITS.W)) // raw instruction bits, reinterpreted as MemDecode below
+ val baddr = Input(UInt(mp.addrBits.W))
+ val vmeCmd = Decoupled(new VMECmd)
+ val readLen = Output(UInt((mp.lenBits + 1).W))
+ val done = Output(Bool())
+ })
+
+ val dec = io.inst.asTypeOf(new MemDecode) // view the instruction bits as MemDecode fields
+
+ val cmdGen = Module (new GenVMECmdWide(tensorType, debug))
+
+ cmdGen.io.start := io.start // control and address inputs are passed straight through
+ cmdGen.io.isBusy := io.isBusy
+ cmdGen.io.baddr := io.baddr
+ io.vmeCmd <> cmdGen.io.vmeCmd
+ io.readLen := cmdGen.io.readLen
+ io.done := cmdGen.io.done
+
+ cmdGen.io.ysize := dec.ysize // geometry and padding fields come from the decoded instruction
+ cmdGen.io.xsize := dec.xsize
+ cmdGen.io.xstride := dec.xstride
+ cmdGen.io.dram_offset := dec.dram_offset
+ cmdGen.io.sram_offset := dec.sram_offset
+ cmdGen.io.xpad_0 := dec.xpad_0
+ cmdGen.io.xpad_1 := dec.xpad_1
+ cmdGen.io.ypad_0 := dec.ypad_0
+ cmdGen.io.updateState := io.vmeCmd.fire() // advance the generator whenever a command is accepted
+ cmdGen.io.canSendCmd := true.B // this wrapper applies no extra throttling
+}
diff --git a/hardware/chisel/src/main/scala/core/TensorStore.scala b/hardware/chisel/src/main/scala/core/TensorStore.scala
index f1556ef..99bffa6 100644
--- a/hardware/chisel/src/main/scala/core/TensorStore.scala
+++ b/hardware/chisel/src/main/scala/core/TensorStore.scala
@@ -41,228 +41,22 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
val vme_wr = new VMEWriteMaster
val tensor = new TensorClient(tensorType)
})
- val tensorLength = tp.tensorLength
- val tensorWidth = tp.tensorWidth
- val tensorElemBits = tp.tensorElemBits
- val memBlockBits = tp.memBlockBits
- val memDepth = tp.memDepth
- val numMemBlock = tp.numMemBlock
- require(numMemBlock > 0, s"-F- TensorStore doesnt support pulse width" +
- s"wider than tensor width. Needed for stride support tensorWidth=${tensorWidth}")
- require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- ${tensorType} Cannot do split direct access")
- val dec = io.inst.asTypeOf(new MemDecode)
- val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
- val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
- val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
- val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
- val xrem = Reg(chiselTypeOf(dec.xsize))
- val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock))
- val xmax = (1 << mp.lenBits).U
- val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
- val ycnt = Reg(chiselTypeOf(dec.ysize))
- val ysize = dec.ysize
- val tag = Reg(UInt(8.W))
- val set = Reg(UInt(8.W))
-
- val xfer_bytes = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
- val xstride_bytes = dec.xstride << log2Ceil(tensorLength * tensorWidth)
- val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
- val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8
- val pulse_bytes_bits = log2Ceil(mp.dataBits >> 3)
-
- val xfer_init_addr = io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
- val xfer_split_addr = waddr_cur + xfer_bytes
- val xfer_stride_addr = waddr_nxt + xstride_bytes
-
- val xfer_init_bytes = xmax_bytes - xfer_init_addr % xmax_bytes
- val xfer_init_pulses = xfer_init_bytes >> pulse_bytes_bits
- val xfer_split_bytes = xmax_bytes - xfer_split_addr % xmax_bytes
- val xfer_split_pulses = xfer_split_bytes >> pulse_bytes_bits
- val xfer_stride_bytes = xmax_bytes - xfer_stride_addr % xmax_bytes
- val xfer_stride_pulses= xfer_stride_bytes >> pulse_bytes_bits
-
- val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5)
- val state = RegInit(sIdle)
-
- // control
- switch(state) {
- is(sIdle) {
- xfer_bytes := xfer_init_bytes
- when (io.start) {
- state := sWriteCmd
- when (xsize < xfer_init_pulses) {
- assert(xsize > 0.U, "Idle => WriteCmd, init, without xrem: must have positive xsize")
- xlen := xsize - 1.U
- xrem := 0.U
- }.otherwise {
- xlen := xfer_init_pulses - 1.U
- assert(xsize >= xfer_init_pulses,
- "Idle => WriteCmd, init, with xrem: must have xsize no smaller than xfer_init_pulses")
- xrem := xsize - xfer_init_pulses
- }
- }
- }
- is(sWriteCmd) {
- when(io.vme_wr.cmd.ready) {
- state := sWriteData
- }
- }
- is(sWriteData) {
- when(io.vme_wr.data.ready) {
- when(xcnt === xlen) {
- state := sWriteAck
- }.elsewhen(tag === (numMemBlock - 1).U) {
- state := sReadMem
- }
- }
- }
- is(sReadMem) {
- state := sWriteData
- }
- is(sWriteAck) {
- when(io.vme_wr.ack) {
- when(xrem === 0.U) {
- when(ycnt === ysize - 1.U) {
- state := sIdle
- }.otherwise { // stride
- state := sWriteCmd
- xfer_bytes := xfer_stride_bytes
- when(xsize < xfer_stride_pulses) {
- assert(xsize > 0.U, "WriteAck => WriteCmd, stride, without xrem: must have positive xsize")
- xlen := xsize - 1.U
- xrem := 0.U
- }.otherwise {
- xlen := xfer_stride_pulses - 1.U
- assert(xsize >= xfer_stride_pulses,
- "WriteAck => WriteCmd, stride, with xrem: must have xsize no smaller than xfer_stride_pulses")
- xrem := xsize - xfer_stride_pulses
- }
- }
- } // split
- .elsewhen(xrem < xfer_split_pulses) {
- state := sWriteCmd
- xfer_bytes := xfer_split_bytes
- assert(xrem > 0.U, "WriteAck => WriteCmd, split, without xrem: must have positive xrem")
- xlen := xrem - 1.U
- xrem := 0.U
- }
- .otherwise {
- state := sWriteCmd
- xfer_bytes := xfer_split_bytes
- xlen := xfer_split_pulses - 1.U
- assert(xrem >= xfer_split_pulses,
- "WriteAck => WriteCmd, split, with xrem: must have xrem no smaller than xfer_split_pulses")
- xrem := xrem - xfer_split_pulses
- }
- }
- }
- }
-
- // write-to-sram
- val tensorFile = Seq.fill(tensorLength) {
- SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W)))
- }
- val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
- val no_mask = Wire(Vec(numMemBlock, Bool()))
-
- wdata_t := DontCare
- no_mask.foreach { m =>
- m := true.B
- }
-
- for (i <- 0 until tensorLength) {
- val inWrData = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata_t)
- when(io.tensor.wr(0).valid) {
- tensorFile(i).write(io.tensor.wr(0).bits.idx, inWrData, no_mask)
- }
- }
-
- // read-from-sram
- val stride = state === sWriteAck &
- io.vme_wr.ack &
- xcnt === xlen + 1.U &
- xrem === 0.U &
- ycnt =/= ysize - 1.U
-
- when(state === sIdle) {
- ycnt := 0.U
- }.elsewhen(stride) {
- ycnt := ycnt + 1.U
- }
-
- when(state === sWriteCmd || tag === (numMemBlock - 1).U) {
- tag := 0.U
- }.elsewhen(io.vme_wr.data.fire()) {
- tag := tag + 1.U
- }
-
- when(
- state === sWriteCmd || (state =/= sReadMem && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
- set := 0.U
- }.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
- set := set + 1.U
- }
-
- val raddr_cur = Reg(UInt(tp.memAddrBits.W))
- val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
- when(state === sIdle) {
- raddr_cur := dec.sram_offset
- raddr_nxt := dec.sram_offset
- }.elsewhen(io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
- raddr_cur := raddr_cur + 1.U
- }.elsewhen(stride) {
- raddr_cur := raddr_nxt + dec.xsize
- raddr_nxt := raddr_nxt + dec.xsize
- }
-
- val tread = Seq.tabulate(tensorLength) { i =>
- i.U ->
- tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem)
- }
- val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
-
- // write-to-dram
- when(state === sIdle) {
- waddr_cur := xfer_init_addr
- waddr_nxt := xfer_init_addr
- }.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
- waddr_cur := xfer_split_addr
- }.elsewhen(stride) {
- waddr_cur := xfer_stride_addr
- waddr_nxt := xfer_stride_addr
- }
-
- io.vme_wr.cmd.valid := state === sWriteCmd
- io.vme_wr.cmd.bits.addr := waddr_cur
- io.vme_wr.cmd.bits.len := xlen
-
- io.vme_wr.data.valid := state === sWriteData
- io.vme_wr.data.bits := mdata(tag)
-
- when(state === sWriteCmd) {
- xcnt := 0.U
- }.elsewhen(io.vme_wr.data.fire()) {
- xcnt := xcnt + 1.U
- }
-
- // disable external read-from-sram requests
- io.tensor.tieoffRead()
-
- // done
- io.done := state === sWriteAck & io.vme_wr.ack & xrem === 0.U & ycnt === ysize - 1.U
-
- // debug
- if (debug) {
- when(io.vme_wr.cmd.fire()) {
- printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n",
- ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
- }
- when(io.vme_wr.data.fire()) {
- printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
- }
- when(io.vme_wr.ack) {
- printf("[TensorStore] ack\n")
- }
+ override def desiredName = "TensorStore" + tensorType.capitalize
+
+ val forceSimpleStore = false // force original store flow. Narrow it is
+
+ if (mp.dataBits >= tp.tensorSizeBits && !forceSimpleStore) {
+ // cacheline is wider than tensor size,
+ // macro memory bitwidth by cache size
+ // bank by tensor size
+ val tensorStore = Module(new TensorStoreWideVME(tensorType, debug))
+ io <> tensorStore.io
+ } else {
+ // tensor is wider than cacheline, bank by
+ // macro memory bitwidth by tensor size
+ // bank by cacheline size
+ val tensorStore = Module(new TensorStoreNarrowVME(tensorType, debug))
+ io <> tensorStore.io
}
}
diff --git a/hardware/chisel/src/main/scala/core/TensorStore.scala b/hardware/chisel/src/main/scala/core/TensorStoreNarrowVME.scala
similarity index 87%
copy from hardware/chisel/src/main/scala/core/TensorStore.scala
copy to hardware/chisel/src/main/scala/core/TensorStoreNarrowVME.scala
index f1556ef..bbef36d 100644
--- a/hardware/chisel/src/main/scala/core/TensorStore.scala
+++ b/hardware/chisel/src/main/scala/core/TensorStoreNarrowVME.scala
@@ -28,7 +28,7 @@ import vta.shell._
*
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/
-class TensorStore(tensorType: String = "none", debug: Boolean = false)(
+class TensorStoreNarrowVME(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
@@ -51,6 +51,12 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
s"wider than tensor width. Needed for stride support tensorWidth=${tensorWidth}")
require(tp.splitWidth == 1 && tp.splitLength == 1, s"-F- ${tensorType} Cannot do split direct access")
+ val writePipeLatency = tp.writePipeLatency
+ // Store write is delayed by writePipeLatency
+ // postpone start by the same number of cycles
+ // expects instr and baddr are valid from start till done
+ val localStart = ShiftRegister(io.start, writePipeLatency, resetData = false.B, en = true.B)
+
val dec = io.inst.asTypeOf(new MemDecode)
val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
@@ -89,16 +95,15 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
switch(state) {
is(sIdle) {
xfer_bytes := xfer_init_bytes
- when (io.start) {
+ when (localStart) {
state := sWriteCmd
when (xsize < xfer_init_pulses) {
- assert(xsize > 0.U, "Idle => WriteCmd, init, without xrem: must have positive xsize")
+ assert(xsize > 0.U)
xlen := xsize - 1.U
xrem := 0.U
}.otherwise {
xlen := xfer_init_pulses - 1.U
- assert(xsize >= xfer_init_pulses,
- "Idle => WriteCmd, init, with xrem: must have xsize no smaller than xfer_init_pulses")
+ assert(xsize >= xfer_init_pulses)
xrem := xsize - xfer_init_pulses
}
}
@@ -129,13 +134,12 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
state := sWriteCmd
xfer_bytes := xfer_stride_bytes
when(xsize < xfer_stride_pulses) {
- assert(xsize > 0.U, "WriteAck => WriteCmd, stride, without xrem: must have positive xsize")
+ assert(xsize > 0.U)
xlen := xsize - 1.U
xrem := 0.U
}.otherwise {
xlen := xfer_stride_pulses - 1.U
- assert(xsize >= xfer_stride_pulses,
- "WriteAck => WriteCmd, stride, with xrem: must have xsize no smaller than xfer_stride_pulses")
+ assert(xsize >= xfer_stride_pulses)
xrem := xsize - xfer_stride_pulses
}
}
@@ -143,7 +147,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
.elsewhen(xrem < xfer_split_pulses) {
state := sWriteCmd
xfer_bytes := xfer_split_bytes
- assert(xrem > 0.U, "WriteAck => WriteCmd, split, without xrem: must have positive xrem")
+ assert(xrem > 0.U)
xlen := xrem - 1.U
xrem := 0.U
}
@@ -151,8 +155,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
state := sWriteCmd
xfer_bytes := xfer_split_bytes
xlen := xfer_split_pulses - 1.U
- assert(xrem >= xfer_split_pulses,
- "WriteAck => WriteCmd, split, with xrem: must have xrem no smaller than xfer_split_pulses")
+ assert(xrem >= xfer_split_pulses)
xrem := xrem - xfer_split_pulses
}
}
@@ -173,8 +176,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
for (i <- 0 until tensorLength) {
val inWrData = io.tensor.wr(0).bits.data(i).asUInt.asTypeOf(wdata_t)
- when(io.tensor.wr(0).valid) {
- tensorFile(i).write(io.tensor.wr(0).bits.idx, inWrData, no_mask)
+ when(ShiftRegister(io.tensor.wr(0).valid, writePipeLatency, resetData = false.B, en = true.B)) {
+ tensorFile(i).write(ShiftRegister(io.tensor.wr(0).bits.idx, writePipeLatency),
+ ShiftRegister(inWrData, writePipeLatency), no_mask)
}
}
@@ -236,9 +240,11 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
io.vme_wr.cmd.valid := state === sWriteCmd
io.vme_wr.cmd.bits.addr := waddr_cur
io.vme_wr.cmd.bits.len := xlen
+ io.vme_wr.cmd.bits.tag := dec.sram_offset
io.vme_wr.data.valid := state === sWriteData
- io.vme_wr.data.bits := mdata(tag)
+ io.vme_wr.data.bits.data := mdata(tag)
+ io.vme_wr.data.bits.strb := Fill(io.vme_wr.data.bits.strb.getWidth, true.B)
when(state === sWriteCmd) {
xcnt := 0.U
@@ -259,7 +265,8 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
}
when(io.vme_wr.data.fire()) {
- printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
+ printf("[TensorStore] data:%x\n", io.vme_wr.data.bits.data)
+ printf("[TensorStore] strb:%x\n", io.vme_wr.data.bits.strb)
}
when(io.vme_wr.ack) {
printf("[TensorStore] ack\n")
diff --git a/hardware/chisel/src/main/scala/core/TensorStoreWideVME.scala b/hardware/chisel/src/main/scala/core/TensorStoreWideVME.scala
new file mode 100644
index 0000000..8a3cea3
--- /dev/null
+++ b/hardware/chisel/src/main/scala/core/TensorStoreWideVME.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package vta.core
+
+import scala.math.pow
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorStore.
+ *
+ * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
+ */
+class TensorStoreWideVME(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
+ val tp = new TensorParams(tensorType)
+ val mp = p(ShellKey).memParams
+ val io = IO(new Bundle {
+ val start = Input(Bool())
+ val done = Output(Bool())
+ val inst = Input(UInt(INST_BITS.W))
+ val baddr = Input(UInt(mp.addrBits.W))
+ val vme_wr = new VMEWriteMaster
+ val tensor = new TensorClient(tensorType)
+ })
+ val writePipeLatency = tp.writePipeLatency
+ // Store write is delayed by writePipeLatency
+ // postpone start by the same number of cycles
+ // expects inst and baddr to remain valid from start until done
+ val localStart = ShiftRegister(io.start, writePipeLatency, resetData = false.B, en = true.B)
+
+ val dec = io.inst.asTypeOf(new MemDecode)
+
+ val sIdle :: sWriteCmd :: sWriteData :: sWriteAck :: Nil = Enum(4)
+ val state = RegInit(sIdle)
+
+ val cmdGen = Module (new GenVMECmdWide(tensorType, debug))
+
+ cmdGen.io.ysize := dec.ysize
+ cmdGen.io.xsize := dec.xsize
+ cmdGen.io.xstride := dec.xstride
+ cmdGen.io.dram_offset := dec.dram_offset
+ cmdGen.io.sram_offset := dec.sram_offset
+ cmdGen.io.xpad_0 := 0.U
+ cmdGen.io.xpad_1 := 0.U
+ cmdGen.io.ypad_0 := 0.U
+
+ cmdGen.io.start := localStart
+ cmdGen.io.isBusy := state =/= sIdle
+ cmdGen.io.baddr := io.baddr
+ cmdGen.io.updateState := state === sWriteCmd
+ cmdGen.io.canSendCmd := cmdGen.io.updateState
+ io.vme_wr.cmd <> cmdGen.io.vmeCmd
+ val commandsDone = cmdGen.io.done
+
+ // latch cmd parameters
+ val readLenReg = Reg(cmdGen.io.readLen.cloneType)
+ val readLen = Wire(readLenReg.cloneType)
+ val fstPulseDataStartReg = Reg(cmdGen.io.fstPulseDataStart.cloneType)
+ val fstPulseDataStart = Wire(fstPulseDataStartReg.cloneType)
+ val lstPulseDataEndReg = Reg(cmdGen.io.lstPulseDataEnd.cloneType)
+ val lstPulseDataEnd = Wire(lstPulseDataEndReg.cloneType)
+ val spElemIdxReg = Reg(cmdGen.io.spElemIdx.cloneType)
+ val spElemIdx = Wire(spElemIdxReg.cloneType)
+ when (cmdGen.io.updateState) {
+ readLen := cmdGen.io.readLen
+ readLenReg := readLen
+ fstPulseDataStart := cmdGen.io.fstPulseDataStart
+ fstPulseDataStartReg := fstPulseDataStart
+ lstPulseDataEnd := cmdGen.io.lstPulseDataEnd
+ lstPulseDataEndReg := lstPulseDataEnd
+ spElemIdx := cmdGen.io.spElemIdx
+ spElemIdxReg := spElemIdx
+ }.otherwise {
+ readLenReg := readLenReg
+ readLen := readLenReg
+ fstPulseDataStartReg := fstPulseDataStartReg
+ fstPulseDataStart := fstPulseDataStartReg
+ lstPulseDataEndReg := lstPulseDataEndReg
+ lstPulseDataEnd := lstPulseDataEndReg
+ spElemIdxReg := spElemIdxReg
+ spElemIdx := spElemIdxReg
+ }
+
+ val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
+ xcnt := xcnt
+ // control
+ val updateState = Wire(Bool())
+ updateState := false.B
+ switch(state) {
+ is(sIdle) {
+ when (localStart) {
+ state := sWriteCmd
+ }
+ }
+ is(sWriteCmd) {
+ when(io.vme_wr.cmd.fire()) {
+ state := sWriteData
+ updateState := true.B
+ xcnt := 0.U
+ }
+ }
+ is(sWriteData) {
+ when(io.vme_wr.data.fire()) {
+ when(xcnt === readLen - 1.U) {
+ state := sWriteAck
+ }.otherwise {
+ xcnt := xcnt + 1.U
+ }
+ }
+ }
+ is(sWriteAck) {
+ when(io.vme_wr.ack) {
+ when(commandsDone) {
+ state := sIdle
+ }.otherwise { // stride
+ state := sWriteCmd
+ }
+ }
+ }
+ }
+
+
+ //--------------------
+ //--- Write memory ---
+ //--------------------
+
+ val splitDataFactor = tp.splitWidth * tp.splitLength
+ val groupSizeBits = tp.tensorSizeBits/splitDataFactor
+ val tensorFile = Seq.fill(tp.clSizeRatio * splitDataFactor) {
+ SyncReadMem(tp.memDepth/tp.clSizeRatio, UInt(groupSizeBits.W))
+ }
+
+ // direct write
+ for (grpIdx <- 0 until splitDataFactor) {
+ val directWrIdx = io.tensor.wr(grpIdx).bits.idx >> log2Ceil(tp.clSizeRatio) // SP idx
+ val directWrTensorIdx =
+ if(tp.clSizeRatio == 1) 0.U
+ else io.tensor.wr(grpIdx).bits.idx(log2Ceil(tp.clSizeRatio) - 1, 0)
+ for (i <- 0 until tp.clSizeRatio) {
+ when(ShiftRegister(io.tensor.wr(grpIdx).valid && directWrTensorIdx === i.U, writePipeLatency,
+ resetData = false.B, en = true.B)) {
+
+ tensorFile(i*splitDataFactor + grpIdx).write(ShiftRegister(directWrIdx, writePipeLatency),
+ ShiftRegister(io.tensor.wr(grpIdx).bits.data.asUInt, writePipeLatency))
+ }
+ }
+ }
+
+
+ //--------------------
+ //--- Read memory ---
+ //--------------------
+ // first pulse doesn't read the whole data size; it is bounded by DRAM data alignment
+ // ! - data pulse boundary
+ // . - tensor boundary
+ // tz - not used
+ // TZ - tensor to store
+ // =TZ= - first pulse tensor
+ // DRAM !-tz-.-tz-.=TZ=!-TZ-.-TZ-.-TZ-!
+ //
+ // SRAM !-tz-.=TZ=.-TZ-!-TZ-.-TZ-.-tz-!
+
+ val isFirstPulse = io.vme_wr.data.fire() && xcnt === 0.U
+ assert(state =/= sWriteData || readLen > 0.U)
+ val firstPulseTenzorsNb = tp.clSizeRatio.U - fstPulseDataStart
+ val isLastPulse = io.vme_wr.data.fire() && xcnt === readLen - 1.U
+ val spReadAddrReg = Reg(UInt(M_SRAM_OFFSET_BITS.W))
+ val spReadAddr = Wire(spReadAddrReg.cloneType)
+ val srcElemOffsetReg = Reg(UInt(log2Ceil(tp.clSizeRatio).W))
+ val srcElemOffset = Wire(srcElemOffsetReg.cloneType)
+ val incrFstIdx = Mux((spElemIdx % tp.clSizeRatio.U) < fstPulseDataStart.asTypeOf(UInt(width = 8.W)), 0.U , 1.U)
+ spReadAddr := DontCare
+ when(state === sWriteCmd) {
+ // init by data block index
+ spReadAddr := spElemIdx >> log2Ceil(tp.clSizeRatio)
+ spReadAddrReg := spReadAddr + incrFstIdx
+ srcElemOffset := spElemIdx % tp.clSizeRatio.U
+ srcElemOffsetReg := (spElemIdx + firstPulseTenzorsNb) % tp.clSizeRatio.U
+ }.elsewhen(io.vme_wr.data.fire()) {
+ spReadAddrReg := spReadAddrReg + 1.U
+ spReadAddr := spReadAddrReg
+ srcElemOffset := (spElemIdx + firstPulseTenzorsNb) % tp.clSizeRatio.U
+ srcElemOffsetReg := srcElemOffset
+ }.otherwise {
+ spReadAddrReg := spReadAddrReg
+ srcElemOffsetReg := srcElemOffsetReg
+ srcElemOffset := srcElemOffsetReg
+ }
+
+
+ val dstData = Wire(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
+ val srcData = Wire(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
+ val srcMemIdx = Wire(Vec(tp.clSizeRatio, spReadAddr.cloneType))
+ val dstOffset = Wire(Vec(tp.clSizeRatio, UInt((log2Ceil(tp.clSizeRatio) + 1).W)))
+ val dstIdx = Wire(Vec(tp.clSizeRatio, UInt(log2Ceil(tp.clSizeRatio).W)))
+
+
+ // D(j+d) = S(j+s) replace i=j+d --> D(i) = S(i-d+s)
+ for (i <- 0 until tp.clSizeRatio) {
+
+ //if src offset overflow, incr that dest idx, read next memory row
+ val incrIdx = if (tp.clSizeRatio == 1 ) {
+ 0.U
+ } else {
+ Mux(i.U >= srcElemOffset, 0.U, 1.U)
+ }
+ srcMemIdx(i) := spReadAddr + incrIdx
+
+ //read memory
+ srcData(i) := VecInit(for (grpIdx <- 0 until splitDataFactor) yield {
+ tensorFile(i*splitDataFactor + grpIdx).read(
+ srcMemIdx(i),
+ state === sWriteCmd | (state === sWriteData && io.vme_wr.data.fire()))
+ }).asTypeOf(UInt(tp.tensorSizeBits.W))
+
+ // crossbar src to dst
+ dstOffset(i) := i.U + spElemIdx % tp.clSizeRatio.U
+ dstIdx(i) := dstOffset(i) -% fstPulseDataStart
+ dstData(i) := Mux1H(UIntToOH(dstIdx(i)), srcData)
+
+ }
+
+ // build valid bytes strb
+ val tensorSizeBytes = tp.tensorSizeBits/8
+ val validBytes = Wire(Vec(tp.clSizeRatio, UInt(tensorSizeBytes.W)))
+ val tensorBytesOnes = (BigInt(1) << tensorSizeBytes) - 1
+ when(isFirstPulse && !isLastPulse) {
+ for (i <- 0 until tp.clSizeRatio) {
+ validBytes(i) := Mux(i.U < fstPulseDataStart && fstPulseDataStart =/= 0.U, 0.U, tensorBytesOnes.U)
+ }
+ }.elsewhen (!isFirstPulse && isLastPulse) {
+ for (i <- 0 until tp.clSizeRatio) {
+ validBytes(i) := Mux(i.U >= lstPulseDataEnd && lstPulseDataEnd =/= 0.U, 0.U, tensorBytesOnes.U)
+ }
+ }.elsewhen (isFirstPulse && isLastPulse) {
+ for (i <- 0 until tp.clSizeRatio) {
+ validBytes(i) := Mux((i.U < fstPulseDataStart && fstPulseDataStart =/= 0.U)
+ || (i.U >= lstPulseDataEnd && lstPulseDataEnd =/= 0.U), 0.U, tensorBytesOnes.U)
+ }
+ }.otherwise {
+ for (i <- 0 until tp.clSizeRatio) {
+ validBytes(i) := tensorBytesOnes.U
+ }
+ }
+
+
+ io.vme_wr.data.valid := state === sWriteData
+ io.vme_wr.data.bits.data := dstData.asUInt
+ io.vme_wr.data.bits.strb := validBytes.asUInt
+
+
+ // disable external read-from-sram requests
+ io.tensor.tieoffRead()
+
+ // done
+ io.done := state === sWriteAck & commandsDone & io.vme_wr.ack
+
+ // debug
+ if (debug) {
+ when(io.vme_wr.data.fire()) {
+ printf("[TensorStore] data:%x\n", io.vme_wr.data.bits.data)
+ printf("[TensorStore] strb:%x\n", io.vme_wr.data.bits.strb)
+ }
+ when(io.vme_wr.ack) {
+ printf("[TensorStore] ack\n")
+ }
+ }
+}
diff --git a/hardware/chisel/src/main/scala/core/TensorUtil.scala b/hardware/chisel/src/main/scala/core/TensorUtil.scala
index dfdecf4..6e79548 100644
--- a/hardware/chisel/src/main/scala/core/TensorUtil.scala
+++ b/hardware/chisel/src/main/scala/core/TensorUtil.scala
@@ -48,6 +48,7 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
else if (tensorType == "fetch") {
// make fetch a 64 bit data to be able to read
// 64 bit aligned address. It works for wide cacheline
+ // fetch tensorload is not used for narrow data load
require(p(ShellKey).memParams.dataBits >= INST_BITS,
"-F- Cannot make fetch tensor narrower than data pulse. TODO: narrow fetch with tensors")
(1, 1, 64)
@@ -79,9 +80,66 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
else
p(CoreKey).outMemDepth
+ // the number of cycles Instruction write is delayed
+ // Idle state writes are not delayed
+ // inserted regs are used to physically deliver signal to memories
+ val writePipeLatency =
+ if (tensorType == "inp") {
+ 0 // VME data load cmd write (per group)
+ } else if (tensorType == "wgt") {
+ 0 // VME data load cmd write (per group)
+ } else if (tensorType == "acc") {
+ 0 // VME data load cmd write (per group)
+ } else if (tensorType == "fetch") {
+ 0
+ } else if (tensorType == "uop") {
+ 0
+ } else if (tensorType == "out") {
+ 0 // direct write from core
+ } else {
+ 0
+ }
+
+ // the number of cycles Idle state data read is delayed
+ // inserted regs are used to physically deliver signal to memories
+ val readTensorLatency =
+ if (tensorType == "inp") {
+ 0 // GEMM inp data read (per memsplit)
+ } else if (tensorType == "wgt") {
+ 0
+ } else if (tensorType == "acc") {
+ 0
+ } else if (tensorType == "fetch") {
+ 0
+ } else if (tensorType == "uop") {
+ 0
+ } else if (tensorType == "out") {
+ 0
+ } else {
+ 0
+ }
+ // the number of cycles vme data signals are delayed
+ // This is a global delay of VME data signals. One for all groups
+ val readVMEDataLatency =
+ if (tensorType == "inp") {
+ 0 // VME data signals delay
+ } else if (tensorType == "wgt") {
+ 0 // VME data signals delay
+ } else if (tensorType == "acc") {
+ 0 // VME data signals delay
+ } else if (tensorType == "fetch") {
+ 0
+ } else if (tensorType == "uop") {
+ 0 // VME data signals delay
+ } else if (tensorType == "out") {
+ 0
+ } else {
+ 0
+ }
+
+
// acc/wgt parts are grouped to form
// a physically compact compute entity
-
val (splitLength, splitWidth) =
if (tensorType == "inp") {
(1, 1)
@@ -94,7 +152,6 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
// continous data and may be placed into different memory
// modules. But the whole idea of a group to localize
// piece of wgt to piece of acc data transformation
- //
(1, p(CoreKey).blockOutFactor)
} else if (tensorType == "fetch") {
(1, 1)
@@ -126,6 +183,15 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
0
}
+ // this split doesn't change the tensorLoad interface, but
+ // allows pipelining VME write control signals per group of memory modules
+ val splitMemsFactor =
+ if (tensorType == "inp") {
+ 1
+ } else {
+ 1
+ }
+
val memAddrBits = log2Ceil(memDepth)
val tensorSizeBits = tensorLength * tensorWidth * tensorElemBits
@@ -177,7 +243,8 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
}
def paramsStr () = {
s" ${tensorType} ${tensorSizeBits*memDepth/8} Byte. length:${tensorLength} width:${tensorWidth}" +
- s" data bits:${tensorElemBits} mem depth:${memDepth} groups split length:${splitLength}"
+ s" data bits:${tensorElemBits} mem depth:${memDepth} groups split length:${splitLength}" +
+ s" split width:${splitWidth} pipe write:${writePipeLatency}"
}
}
diff --git a/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala b/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
index c52260d..48a2e9a 100644
--- a/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
+++ b/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
@@ -21,52 +21,89 @@ package vta.dpi
import chisel3._
import chisel3.util._
+import chisel3.experimental.IntParam
import vta.util.config._
import vta.interface.axi._
import vta.shell._
/** Memory DPI parameters */
-trait VTAMemDPIParams {
- val dpiLenBits = 8
- val dpiAddrBits = 64
- val dpiDataBits = 64
-}
+case class VTAMemDPIParams(
+ dpiDelay : Int,
+ dpiLenBits: Int,
+ dpiAddrBits: Int,
+ dpiDataBits: Int,
+ dpiTagBits: Int
+) {}
+case object DpiKey extends Field[VTAMemDPIParams]
/** Memory master interface.
*
* This interface is tipically used by the Accelerator
*/
-class VTAMemDPIMaster extends Bundle with VTAMemDPIParams {
+
+class MemRequest(implicit val p: Parameters) extends Bundle {
+ val len = (UInt(p(ShellKey).memParams.dataBits.W))
+ val addr = (UInt(p(ShellKey).memParams.addrBits.W))
+ val id = (UInt(p(ShellKey).memParams.idBits.W))
+}
+
+class VTAMemDPIData(implicit val p: Parameters) extends Bundle {
+ val data = UInt(p(ShellKey).memParams.dataBits.W)
+ val id = UInt(p(ShellKey).memParams.idBits.W)
+ override def cloneType =
+ new VTAMemDPIData().asInstanceOf[this.type]
+}
+
+class VTAMemDPIWrData(implicit val p: Parameters) extends Bundle {
+ val data = UInt(p(ShellKey).memParams.dataBits.W)
+ val strb = UInt((p(ShellKey).memParams.dataBits/8).W)
+ override def cloneType =
+ new VTAMemDPIWrData().asInstanceOf[this.type]
+}
+
+
+class VTAMemDPIMaster(implicit val p: Parameters) extends Bundle {
val req = new Bundle {
- val valid = Output(Bool())
- val opcode = Output(Bool())
- val len = Output(UInt(dpiLenBits.W))
- val addr = Output(UInt(dpiAddrBits.W))
+ val ar_valid = Output(Bool())
+ val ar_len = Output(UInt(p(ShellKey).memParams.lenBits.W))
+ val ar_addr = Output(UInt(p(ShellKey).memParams.addrBits.W))
+ val ar_id = Output(UInt(p(ShellKey).memParams.idBits.W))
+ val aw_valid = Output(Bool())
+ val aw_addr = Output(UInt(p(ShellKey).memParams.addrBits.W))
+ val aw_len = Output(UInt(p(ShellKey).memParams.lenBits.W))
}
- val wr = ValidIO(UInt(dpiDataBits.W))
- val rd = Flipped(Decoupled(UInt(dpiDataBits.W)))
+ val wr = ValidIO(new VTAMemDPIWrData)
+ val rd = Flipped(Decoupled(new VTAMemDPIData))
}
/** Memory client interface.
*
* This interface is tipically used by the Host
*/
-class VTAMemDPIClient extends Bundle with VTAMemDPIParams {
+class VTAMemDPIClient(implicit val p: Parameters) extends Bundle {
val req = new Bundle {
- val valid = Input(Bool())
- val opcode = Input(Bool())
- val len = Input(UInt(dpiLenBits.W))
- val addr = Input(UInt(dpiAddrBits.W))
+ val ar_valid = Input(Bool())
+ val ar_len = Input(UInt(p(ShellKey).memParams.lenBits.W))
+ val ar_addr = Input(UInt(p(ShellKey).memParams.addrBits.W))
+ val ar_id = Input(UInt(p(ShellKey).memParams.idBits.W))
+ val aw_valid = Input(Bool())
+ val aw_addr = Input(UInt(p(ShellKey).memParams.addrBits.W))
+ val aw_len = Input(UInt(p(ShellKey).memParams.lenBits.W))
}
- val wr = Flipped(ValidIO(UInt(dpiDataBits.W)))
- val rd = Decoupled(UInt(dpiDataBits.W))
+ val wr = Flipped(ValidIO(new VTAMemDPIWrData))
+ val rd = (Decoupled(new VTAMemDPIData))
}
/** Memory DPI module.
*
* Wrapper for Memory Verilog DPI module.
*/
-class VTAMemDPI extends BlackBox with HasBlackBoxResource {
+class VTAMemDPI(implicit val p: Parameters) extends BlackBox(
+ Map(
+ "LEN_BITS" -> IntParam(p(ShellKey).memParams.lenBits),
+ "ADDR_BITS" -> IntParam(p(ShellKey).memParams.addrBits),
+ "DATA_BITS" -> IntParam(p(ShellKey).memParams.dataBits))) with HasBlackBoxResource {
+
val io = IO(new Bundle {
val clock = Input(Clock())
val reset = Input(Reset())
@@ -75,110 +112,127 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
addResource("/verilog/VTAMemDPI.v")
}
-class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class VTAMemDPIToAXI(debug: Boolean = true)(implicit val p: Parameters) extends Module {
val io = IO(new Bundle {
val dpi = new VTAMemDPIMaster
val axi = new AXIClient(p(ShellKey).memParams)
})
- val opcode = RegInit(false.B)
- val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
- val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
- val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
- Enum(6)
- val state = RegInit(sIdle)
-
- switch(state) {
- is(sIdle) {
- when(io.axi.ar.valid) {
- state := sReadAddress
- }.elsewhen(io.axi.aw.valid) {
- state := sWriteAddress
+ //Read request interface for sw memory manager
+ val ar_valid = RegInit(false.B)
+ val ar_len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.ar_len)))
+ val ar_addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.ar_addr)))
+ val ar_id = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.ar_len)))
+ val rd_data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.rd.bits.data)))
+ val rIdle :: readData :: Nil = Enum(2)
+ val rstate = RegInit(rIdle)
+ //Write request interface for sw memory manager
+ val aw_valid = RegInit(false.B)
+ val aw_len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.aw_len)))
+ val aw_addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.aw_addr)))
+ val wIdle :: writeAddress :: writeData :: writeResponse :: Nil = Enum(4)
+ val wstate = RegInit(wIdle)
+ //Read Interface to Memory Manager
+ val counter = RegInit(0.U(32.W))
+ val dpiDelay = 16.U
+ val dpiReqQueue = Module(new Queue(new MemRequest, 256))
+ dpiReqQueue.io.enq.valid := io.axi.ar.valid & dpiReqQueue.io.enq.ready
+ dpiReqQueue.io.enq.bits.addr := io.axi.ar.bits.addr
+ dpiReqQueue.io.enq.bits.len := io.axi.ar.bits.len
+ dpiReqQueue.io.enq.bits.id := io.axi.ar.bits.id
+
+ when(dpiReqQueue.io.deq.valid && counter < dpiDelay){
+ counter := counter + 1.U
+ }.elsewhen(dpiReqQueue.io.deq.valid && counter === dpiDelay){
+ counter := counter
+ }.otherwise{
+ counter := 0.U
+ }
+
+ switch(rstate){
+ is(rIdle){
+ when(dpiReqQueue.io.deq.valid && dpiReqQueue.io.deq.bits.len =/=0.U && counter === dpiDelay){
+ rstate := readData
}
}
- is(sReadAddress) {
- when(io.axi.ar.valid) {
- state := sReadData
+ is(readData) {
+ when(io.axi.r.ready && io.dpi.rd.valid && ar_len === 1.U) {
+ rstate := rIdle
}
}
- is(sReadData) {
- when(io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
- state := sIdle
+ }
+ when(rstate === rIdle) {
+ when(dpiReqQueue.io.deq.ready){
+ ar_len := dpiReqQueue.io.deq.bits.len
+ ar_addr := dpiReqQueue.io.deq.bits.addr
+ ar_id := dpiReqQueue.io.deq.bits.id
+ }
+ }
+ .elsewhen(rstate === readData){
+ when(io.axi.r.ready && io.dpi.rd.valid && ar_len =/= 0.U){
+ ar_len := ar_len - 1.U
+ }
+ }
+dpiReqQueue.io.deq.ready := ((dpiReqQueue.io.deq.valid && (rstate === rIdle)) && (counter === dpiDelay))
+when(rstate === rIdle && dpiReqQueue.io.deq.valid){
+ io.dpi.req.ar_len := dpiReqQueue.io.deq.bits.len
+ io.dpi.req.ar_addr := dpiReqQueue.io.deq.bits.addr
+ io.dpi.req.ar_id := dpiReqQueue.io.deq.bits.id
+ io.dpi.req.ar_valid := dpiReqQueue.io.deq.ready
+ }.otherwise{
+ io.dpi.req.ar_len := ar_len
+ io.dpi.req.ar_addr := ar_addr
+ io.dpi.req.ar_id := ar_id
+ io.dpi.req.ar_valid := (dpiReqQueue.io.deq.ready)
+ }
+ io.axi.ar.ready := dpiReqQueue.io.enq.ready
+ io.axi.r.valid := io.dpi.rd.valid
+ io.axi.r.bits.data := io.dpi.rd.bits.data
+ io.axi.r.bits.last := (ar_len === 0.U && io.dpi.rd.valid)
+ io.axi.r.bits.resp := 0.U
+ io.axi.r.bits.user := 0.U
+ io.axi.r.bits.id := io.dpi.rd.bits.id
+ io.dpi.rd.ready := io.axi.r.ready
+
+ //Write Request
+ switch(wstate){
+ is(wIdle){
+ when(io.axi.aw.valid){
+ wstate := writeAddress
}
}
- is(sWriteAddress) {
+ is(writeAddress) {
when(io.axi.aw.valid) {
- state := sWriteData
+ wstate := writeData
}
}
- is(sWriteData) {
+ is(writeData) {
when(io.axi.w.valid && io.axi.w.bits.last) {
- state := sWriteResponse
+ wstate := writeResponse
}
}
- is(sWriteResponse) {
+ is(writeResponse) {
when(io.axi.b.ready) {
- state := sIdle
+ wstate := wIdle
}
}
}
-
- when(state === sIdle) {
- when(io.axi.ar.valid) {
- opcode := false.B
- len := io.axi.ar.bits.len
- addr := io.axi.ar.bits.addr
- }.elsewhen(io.axi.aw.valid) {
- opcode := true.B
- len := io.axi.aw.bits.len
- addr := io.axi.aw.bits.addr
- }
- }.elsewhen(state === sReadData) {
- when(io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
- len := len - 1.U
+ when(wstate === wIdle){
+ when(io.axi.aw.valid){
+ aw_len := io.axi.aw.bits.len
+ aw_addr := io.axi.aw.bits.addr
}
}
+ io.dpi.req.aw_addr := aw_addr
+ io.dpi.req.aw_len := aw_len
+ io.dpi.req.aw_valid := RegNext(io.axi.aw.valid) & (wstate === writeAddress)
+ io.axi.aw.ready := wstate === writeAddress
+ io.dpi.wr.valid := wstate === writeData & io.axi.w.valid
+ io.dpi.wr.bits.data := io.axi.w.bits.data
+ io.dpi.wr.bits.strb := io.axi.w.bits.strb
+ io.axi.w.ready := wstate === writeData
- io.dpi.req.valid := (state === sReadAddress & io.axi.ar.valid) | (state === sWriteAddress & io.axi.aw.valid)
- io.dpi.req.opcode := opcode
- io.dpi.req.len := len
- io.dpi.req.addr := addr
-
- io.axi.ar.ready := state === sReadAddress
- io.axi.aw.ready := state === sWriteAddress
-
- io.axi.r.valid := state === sReadData & io.dpi.rd.valid
- io.axi.r.bits.data := io.dpi.rd.bits
- io.axi.r.bits.last := len === 0.U
- io.axi.r.bits.resp := 0.U
- io.axi.r.bits.user := 0.U
- io.axi.r.bits.id := 0.U
- io.dpi.rd.ready := state === sReadData & io.axi.r.ready
-
- io.dpi.wr.valid := state === sWriteData & io.axi.w.valid
- io.dpi.wr.bits := io.axi.w.bits.data
- io.axi.w.ready := state === sWriteData
-
- io.axi.b.valid := state === sWriteResponse
+ io.axi.b.valid := wstate === writeResponse
io.axi.b.bits.resp := 0.U
io.axi.b.bits.user := 0.U
io.axi.b.bits.id := 0.U
-
- if (debug) {
- when(state === sReadAddress && io.axi.ar.valid) {
- printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len)
- }
- when(state === sWriteAddress && io.axi.aw.valid) {
- printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len)
- }
- when(io.axi.r.fire()) {
- printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n",
- io.axi.r.bits.last,
- io.axi.r.bits.data)
- }
- when(io.axi.w.fire()) {
- printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n",
- io.axi.w.bits.last,
- io.axi.w.bits.data)
- }
- }
}
diff --git a/hardware/chisel/src/main/scala/interface/axi/AXI.scala b/hardware/chisel/src/main/scala/interface/axi/AXI.scala
index 5151590..1370006 100644
--- a/hardware/chisel/src/main/scala/interface/axi/AXI.scala
+++ b/hardware/chisel/src/main/scala/interface/axi/AXI.scala
@@ -25,7 +25,7 @@ import vta.util.genericbundle._
case class AXIParams(
coherent: Boolean = false,
- idBits: Int = 1,
+ idBits: Int = 8,
addrBits: Int = 32,
dataBits: Int = 64,
lenBits: Int = 8,
@@ -190,6 +190,9 @@ class AXIMaster(params: AXIParams) extends AXIBase(params) {
r.ready := false.B
}
+ // These values are not changed in VTA
+ // Usually means that there is no implementation for
+ // alternative behavior
def setConst() {
aw.bits.user := params.userConst.U
aw.bits.burst := params.burstConst.U
@@ -199,10 +202,7 @@ class AXIMaster(params: AXIParams) extends AXIBase(params) {
aw.bits.qos := params.qosConst.U
aw.bits.region := params.regionConst.U
aw.bits.size := params.sizeConst.U
- aw.bits.id := params.idConst.U
- w.bits.id := params.idConst.U
w.bits.user := params.userConst.U
- w.bits.strb := Fill(params.strbBits, true.B)
ar.bits.user := params.userConst.U
ar.bits.burst := params.burstConst.U
ar.bits.lock := params.lockConst.U
@@ -211,7 +211,6 @@ class AXIMaster(params: AXIParams) extends AXIBase(params) {
ar.bits.qos := params.qosConst.U
ar.bits.region := params.regionConst.U
ar.bits.size := params.sizeConst.U
- ar.bits.id := params.idConst.U
}
}
diff --git a/hardware/chisel/src/main/scala/shell/VME.scala b/hardware/chisel/src/main/scala/shell/VME.scala
index 41b24d1..77dc069 100644
--- a/hardware/chisel/src/main/scala/shell/VME.scala
+++ b/hardware/chisel/src/main/scala/shell/VME.scala
@@ -18,22 +18,31 @@
*/
package vta.shell
-
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.util.genericbundle._
import vta.interface.axi._
+
/** VME parameters.
*
* These parameters are used on VME interfaces and modules.
*/
-case class VMEParams() {
- val nReadClients: Int = 5
- val nWriteClients: Int = 1
+case class VMEParams
+ (val nReadClients: Int = 5,
+ val nWriteClients: Int = 1,
+ val clientBits : Int = 3,
+ val RequestQueueDepth : Int = 16,
+ val vmeParams : Int = 18,
+ val clientCmdQueueDepth : Int = 1,
+ val clientTagBitWidth : Int = 21,
+ val clientDataQueueDepth : Int = 16) {
+
+ val RequestQueueMaskBits : Int = RequestQueueDepth.toInt
+
require(nReadClients > 0,
- s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
+ s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
require(
nWriteClients == 1,
s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
@@ -46,11 +55,37 @@ abstract class VMEBase(implicit p: Parameters) extends GenericParameterizedBundl
*
* This interface is used for creating write and read requests to memory.
*/
+class clientTag(implicit p:Parameters) extends Bundle{
+ val clientBits = p(ShellKey).vmeParams.clientBits
+ val RequestQueueDepth = p(ShellKey).vmeParams.RequestQueueDepth
+ val RequestQueueMaskBits = p(ShellKey).vmeParams.RequestQueueMaskBits
+ val client_id = UInt(clientBits.W)
+ val client_tag = UInt(p(ShellKey).vmeParams.clientTagBitWidth.W)
+ val client_mask = UInt(RequestQueueMaskBits.W)
+ override def cloneType =
+ new clientTag().asInstanceOf[this.type]
+}
+
class VMECmd(implicit p: Parameters) extends VMEBase {
val addrBits = p(ShellKey).memParams.addrBits
val lenBits = p(ShellKey).memParams.lenBits
+ val tagBits = p(ShellKey).vmeParams.clientTagBitWidth
val addr = UInt(addrBits.W)
val len = UInt(lenBits.W)
+ val tag = UInt(tagBits.W)
+}
+class VMECmdData(implicit p: Parameters) extends VMEBase {
+ val data = UInt(p(ShellKey).memParams.dataBits.W)
+ val last = Bool()
+}
+
+class VMEData(implicit p: Parameters) extends VMEBase {
+ val dataBits = p(ShellKey).memParams.dataBits
+ val data = UInt(dataBits.W)
+ val tag = UInt(p(ShellKey).vmeParams.clientTagBitWidth.W)
+ val last = Bool()
+ override def cloneType =
+ new VMEData().asInstanceOf[this.type]
}
/** VMEReadMaster.
@@ -61,9 +96,9 @@ class VMECmd(implicit p: Parameters) extends VMEBase {
class VMEReadMaster(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Decoupled(new VMECmd)
- val data = Flipped(Decoupled(UInt(dataBits.W)))
+ val data = Flipped(Decoupled(new VMEData))
override def cloneType =
- new VMEReadMaster().asInstanceOf[this.type]
+ new VMEReadMaster().asInstanceOf[this.type]
}
/** VMEReadClient.
@@ -74,9 +109,25 @@ class VMEReadMaster(implicit p: Parameters) extends Bundle {
class VMEReadClient(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Flipped(Decoupled(new VMECmd))
- val data = Decoupled(UInt(dataBits.W))
+ val data = Decoupled(new VMEData)
+ override def cloneType =
+ new VMEReadClient().asInstanceOf[this.type]
+}
+
+/** VMEWriteData.
+ *
+ * This interface is used by the VME to handle write requests from modules inside
+ * the core.
+ */
+class VMEWriteData(implicit p: Parameters) extends Bundle {
+ val dataBits = p(ShellKey).memParams.dataBits
+ val strbBits = dataBits/8
+
+ val data = UInt(dataBits.W)
+ val strb = UInt(strbBits.W)
+
override def cloneType =
- new VMEReadClient().asInstanceOf[this.type]
+ new VMEWriteData().asInstanceOf[this.type]
}
/** VMEWriteMaster.
@@ -87,10 +138,10 @@ class VMEReadClient(implicit p: Parameters) extends Bundle {
class VMEWriteMaster(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Decoupled(new VMECmd)
- val data = Decoupled(UInt(dataBits.W))
+ val data = Decoupled(new VMEWriteData)
val ack = Input(Bool())
override def cloneType =
- new VMEWriteMaster().asInstanceOf[this.type]
+ new VMEWriteMaster().asInstanceOf[this.type]
}
/** VMEWriteClient.
@@ -101,10 +152,10 @@ class VMEWriteMaster(implicit p: Parameters) extends Bundle {
class VMEWriteClient(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Flipped(Decoupled(new VMECmd))
- val data = Flipped(Decoupled(UInt(dataBits.W)))
+ val data = Flipped(Decoupled(new VMEWriteData))
val ack = Output(Bool())
override def cloneType =
- new VMEWriteClient().asInstanceOf[this.type]
+ new VMEWriteClient().asInstanceOf[this.type]
}
/** VMEMaster.
@@ -141,63 +192,174 @@ class VME(implicit p: Parameters) extends Module {
val mem = new AXIMaster(p(ShellKey).memParams)
val vme = new VMEClient
})
-
+ val clientCmdQueueDepth = p(ShellKey).vmeParams.clientCmdQueueDepth
+ val clientDataQueueDepth = p(ShellKey).vmeParams.clientDataQueueDepth
+ val RequestQueueDepth = p(ShellKey).vmeParams.RequestQueueDepth
+ val RequestQueueAddrWidth = log2Ceil(RequestQueueDepth.toInt)
+ val dataBits = p(ShellKey).memParams.dataBits
val nReadClients = p(ShellKey).vmeParams.nReadClients
- val rd_arb = Module(new Arbiter(new VMECmd, nReadClients))
- val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire())
-
- for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd }
-
- val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
- val rstate = RegInit(sReadIdle)
-
- switch(rstate) {
- is(sReadIdle) {
- when(rd_arb.io.out.valid) {
- rstate := sReadAddr
+ val addrBits = p(ShellKey).memParams.addrBits
+ val lenBits = p(ShellKey).memParams.lenBits
+ val idBits = p(ShellKey).memParams.idBits
+ val vmeTag_array = SyncReadMem(RequestQueueDepth,(new(clientTag)))
+ val vmeTag_array_wr_data = Wire(new(clientTag))
+ val vmeTag_array_wr_addr = Wire(UInt(RequestQueueAddrWidth.W))
+ val vmeTag_array_rd_addr = Wire(UInt(RequestQueueAddrWidth.W))
+ val vmeTag_array_wr_en = Wire(Bool())
+ val localTag_out = Wire(new(clientTag))
+ val availableEntriesEn = Wire(Bool())
+ val availableEntriesNext = Wire(UInt(RequestQueueDepth.W))
+ val availableEntries = Reg(availableEntriesNext.cloneType)
+ val freeTagLocation = Wire(UInt(RequestQueueDepth.W))
+ val (resetEntry,newEntry,firstPostn) = firstOneOH(availableEntries.asUInt)
+ val updateEntry = Wire(UInt(RequestQueueDepth.W))
+ when(io.mem.r.bits.last & io.mem.r.valid){
+ availableEntriesNext := updateEntry | availableEntries
+ }.elsewhen(availableEntriesEn && availableEntries =/= 0.U && !(io.mem.r.bits.last & io.mem.r.valid)){
+ availableEntriesNext:= newEntry
+ }.otherwise{
+ availableEntriesNext:= availableEntries
+ }
+ when(reset.toBool){
+ availableEntries := VecInit(Seq.fill(RequestQueueDepth)(true.B)).asUInt
+ updateEntry := 0.U
+ }.otherwise{
+ availableEntries := availableEntriesNext
+ updateEntry := VecInit(IndexedSeq.tabulate(RequestQueueDepth){ i => i.U === (io.mem.r.bits.id).asUInt }).asUInt
+ }
+ // Cmd Queues for each VME client
+ val VMEcmd_Qs = IndexedSeq.fill(5){ Module(new Queue(new VMECmd, clientCmdQueueDepth))}
+
+ //---------------------------------------
+ //--- Find available buffer entries -----
+ //---------------------------------------
+ def firstOneOH (in: UInt) = {
+ val oneHotIdx = for(bitIdx <- 0 until in.getWidth) yield {
+ if (bitIdx == 0){
+ in(0)
}
- }
- is(sReadAddr) {
- when(io.mem.ar.ready) {
- rstate := sReadData
+ else{
+ in(bitIdx) && ~in(bitIdx-1,0).orR
}
}
- is(sReadData) {
- when(io.mem.r.fire() && io.mem.r.bits.last) {
- rstate := sReadIdle
- }
+ val oHot = VecInit(oneHotIdx).asUInt
+ val newVec = in&(~oHot) // turn bit to 0
+ val bitPostn = PriorityEncoder(oneHotIdx)
+ (oHot, newVec,bitPostn)
+ }
+ val default_tag = Wire(new(clientTag))
+ default_tag.client_tag := 0.U
+ default_tag.client_id := 0.U
+ default_tag.client_mask := 0.U
+
+ val cmd_valids = for { q <- VMEcmd_Qs } yield q.io.deq.valid
+
+ val vme_select = PriorityEncoder(cmd_valids :+ true.B)
+ val any_cmd_valid = cmd_valids.foldLeft(false.B){ case (x,y) => x || y}
+ availableEntriesEn := io.mem.ar.ready & any_cmd_valid
+
+ for { i <- 0 until 5} {
+ VMEcmd_Qs(i).io.enq.valid := io.vme.rd(i).cmd.valid & VMEcmd_Qs(i).io.enq.ready
+ VMEcmd_Qs(i).io.enq.bits := io.vme.rd(i).cmd.bits
+ VMEcmd_Qs(i).io.deq.ready := io.mem.ar.ready &
+ (vme_select === i.U) & (availableEntries.asUInt =/= 0.U) &
+ !(io.mem.r.bits.last & io.mem.r.valid)
+ io.vme.rd(i).cmd.ready := VMEcmd_Qs(i).io.enq.ready
+ }
+
+ vmeTag_array_wr_addr := firstPostn.asUInt
+
+
+ val cmd_readys = for { q <- VMEcmd_Qs} yield q.io.deq.ready
+ val any_cmd_ready = cmd_readys.foldLeft(false.B){ case (x,y) => x || y}
+
+ vmeTag_array_wr_en := any_cmd_ready
+
+ when(vmeTag_array_wr_en){
+ val rdwrPort = vmeTag_array(vmeTag_array_wr_addr)
+ rdwrPort := vmeTag_array_wr_data
+ }
+
+ io.mem.ar.bits.addr := 0.U
+ io.mem.ar.bits.len := 0.U
+ io.mem.ar.valid := 0.U
+ io.mem.ar.bits.id := 0.U
+ vmeTag_array_wr_data := default_tag
+
+ // Last assign wins so do this in reverse order
+ for { i <- 4 to 0 by -1} {
+ when(VMEcmd_Qs(i).io.deq.ready){
+ io.mem.ar.bits.addr := VMEcmd_Qs(i).io.deq.bits.addr
+ io.mem.ar.bits.len := VMEcmd_Qs(i).io.deq.bits.len
+ io.mem.ar.valid := VMEcmd_Qs(i).io.deq.valid
+ io.mem.ar.bits.id := vmeTag_array_wr_addr
+ vmeTag_array_wr_data.client_id := i.U
+ vmeTag_array_wr_data.client_tag := VMEcmd_Qs(i).io.deq.bits.tag
+ vmeTag_array_wr_data.client_mask := resetEntry
}
}
+ // We need one clock cycle to look up the local tag in the
+ // centralized tag buffer vmeTag_array, so a flop stage is
+ // added for mem.r.data, mem.r.last and mem.r.valid
+ // until the local tag lookup has completed.
+ io.mem.r.ready := true.B
+ vmeTag_array_rd_addr := io.mem.r.bits.id
+ localTag_out := vmeTag_array(vmeTag_array_rd_addr)
+ freeTagLocation := localTag_out.client_mask
+
+ for (i <- 0 until nReadClients) {
+ io.vme.rd(i).data.valid := ((RegNext(io.mem.r.valid, init = false.B)) && ((localTag_out.client_id) === i.U)
+ && io.vme.rd(i).data.ready)
+ // VME doesn't stop when the client is not ready
+ assert(io.vme.rd(i).data.ready || ~io.vme.rd(i).data.valid)
+ io.vme.rd(i).data.bits.data := RegNext(io.mem.r.bits.data, init = false.B)
+ io.vme.rd(i).data.bits.last := RegNext(io.mem.r.bits.last, init = false.B)
+ io.vme.rd(i).data.bits.tag := localTag_out.client_tag
+ }
+
+ // VME <-> AXI write interface
+ val wr_len = RegInit(0.U(lenBits.W))
+ val wr_addr = RegInit(0.U(addrBits.W))
val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4)
val wstate = RegInit(sWriteIdle)
- val addrBits = p(ShellKey).memParams.addrBits
- val lenBits = p(ShellKey).memParams.lenBits
val wr_cnt = RegInit(0.U(lenBits.W))
-
+ io.vme.wr(0).cmd.ready := wstate === sWriteIdle
+ io.vme.wr(0).ack := io.mem.b.fire()
+ io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
+ io.mem.aw.valid := wstate === sWriteAddr
+ io.mem.aw.bits.addr := wr_addr
+ io.mem.aw.bits.len := wr_len
+ io.mem.aw.bits.id := p(ShellKey).memParams.idConst.U // no support for multiple writes
+ io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
+ io.mem.w.bits.data := io.vme.wr(0).data.bits.data
+ io.mem.w.bits.strb := io.vme.wr(0).data.bits.strb
+ io.mem.w.bits.last := wr_cnt === wr_len
+ io.mem.w.bits.id := p(ShellKey).memParams.idConst.U // no support for multiple writes
+ io.mem.b.ready := wstate === sWriteResp
+ when(io.vme.wr(0).cmd.fire()) {
+ wr_len := io.vme.wr(0).cmd.bits.len
+ wr_addr := io.vme.wr(0).cmd.bits.addr
+ }
when(wstate === sWriteIdle) {
wr_cnt := 0.U
- }.elsewhen(io.mem.w.fire()) {
+ }
+ .elsewhen(io.mem.w.fire()){
wr_cnt := wr_cnt + 1.U
}
-
- switch(wstate) {
- is(sWriteIdle) {
- when(io.vme.wr(0).cmd.valid) {
+ switch(wstate){
+ is(sWriteIdle){
+ when(io.vme.wr(0).cmd.valid){
wstate := sWriteAddr
}
}
- is(sWriteAddr) {
- when(io.mem.aw.ready) {
+ is(sWriteAddr){
+ when(io.mem.aw.ready){
wstate := sWriteData
}
}
- is(sWriteData) {
- when(
- io.vme
- .wr(0)
- .data
- .valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
+ is(sWriteData){
+ when(io.vme.wr(0).data.valid && io.mem.w.ready && wr_cnt === wr_len) {
wstate := sWriteResp
}
}
@@ -207,54 +369,28 @@ class VME(implicit p: Parameters) extends Module {
}
}
}
+ // AXI constants - statically defined
+ io.mem.setConst()
+}
- // registers storing read/write cmds
-
- val rd_len = RegInit(0.U(lenBits.W))
- val wr_len = RegInit(0.U(lenBits.W))
- val rd_addr = RegInit(0.U(addrBits.W))
- val wr_addr = RegInit(0.U(addrBits.W))
-
- when(rd_arb.io.out.fire()) {
- rd_len := rd_arb.io.out.bits.len
- rd_addr := rd_arb.io.out.bits.addr
- }
-
- when(io.vme.wr(0).cmd.fire()) {
- wr_len := io.vme.wr(0).cmd.bits.len
- wr_addr := io.vme.wr(0).cmd.bits.addr
- }
+/** VTA Memory Engine (VME).
+ *
+ * This unit multiplexes the memory controller interface for the Core. Currently,
+ * it supports single-writer and multiple-reader mode and it is also based on AXI.
+ */
+class VMETop(implicit p: Parameters) extends Module {
+ val io = IO(new Bundle {
+ val mem = new AXIMaster(p(ShellKey).memParams)
+ val vme = new VMEClient
+ })
- // rd arb
- rd_arb.io.out.ready := rstate === sReadIdle
+ val forceSimpleVME = false // force simple vme for simple tensor load/uop/fetch
- // vme
- for (i <- 0 until nReadClients) {
- io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid
- io.vme.rd(i).data.bits := io.mem.r.bits.data
+ if (forceSimpleVME) {
+ val vme = Module(new VMESimple)
+ io <> vme.io
+ } else {
+ val vme = Module(new VME)
+ io <> vme.io
}
-
- io.vme.wr(0).cmd.ready := wstate === sWriteIdle
- io.vme.wr(0).ack := io.mem.b.fire()
- io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
-
- // mem
- io.mem.aw.valid := wstate === sWriteAddr
- io.mem.aw.bits.addr := wr_addr
- io.mem.aw.bits.len := wr_len
-
- io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
- io.mem.w.bits.data := io.vme.wr(0).data.bits
- io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len
-
- io.mem.b.ready := wstate === sWriteResp
-
- io.mem.ar.valid := rstate === sReadAddr
- io.mem.ar.bits.addr := rd_addr
- io.mem.ar.bits.len := rd_len
-
- io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready
-
- // AXI constants - statically defined
- io.mem.setConst()
}
diff --git a/hardware/chisel/src/main/scala/shell/VME.scala b/hardware/chisel/src/main/scala/shell/VMESimple.scala
similarity index 55%
copy from hardware/chisel/src/main/scala/shell/VME.scala
copy to hardware/chisel/src/main/scala/shell/VMESimple.scala
index 41b24d1..e430d81 100644
--- a/hardware/chisel/src/main/scala/shell/VME.scala
+++ b/hardware/chisel/src/main/scala/shell/VMESimple.scala
@@ -18,125 +18,19 @@
*/
package vta.shell
-
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.util.genericbundle._
import vta.interface.axi._
-/** VME parameters.
- *
- * These parameters are used on VME interfaces and modules.
- */
-case class VMEParams() {
- val nReadClients: Int = 5
- val nWriteClients: Int = 1
- require(nReadClients > 0,
- s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
- require(
- nWriteClients == 1,
- s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
-}
-
-/** VMEBase. Parametrize base class. */
-abstract class VMEBase(implicit p: Parameters) extends GenericParameterizedBundle(p)
-
-/** VMECmd.
- *
- * This interface is used for creating write and read requests to memory.
- */
-class VMECmd(implicit p: Parameters) extends VMEBase {
- val addrBits = p(ShellKey).memParams.addrBits
- val lenBits = p(ShellKey).memParams.lenBits
- val addr = UInt(addrBits.W)
- val len = UInt(lenBits.W)
-}
-
-/** VMEReadMaster.
- *
- * This interface is used by modules inside the core to generate read requests
- * and receive responses from VME.
- */
-class VMEReadMaster(implicit p: Parameters) extends Bundle {
- val dataBits = p(ShellKey).memParams.dataBits
- val cmd = Decoupled(new VMECmd)
- val data = Flipped(Decoupled(UInt(dataBits.W)))
- override def cloneType =
- new VMEReadMaster().asInstanceOf[this.type]
-}
-
-/** VMEReadClient.
- *
- * This interface is used by the VME to receive read requests and generate
- * responses to modules inside the core.
- */
-class VMEReadClient(implicit p: Parameters) extends Bundle {
- val dataBits = p(ShellKey).memParams.dataBits
- val cmd = Flipped(Decoupled(new VMECmd))
- val data = Decoupled(UInt(dataBits.W))
- override def cloneType =
- new VMEReadClient().asInstanceOf[this.type]
-}
-
-/** VMEWriteMaster.
- *
- * This interface is used by modules inside the core to generate write requests
- * to the VME.
- */
-class VMEWriteMaster(implicit p: Parameters) extends Bundle {
- val dataBits = p(ShellKey).memParams.dataBits
- val cmd = Decoupled(new VMECmd)
- val data = Decoupled(UInt(dataBits.W))
- val ack = Input(Bool())
- override def cloneType =
- new VMEWriteMaster().asInstanceOf[this.type]
-}
-
-/** VMEWriteClient.
- *
- * This interface is used by the VME to handle write requests from modules inside
- * the core.
- */
-class VMEWriteClient(implicit p: Parameters) extends Bundle {
- val dataBits = p(ShellKey).memParams.dataBits
- val cmd = Flipped(Decoupled(new VMECmd))
- val data = Flipped(Decoupled(UInt(dataBits.W)))
- val ack = Output(Bool())
- override def cloneType =
- new VMEWriteClient().asInstanceOf[this.type]
-}
-
-/** VMEMaster.
- *
- * Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster
- * interfaces.
- */
-class VMEMaster(implicit p: Parameters) extends Bundle {
- val nRd = p(ShellKey).vmeParams.nReadClients
- val nWr = p(ShellKey).vmeParams.nWriteClients
- val rd = Vec(nRd, new VMEReadMaster)
- val wr = Vec(nWr, new VMEWriteMaster)
-}
-
-/** VMEClient.
- *
- * Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient
- * interfaces.
- */
-class VMEClient(implicit p: Parameters) extends Bundle {
- val nRd = p(ShellKey).vmeParams.nReadClients
- val nWr = p(ShellKey).vmeParams.nWriteClients
- val rd = Vec(nRd, new VMEReadClient)
- val wr = Vec(nWr, new VMEWriteClient)
-}
/** VTA Memory Engine (VME).
*
* This unit multiplexes the memory controller interface for the Core. Currently,
* it supports single-writer and multiple-reader mode and it is also based on AXI.
*/
-class VME(implicit p: Parameters) extends Module {
+class VMESimple(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val mem = new AXIMaster(p(ShellKey).memParams)
val vme = new VMEClient
@@ -228,10 +122,17 @@ class VME(implicit p: Parameters) extends Module {
// rd arb
rd_arb.io.out.ready := rstate === sReadIdle
+ val localTag = Reg(Vec(nReadClients, UInt(p(ShellKey).vmeParams.clientTagBitWidth.W)))
// vme
for (i <- 0 until nReadClients) {
io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid
- io.vme.rd(i).data.bits := io.mem.r.bits.data
+ io.vme.rd(i).data.bits.data := io.mem.r.bits.data
+ io.vme.rd(i).data.bits.last := io.mem.r.bits.last
+ io.vme.rd(i).data.bits.tag := localTag(i)
+
+ when (io.vme.rd(i).cmd.fire()) {
+ localTag(i) := io.vme.rd(i).cmd.bits.tag
+ }
}
io.vme.wr(0).cmd.ready := wstate === sWriteIdle
@@ -244,14 +145,16 @@ class VME(implicit p: Parameters) extends Module {
io.mem.aw.bits.len := wr_len
io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
- io.mem.w.bits.data := io.vme.wr(0).data.bits
+ io.mem.w.bits.data := io.vme.wr(0).data.bits.data
io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len
+ io.mem.w.bits.strb := Fill(p(ShellKey).memParams.strbBits, true.B)
io.mem.b.ready := wstate === sWriteResp
io.mem.ar.valid := rstate === sReadAddr
io.mem.ar.bits.addr := rd_addr
io.mem.ar.bits.len := rd_len
+ io.mem.ar.bits.id := 0.U
io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready
diff --git a/hardware/chisel/src/main/scala/util/SyncQueue.scala b/hardware/chisel/src/main/scala/util/SyncQueue.scala
new file mode 100644
index 0000000..1e77c99
--- /dev/null
+++ b/hardware/chisel/src/main/scala/util/SyncQueue.scala
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.util
+
+import chisel3._
+import chisel3.util._
+
+import vta.util.config._
+
+//! Queue backed by a one-port or 1R1W SRAM
+class SyncQueue[T <: Data](
+ gen: T,
+ val entries: Int,
+ pipe: Boolean = false,
+ flow: Boolean = false)
+ extends Module() {
+
+ val genType = gen
+ val forceSimpleQueue = false // Force usage of Queue
+
+ val io = IO(new QueueIO(genType, entries))
+
+ require (!pipe, "-F- Not supported")
+ require (!flow, "-F- Not supported")
+
+ if (forceSimpleQueue) {
+ val queue = Module(new Queue(genType.asUInt, entries))
+ io <> queue.io
+ } else {
+ val queue = Module(new SyncQueue2PortMem(genType.asUInt, entries))
+ io <> queue.io
+ }
+
+
+}
+
+// Implement a Queue on a single-port memory
+// pipe/flow not supported
+// combine DoubleQueue with a 3-entry Queue
+// Queue is required to buffer DoubleQueue latency
+class SyncQueue1PortMem[T <: Data](
+ gen: T,
+ val entries: Int,
+ pipe: Boolean = false,
+ flow: Boolean = false)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ require (!pipe, "-F- Not supported")
+ require (!flow, "-F- Not supported")
+
+ if (entries < 4 ) {
+ val queue = Module(new Queue(genType.asUInt, entries))
+ io <> queue.io
+ } else {
+ val queue = Module(new SyncQueue1PortMemImpl(genType.asUInt, entries))
+ io <> queue.io
+ }
+
+
+}
+class SyncQueue1PortMemImpl[T <: Data](
+ gen: T,
+ val entries: Int,
+ pipe: Boolean = false,
+ flow: Boolean = false)
+ extends Module() {
+
+ require (!pipe, "-F- Not supported")
+ require (!flow, "-F- Not supported")
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ require (entries > 3, "-F- TODO: small queue implemetation")
+ val doubleQueue = Module(new DoubleQueue(genType.asUInt, entries))
+ val buffer = Module(new Queue(genType.asUInt, 3))
+
+ val doubleQueueHasValues = doubleQueue.io.count =/= 0.U
+ val bufferInValid = Mux(doubleQueueHasValues, doubleQueue.io.deq.valid, io.enq.valid)
+ val bufferInBits = Mux(doubleQueueHasValues, doubleQueue.io.deq.bits, io.enq.bits)
+ buffer.io.enq.valid := bufferInValid
+ buffer.io.enq.bits := bufferInBits
+
+ io.deq <> buffer.io.deq
+ doubleQueue.io.enq.bits := io.enq.bits
+ doubleQueue.io.enq.valid := io.enq.fire() && (!buffer.io.enq.ready || doubleQueueHasValues)
+ doubleQueue.io.deq.ready := buffer.io.enq.ready
+
+ val count = Wire(UInt(log2Up(entries + 1).W))
+ val countNext = RegEnable(
+ next = count,
+ init = 0.U,
+ enable = io.enq.fire() || io.deq.fire())
+ when (io.enq.fire() && !io.deq.fire()) {
+ assert(countNext < entries.U)
+ count := countNext + 1.U
+ }.elsewhen (!io.enq.fire() && io.deq.fire()) {
+ assert(countNext > 0.U)
+ count := countNext - 1.U
+ }.otherwise {
+ count := countNext
+ }
+
+ io.count := countNext
+ io.enq.ready := countNext =/= entries.U
+ io.deq.valid := countNext =/= 0.U
+ assert(io.deq.valid === buffer.io.deq.valid)
+ assert(io.enq.ready === buffer.io.enq.ready || doubleQueue.io.enq.ready)
+}
+
+class SyncQueue2PortMem[T <: Data](
+ gen: T,
+ val entries: Int,
+ pipe: Boolean = false,
+ flow: Boolean = false)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ require (!pipe, "-F- Not supported")
+ require (!flow, "-F- Not supported")
+
+ if (entries < 4 ) {
+ val queue = Module(new Queue(genType.asUInt, entries))
+ io <> queue.io
+ } else {
+ val queue = Module(new SyncQueue2PortMemImpl(genType.asUInt, entries))
+ io <> queue.io
+ }
+
+}
+
+class SyncQueue2PortMemImpl[T <: Data](
+ gen: T,
+ val entries: Int,
+ pipe: Boolean = false,
+ flow: Boolean = false)
+ extends Module() {
+
+ require (!pipe, "-F- Not supported")
+ require (!flow, "-F- Not supported")
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ require (entries > 3, "-F- TODO: small queue implemetation")
+ val memoryQueue = Module(new OneCycleQueue(genType.asUInt, entries, ""))
+ val buffer = Module(new Queue(genType.asUInt, 3))
+
+ val memoryQueueHasValues = memoryQueue.io.count =/= 0.U
+ val bufferInValid = Mux(memoryQueueHasValues, memoryQueue.io.deq.valid, io.enq.valid)
+ val bufferInBits = Mux(memoryQueueHasValues, memoryQueue.io.deq.bits, io.enq.bits)
+ buffer.io.enq.valid := bufferInValid
+ buffer.io.enq.bits := bufferInBits
+
+ io.deq <> buffer.io.deq
+ memoryQueue.io.enq.bits := io.enq.bits
+ memoryQueue.io.enq.valid := io.enq.fire() && (!buffer.io.enq.ready || memoryQueueHasValues)
+ memoryQueue.io.deq.ready := buffer.io.enq.ready
+
+ val count = Wire(UInt(log2Up(entries + 1).W))
+ val countNext = RegEnable(
+ next = count,
+ init = 0.U,
+ enable = io.enq.fire() || io.deq.fire())
+ when (io.enq.fire() && !io.deq.fire()) {
+ assert(countNext < entries.U)
+ count := countNext + 1.U
+ }.elsewhen (!io.enq.fire() && io.deq.fire()) {
+ assert(countNext > 0.U)
+ count := countNext - 1.U
+ }.otherwise {
+ count := countNext
+ }
+
+ io.count := countNext
+ io.enq.ready := countNext =/= entries.U
+ io.deq.valid := countNext =/= 0.U
+ assert(io.deq.valid === buffer.io.deq.valid)
+ assert(io.enq.ready === buffer.io.enq.ready || memoryQueue.io.enq.ready)
+}
+
+// Combines two one-port-memory TwoCycleQueues into a single queue
+// with a latency of 3
+class DoubleQueue[T <: Data](
+ gen: T,
+ val entries: Int)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ require(entries > 1, "Zero size not tested")
+ val entriesRam0 = entries/2
+ val entriesRam1 = entries - entriesRam0
+ val queue0 = Module(new TwoCycleQueue(genType.asUInt, entriesRam0, "q0"))
+ val queue1 = Module(new TwoCycleQueue(genType.asUInt, entriesRam1, "q1"))
+
+ val enqRR = Wire(Bool())
+ enqRR := RegEnable(
+ next = ~enqRR,
+ init = 1.U,
+ enable = io.enq.fire())
+ val deqRR = Wire(Bool())
+ deqRR := RegEnable(
+ next = ~deqRR,
+ init = 1.U,
+ enable = io.deq.fire())
+
+
+ val do_enq0 = WireInit(io.enq.fire() && enqRR)
+ val do_enq1 = WireInit(io.enq.fire() && ~enqRR)
+ val deq0 = WireInit(io.deq.fire() && deqRR)
+ val deq1 = WireInit(io.deq.fire() && ~deqRR)
+ val do_deq0_next = RegNext(deq0 && do_enq0)
+ val do_deq1_next = RegNext(deq1 && do_enq1)
+ val do_deq0 = (deq0 && ~do_enq0) || do_deq0_next
+ val do_deq1 = (deq1 && ~do_enq1) || do_deq1_next
+
+ val do_deq = WireInit(io.deq.fire())
+ val full = !queue0.io.enq.ready && !queue1.io.enq.ready
+ val empty = !queue0.io.deq.valid && !queue1.io.deq.valid
+
+ queue0.io.enq.bits := io.enq.bits
+ queue0.io.enq.valid := do_enq0
+ queue1.io.enq.bits := io.enq.bits
+ queue1.io.enq.valid := do_enq1
+ queue0.io.deq.ready := do_deq0
+ queue1.io.deq.ready := do_deq1
+
+ io.deq.valid := !empty
+ io.enq.ready := !full
+
+
+ when(do_deq0) {
+ assert(queue0.io.deq.valid, "-F- Deq empty queue 0")
+ }
+ when(do_deq1) {
+ assert(queue1.io.deq.valid, "-F- Deq empty queue 1")
+ }
+
+ when(do_enq0) {
+ assert(queue0.io.enq.ready, "-F- Enq full queue 0")
+ }
+ when(do_enq1) {
+ assert(queue1.io.enq.ready, "-F- Enq full queue 1")
+ }
+
+
+ io.deq.bits := Mux(deqRR, queue0.io.deq.bits, queue1.io.deq.bits)
+ io.count := queue0.io.count +& queue1.io.count
+}
+
+// One-port memory queue.
+// An enq and a deq must not happen in the same cycle.
+// Two consecutive enqs must be separated by at least one cycle.
+// Two consecutive deqs may happen in back-to-back cycles.
+class TwoCycleQueue[T <: Data](
+ gen: T,
+ val entries: Int,
+ val qname: String)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ val ram0 = Module(new OnePortMem(genType.asUInt, entries, qname))
+ val enq_ptr = Counter(entries)
+ val deq_ptr = Counter(entries)
+ val maybe_full = RegInit(false.B)
+
+
+ val ptr_match = enq_ptr.value === deq_ptr.value
+ val empty = ptr_match && !maybe_full
+ val full = ptr_match && maybe_full
+
+ val do_enq = WireInit(io.enq.fire())
+ val do_deq = WireInit(io.deq.fire())
+
+ // check protocol
+ val enq_next = RegNext(do_enq)
+ assert(!(enq_next && do_enq), "-F- Expecting two cycle behavior on enq")
+ assert(!do_enq || !do_deq, "-F- No simultaneous R/W")
+
+ when(do_deq) {
+ deq_ptr.inc()
+ }
+
+ when(do_enq =/= do_deq) {
+ maybe_full := do_enq
+ }
+
+ val firstRead = RegEnable(next = do_enq && io.count === 0.U, init = false.B, enable = true.B)
+ io.deq.valid := !empty && !firstRead
+ io.enq.ready := !full
+
+ when (do_enq) {
+ enq_ptr.inc()
+ }
+
+ val memAddr = Wire(enq_ptr.value.cloneType)
+ memAddr := enq_ptr.value
+ when(!do_enq) {
+ when(firstRead) {// output the 1st written data
+ memAddr := deq_ptr.value
+ }.elsewhen (do_deq) {
+ val wrap = deq_ptr.value === (entries - 1).U
+ when (wrap) {
+ memAddr := 0.U // initiate read of the next entry
+ }.otherwise {
+ memAddr := (deq_ptr.value + 1.U) // initiate read of the next entry
+ }
+ }.otherwise {
+ memAddr := deq_ptr.value
+ }
+ }
+ ram0.io.wr_en := do_enq
+ ram0.io.wr_data := io.enq.bits.asUInt
+ ram0.io.ch_en := do_deq || firstRead || do_enq
+ io.deq.bits := ram0.io.rd_data
+ ram0.io.addr := memAddr
+
+
+
+ val ptr_diff = enq_ptr.value - deq_ptr.value
+ if (isPow2(entries)) {
+ io.count := Mux(maybe_full && ptr_match, entries.U, 0.U) | ptr_diff
+ } else {
+ io.count := Mux(
+ ptr_match,
+ Mux(
+ maybe_full,
+ entries.asUInt, 0.U),
+ Mux(
+ deq_ptr.value > enq_ptr.value,
+ entries.asUInt + ptr_diff, ptr_diff))
+ }
+}
+
+class OneCycleQueue[T <: Data](
+ gen: T,
+ val entries: Int,
+ val qname: String)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new QueueIO(genType, entries))
+
+ val ram0 = Module(new TwoPortMem(genType.asUInt, entries, qname))
+ val enq_ptr = Counter(entries)
+ val deq_ptr = Counter(entries)
+ val maybe_full = RegInit(false.B)
+
+
+ val ptr_match = enq_ptr.value === deq_ptr.value
+ val empty = ptr_match && !maybe_full
+ val full = ptr_match && maybe_full
+
+ val do_enq = WireInit(io.enq.fire())
+ val do_deq = WireInit(io.deq.fire())
+
+
+ when(do_deq) {
+ deq_ptr.inc()
+ }
+
+ when(do_enq =/= do_deq) {
+ maybe_full := do_enq
+ }
+
+ when (do_enq) {
+ enq_ptr.inc()
+ }
+
+ val firstRead = RegEnable(next = do_enq && io.count === 0.U, init = false.B, enable = true.B)
+ io.deq.valid := !empty && !firstRead
+ io.enq.ready := !full
+ assert(!firstRead || !do_deq, "-F- Cannot have deq with first read as queue output is not valid yet")
+
+ val rdAddr = Wire(enq_ptr.value.cloneType)
+ when(firstRead) {// output the 1st written data
+ rdAddr := deq_ptr.value
+ }.elsewhen (do_deq) {
+ val wrap = deq_ptr.value === (entries - 1).U
+ when (wrap) {
+ rdAddr := 0.U // initiate read of the next entry
+ }.otherwise {
+ rdAddr := (deq_ptr.value + 1.U) // initiate read of the next entry
+ }
+ }.otherwise {
+ rdAddr := deq_ptr.value
+ }
+ ram0.io.wr_en := do_enq
+ ram0.io.wr_data := io.enq.bits.asUInt
+ ram0.io.wr_addr := enq_ptr.value
+ ram0.io.rd_en := do_deq || firstRead
+ ram0.io.rd_addr := rdAddr
+ io.deq.bits := ram0.io.rd_data
+
+
+
+ val ptr_diff = enq_ptr.value - deq_ptr.value
+ if (isPow2(entries)) {
+ io.count := Mux(maybe_full && ptr_match, entries.U, 0.U) | ptr_diff
+ } else {
+ io.count := Mux(
+ ptr_match,
+ Mux(
+ maybe_full,
+ entries.asUInt, 0.U),
+ Mux(
+ deq_ptr.value > enq_ptr.value,
+ entries.asUInt + ptr_diff, ptr_diff))
+ }
+}
+
+// one-port memory implementation
+class MemIO[T <: Data](gen: T, entries: Int) extends Bundle
+{
+ val wr_en = Input(Bool())
+ val wr_data = Input(gen.cloneType)
+ val ch_en = Input(Bool())
+ val rd_data = Output(gen.cloneType)
+ val addr = Input(UInt(16.W)) // i dont care
+ override def cloneType: this.type = new MemIO(gen, entries).asInstanceOf[this.type]
+}
+class OnePortMem[T <: Data](
+ gen: T,
+ val entries: Int,
+ val qname: String)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new MemIO(genType, entries))
+
+ val mem = SyncReadMem(entries, genType.asUInt)
+
+ // detected as one-port sync mem interface
+ io.rd_data := DontCare
+ when(io.ch_en) {
+ val rdwrPort = mem(io.addr)
+ when (io.wr_en) { rdwrPort := io.wr_data }
+ .otherwise { io.rd_data := rdwrPort }
+ }
+}
+
+// two-port memory implementation
+class MemIO2P[T <: Data](gen: T, entries: Int) extends Bundle
+{
+ val wr_en = Input(Bool())
+ val wr_addr = Input(UInt(16.W)) // i dont care
+ val wr_data = Input(gen.cloneType)
+ val rd_en = Input(Bool())
+ val rd_addr = Input(UInt(16.W)) // i dont care
+ val rd_data = Output(gen.cloneType)
+ override def cloneType: this.type = new MemIO2P(gen, entries).asInstanceOf[this.type]
+}
+
+class TwoPortMem[T <: Data](
+ gen: T,
+ val entries: Int,
+ val qname: String)
+ extends Module() {
+
+ val genType = gen
+
+ val io = IO(new MemIO2P(genType, entries))
+
+ val mem = SyncReadMem(entries, genType.asUInt)
+
+ when (io.wr_en ) {
+ mem.write(io.wr_addr, io.wr_data.asUInt)
+ }
+ io.rd_data := DontCare
+ when (io.rd_en) {
+ io.rd_data := mem.read(io.rd_addr, io.rd_en)
+ }
+
+}
diff --git a/hardware/chisel/src/test/scala/unittest/SyncQueue2PortMemTest.scala b/hardware/chisel/src/test/scala/unittest/SyncQueue2PortMemTest.scala
new file mode 100644
index 0000000..c8916ad
--- /dev/null
+++ b/hardware/chisel/src/test/scala/unittest/SyncQueue2PortMemTest.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package unittest
+
+import chisel3._
+import chisel3.util._
+import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
+import scala.util.Random
+import unittest.util._
+import vta.util._
+import vta.util.config._
+
+class Checker2P(c: SyncQueue2PTestWrapper[UInt], t: PeekPokeTester[SyncQueue2PTestWrapper[UInt]]) {
+
+ def bits (bits: Int) = {
+ t.expect(c.io.tq.deq.bits, bits)
+ t.expect(c.io.rq.deq.bits, bits)
+
+ }
+ def ready (bits: Int) = {
+ t.expect(c.io.tq.enq.ready, bits)
+ t.expect(c.io.rq.enq.ready, bits)
+
+ }
+ def valid (bits: Int) = {
+ t.expect(c.io.tq.deq.valid, bits)
+ t.expect(c.io.rq.deq.valid, bits)
+
+ }
+ def status () = {
+ val rv = t.peek(c.io.rq.enq.ready)
+ t.expect(c.io.tq.enq.ready, rv)
+ val rc = t.peek(c.io.rq.count)
+ t.expect(c.io.tq.count, rc)
+ val vv = t.peek(c.io.rq.deq.valid)
+ t.expect(c.io.tq.deq.valid, vv)
+ if (vv != 0) {
+ val bv = t.peek(c.io.rq.deq.bits)
+ t.expect(c.io.tq.deq.bits, bv)
+ }
+ t.peek(c.io.rq.count)
+ t.peek(c.io.tq.count)
+ }
+}
+class TestSyncQueue2PLongRead(c: SyncQueue2PTestWrapper[UInt]) extends PeekPokeTester(c) {
+
+ val chr = new Checker2P (c, this)
+
+ def testFillRW(depth: Int) = {
+ val qsize = peek(c.io.tq.count)
+ require(qsize == 0, s"-F- An empty queue is expected ${qsize}")
+
+ poke (c.io.tq.deq.ready, 0)
+ poke (c.io.tq.enq.valid, 0)
+ chr.ready(1)
+ step(1)
+
+ // fill up to depth
+ for (i <- 10 until 10 + depth) {
+ poke (c.io.tq.enq.bits, i)
+ poke (c.io.tq.enq.valid, 1)
+ chr.status()
+ step(1)
+
+ }
+ // read and write same cycle
+ for (i <- 30 + depth until 30 + depth * 2) {
+ poke (c.io.tq.enq.valid, 1)
+ poke (c.io.tq.deq.ready, 1)
+ poke (c.io.tq.enq.bits, i)
+ chr.status()
+ step(1)
+ }
+ // read out
+ for (i <- 0 until depth + 1) {
+ poke (c.io.tq.enq.valid, 0)
+ poke (c.io.tq.deq.ready, 1)
+ poke (c.io.tq.enq.bits, 99)
+ chr.status()
+ step(1)
+ }
+ }
+ for (i <- 1 until 28) {
+ testFillRW(i)
+ }
+}
+class TestSyncQueue2PWaveRead(c: SyncQueue2PTestWrapper[UInt]) extends PeekPokeTester(c) {
+
+ val chr = new Checker2P (c, this)
+
+ def testFillRW(depth: Int) = {
+ val qsize = peek(c.io.tq.count)
+ require(qsize == 0, s"-F- An empty queue is expected ${qsize}")
+
+ poke (c.io.tq.deq.ready, 0)
+ poke (c.io.tq.enq.valid, 0)
+ chr.ready(1)
+ step(1)
+
+ // fill up to depth
+ for (i <- 10 until 10 + depth) {
+ poke (c.io.tq.enq.bits, i)
+ poke (c.io.tq.enq.valid, 1)
+ chr.status()
+ step(1)
+
+ }
+ // read out, no write
+ poke (c.io.tq.enq.valid, 0)
+ poke (c.io.tq.deq.ready, 1)
+ for (i <- 0 until 7) {
+ chr.status()
+ step(1)
+ }
+ // fill more
+ poke (c.io.tq.deq.ready, 0)
+ poke (c.io.tq.enq.valid, 1)
+ for (i <- 0 until 13) {
+ poke (c.io.tq.enq.bits, 99 + i)
+ chr.status()
+ step(1)
+ }
+ // read out, no write
+ poke (c.io.tq.enq.valid, 0)
+ poke (c.io.tq.deq.ready, 1)
+ for (i <- 1 until 14 + depth) {
+ chr.status()
+ step(1)
+ }
+ }
+ // run the fill/drain sequence for depths 1 to 27
+ for (i <- 1 until 28) {
+ testFillRW(i)
+ }
+}
+class SyncQueue2PTestWrapper[T <: Data](
+ gen: T,
+ val entries: Int)
+ extends Module() {
+
+
+ val genType = gen
+
+ val io = IO(new Bundle {
+ val tq = new QueueIO(genType, entries)
+ val rq = new QueueIO(genType, entries)
+
+ })
+
+ val tq = Module(new SyncQueue2PortMem(genType.asUInt, entries))
+ val rq = Module(new Queue(genType.asUInt, entries))
+ io.tq <> tq.io
+ io.rq <> rq.io
+ tq.io.enq.valid := RegNext(io.tq.enq.valid)
+ tq.io.enq.bits := RegNext(io.tq.enq.bits)
+ tq.io.deq.ready := RegNext(io.tq.deq.ready)
+ // connect reference queue input port to test input
+ rq.io.enq.valid := RegNext(io.tq.enq.valid)
+ rq.io.enq.bits := RegNext(io.tq.enq.bits)
+ rq.io.deq.ready := RegNext(io.tq.deq.ready)
+}
+
+class SyncQueue2PTestLongRead24 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 24),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PLongRead(c))
+class SyncQueue2PTestLongRead13 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 13),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PLongRead(c))
+class SyncQueue2PTestWaveRead24 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 24),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PWaveRead(c))
+class SyncQueue2PTestWaveRead1 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 1),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PWaveRead(c))
+class SyncQueue2PTestWaveRead2 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 2),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PWaveRead(c))
+class SyncQueue2PTestWaveRead3 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 3),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PWaveRead(c))
+class SyncQueue2PTestWaveRead4 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueue2PTestWrapper(UInt(16.W), 4),
+ (c:SyncQueue2PTestWrapper[UInt]) => new TestSyncQueue2PWaveRead(c))
diff --git a/hardware/chisel/src/test/scala/unittest/SyncQueueTest.scala b/hardware/chisel/src/test/scala/unittest/SyncQueueTest.scala
new file mode 100644
index 0000000..105d0df
--- /dev/null
+++ b/hardware/chisel/src/test/scala/unittest/SyncQueueTest.scala
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package unittest
+
+import chisel3._
+import chisel3.util._
+import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
+import scala.util.Random
+import unittest.util._
+import vta.util._
+import vta.util.config._
+
+class TestOnePortMem(c: OnePortMem[UInt]) extends PeekPokeTester(c) {
+
+ // write a:0 d:24
+ println("-----------------------------")
+ println("Cycle 0 write 24 to address 0")
+ poke (c.io.wr_en, 1)
+ poke (c.io.wr_data, 24)
+ poke (c.io.ch_en, 1)
+ poke (c.io.addr, 0)
+ step(1)
+ println("-----------------------------")
+ // read a:0
+ println("Cycle 1 read address 0")
+ poke (c.io.wr_en, 0)
+ poke (c.io.addr, 0)
+ poke (c.io.ch_en, 1)
+ step(1)
+ println("-----------------------------")
+ // write a:1 d:99
+ println("Cycle 2 write 99 to address 1")
+ poke (c.io.wr_en, 1)
+ poke (c.io.wr_data, 99)
+ poke (c.io.ch_en, 1)
+ poke (c.io.addr, 1)
+ // read d:24
+ println("Cycle 2 read expect data 24")
+ expect (c.io.rd_data, 24)
+ step(1)
+ println("-----------------------------")
+ println("Cycle 3 should still read data 24")
+ poke (c.io.ch_en, 0)
+ // read d:24
+ expect (c.io.rd_data, 24)
+ step(1)
+ println("-----------------------------")
+ println("Cycle 4 read address 0")
+ poke (c.io.wr_en, 0)
+ poke (c.io.addr, 0)
+ poke (c.io.ch_en, 1)
+ step(1)
+ println("-----------------------------")
+ // no access this cycle: wr_en and ch_en are both 0, so nothing is written
+ poke (c.io.wr_en, 0)
+ poke (c.io.wr_data, 99)
+ poke (c.io.ch_en, 0)
+ poke (c.io.addr, 1)
+ // read d:24
+ println("Cycle 5 read expect data 24")
+ expect (c.io.rd_data, 24)
+ step(1)
+ println("-----------------------------")
+}
+class Checker(c: SyncQueueTestWrapper[UInt], t: PeekPokeTester[SyncQueueTestWrapper[UInt]]) {
+
+ def bits (bits: Int) = {
+ t.expect(c.io.tq.deq.bits, bits)
+ t.expect(c.io.rq.deq.bits, bits)
+
+ }
+ def ready (bits: Int) = {
+ t.expect(c.io.tq.enq.ready, bits)
+ t.expect(c.io.rq.enq.ready, bits)
+
+ }
+ def valid (bits: Int) = {
+ t.expect(c.io.tq.deq.valid, bits)
+ t.expect(c.io.rq.deq.valid, bits)
+
+ }
+ def status () = {
+ val rv = t.peek(c.io.rq.enq.ready)
+ t.expect(c.io.tq.enq.ready, rv)
+ val rc = t.peek(c.io.rq.count)
+ t.expect(c.io.tq.count, rc)
+ val vv = t.peek(c.io.rq.deq.valid)
+ t.expect(c.io.tq.deq.valid, vv)
+ if (vv != 0) {
+ val bv = t.peek(c.io.rq.deq.bits)
+ t.expect(c.io.tq.deq.bits, bv)
+ }
+ t.peek(c.io.rq.count)
+ t.peek(c.io.tq.count)
+ }
+}
+class TestSyncQueueLongRead(c: SyncQueueTestWrapper[UInt]) extends PeekPokeTester(c) {
+
+ val chr = new Checker (c, this)
+
+ def testFillRW(depth: Int) = {
+ val qsize = peek(c.io.tq.count)
+ require(qsize == 0, s"-F- An empty queue is expected ${qsize}")
+
+ poke (c.io.tq.deq.ready, 0)
+ poke (c.io.tq.enq.valid, 0)
+ chr.ready(1)
+ step(1)
+
+ // fill up to depth
+ for (i <- 10 until 10 + depth) {
+ poke (c.io.tq.enq.bits, i)
+ poke (c.io.tq.enq.valid, 1)
+ chr.status()
+ step(1)
+
+ }
+ // read and write same cycle
+ for (i <- 30 + depth until 30 + depth * 2) {
+ poke (c.io.tq.enq.valid, 1)
+ poke (c.io.tq.deq.ready, 1)
+ poke (c.io.tq.enq.bits, i)
+ chr.status()
+ step(1)
+ }
+ // read out
+ for (i <- 0 until depth + 1) {
+ poke (c.io.tq.enq.valid, 0)
+ poke (c.io.tq.deq.ready, 1)
+ poke (c.io.tq.enq.bits, 99)
+ chr.status()
+ step(1)
+ }
+ }
+ for (i <- 1 until 28) {
+ testFillRW(i)
+ }
+}
+class TestSyncQueueWaveRead(c: SyncQueueTestWrapper[UInt]) extends PeekPokeTester(c) {
+
+ val chr = new Checker (c, this)
+
+ def testFillRW(depth: Int) = {
+ val qsize = peek(c.io.tq.count)
+ require(qsize == 0, s"-F- An empty queue is expected ${qsize}")
+
+ poke (c.io.tq.deq.ready, 0)
+ poke (c.io.tq.enq.valid, 0)
+ chr.ready(1)
+ step(1)
+
+ // fill up to depth
+ for (i <- 10 until 10 + depth) {
+ poke (c.io.tq.enq.bits, i)
+ poke (c.io.tq.enq.valid, 1)
+ chr.status()
+ step(1)
+
+ }
+ // read out, no write
+ poke (c.io.tq.enq.valid, 0)
+ poke (c.io.tq.deq.ready, 1)
+ for (i <- 0 until 7) {
+ chr.status()
+ step(1)
+ }
+ // fill more
+ poke (c.io.tq.deq.ready, 0)
+ poke (c.io.tq.enq.valid, 1)
+ for (i <- 0 until 13) {
+ poke (c.io.tq.enq.bits, 99 + i)
+ chr.status()
+ step(1)
+ }
+ // read out, no write
+ poke (c.io.tq.enq.valid, 0)
+ poke (c.io.tq.deq.ready, 1)
+ for (i <- 1 until 14 + depth) {
+ chr.status()
+ step(1)
+ }
+ }
+ // run the fill/drain sequence for depths 1 to 27
+ for (i <- 1 until 28) {
+ testFillRW(i)
+ }
+}
+
+class SyncQueueTestWrapper[T <: Data](
+ gen: T,
+ val entries: Int)
+ extends Module() {
+
+
+ val genType = gen
+
+ val io = IO(new Bundle {
+ val tq = new QueueIO(genType, entries)
+ val rq = new QueueIO(genType, entries)
+
+ })
+
+ val tq = Module(new SyncQueue1PortMem(genType.asUInt, entries))
+ val rq = Module(new Queue(genType.asUInt, entries))
+ io.tq <> tq.io
+ io.rq <> rq.io
+ tq.io.enq.valid := RegNext(io.tq.enq.valid)
+ tq.io.enq.bits := RegNext(io.tq.enq.bits)
+ tq.io.deq.ready := RegNext(io.tq.deq.ready)
+ // connect reference queue input port to test input
+ rq.io.enq.valid := RegNext(io.tq.enq.valid)
+ rq.io.enq.bits := RegNext(io.tq.enq.bits)
+ rq.io.deq.ready := RegNext(io.tq.deq.ready)
+}
+
+class SyncQueueTestLongRead24 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 24),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueLongRead(c))
+class SyncQueueTestLongRead13 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 13),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueLongRead(c))
+class SyncQueueTestWaveRead24 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 24),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueWaveRead(c))
+class OnePorMemTest extends GenericTest(
+ "Queue",
+ (p:Parameters) => new OnePortMem(UInt(16.W), 16, ""),
+ (c:OnePortMem[UInt]) => new TestOnePortMem(c))
+class SyncQueueTestWaveRead1 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 1),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueWaveRead(c))
+class SyncQueueTestWaveRead2 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 2),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueWaveRead(c))
+class SyncQueueTestWaveRead3 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 3),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueWaveRead(c))
+class SyncQueueTestWaveRead4 extends GenericTest(
+ "Queue",
+ (p:Parameters) => new SyncQueueTestWrapper(UInt(16.W), 4),
+ (c:SyncQueueTestWrapper[UInt]) => new TestSyncQueueWaveRead(c))
diff --git a/hardware/dpi/tsim_device.cc b/hardware/dpi/tsim_device.cc
index ffa192b..be19108 100644
--- a/hardware/dpi/tsim_device.cc
+++ b/hardware/dpi/tsim_device.cc
@@ -17,9 +17,11 @@
* under the License.
*/
+#include <cassert>
#include <chrono>
#include <thread>
#include <vta/dpi/tsim.h>
+#include <verilated.h>
#if VM_TRACE
#ifdef VM_TRACE_FST
@@ -58,19 +60,25 @@ void VTAHostDPI(dpi8_t* req_valid,
resp_valid, resp_value);
}
-void VTAMemDPI(dpi8_t req_valid,
- dpi8_t req_opcode,
- dpi8_t req_len,
- dpi64_t req_addr,
+void VTAMemDPI(dpi8_t rd_req_valid,
+ dpi8_t rd_req_len,
+ dpi8_t rd_req_id,
+ dpi64_t rd_req_addr,
+ dpi8_t wr_req_valid,
+ dpi8_t wr_req_len,
+ dpi64_t wr_req_addr,
dpi8_t wr_valid,
- dpi64_t wr_value,
+ const svOpenArrayHandle wr_value,
+ dpi64_t wr_strb,
dpi8_t* rd_valid,
- dpi64_t* rd_value,
+ dpi8_t* rd_id,
+ const svOpenArrayHandle rd_value,
dpi8_t rd_ready) {
assert(_mem_dpi != nullptr);
- (*_mem_dpi)(_ctx, req_valid, req_opcode, req_len,
- req_addr, wr_valid, wr_value,
- rd_valid, rd_value, rd_ready);
+ (*_mem_dpi)(_ctx, rd_req_valid, rd_req_len, rd_req_id,
+ rd_req_addr, wr_req_valid, wr_req_len, wr_req_addr,
+ wr_valid, wr_value, wr_strb,
+ rd_valid, rd_id,rd_value, rd_ready);
}
diff --git a/include/vta/dpi/tsim.h b/include/vta/dpi/tsim.h
index 8e13def..750e3f2 100644
--- a/include/vta/dpi/tsim.h
+++ b/include/vta/dpi/tsim.h
@@ -22,6 +22,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <stdint.h>
+#include <svdpi.h>
#ifdef __cplusplus
extern "C" {
@@ -75,14 +76,19 @@ typedef void (*VTAHostDPIFunc)(
*/
typedef void (*VTAMemDPIFunc)(
VTAContextHandle self,
- dpi8_t req_valid,
- dpi8_t req_opcode,
- dpi8_t req_len,
- dpi64_t req_addr,
+ dpi8_t rd_req_valid,
+ dpi8_t rd_req_len,
+ dpi8_t rd_req_id,
+ dpi64_t rd_req_addr,
+ dpi8_t wr_req_valid,
+ dpi8_t wr_req_len,
+ dpi64_t wr_req_addr,
dpi8_t wr_valid,
- dpi64_t wr_value,
+ const svOpenArrayHandle wr_value,
+ dpi64_t wr_strb,
dpi8_t* rd_valid,
- dpi64_t* rd_value,
+ dpi8_t* rd_id,
+ const svOpenArrayHandle rd_value,
dpi8_t rd_ready);
/*! \brief The type of VTADPIInit function pointer */
diff --git a/src/dpi/module.cc b/src/dpi/module.cc
index bb8284c..def1305 100644
--- a/src/dpi/module.cc
+++ b/src/dpi/module.cc
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <assert.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
@@ -37,6 +38,10 @@
#include "../vmem/virtual_memory.h"
+// Include verilator array access functions code
+#include "verilated.cpp"
+#include "verilated_dpi.cpp"
+
namespace vta {
namespace dpi {
@@ -56,7 +61,8 @@ struct HostResponse {
struct MemResponse {
uint8_t valid;
- uint64_t value;
+ uint8_t id;
+ uint64_t* value;
};
template <typename T>
@@ -118,16 +124,29 @@ class HostDevice {
class MemDevice {
public:
- void SetRequest(uint8_t opcode, uint64_t addr, uint32_t len);
- MemResponse ReadData(uint8_t ready);
- void WriteData(uint64_t value);
+ void SetRequest(
+ uint8_t rd_req_valid,
+ uint64_t rd_req_addr,
+ uint32_t rd_req_len,
+ uint32_t rd_req_id,
+ uint64_t wr_req_addr,
+ uint32_t wr_req_len,
+ uint8_t wr_req_valid);
+ MemResponse ReadData(uint8_t ready, int blkNb);
+ void WriteData(svOpenArrayHandle value, uint64_t wr_strb);
private:
uint64_t* raddr_{0};
uint64_t* waddr_{0};
uint32_t rlen_{0};
+ uint32_t rid_{0};
uint32_t wlen_{0};
std::mutex mutex_;
+ uint64_t dead_beef_ [8] = {0xdeadbeefdeadbeef,0xdeadbeefdeadbeef,
+ 0xdeadbeefdeadbeef,0xdeadbeefdeadbeef,
+ 0xdeadbeefdeadbeef,0xdeadbeefdeadbeef,
+ 0xdeadbeefdeadbeef,0xdeadbeefdeadbeef };
+
};
void SimDevice::Wait() {
@@ -180,36 +199,74 @@ void HostDevice::WaitPopResponse(HostResponse* r) {
resp_.WaitPop(r);
}
-void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) {
+void MemDevice::SetRequest(
+ uint8_t rd_req_valid,
+ uint64_t rd_req_addr,
+ uint32_t rd_req_len,
+ uint32_t rd_req_id,
+ uint64_t wr_req_addr,
+ uint32_t wr_req_len,
+ uint8_t wr_req_valid) {
+
std::lock_guard<std::mutex> lock(mutex_);
- void * vaddr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(addr);
-
- if (opcode == 1) {
- wlen_ = len + 1;
- waddr_ = reinterpret_cast<uint64_t*>(vaddr);
- } else {
- rlen_ = len + 1;
- raddr_ = reinterpret_cast<uint64_t*>(vaddr);
+ if(rd_req_addr !=0 ){
+ void * rd_vaddr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(rd_req_addr);
+ if(rd_req_valid == 1) {
+ rlen_ = rd_req_len + 1;
+ rid_ = rd_req_id;
+ raddr_ = reinterpret_cast<uint64_t*>(rd_vaddr);
+ }
+ }
+
+ if(wr_req_addr != 0){
+ void * wr_vaddr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(wr_req_addr);
+ if (wr_req_valid == 1) {
+ wlen_ = wr_req_len + 1;
+ waddr_ = reinterpret_cast<uint64_t*>(wr_vaddr);
+ }
}
}
-MemResponse MemDevice::ReadData(uint8_t ready) {
+MemResponse MemDevice::ReadData(uint8_t ready, int blkNb) {
std::lock_guard<std::mutex> lock(mutex_);
MemResponse r;
r.valid = rlen_ > 0;
- r.value = rlen_ > 0 ? *raddr_ : 0xdeadbeefdeadbeef;
+ r.value = rlen_ > 0 ? raddr_ : dead_beef_;
+ r.id = rid_;
if (ready == 1 && rlen_ > 0) {
- raddr_++;
+ raddr_ += blkNb;
rlen_ -= 1;
}
return r;
}
-void MemDevice::WriteData(uint64_t value) {
+void MemDevice::WriteData(svOpenArrayHandle value, uint64_t wr_strb) {
+
+ int lftIdx = svLeft(value, 1);
+ int rgtIdx = svRight(value, 1);
+ int blkNb = lftIdx - rgtIdx + 1;
+ assert(lftIdx >= 0);
+ assert(rgtIdx >= 0);
+ assert(lftIdx >= rgtIdx);
+ assert(blkNb > 0);
+ // the 64-bit strb carries 8 strobe bits per block, so at most 8 blocks are supported
+ assert(blkNb <= 8);
+
std::lock_guard<std::mutex> lock(mutex_);
+ int strbMask = 0xff;
if (wlen_ > 0) {
- *waddr_ = value;
- waddr_++;
+ for (int idx = 0 ; idx < blkNb; ++idx) {
+ int strbFlags = (wr_strb >> (idx * 8)) & strbMask;
+ if (!(strbFlags == 0 || strbFlags == strbMask)) {
+ LOG(FATAL) << "Unexpected strb data " << (void*)wr_strb;
+ }
+ if (strbFlags != 0) {
+ uint64_t* elemPtr = (uint64_t*)svGetArrElemPtr1(value, rgtIdx + idx);
+ assert(elemPtr != NULL);
+ waddr_[idx] = (*elemPtr);
+ }
+ }
+ waddr_ += blkNb;
wlen_ -= 1;
}
}
@@ -229,9 +286,9 @@ class DPIModule final : public DPIModuleNode {
const ObjectPtr<Object>& sptr_to_self) final {
if (name == "WriteReg") {
return TypedPackedFunc<void(int, int)>(
- [this](int addr, int value){
- this->WriteReg(addr, value);
- });
+ [this](int addr, int value){
+ this->WriteReg(addr, value);
+ });
} else {
LOG(FATAL) << "Member " << name << "does not exists";
return nullptr;
@@ -241,7 +298,7 @@ class DPIModule final : public DPIModuleNode {
void Init(const std::string& name) {
Load(name);
VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>(
- GetSymbol("VTADPIInit"));
+ GetSymbol("VTADPIInit"));
CHECK(finit != nullptr);
finit(this, VTASimDPI, VTAHostDPI, VTAMemDPI);
ftsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
@@ -314,24 +371,65 @@ class DPIModule final : public DPIModuleNode {
}
void MemDPI(
- dpi8_t req_valid,
- dpi8_t req_opcode,
- dpi8_t req_len,
- dpi64_t req_addr,
+ dpi8_t rd_req_valid,
+ dpi8_t rd_req_len,
+ dpi8_t rd_req_id,
+ dpi64_t rd_req_addr,
+ dpi8_t wr_req_valid,
+ dpi8_t wr_req_len,
+ dpi64_t wr_req_addr,
dpi8_t wr_valid,
- dpi64_t wr_value,
+ const svOpenArrayHandle wr_value,
+ dpi64_t wr_strb,
dpi8_t* rd_valid,
- dpi64_t* rd_value,
+ dpi8_t* rd_id,
+ const svOpenArrayHandle rd_value,
dpi8_t rd_ready) {
- MemResponse r = mem_device_.ReadData(rd_ready);
- *rd_valid = r.valid;
- *rd_value = r.value;
+
+ // check data pointers
+ // data is expected to come in 64bit chunks
+ // up to 512 bits total
+ // more bits require wider strb data
+ assert(wr_value != NULL);
+ assert(svDimensions(wr_value) == 1);
+ assert(svSize(wr_value, 1) <= 8);
+ assert(svSize(wr_value, 0) == 64);
+ assert(rd_value != NULL);
+ assert(svDimensions(rd_value) == 1);
+ assert(svSize(rd_value, 1) <= 8);
+ assert(svSize(rd_value, 0) == 64);
+
+ int lftIdx = svLeft(rd_value, 1);
+ int rgtIdx = svRight(rd_value, 1);
+ int blkNb = lftIdx - rgtIdx + 1;
+ assert(lftIdx >= 0);
+ assert(rgtIdx >= 0);
+ assert(lftIdx >= rgtIdx);
+ assert(blkNb > 0);
+
if (wr_valid) {
- mem_device_.WriteData(wr_value);
+ mem_device_.WriteData(wr_value, wr_strb);
}
- if (req_valid) {
- mem_device_.SetRequest(req_opcode, req_addr, req_len);
+ if (rd_req_valid || wr_req_valid) {
+ mem_device_.SetRequest(
+ rd_req_valid,
+ rd_req_addr,
+ rd_req_len,
+ rd_req_id,
+ wr_req_addr,
+ wr_req_len,
+ wr_req_valid);
+ }
+
+
+ MemResponse r = mem_device_.ReadData(rd_ready, blkNb);
+ *rd_valid = r.valid;
+ for (int idx = 0; idx < blkNb; idx ++) {
+ uint64_t* dataPtr = (uint64_t*)svGetArrElemPtr1(rd_value, rgtIdx + idx);
+ assert(dataPtr != NULL);
+ (*dataPtr) = r.value[idx];
}
+ *rd_id = r.id;
}
static void VTASimDPI(
@@ -352,25 +450,31 @@ class DPIModule final : public DPIModuleNode {
dpi8_t resp_valid,
dpi32_t resp_value) {
static_cast<DPIModule*>(self)->HostDPI(
- req_valid, req_opcode, req_addr,
- req_value, req_deq, resp_valid, resp_value);
+ req_valid, req_opcode, req_addr,
+ req_value, req_deq, resp_valid, resp_value);
}
static void VTAMemDPI(
VTAContextHandle self,
- dpi8_t req_valid,
- dpi8_t req_opcode,
- dpi8_t req_len,
- dpi64_t req_addr,
+ dpi8_t rd_req_valid,
+ dpi8_t rd_req_len,
+ dpi8_t rd_req_id,
+ dpi64_t rd_req_addr,
+ dpi8_t wr_req_valid,
+ dpi8_t wr_req_len,
+ dpi64_t wr_req_addr,
dpi8_t wr_valid,
- dpi64_t wr_value,
+ const svOpenArrayHandle wr_value,
+ dpi64_t wr_strb,
dpi8_t* rd_valid,
- dpi64_t* rd_value,
+ dpi8_t* rd_id,
+ const svOpenArrayHandle rd_value,
dpi8_t rd_ready) {
static_cast<DPIModule*>(self)->MemDPI(
- req_valid, req_opcode, req_len,
- req_addr, wr_valid, wr_value,
- rd_valid, rd_value, rd_ready);
+ rd_req_valid, rd_req_len, rd_req_id,
+ rd_req_addr, wr_req_valid, wr_req_len, wr_req_addr,
+ wr_valid, wr_value, wr_strb,
+ rd_valid, rd_id, rd_value, rd_ready);
}
private:
diff --git a/tests/scripts/docker_bash.sh b/tests/scripts/docker_bash.sh
index cdda5d4..e68d6af 100755
--- a/tests/scripts/docker_bash.sh
+++ b/tests/scripts/docker_bash.sh
@@ -67,12 +67,6 @@ echo "WORKSPACE: ${WORKSPACE}"
echo "DOCKER CONTAINER NAME: ${DOCKER_IMAGE_NAME}"
echo ""
-# FIXME(zhanghao): re-enable the tsim test after ISA is updated
-if [[ ${COMMAND[@]} == "./tests/scripts/task_python_vta_tsim.sh" ]]; then
- echo "Skip '${COMMAND[@]}'"
- exit
-fi
-
echo "Running '${COMMAND[@]}' inside ${DOCKER_IMAGE_NAME}..."
# By default we cleanup - remove the container once it finish running (--rm)