Skip to content

Commit ca2d2c1

Browse files
committed
server/tests: add wikitext
1 parent 526b966 commit ca2d2c1

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

server/tests/e2e/federated.spec.ts

+35-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import {
99
Disco, TrainingSchemes, client as clients,
1010
aggregator as aggregators, informant, defaultTasks
1111
} from '@epfml/discojs-core'
12-
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
12+
import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'
1313

1414
import { startServer } from '../../src'
1515

@@ -81,6 +81,35 @@ describe('end-to-end federated', function () {
8181
return aggregator.model.weights
8282
}
8383

84+
async function wikitextUser (): Promise<WeightsContainer> {
85+
const task = defaultTasks.wikitext.getTask()
86+
console.log('>>', { task })
87+
const data = await (new NodeTextLoader(task).loadAll(['../datasets/wikitext/wiki.train.tokens']))
88+
console.log('~~', { data })
89+
90+
const aggregator = new aggregators.MeanAggregator()
91+
const client = new clients.federated.FederatedClient(url, task, aggregator)
92+
const trainingInformant = new informant.FederatedInformant(task, 10)
93+
const disco = new Disco(task, { scheme: SCHEME, client, aggregator, informant: trainingInformant })
94+
95+
await disco.fit(data)
96+
await disco.close()
97+
98+
assert(
99+
trainingInformant.trainingAccuracy() > 0.6,
100+
`expected training accuracy greater than 0.6 but got ${trainingInformant.trainingAccuracy()}`
101+
)
102+
assert(
103+
trainingInformant.validationAccuracy() > 0.6,
104+
`expected validation accuracy greater than 0.6 but got ${trainingInformant.validationAccuracy()}`
105+
)
106+
107+
if (aggregator.model === undefined) {
108+
throw new Error('model was not set')
109+
}
110+
return aggregator.model.weights
111+
}
112+
84113
it('two cifar10 users reach consensus', async () => {
85114
const [m1, m2] = await Promise.all([cifar10user(), cifar10user()])
86115
assert.isTrue(m1.equals(m2))
@@ -90,4 +119,9 @@ describe('end-to-end federated', function () {
90119
const [m1, m2] = await Promise.all([titanicUser(), titanicUser()])
91120
assert.isTrue(m1.equals(m2))
92121
})
122+
123+
it('two wikitext users reach consensus', async () => {
124+
const [m1, m2] = await Promise.all([wikitextUser(), wikitextUser()])
125+
assert.isTrue(m1.equals(m2))
126+
})
93127
})

0 commit comments

Comments
 (0)