Skip to content

Commit bb51e16

Browse files
committed
server: export starter
1 parent 4ba51f7 commit bb51e16

File tree

8 files changed

+16
-43
lines changed

8 files changed

+16
-43
lines changed

cli/src/cli.ts

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import { Range } from 'immutable'
2+
import type { Server } from 'node:http'
23

34
import type { TrainerLog, data, Task } from '@epfml/discojs-core'
4-
import { Disco, TrainingSchemes } from '@epfml/discojs-core'
5+
import { Disco, TrainingSchemes, aggregator as aggregators, client as clients } from '@epfml/discojs-core'
6+
import { getClient, startServer } from '@epfml/disco-server'
57

6-
import { startServer, saveLog } from './utils'
8+
import { saveLog } from './utils'
79
import { getTaskData } from './data'
810
import { args } from './args'
911

@@ -15,23 +17,25 @@ console.log(infoText)
1517

1618
console.log({ args })
1719

18-
async function runUser (task: Task, url: URL, data: data.DataSplit): Promise<TrainerLog> {
20+
async function runUser (task: Task, server: Server, data: data.DataSplit): Promise<TrainerLog> {
21+
const client = await getClient(clients.federated.FederatedClient, server, task, new aggregators.MeanAggregator(TASK))
22+
1923
// force the federated scheme
2024
const scheme = TrainingSchemes.FEDERATED
21-
const disco = new Disco(task, { scheme, url })
25+
const disco = new Disco(task, { scheme, client })
2226

2327
await disco.fit(data)
2428
await disco.close()
2529
return await disco.logs()
2630
}
2731

2832
async function main (): Promise<void> {
29-
const [server, serverUrl] = await startServer()
33+
const server = await startServer()
3034

3135
const data = await getTaskData(TASK)
3236

3337
const logs = await Promise.all(
34-
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, serverUrl, data)).toArray()
38+
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, server, data)).toArray()
3539
)
3640

3741
if (args.save) {

cli/src/utils.ts

-32
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,6 @@
1-
import type http from 'node:http'
21
import fs from 'node:fs'
32

43
import type { TrainerLog } from '@epfml/discojs-core'
5-
import { Disco } from '@epfml/disco-server'
6-
7-
export async function startServer (): Promise<[http.Server, URL]> {
8-
const disco = new Disco()
9-
await disco.addDefaultTasks()
10-
11-
const server = disco.serve(8000)
12-
await new Promise((resolve, reject) => {
13-
server.once('listening', resolve)
14-
server.once('error', reject)
15-
server.on('error', console.error)
16-
})
17-
18-
let addr: string
19-
const rawAddr = server.address()
20-
if (rawAddr === null) {
21-
throw new Error('unable to get server address')
22-
} else if (typeof rawAddr === 'string') {
23-
addr = rawAddr
24-
} else if (typeof rawAddr === 'object') {
25-
if (rawAddr.family === '4') {
26-
addr = `${rawAddr.address}:${rawAddr.port}`
27-
} else {
28-
addr = `[${rawAddr.address}]:${rawAddr.port}`
29-
}
30-
} else {
31-
throw new Error('unable to get address to server')
32-
}
33-
34-
return [server, new URL('', `http://${addr}`)]
35-
}
364

375
export function saveLog (logs: TrainerLog[], fileName: string): void {
386
const filePath = `./${fileName}`

server/src/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
export { Disco } from './get_server'
2+
export * from './utils'

server/tests/utils.ts server/src/utils.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import type { Server } from 'node:http'
22

33
import type { aggregator, client, Task } from '@epfml/discojs-core'
44

5-
import { runDefaultServer } from '../src/get_server'
5+
import { runDefaultServer } from './get_server'
66

77
export async function startServer (): Promise<Server> {
88
const server = await runDefaultServer()

server/tests/client/decentralized.spec.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type * as http from 'http'
33
import type { Task } from '@epfml/discojs-core'
44
import { aggregator as aggregators, client as clients, defaultTasks } from '@epfml/discojs-core'
55

6-
import { getClient, startServer } from '../utils'
6+
import { getClient, startServer } from '../../src'
77

88
const TASK = defaultTasks.titanic.getTask()
99

server/tests/client/federated.spec.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import type * as http from 'http'
22

33
import { aggregator as aggregators, client as clients, informant, defaultTasks } from '@epfml/discojs-core'
44

5-
import { getClient, startServer } from '../utils'
5+
import { getClient, startServer } from '../../src'
66

77
const TASK = defaultTasks.titanic.getTask()
88

server/tests/e2e/decentralized.spec.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
aggregator as aggregators, informant as informants, client as clients, WeightsContainer, defaultTasks, aggregation
88
} from '@epfml/discojs-core'
99

10-
import { getClient, startServer } from '../utils'
10+
import { getClient, startServer } from '../../src'
1111

1212
// Mocked aggregators with easy-to-fetch aggregation results
1313
class MockMeanAggregator extends aggregators.MeanAggregator {

server/tests/e2e/federated.spec.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import {
1111
} from '@epfml/discojs-core'
1212
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
1313

14-
import { getClient, startServer } from '../utils'
14+
import { getClient, startServer } from '../../src'
1515

1616
const SCHEME = TrainingSchemes.FEDERATED
1717

0 commit comments

Comments
 (0)