Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FedProx implementation #837

Draft
wants to merge 27 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1e05750
webapp/components/containers/testing: show image labels in all caps
JulienVig Nov 25, 2024
683ccb3
webapp/components/containers/ImageCard: improve image hover effect
JulienVig Nov 25, 2024
40de1e9
webapp/components/training/TrainingInformation: display training char…
JulienVig Nov 25, 2024
6916223
discojs/src/training/disco: handle validation split ratio equals zero
JulienVig Nov 25, 2024
4b22ef7
discojs/client/federated/federated_client: wait indefinitely for serv…
JulienVig Nov 25, 2024
7a457a4
disco/default_tasks: implement GDHF tinder dog task
JulienVig Nov 25, 2024
c573bb5
cli: setup tinder dog CLI support
JulienVig Nov 25, 2024
0733ae0
server/controllers/federated_controller: reset state when training se…
JulienVig Nov 26, 2024
bb6242d
webapp/components/dataset_input/LabeledImageDatasetInput/ByGroup: shu…
JulienVig Nov 27, 2024
da536eb
discojs/default_tasks/tinder_dog: improve textual description and lin…
JulienVig Nov 27, 2024
b0be745
discojs/models: use trainOnBatch instead of fit and fitDataset
JulienVig Nov 28, 2024
83a12d7
cli/src: support wikitext task
JulienVig Nov 28, 2024
c52ef11
discojs/models/gpt: compute logits only once, 10% faster
JulienVig Nov 28, 2024
d287795
fixup! discojs/models: use trainOnBatch instead of fit and fitDataset
JulienVig Nov 28, 2024
2391c1b
tmp: overriding weight update yields same as default
JulienVig Nov 28, 2024
a3dfaf4
tmp: sketch of fedprox implementation
JulienVig Nov 28, 2024
9ad8981
discojs/models: use trainOnBatch instead of fit and fitDataset
JulienVig Nov 28, 2024
c20ad82
discojs/models/gpt: compute logits only once, 10% faster
JulienVig Nov 28, 2024
abe7996
tmp: overriding weight update yields same as default
JulienVig Feb 25, 2025
76977e7
tmp: sketch of fedprox implementation
JulienVig Nov 28, 2024
d6952c8
Merge branch 'develop' of github.com:epfml/disco into 802-fedprox-julien
tomasoignons Mar 3, 2025
0e1f4e5
Added the FedAverage training
tomasoignons Mar 14, 2025
576a3ef
Begin to implement the choice between fedaverage and FedProx
tomasoignons Mar 14, 2025
c0d4196
Merge branch '802-fedprox-julien' of github.com:epfml/disco into 802-…
tomasoignons Mar 14, 2025
f88856f
added the fedprox by default if nothing is specified
tomasoignons Mar 14, 2025
de00d93
cli: cleanup node_modules
tharvik Mar 31, 2025
a5ecff3
cli: drop local immutable
tharvik Mar 31, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
"author": "",
"license": "ISC",
"dependencies": {
"server": "*",
"@epfml/discojs-node": "*",
"csv-parse": "^5.6.0",
"server": "*",
"tslib": "2"
},
"devDependencies": {
Expand Down
2 changes: 1 addition & 1 deletion cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
)

const supportedTasks = Map(
Set.of<TaskProvider<"image"> | TaskProvider<"tabular">>(
Set.of<TaskProvider<"image"> | TaskProvider<"tabular"> | TaskProvider<"text">>(
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.simpleFace,
Expand Down
2 changes: 1 addition & 1 deletion cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,4 @@ async function main(args: Required<CLIArguments>): Promise<void> {
}

// You can run this example with "npm start" from this folder
main(args).catch(console.error)
main(args).catch(console.error)
2 changes: 1 addition & 1 deletion cli/src/train_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ async function main(): Promise<void> {
}

// You can run this example with "npm run run_gpt" from this folder
main().catch(console.error)
main().catch(console.error)
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/cifar10.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ export const cifar10: TaskProvider<'image'> = {
metrics: ['accuracy']
})

return new models.TFJS('image', model)
return new models.TFJS('image', model, "fedprox")
}
}
19 changes: 10 additions & 9 deletions discojs/src/default_tasks/lus_covid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export const lusCovid: TaskProvider<'image'> = {

// Model architecture from tensorflow.js docs:
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
async getModel (): Promise<Model<'image'>> {
async getModel(): Promise<Model<'image'>> {
const seed = 42
const imageHeight = 100
const imageWidth = 100
const imageChannels = 3
Expand All @@ -55,7 +56,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
kernelInitializer: tf.initializers.heNormal({ seed })
}))

// The MaxPooling layer acts as a sort of downsampling using max values
Expand All @@ -69,7 +70,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 16,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
kernelInitializer: tf.initializers.heNormal({ seed })
}))
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))

