Skip to content

Commit c3012ef

Browse files
committed
*: make Disco emit nb of participants
Include nb of participants in more server-client messages
1 parent e7aeeea commit c3012ef

17 files changed

+218
-174
lines changed

discojs/src/client/client.ts

+28-9
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ const debug = createDebug("discojs:client");
2020
* Main, abstract, class representing a Disco client in a network, which handles
2121
* communication with other nodes, be it peers or a server.
2222
*/
23-
export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
23+
export abstract class Client extends EventEmitter<{
24+
'status': RoundStatus,
25+
'participants': number
26+
}>{
2427
// Own ID provided by the network's server.
2528
protected _ownId?: NodeID
2629
// The network's server.
@@ -40,7 +43,10 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
4043
* we were doing before waiting (training locally or updating our model).
4144
* We use this attribute to store the status to rollback to when we stop waiting
4245
*/
43-
private previousStatus: RoundStatus | undefined;
46+
#previousStatus: RoundStatus | undefined;
47+
48+
// Current number of participants including this client in the training session
49+
#nbOfParticipants: number = 1;
4450

4551
constructor (
4652
public readonly url: URL, // The network server's URL to connect to
@@ -82,7 +88,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
8288
* the waiting status and once enough participants join, it can display the previous status again
8389
*/
8490
protected saveAndEmit(status: RoundStatus) {
85-
this.previousStatus = status
91+
this.#previousStatus = status
8692
this.emit("status", status)
8793
}
8894

@@ -111,12 +117,13 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
111117
protected setupServerCallbacks(setMessageInversionFlag: () => void) {
112118
// Setup an event callback if the server signals that we should
113119
// wait for more participants
114-
this.server.on(type.WaitingForMoreParticipants, () => {
120+
this.server.on(type.WaitingForMoreParticipants, (event) => {
115121
if (this.promiseForMoreParticipants !== undefined)
116122
throw new Error("Server sent multiple WaitingForMoreParticipants messages")
117123
debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`)
118124
// Display the waiting status right away
119125
this.emit("status", "not enough participants")
126+
this.nbOfParticipants = event.nbOfParticipants // emits the `participants` event
120127
// Upon receiving a WaitingForMoreParticipants message,
121128
// the client will wait for this promise to resolve before sending its
122129
// local weight update
@@ -129,10 +136,10 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
129136
// and directly follows with an EnoughParticipants message when the 2nd participant joins
130137
// However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger)
131138
// so we check whether we received the EnoughParticipants before being assigned a node ID
132-
this.server.once(type.EnoughParticipants, () => {
139+
this.server.once(type.EnoughParticipants, (event) => {
133140
if (this._ownId === undefined) {
134-
debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`)
135141
setMessageInversionFlag()
142+
this.nbOfParticipants = event.nbOfParticipants
136143
}
137144
})
138145
}
@@ -146,10 +153,11 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
146153
protected async createPromiseForMoreParticipants(): Promise<void> {
147154
return new Promise<void>((resolve) => {
148155
// "once" is important because we can't resolve the same promise multiple times
149-
this.server.once(type.EnoughParticipants, () => {
156+
this.server.once(type.EnoughParticipants, (event) => {
150157
debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`)
151158
// Emit the last status emitted before waiting if defined
152-
if (this.previousStatus !== undefined) this.emit("status", this.previousStatus)
159+
if (this.#previousStatus !== undefined) this.emit("status", this.#previousStatus)
160+
this.nbOfParticipants = event.nbOfParticipants
153161
resolve()
154162
})
155163
})
@@ -190,7 +198,18 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
190198
* If federated, it should be the number of participants excluding the server
191199
* If local it should be 1
192200
*/
193-
abstract getNbOfParticipants(): number;
201+
public get nbOfParticipants(): number {
202+
return this.#nbOfParticipants
203+
}
204+
205+
/**
206+
* Setter for the number of participants
207+
* It emits a `participants` event carrying the new number of participants
208+
*/
209+
public set nbOfParticipants(nbOfParticipants: number) {
210+
this.#nbOfParticipants = nbOfParticipants
211+
this.emit("participants", nbOfParticipants)
212+
}
194213

195214
get ownId(): NodeID {
196215
if (this._ownId === undefined) {

discojs/src/client/decentralized/decentralized_client.ts

+14-9
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,16 @@ export class DecentralizedClient extends Client {
2626
#pool?: PeerPool
2727
#connections?: Map<NodeID, PeerConnection>
2828

29-
override getNbOfParticipants(): number {
30-
const nbOfParticipants = this.aggregator.nodes.size
31-
return nbOfParticipants === 0 ? 1 : nbOfParticipants
32-
}
33-
3429
// Used to handle timeouts and promise resolving after calling disconnect
3530
private get isDisconnected() : boolean {
3631
return this._server === undefined
3732
}
33+
34+
private setAggregatorNodes(nodes: Set<NodeID>) {
35+
this.aggregator.setNodes(nodes)
36+
// Emits the `participants` event
37+
this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size
38+
}
3839

3940
/**
4041
* Public method called by disco.ts when starting training. This method sends
@@ -77,7 +78,11 @@ export class DecentralizedClient extends Client {
7778
}
7879
this.server.send(msg)
7980

80-
const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)
81+
const { id, waitForMoreParticipants,
82+
nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)
83+
84+
this.nbOfParticipants = nbOfParticipants
85+
8186

8287
// This should come right after receiving the message to make sure
8388
// we don't miss a subsequent message from the server
@@ -107,7 +112,7 @@ export class DecentralizedClient extends Client {
107112

108113
if (this.#connections !== undefined) {
109114
const peers = this.#connections.keySeq().toSet()
110-
this.aggregator.setNodes(this.aggregator.nodes.subtract(peers))
115+
this.setAggregatorNodes(this.aggregator.nodes.subtract(peers))
111116
}
112117
// Disconnect from server
113118
await this.server?.disconnect()
@@ -180,7 +185,7 @@ export class DecentralizedClient extends Client {
180185
throw new Error('received peer list contains our own id')
181186
}
182187
// Store the list of peers for the current round including ourselves
183-
this.aggregator.setNodes(peers.add(this.ownId))
188+
this.setAggregatorNodes(peers.add(this.ownId))
184189
this.aggregator.setRound(receivedMessage.aggregationRound) // the server gives us the round number
185190

186191
// Initiate peer to peer connections with each peer
@@ -197,7 +202,7 @@ export class DecentralizedClient extends Client {
197202
this.#connections = connections
198203
} catch (e) {
199204
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
200-
this.aggregator.setNodes(Set(this.ownId))
205+
this.setAggregatorNodes(Set(this.ownId))
201206
this.#connections = Map()
202207
}
203208
}

discojs/src/client/decentralized/messages.ts

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export interface NewDecentralizedNodeInfo {
1111
type: type.NewDecentralizedNodeInfo
1212
id: NodeID
1313
waitForMoreParticipants: boolean
14+
nbOfParticipants: number
1415
}
1516

1617
// WebRTC signal to forward to other node

discojs/src/client/federated/federated_client.ts

+2-11
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@ const SERVER_NODE_ID = "federated-server-node-id";
2424
* a specific task in the federated setting.
2525
*/
2626
export class FederatedClient extends Client {
27-
// Total number of other federated contributors, including this client, excluding the server
28-
// E.g., if 3 users are training a federated model, nbOfParticipants is 3
29-
#nbOfParticipants: number = 1;
30-
31-
// the number of participants excluding the server
32-
override getNbOfParticipants(): number {
33-
return this.#nbOfParticipants
34-
}
3527

3628
/**
3729
* Initializes the connection to the server, gets our node ID
@@ -92,7 +84,7 @@ export class FederatedClient extends Client {
9284
this._ownId = id;
9385
debug(`[${shortenId(id)}] joined session at round ${round} `);
9486
this.aggregator.setRound(round)
95-
this.#nbOfParticipants = nbOfParticipants
87+
this.nbOfParticipants = nbOfParticipants
9688
// Upon connecting, the server answers with a boolean
9789
// which indicates whether there are enough participants or not
9890
debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants)
@@ -161,8 +153,7 @@ export class FederatedClient extends Client {
161153
round: serverRound,
162154
nbOfParticipants
163155
} = await waitMessage( this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
164-
165-
this.#nbOfParticipants = nbOfParticipants // Save the current participants
156+
this.nbOfParticipants = nbOfParticipants // Save the current participants
166157
const serverResult = serialization.weights.decode(payloadFromServer);
167158
this.aggregator.setRound(serverRound);
168159

discojs/src/client/local_client.ts

-4
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@ import { Client } from "./client.js";
77
*/
88
export class LocalClient extends Client {
99

10-
override getNbOfParticipants(): number {
11-
return 1;
12-
}
13-
1410
override onRoundBeginCommunication(): Promise<void> {
1511
return Promise.resolve();
1612
}

discojs/src/client/messages.ts

+2
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ export interface ClientConnected {
4747

4848
export interface EnoughParticipants {
4949
type: type.EnoughParticipants
50+
nbOfParticipants: number
5051
}
5152

5253
export interface WaitingForMoreParticipants {
5354
type: type.WaitingForMoreParticipants
55+
nbOfParticipants: number
5456
}
5557

5658
export type Message =

discojs/src/training/disco.ts

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export type RoundStatus = 'not enough participants' | // Server notification to
4949
*/
5050
export class Disco<D extends DataType> extends EventEmitter<{
5151
status: RoundStatus;
52+
participants: number
5253
}> {
5354
public readonly trainer: Trainer<D>;
5455
readonly #client: clients.Client;
@@ -99,6 +100,7 @@ export class Disco<D extends DataType> extends EventEmitter<{
99100
this.trainer = new Trainer(task, client);
100101
// Simply propagate the training status events emitted by the client
101102
this.#client.on("status", (status) => this.emit("status", status));
103+
this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants));
102104
}
103105

104106
/** Train on dataset, yielding logs of every round. */

discojs/src/training/trainer.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ export class Trainer<D extends DataType> {
124124

125125
return {
126126
epochs: epochsLogs,
127-
participants: this.#client.getNbOfParticipants(),
127+
participants: this.#client.nbOfParticipants,
128128
};
129129
}
130130
}

server/src/controllers/decentralized_controller.ts

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export class DecentralizedController<
5050
const msg: messages.NewDecentralizedNodeInfo = {
5151
type: MessageTypes.NewDecentralizedNodeInfo,
5252
id: peerId,
53+
nbOfParticipants: this.connections.size,
5354
waitForMoreParticipants: this.connections.size < minNbOfParticipants
5455
}
5556
ws.send(msgpack.encode(msg), { binary: true })

server/src/controllers/training_controller.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ export abstract class TrainingController<D extends DataType> {
5959
.forEach((participantWs, participantId) => {
6060
debug("Sending enough-participant message to client [%s]", participantId.slice(0, 4))
6161
const msg: client.messages.EnoughParticipants = {
62-
type: client.messages.type.EnoughParticipants
62+
type: client.messages.type.EnoughParticipants,
63+
nbOfParticipants: this.connections.size
6364
}
6465
participantWs.send(msgpack.encode(msg))
6566
})
@@ -78,7 +79,8 @@ export abstract class TrainingController<D extends DataType> {
7879
.forEach((participantWs, participantId) => {
7980
debug("Telling remaining client [%s] to wait for participants", participantId.slice(0, 4))
8081
const msg: client.messages.WaitingForMoreParticipants = {
81-
type: client.messages.type.WaitingForMoreParticipants
82+
type: client.messages.type.WaitingForMoreParticipants,
83+
nbOfParticipants: this.connections.size
8284
}
8385
participantWs.send(msgpack.encode(msg))
8486
})

0 commit comments

Comments
 (0)