Skip to content

Commit 4acd48d

Browse files
committed
fixup! discojs-core/models: add gpt
1 parent bd54004 commit 4acd48d

File tree

1 file changed

+3
-3
lines changed
  • discojs/discojs-core/src/models/gpt

1 file changed

+3
-3
lines changed

discojs/discojs-core/src/models/gpt/index.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ export class GPT extends Model {
5858
this.model.setWeights(ws.weights)
5959
}
6060

61-
private convertCharDataset (dataset: Dataset): Dataset {
61+
private convertCharDataset (dataset: Dataset): tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> {
6262
const batchSize = 4
6363
const sampleSize = GPT.blockSize + 1
6464
const chunkSize = sampleSize * batchSize * 2
@@ -80,8 +80,8 @@ export class GPT extends Model {
8080

8181
const buffer = await chunk.buffer()
8282

83-
const xs = tf.buffer([batchSize, GPT.blockSize], 'int32')
84-
const ys = tf.buffer([batchSize, GPT.blockSize, GPT.vocabSize], 'int32')
83+
const xs = tf.buffer<tf.Rank.R2, 'int32'>([batchSize, GPT.blockSize], 'int32')
84+
const ys = tf.buffer<tf.Rank.R3, 'int32'>([batchSize, GPT.blockSize, GPT.vocabSize], 'int32')
8585

8686
for (let i = 0; i < batchSize; i++) {
8787
for (let j = 0; j < sampleSize; j++) {

0 commit comments

Comments
 (0)