Skip to content

Commit 09331b3

Browse files
committed
server/tests: add wikitext
1 parent aec0ddc commit 09331b3

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

discojs/discojs-core/src/default_tasks/wikitext.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export const wikitext: TaskProvider = {
2020
dataType: 'text',
2121
modelID: 'wikitext-103-raw-model',
2222
validationSplit: 0.2, // TODO: is this used somewhere? because train, eval and test are already split in dataset
23-
epochs: 10_000,
23+
epochs: 10,
2424
// constructing a batch is taken care of automatically in the dataset to make things faster
2525
// so we fake a batch size of 1
2626
batchSize: 1,

server/tests/e2e/federated.spec.ts

+33-4
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ import fs from 'node:fs/promises'
22
import path from 'node:path'
33
import type { Server } from 'node:http'
44
import { Range } from 'immutable'
5-
import { assert } from 'chai'
5+
import { assert, expect } from 'chai'
66

77
import type { WeightsContainer } from '@epfml/discojs-core'
88
import {
9-
Disco, TrainingSchemes, client as clients,
9+
Disco, TrainingSchemes, client as clients, data,
1010
aggregator as aggregators, informant, defaultTasks
1111
} from '@epfml/discojs-core'
12-
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
12+
import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'
1313

1414
import { startServer } from '../../src'
1515

1616
const SCHEME = TrainingSchemes.FEDERATED
1717

1818
describe('end-to-end federated', function () {
19-
this.timeout(120_000)
19+
this.timeout(100_000)
2020

2121
let server: Server
2222
let url: URL
@@ -81,13 +81,42 @@ describe('end-to-end federated', function () {
8181
return aggregator.model.weights
8282
}
8383

84+
/**
 * Runs one federated participant through training on the default wikitext task.
 *
 * Loads the train and validation token files, trains with a `MeanAggregator`
 * and a `FederatedClient` pointed at `url`, then asserts that the first
 * recorded loss is above the last recorded one.
 *
 * NOTE(review): the dataset paths are relative; presumably resolved against the
 * process working directory — confirm against how the test runner is launched.
 *
 * @throws if loading, training, closing or the final loss assertion fails
 */
async function wikitextUser (): Promise<void> {
  const task = defaultTasks.wikitext.getTask()
  const loader = new NodeTextLoader(task)
  const dataSplit: data.DataSplit = {
    train: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.train.tokens'), task),
    validation: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.valid.tokens'), task)
  }

  const aggregator = new aggregators.MeanAggregator()
  const client = new clients.federated.FederatedClient(url, task, aggregator)
  const trainingInformant = new informant.FederatedInformant(task, 10)
  const disco = new Disco(task, { scheme: SCHEME, client, aggregator, informant: trainingInformant })

  try {
    await disco.fit(dataSplit)
  } finally {
    // Always close, even when training rejects, so a failed run does not skip teardown.
    await disco.close()
  }

  // Training is considered successful when the loss went down between first and last record.
  expect(trainingInformant.losses.first()).to.be.above(trainingInformant.losses.last())
}
102+
84103
// Two cifar10 users train concurrently; the values they return must be equal.
// The callback is a regular function (not an arrow) so that `this` is this
// test's Mocha context and the timeout below applies to this test.
it('two cifar10 users reach consensus', async function () {
  this.timeout(90_000)

  const [m1, m2] = await Promise.all([cifar10user(), cifar10user()])
  assert.isTrue(m1.equals(m2))
})
88109

89110
// Two titanic users train concurrently; the values they return must be equal.
// The callback is a regular function (not an arrow) so that `this` is this
// test's Mocha context and the timeout below applies to this test.
it('two titanic users reach consensus', async function () {
  this.timeout(30_000)

  const [m1, m2] = await Promise.all([titanicUser(), titanicUser()])
  assert.isTrue(m1.equals(m2))
})
116+
117+
// Runs a single wikitext user; the loss assertion lives inside `wikitextUser`.
// The callback is a regular function (not an arrow) so that `this` is this
// test's Mocha context; the 120 s limit set here exceeds the suite-level
// timeout, so it must be applied to the test itself to take effect.
it('trains wikitext', async function () {
  this.timeout(120_000)

  await wikitextUser()
})
93122
})

0 commit comments

Comments
 (0)