1
1
import { Range } from 'immutable'
2
+ import type { Server } from 'node:http'
2
3
3
4
import type { TrainerLog , data , Task } from '@epfml/discojs-core'
4
- import { Disco , TrainingSchemes } from '@epfml/discojs-core'
5
+ import { Disco , TrainingSchemes , aggregator as aggregators , client as clients } from '@epfml/discojs-core'
6
+ import { getClient , startServer } from '@epfml/disco-server'
5
7
6
- import { startServer , saveLog } from './utils'
8
+ import { saveLog } from './utils'
7
9
import { getTaskData } from './data'
8
10
import { args } from './args'
9
11
@@ -15,23 +17,25 @@ console.log(infoText)
15
17
16
18
console . log ( { args } )
17
19
18
- async function runUser ( task : Task , url : URL , data : data . DataSplit ) : Promise < TrainerLog > {
20
+ async function runUser ( task : Task , server : Server , data : data . DataSplit ) : Promise < TrainerLog > {
21
+ const client = await getClient ( clients . federated . FederatedClient , server , task , new aggregators . MeanAggregator ( TASK ) )
22
+
19
23
// force the federated scheme
20
24
const scheme = TrainingSchemes . FEDERATED
21
- const disco = new Disco ( task , { scheme, url } )
25
+ const disco = new Disco ( task , { scheme, client } )
22
26
23
27
await disco . fit ( data )
24
28
await disco . close ( )
25
29
return await disco . logs ( )
26
30
}
27
31
28
32
async function main ( ) : Promise < void > {
29
- const [ server , serverUrl ] = await startServer ( )
33
+ const server = await startServer ( )
30
34
31
35
const data = await getTaskData ( TASK )
32
36
33
37
const logs = await Promise . all (
34
- Range ( 0 , NUMBER_OF_USERS ) . map ( async ( _ ) => await runUser ( TASK , serverUrl , data ) ) . toArray ( )
38
+ Range ( 0 , NUMBER_OF_USERS ) . map ( async ( _ ) => await runUser ( TASK , server , data ) ) . toArray ( )
35
39
)
36
40
37
41
if ( args . save ) {
0 commit comments