Expand All @@ -82,16 +83,16 @@ export const lusCovid: TaskProvider<'image'> = {
// output class.
model.add(tf.layers.dense({
units: numOutputClasses,
kernelInitializer: 'varianceScaling',
activation: 'softmax'
activation: 'softmax',
kernelInitializer: tf.initializers.heNormal({ seed })
}))

model.compile({
optimizer: 'sgd',
optimizer: tf.train.sgd(0.001),
loss: 'binaryCrossentropy',
metrics: ['accuracy']
})

return Promise.resolve(new models.TFJS('image', model))
return Promise.resolve(new models.TFJS('image', model, "fedprox"))
}
}
}
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/mnist.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ export const mnist: TaskProvider<'image'> = {
metrics: ['accuracy']
})

return Promise.resolve(new models.TFJS('image', model))
return Promise.resolve(new models.TFJS('image', model, "fedprox"))
}
}
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/simple_face.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ export const simpleFace: TaskProvider<'image'> = {
metrics: ['accuracy']
})

return new models.TFJS('image', model)
return new models.TFJS('image', model, "fedprox")
}
}
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/tinder_dog.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ export const tinderDog: TaskProvider<'image'> = {
metrics: ['accuracy']
})

return Promise.resolve(new models.TFJS('image', model))
return Promise.resolve(new models.TFJS('image', model, "fedprox"))
}
}
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/titanic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ export const titanic: TaskProvider<'tabular'> = {
metrics: ['accuracy']
})

return Promise.resolve(new models.TFJS('tabular', model))
return Promise.resolve(new models.TFJS('tabular', model, "fedprox"))
}
}
28 changes: 4 additions & 24 deletions discojs/src/models/gpt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,10 @@ export class GPT extends Model<"text"> {
async #runBatch(
batch: Batched<DataFormat.ModelEncoded["text"]>,
): Promise<BatchLogs> {
const tfBatch = this.#batchToTF(batch);

let logs: tf.Logs | undefined;
await this.model.fitDataset(tf.data.array([tfBatch]), {
epochs: 1,
verbose: 0, // don't pollute
callbacks: {
onEpochEnd: (_, cur) => {
logs = cur;
},
},
});
tf.dispose(tfBatch);
if (logs === undefined) throw new Error("batch didn't gave any logs");

const { loss, acc: accuracy } = logs;
if (loss === undefined || isNaN(loss))
throw new Error("training loss is undefined or NaN");

return {
accuracy,
loss,
memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
};
const {xs, ys} = this.#batchToTF(batch);
const logs = await this.model.trainOnBatch(xs, ys);
tf.dispose([xs, ys])
return this.getBatchLogs(logs)
}

async #evaluate(
Expand Down
136 changes: 43 additions & 93 deletions discojs/src/models/gpt/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import * as tf from '@tensorflow/tfjs'
import type { GPTConfig } from './config.js'
import { getModelSizes, DefaultGPTConfig } from './config.js'
import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js'
import evaluate from './evaluate.js'
import { GPTArchitecture } from './layers.js'

const debug = createDebug("discojs:models:gpt:model");
Expand Down Expand Up @@ -55,101 +54,52 @@ export class GPTModel extends tf.LayersModel {
: tf.train.adam(this.config.lr)
}

override async fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History> {
const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs
const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>
await callbacks.onTrainBegin?.()
override async trainOnBatch(x: tf.Tensor, y: tf.Tensor): Promise<number | number[]> {
let weightUpdateTime = performance.now()

for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
let accuracyFraction: [number, number] = [0, 0];
let averageLoss = 0
let iteration = 1
const iterator = await dataset.iterator()
let next = await iterator.next()
let preprocessingTime = performance.now()
await Promise.all([x.data(), y.data()])
preprocessingTime = performance.now() - preprocessingTime

while (next.done !== true && iteration <= this.config.maxIter) {
let weightUpdateTime = performance.now()
await callbacks.onEpochBegin?.(epoch)
const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
let logitsTensor: tf.Tensor<tf.Rank>;
const lossTensor = tf.tidy(() => {
const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
const logits = this.apply(x)
if (Array.isArray(logits))
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
logitsTensor = tf.keep(logits)
return tf.losses.softmaxCrossEntropy(y, logits)
})
const gradsClipped = clipByGlobalNormObj(grads, 1)
this.optimizer.applyGradients(gradsClipped)
return lossTensor
})

let preprocessingTime = performance.now()
await Promise.all([xs.data(), ys.data()])
preprocessingTime = performance.now() - preprocessingTime

// TODO include as a tensor inside the model
const accTensor = tf.tidy(() => {
const logits = this.apply(xs)
if (Array.isArray(logits))
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
return tf.metrics.categoricalAccuracy(ys, logits)
})
const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
const accSumTensor = accTensor.sum()
const accSum = await accSumTensor.array()
tf.dispose(accSumTensor)
if (typeof accSum !== 'number')
throw new Error('got multiple accuracy sum')
accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
tf.dispose([accTensor])
// @ts-expect-error Variable 'logitsTensor' is used before being assigned
const accTensor = tf.metrics.categoricalAccuracy(y, logitsTensor)
const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
const accSumTensor = accTensor.sum()
const accSum = await accSumTensor.array()
if (typeof accSum !== 'number')
throw new Error('got multiple accuracy sum')
// @ts-expect-error Variable 'logitsTensor' is used before being assigned
tf.dispose([accTensor, accSumTensor, logitsTensor])

const loss = await lossTensor.array()
weightUpdateTime = performance.now() - weightUpdateTime

const lossTensor = tf.tidy(() => {
const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
const logits = this.apply(xs)
if (Array.isArray(logits))
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
return tf.losses.softmaxCrossEntropy(ys, logits)
})
const gradsClipped = clipByGlobalNormObj(grads, 1)
this.optimizer.applyGradients(gradsClipped)
return lossTensor
})

