Skip to content

Commit 4ba51f7

Browse files
committed
discojs-core/train/models: initial
1 parent 6951421 commit 4ba51f7

35 files changed

+340
-214
lines changed

discojs/discojs-core/src/aggregator/base.ts

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import { Map, Set } from 'immutable'
2-
import type tf from '@tensorflow/tfjs'
32

4-
import type { client, Task, AsyncInformant } from '..'
3+
import type { client, Model, Task, AsyncInformant } from '..'
54

65
import { EventEmitter } from '../utils/event_emitter'
76

@@ -60,9 +59,9 @@ export abstract class Base<T> {
6059
*/
6160
public readonly task: Task,
6261
/**
63-
* The TF.js model whose weights are updated on aggregation.
62+
* The Model whose weights are updated on aggregation.
6463
*/
65-
protected _model?: tf.LayersModel,
64+
protected _model?: Model,
6665
/**
6766
* The round cut-off for contributions.
6867
*/
@@ -141,7 +140,7 @@ export abstract class Base<T> {
141140
* Sets the aggregator's TF.js model.
142141
* @param model The new TF.js model
143142
*/
144-
setModel (model: tf.LayersModel): void {
143+
setModel (model: Model): void {
145144
this._model = model
146145
}
147146

@@ -267,7 +266,7 @@ export abstract class Base<T> {
267266
/**
268267
* The aggregator's current model.
269268
*/
270-
get model (): tf.LayersModel | undefined {
269+
get model (): Model | undefined {
271270
return this._model
272271
}
273272

discojs/discojs-core/src/aggregator/mean.spec.ts

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import { assert, expect } from 'chai'
22
import type { Map } from 'immutable'
3-
import type tf from '@tensorflow/tfjs'
43

5-
import type { client, Task } from '..'
4+
import type { client, Model, Task } from '..'
65
import { aggregator, defaultTasks } from '..'
76
import { AggregationStep } from './base'
87

@@ -16,7 +15,7 @@ const bufferCapacity = weights.length
1615
export class MockMeanAggregator extends aggregator.AggregatorBase<number> {
1716
constructor (
1817
task: Task,
19-
model: tf.LayersModel,
18+
model: Model,
2019
private readonly threshold: number,
2120
roundCutoff = 0
2221
) {

discojs/discojs-core/src/aggregator/mean.ts

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import type { Map } from 'immutable'
2-
import type tf from '@tensorflow/tfjs'
32

43
import { AggregationStep, Base as Aggregator } from './base'
5-
import type { Task, WeightsContainer, client } from '..'
4+
import type { Model, Task, WeightsContainer, client } from '..'
65
import { aggregation } from '..'
76

87
/**
@@ -18,7 +17,7 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
1817

1918
constructor (
2019
task: Task,
21-
model?: tf.LayersModel,
20+
model?: Model,
2221
roundCutoff = 0,
2322
threshold = 1
2423
) {
@@ -69,7 +68,9 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
6968
aggregate (): void {
7069
this.log(AggregationStep.AGGREGATE)
7170
const result = aggregation.avg(this.contributions.get(0)?.values() as Iterable<WeightsContainer>)
72-
this.model?.setWeights(result.weights)
71+
if (this.model !== undefined) {
72+
this.model.weights = result
73+
}
7374
this.emit(result)
7475
}
7576

discojs/discojs-core/src/aggregator/secure.ts

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { Map, List, Range } from 'immutable'
33
import tf from '@tensorflow/tfjs'
44

55
import { AggregationStep, Base as Aggregator } from './base'
6-
import type { Task, WeightsContainer, client } from '..'
6+
import type { Model, Task, WeightsContainer, client } from '..'
77
import { aggregation } from '..'
88

99
/**
@@ -20,7 +20,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
2020

2121
constructor (
2222
task: Task,
23-
model?: tf.LayersModel
23+
model?: Model
2424
) {
2525
super(task, model, 0, 2)
2626

@@ -36,7 +36,9 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
3636
} else if (this.communicationRound === 1) {
3737
// Average the received partial sums
3838
const result = aggregation.avg(this.contributions.get(1)?.values() as Iterable<WeightsContainer>)
39-
this.model?.setWeights(result.weights)
39+
if (this.model !== undefined) {
40+
this.model.weights = result
41+
}
4042
this.emit(result)
4143
} else {
4244
throw new Error('communication round is out of bounds')

discojs/discojs-core/src/client/base.ts

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import axios from 'axios'
22
import type { Set } from 'immutable'
3-
import type tf from '@tensorflow/tfjs'
43

5-
import type { Task, TrainingInformant, WeightsContainer } from '..'
4+
import type { Model, Task, TrainingInformant, WeightsContainer } from '..'
65
import { serialization } from '..'
76
import type { NodeID } from './types'
87
import type { EventConnection } from './event_connection'
@@ -55,7 +54,7 @@ export abstract class Base {
5554
* Fetches the latest model available on the network's server, for the adequate task.
5655
* @returns The latest model
5756
*/
58-
async getLatestModel (): Promise<tf.LayersModel> {
57+
async getLatestModel (): Promise<Model> {
5958
const url = new URL('', this.url.href)
6059
if (!url.pathname.endsWith('/')) {
6160
url.pathname += '/'

discojs/discojs-core/src/default_tasks/cifar10.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tf from '@tensorflow/tfjs'
22

3-
import type { Task, TaskProvider } from '..'
4-
import { data } from '..'
3+
import type { Model, Task, TaskProvider } from '..'
4+
import { data, models } from '..'
55

66
export const cifar10: TaskProvider = {
77
getTask (): Task {
@@ -40,7 +40,7 @@ export const cifar10: TaskProvider = {
4040
}
4141
},
4242

43-
async getModel (): Promise<tf.LayersModel> {
43+
async getModel (): Promise<Model> {
4444
const mobilenet = await tf.loadLayersModel(
4545
'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
4646
)
@@ -61,6 +61,6 @@ export const cifar10: TaskProvider = {
6161
metrics: ['accuracy']
6262
})
6363

64-
return model
64+
return new models.TFJS(model)
6565
}
6666
}

discojs/discojs-core/src/default_tasks/geotags.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { Range } from 'immutable'
22
import tf from '@tensorflow/tfjs'
33

4-
import type { Task, TaskProvider } from '..'
5-
import { data } from '..'
4+
import type { Model, Task, TaskProvider } from '..'
5+
import { data, models } from '..'
66
import { LabelTypeEnum } from '../task/label_type'
77

88
export const geotags: TaskProvider = {
@@ -44,7 +44,7 @@ export const geotags: TaskProvider = {
4444
}
4545
},
4646

47-
async getModel (): Promise<tf.LayersModel> {
47+
async getModel (): Promise<Model> {
4848
const pretrainedModel = await tf.loadLayersModel(
4949
'https://storage.googleapis.com/deai-313515.appspot.com/models/geotags/model.json'
5050
)
@@ -68,6 +68,6 @@ export const geotags: TaskProvider = {
6868
metrics: ['accuracy']
6969
})
7070

71-
return model
71+
return new models.TFJS(model)
7272
}
7373
}

discojs/discojs-core/src/default_tasks/lus_covid.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tf from '@tensorflow/tfjs'
22

3-
import type { Task, TaskProvider } from '..'
4-
import { data } from '..'
3+
import type { Model, Task, TaskProvider } from '..'
4+
import { data, models } from '..'
55

66
export const lusCovid: TaskProvider = {
77
getTask (): Task {
@@ -40,7 +40,7 @@ export const lusCovid: TaskProvider = {
4040
}
4141
},
4242

43-
async getModel (): Promise<tf.LayersModel> {
43+
async getModel (): Promise<Model> {
4444
const imageHeight = 100
4545
const imageWidth = 100
4646
const imageChannels = 3
@@ -93,6 +93,6 @@ export const lusCovid: TaskProvider = {
9393
metrics: ['accuracy']
9494
})
9595

96-
return model
96+
return new models.TFJS(model)
9797
}
9898
}

discojs/discojs-core/src/default_tasks/mnist.ts

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tf from '@tensorflow/tfjs'
22

3-
import type { Task, TaskProvider } from '..'
3+
import type { Model, Task, TaskProvider } from '..'
4+
import { models } from '..'
45

56
export const mnist: TaskProvider = {
67
getTask (): Task {
@@ -39,7 +40,7 @@ export const mnist: TaskProvider = {
3940
}
4041
},
4142

42-
async getModel (): Promise<tf.LayersModel> {
43+
async getModel (): Promise<Model> {
4344
const model = tf.sequential()
4445

4546
model.add(
@@ -68,6 +69,6 @@ export const mnist: TaskProvider = {
6869
metrics: ['accuracy']
6970
})
7071

71-
return model
72+
return new models.TFJS(model)
7273
}
7374
}

discojs/discojs-core/src/default_tasks/simple_face.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tf from '@tensorflow/tfjs'
22

3-
import type { Task, TaskProvider } from '..'
4-
import { data } from '..'
3+
import type { Model, Task, TaskProvider } from '..'
4+
import { data, models } from '..'
55

66
export const simpleFace: TaskProvider = {
77
getTask (): Task {
@@ -37,7 +37,7 @@ export const simpleFace: TaskProvider = {
3737
}
3838
},
3939

40-
async getModel (): Promise<tf.LayersModel> {
40+
async getModel (): Promise<Model> {
4141
const model = await tf.loadLayersModel(
4242
'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json'
4343
)
@@ -48,6 +48,6 @@ export const simpleFace: TaskProvider = {
4848
metrics: ['accuracy']
4949
})
5050

51-
return model
51+
return new models.TFJS(model)
5252
}
5353
}

discojs/discojs-core/src/default_tasks/skin_mnist.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tf from '@tensorflow/tfjs'
22

3-
import type { Task, TaskProvider } from '..'
4-
import { data } from '..'
3+
import type { Model, Task, TaskProvider } from '..'
4+
import { data, models } from '..'
55

66
export const skinMnist: TaskProvider = {
77
getTask (): Task {
@@ -47,7 +47,7 @@ export const skinMnist: TaskProvider = {
4747
}
4848
},
4949

50-
async getModel (): Promise<tf.LayersModel> {
50+
async getModel (): Promise<Model> {
5151
const numClasses = 7
5252
const size = 28
5353

@@ -98,6 +98,6 @@ export const skinMnist: TaskProvider = {
9898
metrics: ['accuracy']
9999
})
100100

101-
return model
101+
return new models.TFJS(model)
102102
}
103103
}

discojs/discojs-core/src/default_tasks/titanic.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tf from '@tensorflow/tfjs'
22

3-
import type { Task, TaskProvider } from '..'
4-
import { data } from '..'
3+
import type { Model, Task, TaskProvider } from '..'
4+
import { data, models } from '..'
55

66
export const titanic: TaskProvider = {
77
getTask (): Task {
@@ -71,7 +71,7 @@ export const titanic: TaskProvider = {
7171
}
7272
},
7373

74-
async getModel (): Promise<tf.LayersModel> {
74+
async getModel (): Promise<Model> {
7575
const model = tf.sequential()
7676

7777
model.add(
@@ -92,6 +92,6 @@ export const titanic: TaskProvider = {
9292
metrics: ['accuracy']
9393
})
9494

95-
return model
95+
return new models.TFJS(model)
9696
}
9797
}

discojs/discojs-core/src/index.ts

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ export { Memory, ModelType, type ModelInfo, type Path, type ModelSource, Empty a
1414
export { Disco, TrainingSchemes } from './training'
1515
export { Validator } from './validation'
1616

17+
export { Model } from './models'
18+
export * as models from './models'
19+
1720
export * from './task'
1821
export * as defaultTasks from './default_tasks'
1922

discojs/discojs-core/src/memory/base.ts

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// only used browser-side
22
// TODO: replace IO type
3-
import type tf from '@tensorflow/tfjs'
43

5-
import type { TaskID } from '..'
4+
import type { Model, TaskID } from '..'
65
import type { ModelType } from './model_type'
76

87
/**
@@ -49,7 +48,7 @@ export abstract class Memory {
4948
* @param source The model source
5049
* @returns The model
5150
*/
52-
abstract getModel (source: ModelSource): Promise<tf.LayersModel>
51+
abstract getModel (source: ModelSource): Promise<Model>
5352

5453
/**
5554
* Removes the model identified by the given model source from memory.
@@ -77,7 +76,7 @@ export abstract class Memory {
7776
* @param source The model source
7877
* @param model The new model
7978
*/
80-
abstract updateWorkingModel (source: ModelSource, model: tf.LayersModel): Promise<void>
79+
abstract updateWorkingModel (source: ModelSource, model: Model): Promise<void>
8180

8281
/**
8382
* Creates a saved model copy from the working model identified by the given model source.
@@ -94,7 +93,7 @@ export abstract class Memory {
9493
* @param model The new model
9594
* @returns The saved model's path
9695
*/
97-
abstract saveModel (source: ModelSource, model: tf.LayersModel): Promise<Path | undefined>
96+
abstract saveModel (source: ModelSource, model: Model): Promise<Path | undefined>
9897

9998
/**
10099
* Moves the model identified by the model source to a file system. This is platform-dependent.

discojs/discojs-core/src/memory/empty.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type tf from '@tensorflow/tfjs'
1+
import type { Model } from '..'
22

33
import type { ModelInfo, Path } from './base'
44
import { Memory } from './base'
@@ -15,7 +15,7 @@ export class Empty extends Memory {
1515
return false
1616
}
1717

18-
async getModel (): Promise<tf.LayersModel> {
18+
async getModel (): Promise<Model> {
1919
throw new Error('empty')
2020
}
2121

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export { Model } from './model'
2+
export { TFJS } from './tfjs'

0 commit comments

Comments
 (0)