9
9
Disco , TrainingSchemes , client as clients ,
10
10
aggregator as aggregators , informant , defaultTasks
11
11
} from '@epfml/discojs-core'
12
- import { NodeImageLoader , NodeTabularLoader } from '@epfml/discojs-node'
12
+ import { NodeImageLoader , NodeTabularLoader , NodeTextLoader } from '@epfml/discojs-node'
13
13
14
14
import { startServer } from '../../src'
15
15
@@ -81,6 +81,35 @@ describe('end-to-end federated', function () {
81
81
return aggregator . model . weights
82
82
}
83
83
84
+ async function wikitextUser ( ) : Promise < WeightsContainer > {
85
+ const task = defaultTasks . wikitext . getTask ( )
86
+ console . log ( '>>' , { task } )
87
+ const data = await ( new NodeTextLoader ( task ) . loadAll ( [ '../datasets/wikitext/wiki.train.tokens' ] ) )
88
+ console . log ( '~~' , { data } )
89
+
90
+ const aggregator = new aggregators . MeanAggregator ( )
91
+ const client = new clients . federated . FederatedClient ( url , task , aggregator )
92
+ const trainingInformant = new informant . FederatedInformant ( task , 10 )
93
+ const disco = new Disco ( task , { scheme : SCHEME , client, aggregator, informant : trainingInformant } )
94
+
95
+ await disco . fit ( data )
96
+ await disco . close ( )
97
+
98
+ assert (
99
+ trainingInformant . trainingAccuracy ( ) > 0.6 ,
100
+ `expected training accuracy greater than 0.6 but got ${ trainingInformant . trainingAccuracy ( ) } `
101
+ )
102
+ assert (
103
+ trainingInformant . validationAccuracy ( ) > 0.6 ,
104
+ `expected validation accuracy greater than 0.6 but got ${ trainingInformant . validationAccuracy ( ) } `
105
+ )
106
+
107
+ if ( aggregator . model === undefined ) {
108
+ throw new Error ( 'model was not set' )
109
+ }
110
+ return aggregator . model . weights
111
+ }
112
+
84
113
it ( 'two cifar10 users reach consensus' , async ( ) => {
85
114
const [ m1 , m2 ] = await Promise . all ( [ cifar10user ( ) , cifar10user ( ) ] )
86
115
assert . isTrue ( m1 . equals ( m2 ) )
@@ -90,4 +119,9 @@ describe('end-to-end federated', function () {
90
119
const [ m1 , m2 ] = await Promise . all ( [ titanicUser ( ) , titanicUser ( ) ] )
91
120
assert . isTrue ( m1 . equals ( m2 ) )
92
121
} )
122
+
123
+ it ( 'two wikitext users reach consensus' , async ( ) => {
124
+ const [ m1 , m2 ] = await Promise . all ( [ wikitextUser ( ) , wikitextUser ( ) ] )
125
+ assert . isTrue ( m1 . equals ( m2 ) )
126
+ } )
93
127
} )
0 commit comments