@@ -2,21 +2,21 @@ import fs from 'node:fs/promises'
2
2
import path from 'node:path'
3
3
import type { Server } from 'node:http'
4
4
import { Range } from 'immutable'
5
- import { assert } from 'chai'
5
+ import { assert , expect } from 'chai'
6
6
7
7
import type { WeightsContainer } from '@epfml/discojs-core'
8
8
import {
9
- Disco , TrainingSchemes , client as clients ,
9
+ Disco , TrainingSchemes , client as clients , data ,
10
10
aggregator as aggregators , informant , defaultTasks
11
11
} from '@epfml/discojs-core'
12
- import { NodeImageLoader , NodeTabularLoader } from '@epfml/discojs-node'
12
+ import { NodeImageLoader , NodeTabularLoader , NodeTextLoader } from '@epfml/discojs-node'
13
13
14
14
import { startServer } from '../../src'
15
15
16
16
const SCHEME = TrainingSchemes . FEDERATED
17
17
18
18
describe ( 'end-to-end federated' , function ( ) {
19
- this . timeout ( 120_000 )
19
+ this . timeout ( 100_000 )
20
20
21
21
let server : Server
22
22
let url : URL
@@ -81,13 +81,42 @@ describe('end-to-end federated', function () {
81
81
return aggregator . model . weights
82
82
}
83
83
84
+ async function wikitextUser ( ) : Promise < void > {
85
+ const task = defaultTasks . wikitext . getTask ( )
86
+ const loader = new NodeTextLoader ( task )
87
+ const dataSplit : data . DataSplit = {
88
+ train : await data . TextData . init ( ( await loader . load ( '../datasets/wikitext/wiki.train.tokens' ) ) , task ) ,
89
+ validation : await data . TextData . init ( await loader . load ( '../datasets/wikitext/wiki.valid.tokens' ) , task )
90
+ }
91
+
92
+ const aggregator = new aggregators . MeanAggregator ( )
93
+ const client = new clients . federated . FederatedClient ( url , task , aggregator )
94
+ const trainingInformant = new informant . FederatedInformant ( task , 10 )
95
+ const disco = new Disco ( task , { scheme : SCHEME , client, aggregator, informant : trainingInformant } )
96
+
97
+ await disco . fit ( dataSplit )
98
+ await disco . close ( )
99
+
100
+ expect ( trainingInformant . losses . first ( ) ) . to . be . above ( trainingInformant . losses . last ( ) )
101
+ }
102
+
84
103
it ( 'two cifar10 users reach consensus' , async ( ) => {
104
+ this . timeout ( 90_000 )
105
+
85
106
const [ m1 , m2 ] = await Promise . all ( [ cifar10user ( ) , cifar10user ( ) ] )
86
107
assert . isTrue ( m1 . equals ( m2 ) )
87
108
} )
88
109
89
110
it ( 'two titanic users reach consensus' , async ( ) => {
111
+ this . timeout ( 30_000 )
112
+
90
113
const [ m1 , m2 ] = await Promise . all ( [ titanicUser ( ) , titanicUser ( ) ] )
91
114
assert . isTrue ( m1 . equals ( m2 ) )
92
115
} )
116
+
117
+ it ( 'trains wikitext' , async ( ) => {
118
+ this . timeout ( 120_000 )
119
+
120
+ await wikitextUser ( )
121
+ } )
93
122
} )
0 commit comments