
Commit a3dfaf4

tmp: sketch of fedprox implementation
1 parent 2391c1b commit a3dfaf4

File tree

3 files changed: +49, -46 lines


discojs/src/models/model.ts

+4
@@ -17,12 +17,16 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
  **/
 // TODO make it typesafe: same shape of data/input/weights
 export abstract class Model<D extends DataType> implements Disposable {
+  protected prevRoundWeights: WeightsContainer | undefined;
   // TODO don't allow external access but upgrade train to return weights on every epoch
   /** Return training state */
   abstract get weights(): WeightsContainer;
   /** Set training state */
   abstract set weights(ws: WeightsContainer);
 
+  set previousRoundWeights(ws: WeightsContainer | undefined) {
+    this.prevRoundWeights = ws
+  }
   /**
   * Improve predictor
   *
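For context, FedProx (Li et al., 2020) augments each client's local objective with a proximal term that penalizes drift from the global weights of the previous round: F_k(w) + (mu / 2) * ||w - w_prev||^2. The new setter lets the trainer hand those previous-round weights to the model before local training. Below is a minimal, hedged sketch of that penalty on its own, written against plain @tensorflow/tfjs tensor arrays rather than discojs' WeightsContainer; the proximalTerm helper and the mu default are illustrative, not part of this commit.

import * as tf from '@tensorflow/tfjs';

// Hypothetical helper (not in the commit): (mu / 2) * ||w - wPrev||^2,
// summed over all weight tensors of the model.
function proximalTerm(
  weights: tf.Tensor[],
  prevRoundWeights: tf.Tensor[] | undefined,
  mu = 1,
): tf.Scalar {
  if (prevRoundWeights === undefined) return tf.scalar(0); // round 0: no anchor yet
  const prev = prevRoundWeights;
  const squaredNorm = weights
    .map((w, i) => w.sub(prev[i]).square().sum())
    .reduce<tf.Tensor>((acc, t) => acc.add(t), tf.scalar(0));
  return squaredNorm.mul(mu / 2).asScalar();
}

// Example: one coordinate moved by 0.5, so the penalty is (1 / 2) * 0.25 = 0.125.
const anchor = [tf.tensor([1, 2, 3])];
const updated = [tf.tensor([1.5, 2, 3])];
proximalTerm(updated, anchor).print();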

discojs/src/models/tfjs.ts

+43 -45
@@ -73,62 +73,60 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
   }
 
   async trainFedProx(
-    xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-    // let logitsTensor: tf.Tensor<tf.Rank>;
-    debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+    xs: tf.Tensor, ys: tf.Tensor,
+  ): Promise<[number, number]> {
+    let logitsTensor: tf.Tensor<tf.Rank>;
     const lossFunction: () => tf.Scalar = () => {
+      // Proximal term
+      let proximalTerm = tf.tensor(0)
+      if (this.prevRoundWeights !== undefined) {
+        // squared norm
+        const norm = new WeightsContainer(this.model.getWeights())
+          .sub(this.prevRoundWeights)
+          .map(t => t.square().sum())
+          .reduce((t, acc) => tf.add(t, acc)).asScalar()
+        const mu = 1
+        proximalTerm = tf.mul(mu / 2, norm)
+      }
+
       this.model.apply(xs)
       const logits = this.model.apply(xs)
-      if (Array.isArray(logits))
-        throw new Error('model outputs too many tensor')
-      if (logits instanceof tf.SymbolicTensor)
-        throw new Error('model outputs symbolic tensor')
-      // logitsTensor = tf.keep(logits)
-      // return tf.losses.softmaxCrossEntropy(ys, logits)
-      let y: tf.Tensor;
-      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
-      y = tf.log(tf.div(y, tf.sub(1, y)));
-      return tf.losses.sigmoidCrossEntropy(ys, y);
-      // return tf.losses.sigmoidCrossEntropy(ys, logits)
+      if (Array.isArray(logits))
+        throw new Error('model outputs too many tensor')
+      if (logits instanceof tf.SymbolicTensor)
+        throw new Error('model outputs symbolic tensor')
+      logitsTensor = tf.keep(logits)
+      // binaryCrossEntropy
+      let y: tf.Tensor;
+      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+      y = tf.log(tf.div(y, tf.sub(1, y)));
+      const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+      console.log(loss.dataSync(), proximalTerm.dataSync())
+      return tf.add(loss, proximalTerm)
     }
     const lossTensor = this.model.optimizer.minimize(lossFunction, true)
     if (lossTensor === null) throw new Error("loss should not be null")
-    // const lossTensor = tf.tidy(() => {
-    //   const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
-    //     const logits = this.model.apply(xs)
-    //     if (Array.isArray(logits))
-    //       throw new Error('model outputs too many tensor')
-    //     if (logits instanceof tf.SymbolicTensor)
-    //       throw new Error('model outputs symbolic tensor')
-    //     logitsTensor = tf.keep(logits)
-    //     // return tf.losses.softmaxCrossEntropy(ys, logits)
-    //     return this.model.calculateLosses(ys, logits)[0]
-    //   })
-    //   this.model.optimizer.applyGradients(grads)
-    //   return lossTensor
-    // })
 
-    // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
-    // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
-    // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
-    // const accSumTensor = accTensor.sum()
-    // const accSum = await accSumTensor.array()
-    // if (typeof accSum !== 'number')
-    //   throw new Error('got multiple accuracy sum')
-    // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
-    // tf.dispose([accTensor, accSumTensor, logitsTensor])
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSumTensor = accTensor.sum()
+    const accSum = await accSumTensor.array()
+    if (typeof accSum !== 'number')
+      throw new Error('got multiple accuracy sum')
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    tf.dispose([accTensor, accSumTensor, logitsTensor])
 
     const loss = await lossTensor.array()
     tf.dispose([xs, ys, lossTensor])
 
-    // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
-    // debug("training metrics: %O", {
-    //   loss,
-    //   memory,
-    //   allocated: tf.memory().numTensors,
-    // });
-    return [loss, 0]
-    // return [loss, accSum / accSize]
+    const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+    debug("training metrics: %O", {
+      loss,
+      memory,
+      allocated: tf.memory().numTensors,
+    });
+    return [loss, accSum / accSize]
   }
 
   async #evaluate(
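A note on the loss in the new lossFunction: the clip-and-log lines suggest the model emits sigmoid probabilities p rather than raw logits, so the code clips them away from 0 and 1 and applies the inverse sigmoid, log(p / (1 - p)), before calling tf.losses.sigmoidCrossEntropy, which expects logits. The sketch below illustrates, under that assumption, that this recovered-logit path reproduces plain binary cross-entropy on the probabilities; the labels and tensors are made up for the example.

import * as tf from '@tensorflow/tfjs';

// Assumption: the model outputs sigmoid probabilities p, not logits.
// Recovering logits z = log(p / (1 - p)) makes sigmoidCrossEntropy(ys, z)
// equal to the usual binary cross-entropy on p.
const ys = tf.tensor([[1], [0], [1]]);
const probs = tf.tensor([[0.9], [0.2], [0.6]]);

const eps = 0.00001;
const clipped = tf.clipByValue(probs, eps, 1 - eps);
const logits = tf.log(tf.div(clipped, tf.sub(1, clipped)));

const viaLogits = tf.losses.sigmoidCrossEntropy(ys, logits);

// Reference: binary cross-entropy computed directly on the probabilities.
const direct = tf.mean(
  tf.neg(tf.add(
    tf.mul(ys, tf.log(clipped)),
    tf.mul(tf.sub(1, ys), tf.log(tf.sub(1, clipped))),
  )),
);

viaLogits.print(); // ~0.2798
direct.print();    // same value, up to clipping error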

discojs/src/training/trainer.ts

+2 -1
@@ -90,7 +90,8 @@ export class Trainer<D extends DataType> {
     let previousRoundWeights: WeightsContainer | undefined;
     for (let round = 0; round < totalRound; round++) {
       await this.#client.onRoundBeginCommunication();
-
+
+      this.model.previousRoundWeights = previousRoundWeights
       yield this.#runRound(dataset, validationDataset);
 
       let localWeights = this.model.weights;
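The trainer change only wires the previous round's weights into the model before each round; on the first round previousRoundWeights is still undefined, so trainFedProx falls back to the plain loss. Where the anchor gets refreshed after aggregation is not shown in this hunk. A minimal, self-contained sketch of such a round loop, with made-up trainLocally and aggregateWithServer helpers standing in for discojs' actual trainer and client API:

import * as tf from '@tensorflow/tfjs';

type Weights = tf.Tensor[]; // stand-in for discojs' WeightsContainer

// Hypothetical helpers for the pieces outside this diff.
async function trainLocally(prevRoundWeights: Weights | undefined): Promise<Weights> {
  // local epochs would run here, adding the proximal term when prevRoundWeights is defined
  void prevRoundWeights;
  return [tf.tensor([0])];
}
async function aggregateWithServer(localWeights: Weights): Promise<Weights> {
  // send local weights, receive the aggregated global weights
  return localWeights;
}

async function run(totalRound: number): Promise<void> {
  let previousRoundWeights: Weights | undefined; // undefined means no proximal term in round 0
  for (let round = 0; round < totalRound; round++) {
    const localWeights = await trainLocally(previousRoundWeights);
    // the aggregated result becomes the anchor for the next round's proximal term
    previousRoundWeights = await aggregateWithServer(localWeights);
  }
}

void run(3);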
