Skip to content

Commit

Permalink
Batch: use struct memcpy to replace member var assign
Browse files Browse the repository at this point in the history
For possible misalignment of struct member, we add packed attribute
to struct declare so all member will be aligned by byte, which helps
use memcpy for DPIC.
  • Loading branch information
klin02 committed Feb 6, 2024
1 parent 22e224e commit cc45380
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 76 deletions.
14 changes: 6 additions & 8 deletions src/main/scala/Batch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class BatchOutput(dataWidth: Int, infoWidth: Int, config: GatewayConfig) extends

class BatchInfo extends Bundle {
val id = UInt(8.W)
val len = UInt(16.W)
}

object Batch {
Expand All @@ -58,12 +57,12 @@ class BatchEndpoint(template: Seq[DifftestBundle], bundles: Seq[DifftestBundle],

def bundleAlign(bundle: DifftestBundle): UInt = {
def byteAlign(data: Data): UInt = {
val width: Int = data.getWidth + (8 - data.getWidth % 8) % 8
val width: Int = (data.getWidth + 7) / 8 * 8
data.asTypeOf(UInt(width.W))
}
val element = ListBuffer.empty[UInt]
bundle.elements.toSeq.reverse.foreach { case (name, data) =>
if (name != "valid") {
if (!(bundle.isFlatten && name == "valid")) {
data match {
case vec: Vec[_] => element ++= vec.map(byteAlign(_))
case data: Data => element += byteAlign(data)
Expand Down Expand Up @@ -92,12 +91,12 @@ class BatchEndpoint(template: Seq[DifftestBundle], bundles: Seq[DifftestBundle],
}.toSeq)

// Maxixum 4000 byte packed. Now we set maxium of data byte as 3000, info as 900
val MaxDataByteLen = 3600
val MaxDataByteLen = 3900
val MaxDataByteWidth = log2Ceil(MaxDataByteLen)
val MaxDataBitLen = MaxDataByteLen * 8
val infoWidth = (new BatchInfo).getWidth
// Append BatchInterval and BatchFinish Info
val MaxInfoByteLen = math.min((config.batchSize * (bundleNum + 1) + 1) * (infoWidth / 8), 300)
val MaxInfoByteLen = math.min((config.batchSize * (bundleNum + 1) + 1) * (infoWidth / 8), 90)
val MaxInfoByteWidth = log2Ceil(MaxInfoByteLen)
val MaxInfoBitLen = MaxInfoByteLen * 8

Expand All @@ -110,7 +109,6 @@ class BatchEndpoint(template: Seq[DifftestBundle], bundles: Seq[DifftestBundle],
val data_len = (aligned_data(idx).getWidth / 8).U
val info = Wire(new BatchInfo)
info.id := getBundleID(in(idx).desiredModuleName).U
info.len := data_len
if (idx == 0) {
data_vec(idx) := Mux(delayed_valid(idx), delayed_data(idx), 0.U)
info_vec(idx) := Mux(delayed_valid(idx), info.asUInt, 0.U)
Expand All @@ -124,8 +122,8 @@ class BatchEndpoint(template: Seq[DifftestBundle], bundles: Seq[DifftestBundle],
}
}

val BatchInterval = WireInit(0.U.asTypeOf(new BatchInfo))
val BatchFinish = WireInit(0.U.asTypeOf(new BatchInfo))
val BatchInterval = Wire(new BatchInfo)
val BatchFinish = Wire(new BatchInfo)
BatchInterval.id := template.length.U
BatchFinish.id := (template.length + 1).U
val step_data = WireInit(data_vec(bundleNum - 1))
Expand Down
107 changes: 42 additions & 65 deletions src/main/scala/DPIC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -178,57 +178,45 @@ class DPICBatch[T <: Seq[DifftestBundle]](template: T, bundle: GatewayBatchBundl

def getDPICBundleUnpack(gen: DifftestBundle): String = {
val unpack = ListBuffer.empty[String]
def byteCnt(data: Data): Int = (data.getWidth + (8 - data.getWidth % 8) % 8) / 8
case class ArgPair(name: String, width: Int, offset: Int, isVec: Boolean)
val argsWithWidthOffset: Seq[ArgPair] = {
def byteCnt(data: Data): Int = (data.getWidth + 7) / 8
case class ArgPair(name: String, width: Int, offset: Int)
def getBundleArgs(gen: DifftestBundle): Seq[ArgPair] = {
val list = ListBuffer.empty[ArgPair]
var offset: Int = 0
for ((name, data) <- gen.elements.toSeq.reverse) {
if (name != "valid") {
if (!(gen.isFlatten && name == "valid")) {
data match {
case vec: Vec[_] => {
for ((v, i) <- vec.zipWithIndex) {
list += ArgPair(s"${name}_$i", byteCnt(v), offset, true)
list += ArgPair(s"${name}_$i", byteCnt(v), offset)
offset += byteCnt(v)
}
}
case d: Data => {
list += ArgPair(name, byteCnt(d), offset, false)
list += ArgPair(name, byteCnt(d), offset)
offset += byteCnt(d)
}
}
}
}
list.toSeq
}
def varAssign(pair: ArgPair, prefix: String): String = {
val rhs = (0 until pair.width).map { i =>
val appendOffset = if (i != 0) s" << ${8 * i}" else ""
s"(uint${pair.width * 8}_t)data[offset + ${pair.offset + i}]$appendOffset"
}.mkString(" |\n ")
val lhs = (pair.name, pair.isVec) match {
case (x, true) => x.slice(0, x.lastIndexOf('_')) + s"[${x.split('_').last}]"
case (x, false) => x
}
s"$prefix$lhs = $rhs;"
}
val filterIn = ListBuffer("coreid")
val index = if (gen.isIndexed) {
filterIn += "index"
"[index]"
} else if (gen.isFlatten) {
filterIn += "address"
"[address]"
} else {
""
}
unpack ++= argsWithWidthOffset.filter(p => filterIn.contains(p.name)).map(p => varAssign(p, ""))

// Note: filterArgs will not in struct defined, but at the beginning or the end of Bundle
val bundleArgs = getBundleArgs(gen)
val filterArgs = Seq("coreid", "index", "address")
unpack ++= bundleArgs.filter(p => filterArgs.contains(p.name)).map(p => s"${p.name} = data[${p.offset}];")
val dut_zone = if (config.hasDutZone) "io_dut_zone" else "0"
val packet = s"DUT_BUF(coreid, $dut_zone, dut_index)->${gen.desiredCppName}"
val index = if (gen.isIndexed) "[index]" else if (gen.isFlatten) "[address]" else ""
unpack += s"auto packet = &($packet$index);"
if (gen.bits.hasValid && !gen.isFlatten) unpack += "packet->valid = true;"
val filterOut = Seq("coreid", "index", "address", "valid")
unpack ++= argsWithWidthOffset.filterNot(p => filterOut.contains(p.name)).map(p => varAssign(p, "packet->"))
val packedArgs = bundleArgs.filterNot(p => filterArgs.contains(p.name))
val ptrOffset: String = packedArgs.head.offset match {
case 0 => ""
case n => s" + $n"
}
unpack += s"memcpy(packet, data$ptrOffset, sizeof(${gen.desiredModuleName}));"
unpack += s"data += ${bundleArgs.map(_.width).sum};"
unpack.toSeq.mkString("\n ")
}

Expand All @@ -244,52 +232,41 @@ class DPICBatch[T <: Seq[DifftestBundle]](template: T, bundle: GatewayBatchBundl
| ${dpicFuncArgs.map(arg => getDPICArgString(arg._1, arg._2, true)).mkString(",\n ")}
|)""".stripMargin
val dpicFunc: String = {
val byteUnpack = dpicFuncArgs
.filter(_._2.getWidth > 8)
.map { case (name, data) =>
val array = name.replace("io_", "")
s"""
| static uint8_t $array[${data.getWidth / 8}];
| for (int i = 0; i < ${data.getWidth / 32}; i++) {
| $array[i * 4] = (uint8_t)($name[i] & 0xFF);
| $array[i * 4 + 1] = (uint8_t)(($name[i] >> 8) & 0xFF);
| $array[i * 4 + 2] = (uint8_t)(($name[i] >> 16) & 0xFF);
| $array[i * 4 + 3] = (uint8_t)(($name[i] >> 24) & 0xFF);
| }
|""".stripMargin
}
.mkString("")
val bundleEnum = template.map(_.desiredModuleName.replace("Difftest", "")) ++ Seq("BatchInterval", "BatchFinish")
val bundleAssign = template.zipWithIndex.map { case (t, idx) =>
s"""
| else if (id == ${bundleEnum(idx)}) {
| ${getDPICBundleUnpack(t)}
| }
val bundleAssign =
(Seq(s"""
| if (id == BatchFinish) {
| break;
| }
| else if (id == BatchInterval && i != 0) {
| dut_index ++;
| continue;
| }
|""".stripMargin)
++ template.zipWithIndex.map { case (t, idx) =>
s"""
| else if (id == ${bundleEnum(idx)}) {
| ${getDPICBundleUnpack(t)}
| }
""".stripMargin
}.mkString("")
}).mkString("")

val infoLen = io.info.getWidth / 8
s"""
|enum DifftestBundle {
|enum DifftestBundleType {
| ${bundleEnum.mkString(",\n ")}
|};
|$dpicFuncProto {
| if (!diffstate_buffer) return;
| uint64_t offset = 0;
| uint32_t dut_index = 0;
|$byteUnpack
| for (int i = 0; i < $infoLen; i+=3) {
| uint8_t id = (uint8_t)info[i+2];
| uint16_t len = (uint16_t)info[i+1] << 8 | (uint16_t)info[i];
| static uint8_t info[$infoLen];
| memcpy(info, io_info, $infoLen * sizeof(uint8_t));
| uint8_t* data = (uint8_t*)io_data;
| for (int i = 0; i < $infoLen; i++) {
| uint8_t id = info[i];
| uint32_t coreid, index, address;
| if (id == BatchFinish) {
| break;
| }
| else if (id == BatchInterval && i != 0) {
| dut_index ++;
| continue;
| }
| $bundleAssign
| offset += len;
| }
|}
|""".stripMargin
Expand Down
9 changes: 6 additions & 3 deletions src/main/scala/Difftest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ sealed trait DifftestBundle extends Bundle with DifftestWithCoreid { this: Difft
val macroName = s"CONFIG_DIFFTEST_${desiredModuleName.toUpperCase.replace("DIFFTEST", "")}"
s"#define $macroName"
}
def toCppDeclaration: String = {
def toCppDeclaration(packed: Boolean): String = {
val cpp = ListBuffer.empty[String]
cpp += "typedef struct {"
val attribute = if (packed) "__attribute__((packed))" else ""
cpp += s"typedef struct $attribute {"
for (((name, elem), size) <- diffElements.zip(diffSizes(8))) {
val isRemoved = isFlatten && Seq("valid", "address").contains(name)
if (!isRemoved) {
Expand Down Expand Up @@ -261,6 +262,7 @@ object DifftestModule {
private val instances = ListBuffer.empty[(DifftestBundle, String)]
private val cppMacros = ListBuffer.empty[String]
private val vMacros = ListBuffer.empty[String]
private var structPacked = false

def apply[T <: DifftestBundle](
gen: T,
Expand Down Expand Up @@ -293,6 +295,7 @@ object DifftestModule {
cppMacros ++= gateway.cppMacros
vMacros ++= gateway.vMacros
instances ++= gateway.instances
structPacked = gateway.structPacked.getOrElse(false)

if (cppHeader.isDefined) {
generateCppHeader(cpu, cppHeader.get)
Expand Down Expand Up @@ -365,7 +368,7 @@ object DifftestModule {
val configWidthName = s"CONFIG_DIFF_${macroName}_WIDTH"
difftestCpp += s"#define $configWidthName ${bundleType.bits.getNumElements}"
}
difftestCpp += bundleType.toCppDeclaration
difftestCpp += bundleType.toCppDeclaration(structPacked)
difftestCpp += ""
})

Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/Gateway.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ case class GatewayResult(
cppMacros: Seq[String] = Seq(),
vMacros: Seq[String] = Seq(),
instances: Seq[(DifftestBundle, String)] = Seq(),
structPacked: Option[Boolean] = None,
step: Option[UInt] = None,
) {
def +(that: GatewayResult): GatewayResult = {
GatewayResult(
cppMacros = cppMacros ++ that.cppMacros,
vMacros = vMacros ++ that.vMacros,
instances = instances ++ that.instances,
structPacked = if (structPacked.isDefined) structPacked else that.structPacked,
step = if (step.isDefined) step else that.step,
)
}
Expand Down Expand Up @@ -114,6 +116,7 @@ object Gateway {
val endpoint = Module(new GatewayEndpoint(instances.toSeq, config))
GatewayResult(
instances = endpoint.instances,
structPacked = Some(config.isBatch),
step = Some(endpoint.step),
)
} else {
Expand Down

0 comments on commit cc45380

Please sign in to comment.