This commit is contained in:
junxiong-ji 2025-04-22 10:21:37 +08:00 committed by GitHub
commit cb45183e52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 118 additions and 72 deletions

View File

@ -186,6 +186,14 @@ class HashModule(implicit p: Parameters) extends XSModule {
}
class BlockCipherModule(implicit p: Parameters) extends XSModule {
def asVecByWidth(x: UInt, width: Int): Vec[UInt] = {
require(x.getWidth % width == 0)
val numElem: Int = x.getWidth / width
VecInit((0 until numElem).map(i => x(width * (i + 1) - 1, width * i)))
}
def asBytes(x: UInt): Vec[UInt] = asVecByWidth(x, 8)
val io = IO(new Bundle() {
val src = Vec(2, Input(UInt(XLEN.W)))
val func = Input(UInt())
@ -195,36 +203,62 @@ class BlockCipherModule(implicit p: Parameters) extends XSModule {
val (src1, src2, func, funcReg) = (io.src(0), io.src(1), io.func, RegEnable(io.func, io.regEnable))
val src1Bytes = VecInit((0 until 8).map(i => src1(i*8+7, i*8)))
val src2Bytes = VecInit((0 until 8).map(i => src2(i*8+7, i*8)))
val src1Bytes = asBytes(src1)
val src2Bytes = asBytes(src2)
val src1_reg = RegEnable(src1, io.regEnable)
val src2_reg = RegEnable(src2, io.regEnable)
// AES
val aesSboxIn = ForwardShiftRows(src1Bytes, src2Bytes)
val aesSboxMid = Reg(Vec(8, Vec(18, Bool())))
val aesSboxIn = Wire(Vec(8, UInt(8.W)))
val aesSboxOut = Wire(Vec(8, UInt(8.W)))
val iaesSboxIn = InverseShiftRows(src1Bytes, src2Bytes)
val iaesSboxMid = Reg(Vec(8, Vec(18, Bool())))
val iaesSboxIn = Wire(Vec(8, UInt(8.W)))
val iaesSboxOut = Wire(Vec(8, UInt(8.W)))
val ksSboxIn = Wire(Vec(4, UInt(8.W)))
val ksSboxOut = Wire(Vec(4, UInt(8.W)))
aesSboxOut.zip(aesSboxMid).zip(aesSboxIn)foreach { case ((out, mid), in) =>
when (io.regEnable) {
mid := SboxInv(SboxAesTop(in))
}
out := SboxAesOut(mid)
}
iaesSboxOut.zip(iaesSboxMid).zip(iaesSboxIn)foreach { case ((out, mid), in) =>
when (io.regEnable) {
mid := SboxInv(SboxIaesTop(in))
}
out := SboxIaesOut(mid)
val isInvAes = func(1) // 0: aes64es/aes64esm/aes64ks1i 1: aes64ds/aes64dsm
val isKsAes = func(2) // 0: aes64es/aes64esm 1: aes64ks1i
val aesEncSboxIn = Wire(Vec(8, UInt(8.W)))
val aesDecSboxIn = Wire(Vec(8, UInt(8.W)))
val aesEncSboxInLow = Wire(Vec(4, UInt(8.W)))
val aesEncSboxInHigh = Wire(Vec(4, UInt(8.W)))
aesEncSboxIn := ForwardShiftRows(src1Bytes, src2Bytes)
aesDecSboxIn := InverseShiftRows(src1Bytes, src2Bytes)
aesEncSboxInLow := aesEncSboxIn.take(4)
aesEncSboxInHigh := aesEncSboxIn.drop(4)
ksSboxIn := Mux(
src2(3, 0) === 0xa.U(4.W),
VecInit(src1Bytes(4), src1Bytes(5), src1Bytes(6), src1Bytes(7)),
VecInit(src1Bytes(5), src1Bytes(6), src1Bytes(7), src1Bytes(4))
)
aesSboxIn := Mux(isKsAes, ksSboxIn, aesEncSboxInLow) ++ aesEncSboxInHigh
iaesSboxIn := aesDecSboxIn
(aesSboxIn lazyZip iaesSboxIn lazyZip aesSboxOut lazyZip iaesSboxOut).foreach {
case (fwdIn, invIn, fwdOut, invOut) =>
val aesSBox1FwdOut = SboxAesTop(fwdIn)
val aesSBox1InvOut = SboxIaesTop(invIn)
val aesSBox2In = Mux(isInvAes, aesSBox1InvOut, aesSBox1FwdOut)
val aesSBox2Out = Reg(UInt(18.W)) // 1 cycle delay
when (io.regEnable) {
aesSBox2Out := SboxInv(aesSBox2In)
}
fwdOut := SboxAesOut(aesSBox2Out)
invOut := SboxIaesOut(aesSBox2Out)
}
ksSboxOut := aesSboxOut.take(4)
// AES encryption/decryption
val aes64es = aesSboxOut.asUInt
val aes64ds = iaesSboxOut.asUInt
val imMinIn = RegEnable(src1Bytes, io.regEnable)
val imMinIn = asBytes(src1_reg)
val aes64esm = Cat(MixFwd(Seq(aesSboxOut(4), aesSboxOut(5), aesSboxOut(6), aesSboxOut(7))),
MixFwd(Seq(aesSboxOut(0), aesSboxOut(1), aesSboxOut(2), aesSboxOut(3))))
@ -233,30 +267,19 @@ class BlockCipherModule(implicit p: Parameters) extends XSModule {
val aes64im = Cat(MixInv(Seq(imMinIn(4), imMinIn(5), imMinIn(6), imMinIn(7))),
MixInv(Seq(imMinIn(0), imMinIn(1), imMinIn(2), imMinIn(3))))
// AES key schedule
val rconSeq: Seq[Int] = Seq(
0x01, 0x02, 0x04, 0x08,
0x10, 0x20, 0x40, 0x80,
0x1b, 0x36, 0x00
)
val rcon = VecInit(rconSeq.map(_.U(8.W)))
val rcon = WireInit(VecInit(Seq("h01".U, "h02".U, "h04".U, "h08".U,
"h10".U, "h20".U, "h40".U, "h80".U,
"h1b".U, "h36".U, "h00".U)))
val ksSboxIn = Wire(Vec(4, UInt(8.W)))
val ksSboxTop = Reg(Vec(4, Vec(21, Bool())))
val ksSboxOut = Wire(Vec(4, UInt(8.W)))
ksSboxIn(0) := Mux(src2(3,0) === "ha".U, src1Bytes(4), src1Bytes(5))
ksSboxIn(1) := Mux(src2(3,0) === "ha".U, src1Bytes(5), src1Bytes(6))
ksSboxIn(2) := Mux(src2(3,0) === "ha".U, src1Bytes(6), src1Bytes(7))
ksSboxIn(3) := Mux(src2(3,0) === "ha".U, src1Bytes(7), src1Bytes(4))
ksSboxOut.zip(ksSboxTop).zip(ksSboxIn).foreach{ case ((out, top), in) =>
when (io.regEnable) {
top := SboxAesTop(in)
}
out := SboxAesOut(SboxInv(top))
}
val ks1Idx = RegEnable(src2(3,0), io.regEnable)
val ks1Idx = src2_reg(3, 0)
val aes64ks1i = Cat(ksSboxOut.asUInt ^ rcon(ks1Idx), ksSboxOut.asUInt ^ rcon(ks1Idx))
val aes64ks2Temp = src1(63,32) ^ src2(31,0)
val aes64ks2 = RegEnable(Cat(aes64ks2Temp ^ src2(63,32), aes64ks2Temp), io.regEnable)
val aes64ks2Temp = src1_reg(63, 32) ^ src2_reg(31, 0)
val aes64ks2 = Cat(aes64ks2Temp ^ src2_reg(63, 32), aes64ks2Temp)
val aesResult = LookupTreeDefault(funcReg, aes64es, List(
BKUOpType.aes64es -> aes64es,
@ -269,26 +292,35 @@ class BlockCipherModule(implicit p: Parameters) extends XSModule {
))
// SM4
val sm4SboxIn = src2Bytes(func(1,0))
val sm4SboxTop = Reg(Vec(21, Bool()))
val sm4SboxIn = src2Bytes(func(1, 0))
val sm4SboxTop = Reg(UInt(21.W))
when (io.regEnable) {
sm4SboxTop := SboxSm4Top(sm4SboxIn)
}
val sm4SboxOut = SboxSm4Out(SboxInv(sm4SboxTop))
// val sm4SboxTop = Reg(Vec(21, Bool()))
// when (io.regEnable) {
// sm4SboxTop := SboxSm4Top(sm4SboxIn)
// }
// val sm4SboxOut = SboxSm4Out(SboxInv(sm4SboxTop))
val sm4ed = sm4SboxOut ^ (sm4SboxOut<<8) ^ (sm4SboxOut<<2) ^ (sm4SboxOut<<18) ^ ((sm4SboxOut&"h3f".U)<<26) ^ ((sm4SboxOut&"hc0".U)<<10)
val sm4ks = sm4SboxOut ^ ((sm4SboxOut&"h07".U)<<29) ^ ((sm4SboxOut&"hfe".U)<<7) ^ ((sm4SboxOut&"h01".U)<<23) ^ ((sm4SboxOut&"hf8".U)<<13)
val sm4ed = sm4SboxOut ^ (sm4SboxOut << 8).asUInt ^ (sm4SboxOut << 2).asUInt ^ (sm4SboxOut << 18).asUInt ^
((sm4SboxOut & 0x3f.U) << 26).asUInt ^ ((sm4SboxOut & 0xc0.U) << 10).asUInt
val sm4ks = sm4SboxOut ^ ((sm4SboxOut & 0x07.U) << 29).asUInt ^ ((sm4SboxOut & 0xfe.U) << 7).asUInt ^
((sm4SboxOut & 0x01.U) << 23).asUInt ^ ((sm4SboxOut & 0xf8.U) << 13).asUInt
val sm4Source = VecInit(Seq(
sm4ed(31,0),
Cat(sm4ed(23,0), sm4ed(31,24)),
Cat(sm4ed(15,0), sm4ed(31,16)),
Cat(sm4ed( 7,0), sm4ed(31,8)),
sm4ks(31,0),
Cat(sm4ks(23,0), sm4ks(31,24)),
Cat(sm4ks(15,0), sm4ks(31,16)),
Cat(sm4ks( 7,0), sm4ks(31,8))
sm4ed(31, 0),
Cat(sm4ed(23, 0), sm4ed(31, 24)),
Cat(sm4ed(15, 0), sm4ed(31, 16)),
Cat(sm4ed( 7, 0), sm4ed(31, 8)),
sm4ks(31, 0),
Cat(sm4ks(23, 0), sm4ks(31, 24)),
Cat(sm4ks(15, 0), sm4ks(31, 16)),
Cat(sm4ks( 7, 0), sm4ks(31, 8))
))
val sm4Result = SignExt((sm4Source(funcReg(2,0)) ^ RegEnable(src1(31,0), io.regEnable))(31,0), XLEN)
val sm4ResultTemp = Wire(UInt(32.W))
sm4ResultTemp := sm4Source(funcReg(2, 0)) ^ src1_reg(31, 0)
val sm4Result = SignExt(sm4ResultTemp, XLEN)
io.out := Mux(funcReg(3), sm4Result, aesResult)
}

View File

@ -61,25 +61,27 @@ object ROR64 {
// AES forward shift rows
object ForwardShiftRows {
def apply(src1: Seq[UInt], src2: Seq[UInt]): Seq[UInt] = {
VecInit(Seq(src1(0), src1(5), src2(2), src2(7),
src1(4), src2(1), src2(6), src1(3)))
def apply(src1: Vec[UInt], src2: Vec[UInt]): Vec[UInt] = {
VecInit(src1(0), src1(5), src2(2), src2(7),
src1(4), src2(1), src2(6), src1(3))
}
}
// AES inverse shift rows
object InverseShiftRows {
def apply(src1: Seq[UInt], src2: Seq[UInt]): Seq[UInt] = {
VecInit(Seq(src1(0), src2(5), src2(2), src1(7),
src1(4), src1(1), src2(6), src2(3)))
def apply(src1: Vec[UInt], src2: Vec[UInt]): Vec[UInt] = {
VecInit(src1(0), src2(5), src2(2), src1(7),
src1(4), src1(1), src2(6), src2(3))
}
}
// AES encode sbox top
object SboxAesTop {
def apply(i: UInt): Seq[Bool] = {
def apply(i: UInt): UInt = {
require(i.getWidth == 8)
val t = Wire(Vec(6, Bool()))
val o = Wire(Vec(21, Bool()))
t( 0) := i( 3) ^ i( 1)
t( 1) := i( 6) ^ i( 5)
t( 2) := i( 6) ^ i( 2)
@ -107,15 +109,17 @@ object SboxAesTop {
o(18) := o( 2) ^ o( 8)
o(19) := o(15) ^ o(13)
o(20) := o( 1) ^ t( 3)
o
o.asUInt
}
}
// AES decode sbox top
object SboxIaesTop {
def apply(i: UInt): Seq[Bool] = {
def apply(i: UInt): UInt = {
require(i.getWidth == 8)
val t = Wire(Vec(5, Bool()))
val o = Wire(Vec(21, Bool()))
t( 0) := i( 1) ^ i( 0)
t( 1) := i( 6) ^ i( 1)
t( 2) := i( 5) ^ ~i( 2)
@ -142,15 +146,17 @@ object SboxIaesTop {
o(18) := i( 3) ^ ~i( 0)
o(19) := i( 5) ^ ~o( 1)
o(20) := o( 1) ^ t( 3)
o
o.asUInt
}
}
// SM4 encode/decode sbox top
object SboxSm4Top {
def apply(i: UInt): Seq[Bool] = {
def apply(i: UInt): UInt = {
require(i.getWidth == 8)
val t = Wire(Vec(7, Bool()))
val o = Wire(Vec(21, Bool()))
t( 0) := i(3) ^ i( 4)
t( 1) := i(2) ^ i( 7)
t( 2) := i(7) ^ o(18)
@ -179,15 +185,17 @@ object SboxSm4Top {
o(18) := i(2) ^ i( 6)
o(19) := i(5) ^ ~o(14)
o(20) := i(0) ^ t( 1)
o
o.asUInt
}
}
// Sbox middle part for AES, AES^-1, SM4
object SboxInv {
def apply(i: Seq[Bool]): Seq[Bool] = {
def apply(i: UInt): UInt = {
require(i.getWidth == 21)
val t = Wire(Vec(46, Bool()))
val o = Wire(Vec(18, Bool()))
t( 0) := i( 3) ^ i(12)
t( 1) := i( 9) & i( 5)
t( 2) := i(17) & i( 6)
@ -252,15 +260,17 @@ object SboxInv {
o(15) := t(40) & i( 6)
o(16) := t(39) & i( 0)
o(17) := t(43) & i(12)
o
o.asUInt
}
}
// AES encode sbox out
object SboxAesOut {
def apply(i: Seq[Bool]): UInt = {
def apply(i: UInt): UInt = {
require(i.getWidth == 18)
val t = Wire(Vec(30, Bool()))
val o = Wire(Vec(8, Bool()))
t( 0) := i(11) ^ i(12)
t( 1) := i( 0) ^ i( 6)
t( 2) := i(14) ^ i(16)
@ -305,9 +315,11 @@ object SboxAesOut {
// AES decode sbox out
object SboxIaesOut {
def apply(i: Seq[Bool]): UInt = {
def apply(i: UInt): UInt = {
require(i.getWidth == 18)
val t = Wire(Vec(30, Bool()))
val o = Wire(Vec(8, Bool()))
t( 0) := i( 2) ^ i(11)
t( 1) := i( 8) ^ i( 9)
t( 2) := i( 4) ^ i(12)
@ -352,9 +364,11 @@ object SboxIaesOut {
// SM4 encode/decode sbox out
object SboxSm4Out {
def apply(i: Seq[Bool]): UInt = {
def apply(i: UInt): UInt = {
require(i.getWidth == 18)
val t = Wire(Vec(30, Bool()))
val o = Wire(Vec(8, Bool()))
t( 0) := i( 4) ^ i( 7)
t( 1) := i(13) ^ i(15)
t( 2) := i( 2) ^ i(16)
@ -417,7 +431,7 @@ object SboxSm4 {
// Mix Column
object XtN {
def Xt2(byte: UInt): UInt = ((byte << 1) ^ Mux(byte(7), "h1b".U, 0.U))(7,0)
def Xt2(byte: UInt): UInt = ((byte << 1).asUInt ^ Mux(byte(7), "h1b".U, 0.U))(7,0)
def apply(byte: UInt, t: UInt): UInt = {
val byte1 = Xt2(byte)