const loss = await lossTensor.array()
averageLoss += loss
weightUpdateTime = performance.now() - weightUpdateTime

tf.dispose([xs, ys, lossTensor])

if (
evalDataset !== undefined &&
this.config.evaluateEvery !== undefined &&
iteration % this.config.evaluateEvery == 0
){
const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches)
debug('evaluation metrics: %O', iterationLogs);
}
const memory = tf.memory().numBytes / 1024 / 1024 / 1024
debug("training metrics: %O", {
epoch,
iteration,
loss,
memory,
allocated: tf.memory().numTensors,
preprocessingTime,
weightUpdateTime,
});
iteration++
next = await iterator.next()
}
// Memory leak: If we reached the last iteration rather than the end of the dataset, cleanup the tensors
if (next.done !== true && iteration > this.config.maxIter) {
const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
tf.dispose([xs, ys])
}
let logs: tf.Logs = {
'loss': averageLoss / (iteration - 1), // -1 because iteration got incremented at the end of the loop
'acc': accuracyFraction[0] / accuracyFraction[1],
}
if (evalDataset !== undefined) {
logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) }
}
await callbacks.onEpochEnd?.(epoch, logs)
}
await callbacks.onTrainEnd?.()
return new tf.History()
tf.dispose([x, y, lossTensor])

const memory = tf.memory().numBytes / 1024 / 1024 / 1024
debug("training metrics: %O", {
loss,
memory,
allocated: tf.memory().numTensors,
preprocessingTime,
weightUpdateTime,
});
return [loss, accSum / accSize]
}
}
26 changes: 26 additions & 0 deletions discojs/src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import type {
WeightsContainer,
} from "../index.js";

import * as tf from "@tensorflow/tfjs";

import type { BatchLogs, EpochLogs } from "./logs.js";

/**
Expand All @@ -15,12 +17,16 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
**/
// TODO make it typesafe: same shape of data/input/weights
export abstract class Model<D extends DataType> implements Disposable {
protected prevRoundWeights: WeightsContainer | undefined;
// TODO don't allow external access but upgrade train to return weights on every epoch
/** Return training state */
abstract get weights(): WeightsContainer;
/** Set training state */
abstract set weights(ws: WeightsContainer);

// Record the weights from the previous federated round.
// NOTE(review): presumably consumed by the FedProx proximal term — confirm
// against the trainer that sets it. Write-only accessor (no matching getter);
// subclasses read the protected `prevRoundWeights` field directly.
// Passing `undefined` clears the stored weights.
set previousRoundWeights(ws: WeightsContainer | undefined) {
this.prevRoundWeights = ws
}
/**
* Improve predictor
*
Expand All @@ -39,6 +45,26 @@ export abstract class Model<D extends DataType> implements Disposable {
batch: Batched<DataFormat.ModelEncoded[D][0]>,
): Promise<Batched<DataFormat.ModelEncoded[D][1]>>;

/**
 * Convert the raw output of `tf.LayersModel.trainOnBatch` into structured
 * batch logs.
 *
 * @param logs output of `trainOnBatch`; expected to be a `[loss, accuracy]`
 *   pair — a bare number (loss only, i.e. a model compiled without metrics)
 *   is rejected
 * @returns the loss, the accuracy and the current TF.js memory usage in GiB
 * @throws Error if `logs` is not a two-element array of non-NaN numbers
 */
protected getBatchLogs(
  logs: number | number[],
): BatchLogs {
  // trainOnBatch returns a single number when the model has no metrics;
  // callers here are expected to compile with an accuracy metric.
  if (!Array.isArray(logs) || logs.length !== 2)
    throw new Error("training output has unexpected shape");

  const [loss, accuracy] = logs;

  // Number.isNaN avoids the implicit coercion of the global isNaN;
  // equivalent here since the typeof guard runs first.
  if (
    typeof loss !== "number" || Number.isNaN(loss) ||
    typeof accuracy !== "number" || Number.isNaN(accuracy)
  )
    throw new Error("training loss or accuracy is undefined or NaN");

  return {
    accuracy,
    loss,
    // tf.memory().numBytes is in bytes; convert to GiB
    memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
  };
}
/**
* This method is automatically called to cleanup the memory occupied by the model
* when leaving the definition scope if the instance has been defined with the `using` keyword.
Expand Down
Loading
Loading