Skip to content

Commit

Permalink
Merge pull request #3714 from richardyrh/master
Browse files Browse the repository at this point in the history
Changes to the Cluster API used by the GPU Project
  • Loading branch information
jerryz123 authored Feb 9, 2025
2 parents c7aabd2 + 4687b06 commit 47d6527
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 35 deletions.
63 changes: 37 additions & 26 deletions src/main/scala/subsystem/Cluster.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,20 @@ import scala.collection.immutable.SortedMap

case class ClustersLocated(loc: HierarchicalLocation) extends Field[Seq[CanAttachCluster]](Nil)

trait BaseClusterParams extends HierarchicalElementParams {
val clusterId: Int
}

abstract class InstantiableClusterParams[ClusterType <: Cluster]
extends HierarchicalElementParams
with BaseClusterParams {
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByClusterIdImpl)(implicit p: Parameters): ClusterType
}

case class ClusterParams(
val clusterId: Int,
val clockSinkParams: ClockSinkParameters = ClockSinkParameters()
) extends HierarchicalElementParams {
) extends InstantiableClusterParams[Cluster] {
val baseName = "cluster"
val uniqueName = s"${baseName}_$clusterId"
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByClusterIdImpl)(implicit p: Parameters): Cluster = {
Expand All @@ -31,7 +41,7 @@ case class ClusterParams(
}

class Cluster(
val thisClusterParams: ClusterParams,
val thisClusterParams: BaseClusterParams,
crossing: ClockCrossingType,
lookup: LookupByClusterIdImpl)(implicit p: Parameters) extends BaseHierarchicalElement(crossing)(p)
with Attachable
Expand All @@ -55,8 +65,6 @@ class Cluster(
val slaveNode = ccbus.inwardNode
val masterNode = cmbus.outwardNode



lazy val ibus = LazyModule(new InterruptBusWrapper)
ibus.clockNode := csbus.fixedClockNode

Expand All @@ -66,7 +74,7 @@ class Cluster(
def toPlicDomain = this
lazy val msipNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val meipNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val seipNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val seipNodes = totalTiles.filter(_._2.tileParams.core.useSupervisor).keys.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val tileToPlicNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val debugNodes = totalTileIdList.map { i => (i, IntSyncIdentityNode()) }.to(SortedMap)
lazy val nmiNodes = totalTiles.filter { case (i,t) => t.tileParams.core.useNMI }
Expand All @@ -79,7 +87,7 @@ class Cluster(
// TODO fix: shouldn't need to connect dummy notifications
tileHaltXbarNode := NullIntSource()
tileWFIXbarNode := NullIntSource()
tileCeaseXbarNode := NullIntSource()
// tileCeaseXbarNode := NullIntSource()

override lazy val module = new ClusterModuleImp(this)
}
Expand All @@ -88,12 +96,12 @@ class ClusterModuleImp(outer: Cluster) extends BaseHierarchicalElementModuleImp[

case class InCluster(id: Int) extends HierarchicalLocation(s"Cluster$id")

class ClusterPRCIDomain(
abstract class ClusterPRCIDomain[ClusterType <: Cluster](
clockSinkParams: ClockSinkParameters,
crossingParams: HierarchicalElementCrossingParamsLike,
clusterParams: ClusterParams,
clusterParams: InstantiableClusterParams[ClusterType],
lookup: LookupByClusterIdImpl)
(implicit p: Parameters) extends HierarchicalElementPRCIDomain[Cluster](clockSinkParams, crossingParams)
(implicit p: Parameters) extends HierarchicalElementPRCIDomain[ClusterType](clockSinkParams, crossingParams)
{
val element = element_reset_domain {
LazyModule(clusterParams.instantiate(crossingParams, lookup))
Expand All @@ -104,19 +112,19 @@ class ClusterPRCIDomain(


trait CanAttachCluster {
type ClusterType <: Cluster
type ClusterContextType <: DefaultHierarchicalElementContextType

def clusterParams: ClusterParams
def clusterParams: InstantiableClusterParams[ClusterType]
def crossingParams: HierarchicalElementCrossingParamsLike

def instantiate(allClusterParams: Seq[ClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain])(implicit p: Parameters): ClusterPRCIDomain = {
def instantiate(allClusterParams: Seq[BaseClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain[_]])(implicit p: Parameters): ClusterPRCIDomain[ClusterType] = {
val clockSinkParams = clusterParams.clockSinkParams.copy(name = Some(clusterParams.uniqueName))
val cluster_prci_domain = LazyModule(new ClusterPRCIDomain(
clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)))
val cluster_prci_domain = LazyModule(new ClusterPRCIDomain[ClusterType](
clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)) {})
cluster_prci_domain
}

def connect(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connect(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
connectMasterPorts(domain, context)
connectSlavePorts(domain, context)
connectInterrupts(domain, context)
Expand All @@ -126,21 +134,21 @@ trait CanAttachCluster {
connectTrace(domain, context)
}

def connectMasterPorts(domain: ClusterPRCIDomain, context: Attachable): Unit = {
def connectMasterPorts(domain: ClusterPRCIDomain[ClusterType], context: Attachable): Unit = {
implicit val p = context.p
val dataBus = context.locateTLBusWrapper(crossingParams.master.where)
dataBus.coupleFrom(clusterParams.baseName) { bus =>
bus :=* crossingParams.master.injectNode(context) :=* domain.crossMasterPort(crossingParams.crossingType)
}
}
def connectSlavePorts(domain: ClusterPRCIDomain, context: Attachable): Unit = {
def connectSlavePorts(domain: ClusterPRCIDomain[ClusterType], context: Attachable): Unit = {
implicit val p = context.p
val controlBus = context.locateTLBusWrapper(crossingParams.slave.where)
controlBus.coupleTo(clusterParams.baseName) { bus =>
domain.crossSlavePort(crossingParams.crossingType) :*= crossingParams.slave.injectNode(context) :*= TLWidthWidget(controlBus.beatBytes) :*= bus
}
}
def connectInterrupts(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectInterrupts(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p

domain.element.debugNodes.foreach { case (hartid, node) =>
Expand Down Expand Up @@ -170,23 +178,23 @@ trait CanAttachCluster {
}
}

def connectPRC(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectPRC(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
domain.element.allClockGroupsNode :*= context.allClockGroupsNode
domain {
domain.element_reset_domain.clockNode := crossingParams.resetCrossingType.injectClockNode := domain.clockNode
}
}

def connectOutputNotifications(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectOutputNotifications(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
context.tileHaltXbarNode :=* domain.crossIntOut(NoCrossing, domain.element.tileHaltXbarNode)
context.tileWFIXbarNode :=* domain.crossIntOut(NoCrossing, domain.element.tileWFIXbarNode)
context.tileCeaseXbarNode :=* domain.crossIntOut(NoCrossing, domain.element.tileCeaseXbarNode)

}

def connectInputConstants(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectInputConstants(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
val tlBusToGetPrefixFrom = context.locateTLBusWrapper(crossingParams.mmioBaseAddressPrefixWhere)
domain.element.tileHartIdNodes.foreach { case (hartid, node) =>
Expand All @@ -197,7 +205,7 @@ trait CanAttachCluster {
}
}

def connectTrace(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectTrace(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
domain.element.traceNodes.foreach { case (hartid, node) =>
val traceNexusNode = BundleBridgeBlockDuringReset[TraceBundle](
Expand All @@ -212,23 +220,26 @@ trait CanAttachCluster {
}
}

case class ClusterAttachParams(
case class ClusterAttachParams (
clusterParams: ClusterParams,
crossingParams: HierarchicalElementCrossingParamsLike
) extends CanAttachCluster
) extends CanAttachCluster {
type ClusterType = Cluster
}

case class CloneClusterAttachParams(
sourceClusterId: Int,
cloneParams: CanAttachCluster
) extends CanAttachCluster {
type ClusterType = cloneParams.ClusterType
def clusterParams = cloneParams.clusterParams
def crossingParams = cloneParams.crossingParams

override def instantiate(allClusterParams: Seq[ClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain])(implicit p: Parameters): ClusterPRCIDomain = {
override def instantiate(allClusterParams: Seq[BaseClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain[_]])(implicit p: Parameters): ClusterPRCIDomain[ClusterType] = {
require(instantiatedClusters.contains(sourceClusterId))
val clockSinkParams = clusterParams.clockSinkParams.copy(name = Some(clusterParams.uniqueName))
val cluster_prci_domain = CloneLazyModule(
new ClusterPRCIDomain(clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)),
new ClusterPRCIDomain[ClusterType](clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)) {},
instantiatedClusters(sourceClusterId)
)
cluster_prci_domain
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/subsystem/HasHierarchicalElements.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,16 @@ trait InstantiatesHierarchicalElements { this: LazyModule with Attachable =>
}

val clusterAttachParams: Seq[CanAttachCluster] = p(ClustersLocated(location)).sortBy(_.clusterParams.clusterId)
val clusterParams: Seq[ClusterParams] = clusterAttachParams.map(_.clusterParams)
val clusterParams: Seq[BaseClusterParams] = clusterAttachParams.map(_.clusterParams)
val clusterCrossingTypes: Seq[ClockCrossingType] = clusterAttachParams.map(_.crossingParams.crossingType)
val cluster_prci_domains: SortedMap[Int, ClusterPRCIDomain] = clusterAttachParams.foldLeft(SortedMap[Int, ClusterPRCIDomain]()) {
val cluster_prci_domains: SortedMap[Int, ClusterPRCIDomain[_]] = clusterAttachParams.foldLeft(SortedMap[Int, ClusterPRCIDomain[_]]()) {
case (instantiated, params) => instantiated + (params.clusterParams.clusterId -> params.instantiate(clusterParams, instantiated)(p))
}

val element_prci_domains: Seq[HierarchicalElementPRCIDomain[_]] = tile_prci_domains.values.toSeq ++ cluster_prci_domains.values.toSeq

val leafTiles: SortedMap[Int, BaseTile] = SortedMap(tile_prci_domains.mapValues(_.element.asInstanceOf[BaseTile]).toSeq.sortBy(_._1):_*)
val totalTiles: SortedMap[Int, BaseTile] = (leafTiles ++ cluster_prci_domains.values.map(_.element.totalTiles).flatten)
val totalTiles: SortedMap[Int, BaseTile] = (leafTiles ++ cluster_prci_domains.values.map(_.element.asInstanceOf[Cluster].totalTiles).flatten)

// Helper functions for accessing certain parameters that are popular to refer to in subsystem code
def nLeafTiles: Int = leafTiles.size
Expand All @@ -123,7 +123,7 @@ trait HasHierarchicalElements extends DefaultHierarchicalElementContextType
params.connect(tile_prci_domains(params.tileParams.tileId).asInstanceOf[TilePRCIDomain[params.TileType]], this.asInstanceOf[params.TileContextType])
}
clusterAttachParams.foreach { params =>
params.connect(cluster_prci_domains(params.clusterParams.clusterId).asInstanceOf[ClusterPRCIDomain], this.asInstanceOf[params.ClusterContextType])
params.connect(cluster_prci_domains(params.clusterParams.clusterId).asInstanceOf[ClusterPRCIDomain[params.ClusterType]], this.asInstanceOf[params.ClusterContextType])
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/subsystem/LookupByClusterId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import chisel3._
import chisel3.util._

abstract class LookupByClusterIdImpl {
def apply[T <: Data](f: ClusterParams => Option[T], clusterId: UInt): T
def apply[T <: Data](f: BaseClusterParams => Option[T], clusterId: UInt): T
}

case class ClustersWontDeduplicate(t: ClusterParams) extends LookupByClusterIdImpl {
def apply[T <: Data](f: ClusterParams => Option[T], clusterId: UInt): T = f(t).get
case class ClustersWontDeduplicate(t: BaseClusterParams) extends LookupByClusterIdImpl {
def apply[T <: Data](f: BaseClusterParams => Option[T], clusterId: UInt): T = f(t).get
}

case class PriorityMuxClusterIdFromSeq(seq: Seq[ClusterParams]) extends LookupByClusterIdImpl {
def apply[T <: Data](f: ClusterParams => Option[T], clusterId: UInt): T =
case class PriorityMuxClusterIdFromSeq(seq: Seq[BaseClusterParams]) extends LookupByClusterIdImpl {
def apply[T <: Data](f: BaseClusterParams => Option[T], clusterId: UInt): T =
PriorityMux(seq.collect { case t if f(t).isDefined => (t.clusterId.U === clusterId) -> f(t).get })
}

0 comments on commit 47d6527

Please sign in to comment.