1
+ import createDebug from "debug" ;
1
2
import { List , Map , Range } from "immutable" ;
2
3
import * as tf from '@tensorflow/tfjs'
3
4
@@ -13,6 +14,8 @@ import { BatchLogs } from './index.js'
13
14
import { Model } from './index.js'
14
15
import { EpochLogs } from './logs.js'
15
16
17
+ const debug = createDebug ( "discojs:models:tfjs" ) ;
18
+
16
19
type Serialized < D extends DataType > = [ D , tf . io . ModelArtifacts ] ;
17
20
18
21
/** TensorFlow JavaScript model with standard training */
@@ -63,11 +66,71 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
63
66
batch : Batched < DataFormat . ModelEncoded [ D ] > ,
64
67
) : Promise < BatchLogs > {
65
68
const { xs, ys } = this . #batchToTF( batch ) ;
66
- const logs = await this . model . trainOnBatch ( xs , ys ) ;
69
+ const logs = await this . trainFedProx ( xs , ys ) ;
70
+ // const logs = await this.model.trainOnBatch(xs, ys);
67
71
tf . dispose ( [ xs , ys ] )
68
72
return this . getBatchLogs ( logs )
69
73
}
70
74
75
+ async trainFedProx (
76
+ xs : tf . Tensor , ys : tf . Tensor ) : Promise < [ number , number ] > {
77
+ // let logitsTensor: tf.Tensor<tf.Rank>;
78
+ debug ( this . model . loss , this . model . losses , this . model . lossFunctions )
79
+ const lossFunction : ( ) => tf . Scalar = ( ) => {
80
+ this . model . apply ( xs )
81
+ const logits = this . model . apply ( xs )
82
+ if ( Array . isArray ( logits ) )
83
+ throw new Error ( 'model outputs too many tensor' )
84
+ if ( logits instanceof tf . SymbolicTensor )
85
+ throw new Error ( 'model outputs symbolic tensor' )
86
+ // logitsTensor = tf.keep(logits)
87
+ // return tf.losses.softmaxCrossEntropy(ys, logits)
88
+ let y : tf . Tensor ;
89
+ y = tf . clipByValue ( logits , 0.00001 , 1 - 0.00001 ) ;
90
+ y = tf . log ( tf . div ( y , tf . sub ( 1 , y ) ) ) ;
91
+ return tf . losses . sigmoidCrossEntropy ( ys , y ) ;
92
+ // return tf.losses.sigmoidCrossEntropy(ys, logits)
93
+ }
94
+ const lossTensor = this . model . optimizer . minimize ( lossFunction , true )
95
+ if ( lossTensor === null ) throw new Error ( "loss should not be null" )
96
+ // const lossTensor = tf.tidy(() => {
97
+ // const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
98
+ // const logits = this.model.apply(xs)
99
+ // if (Array.isArray(logits))
100
+ // throw new Error('model outputs too many tensor')
101
+ // if (logits instanceof tf.SymbolicTensor)
102
+ // throw new Error('model outputs symbolic tensor')
103
+ // logitsTensor = tf.keep(logits)
104
+ // // return tf.losses.softmaxCrossEntropy(ys, logits)
105
+ // return this.model.calculateLosses(ys, logits)[0]
106
+ // })
107
+ // this.model.optimizer.applyGradients(grads)
108
+ // return lossTensor
109
+ // })
110
+
111
+ // // @ts -expect-error Variable 'logitsTensor' is used before being assigned
112
+ // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
113
+ // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
114
+ // const accSumTensor = accTensor.sum()
115
+ // const accSum = await accSumTensor.array()
116
+ // if (typeof accSum !== 'number')
117
+ // throw new Error('got multiple accuracy sum')
118
+ // // @ts -expect-error Variable 'logitsTensor' is used before being assigned
119
+ // tf.dispose([accTensor, accSumTensor, logitsTensor])
120
+
121
+ const loss = await lossTensor . array ( )
122
+ tf . dispose ( [ xs , ys , lossTensor ] )
123
+
124
+ // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
125
+ // debug("training metrics: %O", {
126
+ // loss,
127
+ // memory,
128
+ // allocated: tf.memory().numTensors,
129
+ // });
130
+ return [ loss , 0 ]
131
+ // return [loss, accSum / accSize]
132
+ }
133
+
71
134
async #evaluate(
72
135
dataset : Dataset < Batched < DataFormat . ModelEncoded [ D ] > > ,
73
136
) : Promise < Record < "accuracy" | "loss" , number > > {
@@ -160,7 +223,10 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
160
223
return new this (
161
224
datatype ,
162
225
await tf . loadLayersModel ( {
163
- load : ( ) => Promise . resolve ( artifacts ) ,
226
+ load : ( ) => {
227
+ console . log ( "deserialize called" )
228
+ return Promise . resolve ( artifacts )
229
+ } ,
164
230
} ) ,
165
231
) ;
166
232
}
@@ -187,7 +253,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
187
253
return [ this . datatype , await ret ]
188
254
}
189
255
190
- [ Symbol . dispose ] ( ) : void {
256
+ [ Symbol . dispose ] ( ) : void {
191
257
this . model . dispose ( )
192
258
}
193
259
0 commit comments