You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/04/17 16:44:00 UTC
[tvm-vta] branch main updated: adapt chisel impl to new VTA ISA (#24)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-vta.git
The following commit(s) were added to refs/heads/main by this push:
new 4319417 adapt chisel impl to new VTA ISA (#24)
4319417 is described below
commit 43194178b4e570a5f1dd4f3f9d37ee16fc1b65be
Author: Luis Vega <ve...@users.noreply.github.com>
AuthorDate: Sat Apr 17 09:43:52 2021 -0700
adapt chisel impl to new VTA ISA (#24)
* adapt chisel impl to new VTA ISA
* add comment
* rename variable
* update comment
* remove comment
---
hardware/chisel/src/main/scala/core/Decode.scala | 9 ++++-----
hardware/chisel/src/main/scala/core/ISA.scala | 19 ++++++++++---------
hardware/chisel/src/main/scala/core/TensorAlu.scala | 10 ++++++----
3 files changed, 20 insertions(+), 18 deletions(-)
diff --git a/hardware/chisel/src/main/scala/core/Decode.scala b/hardware/chisel/src/main/scala/core/Decode.scala
index 37f6ab4..dc8d3e1 100644
--- a/hardware/chisel/src/main/scala/core/Decode.scala
+++ b/hardware/chisel/src/main/scala/core/Decode.scala
@@ -43,7 +43,7 @@ class MemDecode extends Bundle {
val xstride = UInt(M_STRIDE_BITS.W)
val xsize = UInt(M_SIZE_BITS.W)
val ysize = UInt(M_SIZE_BITS.W)
- val empty_0 = UInt(7.W) // derive this
+ val empty_0 = UInt(6.W) // derive this
val dram_offset = UInt(M_DRAM_OFFSET_BITS.W)
val sram_offset = UInt(M_SRAM_OFFSET_BITS.W)
val id = UInt(M_ID_BITS.W)
@@ -90,12 +90,11 @@ class GemmDecode extends Bundle {
* - VSHX
*/
class AluDecode extends Bundle {
- val empty_1 = Bool()
val alu_imm = UInt(C_ALU_IMM_BITS.W)
val alu_use_imm = Bool()
- val alu_op = UInt(C_ALU_DEC_BITS.W)
- val src_1 = UInt(C_IIDX_BITS.W)
- val src_0 = UInt(C_IIDX_BITS.W)
+ val alu_op = UInt(C_ALU_OP_BITS.W)
+ val src_1 = UInt(C_AIDX_BITS.W)
+ val src_0 = UInt(C_AIDX_BITS.W)
val dst_1 = UInt(C_AIDX_BITS.W)
val dst_0 = UInt(C_AIDX_BITS.W)
val empty_0 = Bool()
diff --git a/hardware/chisel/src/main/scala/core/ISA.scala b/hardware/chisel/src/main/scala/core/ISA.scala
index bfe89eb..503cc2b 100644
--- a/hardware/chisel/src/main/scala/core/ISA.scala
+++ b/hardware/chisel/src/main/scala/core/ISA.scala
@@ -33,7 +33,7 @@ trait ISAConstants {
val OP_BITS = 3
val M_DEP_BITS = 4
- val M_ID_BITS = 2
+ val M_ID_BITS = 3
val M_SRAM_OFFSET_BITS = 16
val M_DRAM_OFFSET_BITS = 32
val M_SIZE_BITS = 16
@@ -46,7 +46,7 @@ trait ISAConstants {
val C_AIDX_BITS = 11
val C_IIDX_BITS = 11
val C_WIDX_BITS = 10
- val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
+ val C_ALU_DEC_BITS = 2
val C_ALU_OP_BITS = 3
val C_ALU_IMM_BITS = 16
@@ -67,6 +67,7 @@ trait ISAConstants {
val M_ID_W = 1.asUInt(M_ID_BITS.W)
val M_ID_I = 2.asUInt(M_ID_BITS.W)
val M_ID_A = 3.asUInt(M_ID_BITS.W)
+ val M_ID_O = 4.asUInt(M_ID_BITS.W)
}
/** ISA.
@@ -82,7 +83,7 @@ object ISA {
private val depBits = 4
private val idBits: HashMap[String, Int] =
- HashMap(("task", 3), ("mem", 2), ("alu", 2))
+ HashMap(("task", 3), ("mem", 3), ("alu", 3))
private val taskId: HashMap[String, String] =
HashMap(("load", "000"),
@@ -92,13 +93,13 @@ object ISA {
("alu", "100"))
private val memId: HashMap[String, String] =
- HashMap(("uop", "00"), ("wgt", "01"), ("inp", "10"), ("acc", "11"))
+ HashMap(("uop", "000"), ("wgt", "001"), ("inp", "010"), ("acc", "011"), ("out", "100"))
private val aluId: HashMap[String, String] =
- HashMap(("minpool", "00"),
- ("maxpool", "01"),
- ("add", "10"),
- ("shift", "11"))
+ HashMap(("minpool", "000"),
+ ("maxpool", "001"),
+ ("add", "010"),
+ ("shift", "011"))
private def dontCare(bits: Int): String = "?" * bits
@@ -124,7 +125,7 @@ object ISA {
private def alu(id: String): BitPat = {
// TODO: move alu id next to task id
- val inst = dontCare(18) + aluId(id) + dontCare(105) + taskId("alu")
+ val inst = dontCare(17) + aluId(id) + dontCare(105) + taskId("alu")
instPat(inst)
}
diff --git a/hardware/chisel/src/main/scala/core/TensorAlu.scala b/hardware/chisel/src/main/scala/core/TensorAlu.scala
index 6af3c83..81abb8e 100644
--- a/hardware/chisel/src/main/scala/core/TensorAlu.scala
+++ b/hardware/chisel/src/main/scala/core/TensorAlu.scala
@@ -39,6 +39,7 @@ class Alu(implicit p: Parameters) extends Module {
val m = ~ub(width - 1, 0) + 1.U
val n = ub(width - 1, 0)
+ // opcode - min:0, max:1, add:2, shr:3, shl:4
val fop = Seq(Mux(io.a < io.b, io.a, io.b), Mux(io.a < io.b, io.b, io.a),
io.a + io.b, io.a >> n, io.a << m)
@@ -214,7 +215,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
io.uop.idx.valid := state === sReadUop
io.uop.idx.bits := uop_idx
- // acc_i
+ // acc (input)
io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
@@ -230,8 +231,9 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
// alu
val isSHR = dec.alu_op === ALU_OP(3)
- val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
- val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
+ val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
+ // opcode - min:0, max:1, add:2, shr:3, shl:4
+ val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
alu.io.opcode := fixme_alu_op
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits
@@ -242,7 +244,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
tensorImm.data.bits,
io.acc.rd.data.bits)
- // acc_o
+ // acc (output)
io.acc.wr.valid := alu.io.acc_y.data.valid
io.acc.wr.bits.idx := uop_dst
io.acc.wr.bits.data <> alu.io.acc_y.data.bits