@@ -73,62 +73,60 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
   }
 
   async trainFedProx (
-    xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-    // let logitsTensor: tf.Tensor<tf.Rank>;
-    debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+    xs: tf.Tensor, ys: tf.Tensor,
+  ): Promise<[number, number]> {
+    let logitsTensor: tf.Tensor<tf.Rank>;
 
     const lossFunction: () => tf.Scalar = () => {
+      // Proximal term
+      let proximalTerm = tf.tensor(0)
+      if (this.prevRoundWeights !== undefined) {
+        // squared norm
+        const norm = new WeightsContainer(this.model.getWeights())
+          .sub(this.prevRoundWeights)
+          .map(t => t.square().sum())
+          .reduce((t, acc) => tf.add(t, acc)).asScalar()
+        const mu = 1
+        proximalTerm = tf.mul(mu / 2, norm)
+      }
+
       this.model.apply(xs)
       const logits = this.model.apply(xs)
-      if (Array.isArray(logits))
-        throw new Error('model outputs too many tensor')
-      if (logits instanceof tf.SymbolicTensor)
-        throw new Error('model outputs symbolic tensor')
-      // logitsTensor = tf.keep(logits)
-      // return tf.losses.softmaxCrossEntropy(ys, logits)
-      let y: tf.Tensor;
-      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
-      y = tf.log(tf.div(y, tf.sub(1, y)));
-      return tf.losses.sigmoidCrossEntropy(ys, y);
-      // return tf.losses.sigmoidCrossEntropy(ys, logits)
+      if (Array.isArray(logits))
+        throw new Error('model outputs too many tensor')
+      if (logits instanceof tf.SymbolicTensor)
+        throw new Error('model outputs symbolic tensor')
+      logitsTensor = tf.keep(logits)
+      // binaryCrossEntropy
+      let y: tf.Tensor;
+      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+      y = tf.log(tf.div(y, tf.sub(1, y)));
+      const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+      console.log(loss.dataSync(), proximalTerm.dataSync())
+      return tf.add(loss, proximalTerm)
     }
     const lossTensor = this.model.optimizer.minimize(lossFunction, true)
     if (lossTensor === null) throw new Error("loss should not be null")
-    // const lossTensor = tf.tidy(() => {
-    //   const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
-    //     const logits = this.model.apply(xs)
-    //     if (Array.isArray(logits))
-    //       throw new Error('model outputs too many tensor')
-    //     if (logits instanceof tf.SymbolicTensor)
-    //       throw new Error('model outputs symbolic tensor')
-    //     logitsTensor = tf.keep(logits)
-    //     // return tf.losses.softmaxCrossEntropy(ys, logits)
-    //     return this.model.calculateLosses(ys, logits)[0]
-    //   })
-    //   this.model.optimizer.applyGradients(grads)
-    //   return lossTensor
-    // })
 
-    // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
-    // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
-    // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
-    // const accSumTensor = accTensor.sum()
-    // const accSum = await accSumTensor.array()
-    // if (typeof accSum !== 'number')
-    //   throw new Error('got multiple accuracy sum')
-    // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
-    // tf.dispose([accTensor, accSumTensor, logitsTensor])
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSumTensor = accTensor.sum()
+    const accSum = await accSumTensor.array()
+    if (typeof accSum !== 'number')
+      throw new Error('got multiple accuracy sum')
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    tf.dispose([accTensor, accSumTensor, logitsTensor])
 
     const loss = await lossTensor.array()
     tf.dispose([xs, ys, lossTensor])
 
-    // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
-    // debug("training metrics: %O", {
-    //   loss,
-    //   memory,
-    //   allocated: tf.memory().numTensors,
-    // });
-    return [loss, 0]
-    // return [loss, accSum / accSize]
+    const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+    debug("training metrics: %O", {
+      loss,
+      memory,
+      allocated: tf.memory().numTensors,
+    });
+    return [loss, accSum / accSize]
   }
 
   async #evaluate(
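
For reference, the added `lossFunction` implements the FedProx objective: the usual task loss plus a proximal term `(mu / 2) * ||w - wPrev||^2` that penalizes drift from the weights received at the start of the federated round. Below is a minimal self-contained sketch of that term using plain `tf.Tensor[]` arrays; the helper name and the `mu` parameter are illustrative, since the commit builds the same quantity through the repository's `WeightsContainer` and hardcodes `mu = 1`.

```ts
import * as tf from "@tensorflow/tfjs";

// FedProx proximal term: (mu / 2) * ||w - wPrev||^2,
// with the squared L2 distance summed over every weight tensor.
function proximalTerm(
  weights: tf.Tensor[],
  prevRoundWeights: tf.Tensor[],
  mu = 1, // the commit hardcodes mu = 1
): tf.Scalar {
  return tf.tidy(() => {
    const squaredNorm = weights
      .map((w, i) => w.sub(prevRoundWeights[i]).square().sum())
      .reduce<tf.Tensor>((acc, t) => acc.add(t), tf.scalar(0));
    // mu controls how strongly each client is pulled back toward
    // the weights it received at the start of the round.
    return tf.mul(mu / 2, squaredNorm).asScalar();
  });
}
```

One note on the diff itself: the `console.log(loss.dataSync(), proximalTerm.dataSync())` inside the loss closure forces a synchronous read-back from the backend on every optimizer step, which is handy for debugging but costly if left in.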
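The binary cross-entropy in the added code is slightly indirect: `tf.losses.sigmoidCrossEntropy` expects raw logits, while the model's output appears to already be sigmoid probabilities, so the code clips the outputs away from 0 and 1 and inverts the sigmoid with `logit(p) = log(p / (1 - p))` before calling the loss. A sketch of that transformation, with a hypothetical helper name and signature:

```ts
import * as tf from "@tensorflow/tfjs";

// Recover logits from probability outputs so tf.losses.sigmoidCrossEntropy
// (which expects logits) can be applied to a model whose last layer already
// applies a sigmoid. Clipping keeps the division and the log finite as the
// probabilities approach 0 or 1.
function binaryCrossEntropyFromProbs(
  ys: tf.Tensor,
  probs: tf.Tensor,
  epsilon = 0.00001, // same clip bound as the commit
): tf.Scalar {
  return tf.tidy(() => {
    const clipped = tf.clipByValue(probs, epsilon, 1 - epsilon);
    const logits = tf.log(tf.div(clipped, tf.sub(1, clipped)));
    return tf.losses.sigmoidCrossEntropy(ys, logits).asScalar();
  });
}
```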