-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathargs.ts
81 lines (69 loc) · 2.59 KB
/
args.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import { parse } from 'ts-command-line-args'
import { Map, Set } from 'immutable'
import type { DataType, TaskProvider } from "@epfml/discojs";
import { defaultTasks } from '@epfml/discojs'
/** Fully resolved options for a benchmark run. */
interface BenchmarkArguments {
  /** Provider of the task (and its model) being benchmarked. */
  provider: TaskProvider<DataType>
  /** Host to connect to. */
  host: URL
  /** Number of simulated users taking part in the run. */
  numberOfUsers: number

  // Training hyperparameters; they are written into the task's training information.
  /** Number of training epochs. */
  epochs: number
  /** Round duration, expressed in epochs. */
  roundDuration: number
  /** Training batch size. */
  batchSize: number

  /** Whether to save logs of the benchmark. */
  save: boolean
}
/**
 * Shape produced by the command-line parser: every benchmark option except
 * `provider`, which is resolved afterwards from the `task` name.
 */
type BenchmarkUnsafeArguments = {
  /** Task identifier, looked up among the supported task providers. */
  task: string
  /** Set when the usage guide was requested. */
  help?: boolean
} & Omit<BenchmarkArguments, 'provider'>
const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'

/**
 * Raw command-line options, before the task name is resolved to a provider.
 * Passing `--help` / `-h` prints the usage guide (configured via `helpArg`).
 *
 * @throws Error if a numeric option is not a positive integer.
 */
const unsafeArgs = parse<BenchmarkUnsafeArguments>(
  {
    task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
    numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
    epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
    roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
    batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
    save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
    host: {
      type: (raw: string) => new URL(raw),
      typeLabel: "URL",
      description: "Host to connect to",
      defaultValue: new URL("http://localhost:8080"),
    },
    help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }
  },
  {
    helpArg: 'help',
    headerContentSections: [{ header: 'DISCO CLI', content: 'npm start -- [Options]\n' + argExample }]
  }
)

// `type: Number` converts with Number(...), so a non-numeric value becomes NaN,
// and zero, negative or fractional values are accepted as-is. Reject them here
// rather than passing them on as user counts / training hyperparameters.
for (const key of ['numberOfUsers', 'epochs', 'roundDuration', 'batchSize'] as const) {
  const value = unsafeArgs[key]
  if (!Number.isInteger(value) || value <= 0) {
    throw new Error(`--${key} must be a positive integer, got ${value}. ${argExample}`)
  }
}
// Tasks the benchmark can run, keyed by task id (the value accepted by --task).
const supportedTasks = Map(
  Set.of<TaskProvider<"image"> | TaskProvider<"tabular">>(
    defaultTasks.cifar10,
    defaultTasks.lusCovid,
    defaultTasks.simpleFace,
    defaultTasks.titanic,
    defaultTasks.tinderDog,
  ).map((t) => [t.getTask().id, t]),
);

const provider = supportedTasks.get(unsafeArgs.task);
if (provider === undefined) {
  // An unknown id is a user input error, not a missing implementation:
  // list the valid ids so a typo in --task is easy to fix.
  throw new Error(
    `Unknown task "${unsafeArgs.task}". Supported tasks: ${supportedTasks.keySeq().join(", ")}.`,
  );
}
/**
 * Resolved benchmark arguments.
 *
 * `provider` wraps the selected task provider: every task it returns has its
 * batch size, round duration and epoch count replaced by the command-line
 * values, while models are forwarded untouched. Note that the spread also
 * copies the CLI-only `task` and `help` fields onto this object at runtime,
 * even though `BenchmarkArguments` does not expose them.
 */
export const args: BenchmarkArguments = {
  ...unsafeArgs,
  provider: {
    getTask: () => {
      const { batchSize, roundDuration, epochs } = unsafeArgs;
      const task = provider.getTask();
      const info = task.trainingInformation;

      // Command-line values replace what the provider put in the training information.
      info.batchSize = batchSize;
      info.roundDuration = roundDuration;
      info.epochs = epochs;

      // DP settings, kept disabled for reference:
      //   info.clippingRadius = 10000000
      //   info.noiseScale = 0

      return task;
    },
    getModel() {
      return provider.getModel();
    },
  },
};