You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/04/17 16:44:00 UTC

[tvm-vta] branch main updated: adapt chisel impl to new VTA ISA (#24)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-vta.git


The following commit(s) were added to refs/heads/main by this push:
     new 4319417  adapt chisel impl to new VTA ISA (#24)
4319417 is described below

commit 43194178b4e570a5f1dd4f3f9d37ee16fc1b65be
Author: Luis Vega <ve...@users.noreply.github.com>
AuthorDate: Sat Apr 17 09:43:52 2021 -0700

    adapt chisel impl to new VTA ISA (#24)
    
    * adapt chisel impl to new VTA ISA
    
    * add comment
    
    * rename variable
    
    * update comment
    
    * remove comment
---
 hardware/chisel/src/main/scala/core/Decode.scala    |  9 ++++-----
 hardware/chisel/src/main/scala/core/ISA.scala       | 19 ++++++++++---------
 hardware/chisel/src/main/scala/core/TensorAlu.scala | 10 ++++++----
 3 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/hardware/chisel/src/main/scala/core/Decode.scala b/hardware/chisel/src/main/scala/core/Decode.scala
index 37f6ab4..dc8d3e1 100644
--- a/hardware/chisel/src/main/scala/core/Decode.scala
+++ b/hardware/chisel/src/main/scala/core/Decode.scala
@@ -43,7 +43,7 @@ class MemDecode extends Bundle {
   val xstride = UInt(M_STRIDE_BITS.W)
   val xsize = UInt(M_SIZE_BITS.W)
   val ysize = UInt(M_SIZE_BITS.W)
-  val empty_0 = UInt(7.W) // derive this
+  val empty_0 = UInt(6.W) // derive this
   val dram_offset = UInt(M_DRAM_OFFSET_BITS.W)
   val sram_offset = UInt(M_SRAM_OFFSET_BITS.W)
   val id = UInt(M_ID_BITS.W)
@@ -90,12 +90,11 @@ class GemmDecode extends Bundle {
  *   - VSHX
  */
 class AluDecode extends Bundle {
-  val empty_1 = Bool()
   val alu_imm = UInt(C_ALU_IMM_BITS.W)
   val alu_use_imm = Bool()
-  val alu_op = UInt(C_ALU_DEC_BITS.W)
-  val src_1 = UInt(C_IIDX_BITS.W)
-  val src_0 = UInt(C_IIDX_BITS.W)
+  val alu_op = UInt(C_ALU_OP_BITS.W)
+  val src_1 = UInt(C_AIDX_BITS.W)
+  val src_0 = UInt(C_AIDX_BITS.W)
   val dst_1 = UInt(C_AIDX_BITS.W)
   val dst_0 = UInt(C_AIDX_BITS.W)
   val empty_0 = Bool()
diff --git a/hardware/chisel/src/main/scala/core/ISA.scala b/hardware/chisel/src/main/scala/core/ISA.scala
index bfe89eb..503cc2b 100644
--- a/hardware/chisel/src/main/scala/core/ISA.scala
+++ b/hardware/chisel/src/main/scala/core/ISA.scala
@@ -33,7 +33,7 @@ trait ISAConstants {
   val OP_BITS = 3
 
   val M_DEP_BITS = 4
-  val M_ID_BITS = 2
+  val M_ID_BITS = 3
   val M_SRAM_OFFSET_BITS = 16
   val M_DRAM_OFFSET_BITS = 32
   val M_SIZE_BITS = 16
@@ -46,7 +46,7 @@ trait ISAConstants {
   val C_AIDX_BITS = 11
   val C_IIDX_BITS = 11
   val C_WIDX_BITS = 10
-  val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
+  val C_ALU_DEC_BITS = 2
   val C_ALU_OP_BITS = 3
   val C_ALU_IMM_BITS = 16
 
@@ -67,6 +67,7 @@ trait ISAConstants {
   val M_ID_W = 1.asUInt(M_ID_BITS.W)
   val M_ID_I = 2.asUInt(M_ID_BITS.W)
   val M_ID_A = 3.asUInt(M_ID_BITS.W)
+  val M_ID_O = 4.asUInt(M_ID_BITS.W)
 }
 
 /** ISA.
@@ -82,7 +83,7 @@ object ISA {
   private val depBits = 4
 
   private val idBits: HashMap[String, Int] =
-    HashMap(("task", 3), ("mem", 2), ("alu", 2))
+    HashMap(("task", 3), ("mem", 3), ("alu", 3))
 
   private val taskId: HashMap[String, String] =
     HashMap(("load", "000"),
@@ -92,13 +93,13 @@ object ISA {
       ("alu", "100"))
 
   private val memId: HashMap[String, String] =
-    HashMap(("uop", "00"), ("wgt", "01"), ("inp", "10"), ("acc", "11"))
+    HashMap(("uop", "000"), ("wgt", "001"), ("inp", "010"), ("acc", "011"), ("out", "100"))
 
   private val aluId: HashMap[String, String] =
-    HashMap(("minpool", "00"),
-      ("maxpool", "01"),
-      ("add", "10"),
-      ("shift", "11"))
+    HashMap(("minpool", "000"),
+      ("maxpool", "001"),
+      ("add", "010"),
+      ("shift", "011"))
 
   private def dontCare(bits: Int): String = "?" * bits
 
@@ -124,7 +125,7 @@ object ISA {
 
   private def alu(id: String): BitPat = {
     // TODO: move alu id next to task id
-    val inst = dontCare(18) + aluId(id) + dontCare(105) + taskId("alu")
+    val inst = dontCare(17) + aluId(id) + dontCare(105) + taskId("alu")
     instPat(inst)
   }
 
diff --git a/hardware/chisel/src/main/scala/core/TensorAlu.scala b/hardware/chisel/src/main/scala/core/TensorAlu.scala
index 6af3c83..81abb8e 100644
--- a/hardware/chisel/src/main/scala/core/TensorAlu.scala
+++ b/hardware/chisel/src/main/scala/core/TensorAlu.scala
@@ -39,6 +39,7 @@ class Alu(implicit p: Parameters) extends Module {
   val m = ~ub(width - 1, 0) + 1.U
 
   val n = ub(width - 1, 0)
+  // opcode - min:0, max:1, add:2, shr:3, shl:4
   val fop = Seq(Mux(io.a < io.b, io.a, io.b), Mux(io.a < io.b, io.b, io.a),
     io.a + io.b, io.a >> n, io.a << m)
 
@@ -214,7 +215,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
   io.uop.idx.valid := state === sReadUop
   io.uop.idx.bits := uop_idx
 
-  // acc_i
+  // acc (input)
   io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
   io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
 
@@ -230,8 +231,9 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   // alu
   val isSHR = dec.alu_op === ALU_OP(3)
-  val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
-  val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
+  val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
+  // opcode - min:0, max:1, add:2, shr:3, shl:4
+  val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
   alu.io.opcode := fixme_alu_op
   alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
   alu.io.acc_a.data.bits <> io.acc.rd.data.bits
@@ -242,7 +244,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
     tensorImm.data.bits,
     io.acc.rd.data.bits)
 
-  // acc_o
+  // acc (output)
   io.acc.wr.valid := alu.io.acc_y.data.valid
   io.acc.wr.bits.idx := uop_dst
   io.acc.wr.bits.data <> alu.io.acc_y.data.bits