feat(sc): change SCTable dual-port Sram to singlePort

This commit is contained in:
YuanDL 2025-09-30 15:16:56 +08:00
parent 0443023029
commit 00f9c5d2cc
1 changed files with 56 additions and 26 deletions

View File

@ -38,7 +38,10 @@ import utility.sram.SRAMConflictBehavior
import utility.sram.SRAMTemplate
import xiangshan._
trait HasSCParameter extends TageParams {}
trait HasSCParameter extends TageParams {
val nBanks = 2
def bankIdxWidth = log2Ceil(nBanks)
}
class SCReq(implicit p: Parameters) extends TageReq
@ -46,6 +49,7 @@ abstract class SCBundle(implicit p: Parameters) extends TageBundle with HasSCPar
abstract class SCModule(implicit p: Parameters) extends TageModule with HasSCParameter {}
class SCMeta(val ntables: Int)(implicit p: Parameters) extends XSBundle with HasSCParameter {
val valid = Bool()
val scPreds = Vec(numBr, Bool())
// Suppose ctrbits of all tables are identical
val ctrs = Vec(numBr, Vec(ntables, SInt(SCCtrBits.W)))
@ -66,7 +70,7 @@ class SCUpdate(val ctrBits: Int = 6)(implicit p: Parameters) extends SCBundle {
class SCTableIO(val ctrBits: Int = 6)(implicit p: Parameters) extends SCBundle {
val req = Input(Valid(new SCReq))
val resp = Output(new SCResp(ctrBits))
val resp = Output(Valid(new SCResp(ctrBits)))
val update = Input(new SCUpdate(ctrBits))
}
@ -74,19 +78,21 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa
extends SCModule with HasFoldedHistory {
val io = IO(new SCTableIO(ctrBits))
def get_bank_mask(idx: UInt): Vec[Bool] = VecInit((0 until nBanks).map(idx(bankIdxWidth - 1, 0) === _.U))
def get_bank_idx(idx: UInt): UInt = idx >> bankIdxWidth
// val table = Module(new SRAMTemplate(SInt(ctrBits.W), set=nRows, way=2*TageBanks, shouldReset=true, holdRead=true, singlePort=false))
val table = Module(new SRAMTemplate(
val table = Seq.fill(nBanks)(Module(new SRAMTemplate(
SInt(ctrBits.W),
set = nRows,
way = 2 * TageBanks,
shouldReset = true,
holdRead = true,
singlePort = false,
conflictBehavior = SRAMConflictBehavior.BufferWriteLossy,
singlePort = true,
withClockGate = true,
hasMbist = hasMbist,
hasSramCtl = hasSramCtl
))
)))
private val mbistPl = MbistPipeline.PlaceMbistPipeline(1, "MbistPipeSc", hasMbist)
// def getIdx(hist: UInt, pc: UInt) = {
// (compute_folded_ghist(hist, log2Ceil(nRows)) ^ (pc >> instOffsetBits))(log2Ceil(nRows)-1,0)
@ -107,16 +113,27 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa
def ctrUpdate(ctr: SInt, cond: Bool): SInt = signedSatUpdate(ctr, ctrBits, cond)
val s0_idx = getIdx(io.req.bits.pc, io.req.bits.folded_hist)
val s1_idx = RegEnable(s0_idx, io.req.valid)
val s0_idx = getIdx(io.req.bits.pc, io.req.bits.folded_hist)
val s0_bank_idx = get_bank_idx(s0_idx)
val s0_bank_mask = get_bank_mask(s0_idx)
val s0_invalid_by_conflict = Mux1H(s0_bank_mask, table.map(_.io.w.req.valid))
val s1_idx = RegEnable(s0_idx, io.req.valid)
val s1_bank_idx = get_bank_idx(s1_idx)
val s1_bank_mask = get_bank_mask(s1_idx)
val s1_resp_invalid_by_conflict = RegEnable(s0_invalid_by_conflict, io.req.valid)
val s1_pc = RegEnable(io.req.bits.pc, io.req.fire)
val s1_unhashed_idx = s1_pc >> instOffsetBits
table.io.r.req.valid := io.req.valid
table.io.r.req.bits.setIdx := s0_idx
table.zip(s0_bank_mask).foreach { case (bank, bankEnable) =>
bank.io.r.req.valid := io.req.valid && bankEnable
bank.io.r.req.bits.setIdx := s0_bank_idx
}
val per_br_ctrs_unshuffled = table.io.r.resp.data.sliding(2, 2).toSeq.map(VecInit(_))
val table_resp = Mux1H(s1_bank_mask, table.map(_.io.r.resp))
val per_br_ctrs_unshuffled = table_resp.data.sliding(2, 2).toSeq.map(VecInit(_))
val per_br_ctrs = VecInit((0 until numBr).map(i =>
Mux1H(
UIntToOH(get_phy_br_idx(s1_unhashed_idx, i), numBr),
@ -124,7 +141,8 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa
)
))
io.resp.ctrs := per_br_ctrs
io.resp.valid := !s1_resp_invalid_by_conflict
io.resp.bits.ctrs := per_br_ctrs
val update_wdata = Wire(Vec(numBr, SInt(ctrBits.W))) // correspond to physical bridx
val update_wdata_packed = VecInit(update_wdata.map(Seq.fill(2)(_)).reduce(_ ++ _))
@ -144,14 +162,18 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa
if (histLen > 0) {
update_folded_hist.getHistWithInfo(idxFhInfo).folded_hist := compute_folded_ghist(io.update.ghist, log2Ceil(nRows))
}
val update_idx = getIdx(io.update.pc, update_folded_hist)
val update_idx = getIdx(io.update.pc, update_folded_hist)
val update_bank_idx = get_bank_idx(update_idx)
val update_bank_mask = get_bank_mask(update_idx)
table.io.w.apply(
valid = io.update.mask.reduce(_ || _),
data = update_wdata_packed,
setIdx = update_idx,
waymask = updateWayMask.asUInt
)
table.zip(update_bank_mask).foreach { case (bank, bankEnable) =>
bank.io.w.apply(
valid = io.update.mask.reduce(_ || _) && bankEnable,
data = update_wdata_packed,
setIdx = update_bank_idx,
waymask = updateWayMask.asUInt
)
}
val wrBypassEntries = 16
@ -192,7 +214,7 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa
XSDebug(
RegNext(io.req.valid),
p"scTableResp: s1_idx=${s1_idx}," +
p"ctr:${io.resp.ctrs}\n"
p"ctr:${io.resp.bits.ctrs}\n"
)
XSDebug(
io.update.mask.reduce(_ || _),
@ -297,19 +319,25 @@ trait HasSC extends HasSCParameter with HasPerfEvents { this: Tage =>
// do summation in s2
val s1_scTableSums = VecInit(
(0 to 1) map { i =>
ParallelSingedExpandingAdd(s1_scResps map (r => getCentered(r.ctrs(w)(i)))) // TODO: rewrite with wallace tree
ParallelSingedExpandingAdd(s1_scResps map (r =>
getCentered(r.bits.ctrs(w)(i))
)) // TODO: rewrite with wallace tree
}
)
val s2_scResps = VecInit(s1_scResps.map(RegEnable(_, io.s1_fire(3))))
val s2_scRespValids = VecInit(s2_scResps.map(_.valid))
val s2_scTableSums = RegEnable(s1_scTableSums, io.s1_fire(3))
val s2_tagePrvdCtrCentered = getPvdrCentered(RegEnable(s1_providerResps(w).ctr, io.s1_fire(3)))
val s2_totalSums = s2_scTableSums.map(_ +& s2_tagePrvdCtrCentered)
val s2_sumAboveThresholds =
VecInit((0 to 1).map(i => aboveThreshold(s2_scTableSums(i), s2_tagePrvdCtrCentered, useThresholds(w))))
VecInit((0 to 1).map(i =>
aboveThreshold(s2_scTableSums(i), s2_tagePrvdCtrCentered, useThresholds(w)) && s2_scRespValids(i)
))
val s2_scPreds = VecInit(s2_totalSums.map(_ >= 0.S))
val s2_scResps = VecInit(RegEnable(s1_scResps, io.s1_fire(3)).map(_.ctrs(w)))
val s2_scCtrs = VecInit(s2_scResps.map(_(s2_tageTakens_dup(3)(w).asUInt)))
val s2_chooseBit = s2_tageTakens_dup(3)(w)
val s2_scRespsCtrs = VecInit(s2_scResps.map(_.bits.ctrs(w)))
val s2_scCtrs = VecInit(s2_scRespsCtrs.map(_(s2_tageTakens_dup(3)(w).asUInt)))
val s2_chooseBit = s2_tageTakens_dup(3)(w)
val s2_pred =
Mux(s2_provideds(w) && s2_sumAboveThresholds(s2_chooseBit), s2_scPreds(s2_chooseBit), s2_tageTakens_dup(3)(w))
@ -317,6 +345,7 @@ trait HasSC extends HasSCParameter with HasPerfEvents { this: Tage =>
val s3_disagree = RegEnable(s2_disagree, io.s2_fire(3))
io.out.last_stage_spec_info.sc_disagree.map(_ := s3_disagree)
scMeta.valid := RegEnable(s2_scRespValids(w), io.s2_fire(3))
scMeta.scPreds(w) := RegEnable(s2_scPreds(s2_chooseBit), io.s2_fire(3))
scMeta.ctrs(w) := RegEnable(s2_scCtrs, io.s2_fire(3))
@ -356,6 +385,7 @@ trait HasSC extends HasSCParameter with HasPerfEvents { this: Tage =>
val tagePred = updateTageMeta.takens(w)
val taken = update.br_taken_mask(w)
val scOldCtrs = updateSCMeta.ctrs(w)
val scPredValid = updateSCMeta.valid
val pvdrCtr = updateTageMeta.providerResps(w).ctr
val tableSum = ParallelSingedExpandingAdd(scOldCtrs.map(getCentered))
val totalSumAbs = (tableSum +& getPvdrCentered(pvdrCtr)).abs.asUInt
@ -363,7 +393,7 @@ trait HasSC extends HasSCParameter with HasPerfEvents { this: Tage =>
val sumAboveThreshold = aboveThreshold(tableSum, getPvdrCentered(pvdrCtr), updateThres)
val thres = useThresholds(w)
val newThres = scThresholds(w).update(scPred =/= taken)
when(updateValids(w) && updateTageMeta.providers(w).valid) {
when(scPredValid && updateValids(w) && updateTageMeta.providers(w).valid) {
scUpdateTagePreds(w) := tagePred
scUpdateTakens(w) := taken
(scUpdateOldCtrs(w) zip scOldCtrs).foreach { case (t, c) => t := c }