-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtraining.ts
87 lines (72 loc) · 2.72 KB
/
training.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import { Repeat } from 'immutable'
import * as path from 'node:path'
import '@tensorflow/tfjs-node'
import type { Dataset, DataFormat, DataType, Image, Task } from '@epfml/discojs'
import { Disco, fetchTasks, defaultTasks } from '@epfml/discojs'
import { loadCSV, loadImagesInDir } from '@epfml/discojs-node'
import { Server } from 'server'
/**
 * Example of the discojs API: we launch a server, load data, create the Disco object
 * and finally start training.
 */
/**
 * Runs a single simulated user: creates a Disco object for the given task and
 * server, trains on the dataset, then disconnects.
 *
 * @param url address of the server to train with
 * @param task task to train
 * @param dataset local raw data used for training
 * @returns a promise settling once training has finished and the client is closed
 * @throws rethrows any failure of the training or of the disconnection
 */
async function runUser<D extends DataType>(
  url: URL,
  task: Task<D>,
  dataset: Dataset<DataFormat.Raw[D]>,
): Promise<void> {
  // Create Disco object associated with the server url, using the federated training scheme
  const disco = new Disco(task, url, { scheme: 'federated' })

  try {
    // Run training on the dataset
    await disco.trainFully(dataset)
  } finally {
    // Always disconnect from the remote server, even when training failed,
    // so that a failing user doesn't leave its connection open
    await disco.close()
  }
}
/** Pair of a task and the raw dataset to train it on. */
type TaskAndDataset<D extends DataType> = [Task<D>, Dataset<DataFormat.Raw[D]>]

/**
 * Entry point of the example: launches a server, picks a task, loads its
 * local data, trains it with three concurrent users and closes the server.
 *
 * @returns a promise settling once every user is done and the server is closed
 * @throws if the chosen task is unknown or isn't served, or if loading,
 * training or closing the server fails
 */
async function main (): Promise<void> {
  // Arbitrarily chosen task ID
  const NAME: string = 'titanic'

  // Launch a server instance
  const [server, url] = await new Server().serve(undefined, defaultTasks.simpleFace, defaultTasks.titanic)

  try {
    // Get all pre-defined tasks
    const tasks = await fetchTasks(url)

    // Choose the task and load local data
    // Make sure you first ran ./get_training_data
    let taskAndDataset: TaskAndDataset<'image' | 'tabular'>
    switch (NAME) {
      case 'titanic': {
        const task = tasks.get('titanic') as Task<'tabular'> | undefined
        if (task === undefined) throw new Error('task not found')
        taskAndDataset = [task, loadCSV('../../datasets/titanic_train.csv')]
        break
      }
      case 'simple_face': {
        const task = tasks.get('simple_face') as Task<'image'> | undefined
        if (task === undefined) throw new Error('task not found')
        taskAndDataset = [task, await loadSimpleFaceData()]
        break
      }
      default:
        throw new Error('task id not found')
    }

    // Add more users to the list to simulate more than 3 clients
    await Promise.all([
      runUser(url, ...taskAndDataset),
      runUser(url, ...taskAndDataset),
      runUser(url, ...taskAndDataset),
    ])
  } finally {
    // Close server, whether or not the steps above succeeded, so that a
    // failure doesn't leave it listening
    await new Promise<void>((resolve, reject) => {
      // The 'close' listener is registered before calling close(), as in the
      // original ordering; the close() callback is wired to reject
      // NOTE(review): presumably a Node `http.Server` — confirm what `serve` returns
      server.once('close', resolve)
      server.close(reject)
    })
  }
}
/**
 * Loads the simple_face images and labels each of them with the name of the
 * folder it comes from, "adult" or "child".
 *
 * @returns the labeled adult images, chained with the labeled child images
 * @throws presumably if one of the image folders can't be read — depends on `loadImagesInDir`
 */
async function loadSimpleFaceData(): Promise<Dataset<[Image, string]>> {
  // Same "../../datasets" root as the titanic CSV loaded in main: both are
  // given as relative paths and the example is run from a single folder
  // NOTE(review): assumes the loaders resolve relative paths the same way — confirm
  const folder = "../../datasets/simple_face";

  // Pair every image with a constant label by zipping with an endless repetition of it
  const adults: Dataset<[Image, string]> =
    (await loadImagesInDir(path.join(folder, "adult"))).zip(Repeat("adult"));
  const children: Dataset<[Image, string]> =
    (await loadImagesInDir(path.join(folder, "child"))).zip(Repeat("child"));

  return adults.chain(children);
}
// You can run this example with "npm run train" from this folder
main().catch((error: unknown) => {
  console.error(error)
  // Report the failure through the exit code too, so that callers
  // (npm, CI, shell scripts) don't mistake a failed run for a success
  process.exitCode = 1
})