Skip to content

Commit 2391c1b

Browse files
committed
tmp: overriding weight update yields same as default
1 parent d287795 commit 2391c1b

File tree

2 files changed

+78
-11
lines changed

2 files changed

+78
-11
lines changed

discojs/src/default_tasks/lus_covid.ts

+9-8
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ export const lusCovid: TaskProvider<'image'> = {
3939

4040
// Model architecture from tensorflow.js docs:
4141
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
42-
async getModel (): Promise<Model<'image'>> {
42+
async getModel(): Promise<Model<'image'>> {
43+
const seed = 42
4344
const imageHeight = 100
4445
const imageWidth = 100
4546
const imageChannels = 3
@@ -55,7 +56,7 @@ export const lusCovid: TaskProvider<'image'> = {
5556
filters: 8,
5657
strides: 1,
5758
activation: 'relu',
58-
kernelInitializer: 'varianceScaling'
59+
kernelInitializer: tf.initializers.heNormal({ seed })
5960
}))
6061

6162
// The MaxPooling layer acts as a sort of downsampling using max values
@@ -69,7 +70,7 @@ export const lusCovid: TaskProvider<'image'> = {
6970
filters: 16,
7071
strides: 1,
7172
activation: 'relu',
72-
kernelInitializer: 'varianceScaling'
73+
kernelInitializer: tf.initializers.heNormal({ seed })
7374
}))
7475
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
7576

@@ -82,16 +83,16 @@ export const lusCovid: TaskProvider<'image'> = {
8283
// output class.
8384
model.add(tf.layers.dense({
8485
units: numOutputClasses,
85-
kernelInitializer: 'varianceScaling',
86-
activation: 'softmax'
86+
activation: 'softmax',
87+
kernelInitializer: tf.initializers.heNormal({ seed })
8788
}))
88-
89+
8990
model.compile({
90-
optimizer: 'sgd',
91+
optimizer: tf.train.sgd(0.001),
9192
loss: 'binaryCrossentropy',
9293
metrics: ['accuracy']
9394
})
9495

9596
return Promise.resolve(new models.TFJS('image', model))
9697
}
97-
}
98+
}

discojs/src/models/tfjs.ts

+69-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import createDebug from "debug";
12
import { List, Map, Range } from "immutable";
23
import * as tf from '@tensorflow/tfjs'
34

@@ -13,6 +14,8 @@ import { BatchLogs } from './index.js'
1314
import { Model } from './index.js'
1415
import { EpochLogs } from './logs.js'
1516

17+
const debug = createDebug("discojs:models:tfjs");
18+
1619
type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];
1720

1821
/** TensorFlow JavaScript model with standard training */
@@ -63,11 +66,71 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
6366
batch: Batched<DataFormat.ModelEncoded[D]>,
6467
): Promise<BatchLogs> {
6568
const { xs, ys } = this.#batchToTF(batch);
66-
const logs = await this.model.trainOnBatch(xs, ys);
69+
const logs = await this.trainFedProx(xs, ys);
70+
// const logs = await this.model.trainOnBatch(xs, ys);
6771
tf.dispose([xs, ys])
6872
return this.getBatchLogs(logs)
6973
}
7074

75+
async trainFedProx(
76+
xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
77+
// let logitsTensor: tf.Tensor<tf.Rank>;
78+
debug(this.model.loss, this.model.losses, this.model.lossFunctions)
79+
const lossFunction: () => tf.Scalar = () => {
80+
this.model.apply(xs)
81+
const logits = this.model.apply(xs)
82+
if (Array.isArray(logits))
83+
throw new Error('model outputs too many tensor')
84+
if (logits instanceof tf.SymbolicTensor)
85+
throw new Error('model outputs symbolic tensor')
86+
// logitsTensor = tf.keep(logits)
87+
// return tf.losses.softmaxCrossEntropy(ys, logits)
88+
let y: tf.Tensor;
89+
y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
90+
y = tf.log(tf.div(y, tf.sub(1, y)));
91+
return tf.losses.sigmoidCrossEntropy(ys, y);
92+
// return tf.losses.sigmoidCrossEntropy(ys, logits)
93+
}
94+
const lossTensor = this.model.optimizer.minimize(lossFunction, true)
95+
if (lossTensor === null) throw new Error("loss should not be null")
96+
// const lossTensor = tf.tidy(() => {
97+
// const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
98+
// const logits = this.model.apply(xs)
99+
// if (Array.isArray(logits))
100+
// throw new Error('model outputs too many tensor')
101+
// if (logits instanceof tf.SymbolicTensor)
102+
// throw new Error('model outputs symbolic tensor')
103+
// logitsTensor = tf.keep(logits)
104+
// // return tf.losses.softmaxCrossEntropy(ys, logits)
105+
// return this.model.calculateLosses(ys, logits)[0]
106+
// })
107+
// this.model.optimizer.applyGradients(grads)
108+
// return lossTensor
109+
// })
110+
111+
// // @ts-expect-error Variable 'logitsTensor' is used before being assigned
112+
// const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
113+
// const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
114+
// const accSumTensor = accTensor.sum()
115+
// const accSum = await accSumTensor.array()
116+
// if (typeof accSum !== 'number')
117+
// throw new Error('got multiple accuracy sum')
118+
// // @ts-expect-error Variable 'logitsTensor' is used before being assigned
119+
// tf.dispose([accTensor, accSumTensor, logitsTensor])
120+
121+
const loss = await lossTensor.array()
122+
tf.dispose([xs, ys, lossTensor])
123+
124+
// const memory = tf.memory().numBytes / 1024 / 1024 / 1024
125+
// debug("training metrics: %O", {
126+
// loss,
127+
// memory,
128+
// allocated: tf.memory().numTensors,
129+
// });
130+
return [loss, 0]
131+
// return [loss, accSum / accSize]
132+
}
133+
71134
async #evaluate(
72135
dataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
73136
): Promise<Record<"accuracy" | "loss", number>> {
@@ -160,7 +223,10 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
160223
return new this(
161224
datatype,
162225
await tf.loadLayersModel({
163-
load: () => Promise.resolve(artifacts),
226+
load: () => {
227+
console.log("deserialize called")
228+
return Promise.resolve(artifacts)
229+
},
164230
}),
165231
);
166232
}
@@ -187,7 +253,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
187253
return [this.datatype, await ret]
188254
}
189255

190-
[Symbol.dispose](): void{
256+
[Symbol.dispose](): void {
191257
this.model.dispose()
192258
}
193259

0 commit comments

Comments
 (0)