@@ -176,11 +176,26 @@ public struct ModelSupport: Codable, Equatable {
176
176
private enum CodingKeys : String , CodingKey {
177
177
case `default`, supported
178
178
}
179
+
180
/// Creates a model support description.
/// - Parameters:
///   - default: Value stored in the `default` property.
///   - supported: Value stored in the `supported` property.
///   - disabled: Value stored in the `disabled` property; defaults to an empty array.
public init(
    default: String,
    supported: [String],
    disabled: [String] = []
) {
    // `default` is a keyword, so the parameter is referenced with backticks.
    self.default = `default`
    self.supported = supported
    self.disabled = disabled
}
179
189
}
180
190
181
191
/// Associates a list of identifiers with the `ModelSupport` that applies to them.
public struct DeviceSupport: Codable {
    /// Identifiers this entry applies to.
    /// NOTE(review): presumably hardware/device identifiers — confirm against the config producer.
    public let identifiers: [String]
    /// Model support information for `identifiers`; mutable after construction.
    public var models: ModelSupport

    /// Creates a device support entry.
    /// - Parameters:
    ///   - identifiers: Value stored in `identifiers`.
    ///   - models: Value stored in `models`.
    public init(identifiers: [String], models: ModelSupport) {
        self.identifiers = identifiers
        self.models = models
    }
}
185
200
186
201
public struct ModelSupportConfig : Codable {
@@ -280,6 +295,11 @@ public struct ModelSupportConfig: Codable {
280
295
/// A run of audio samples paired with a seek offset index.
public struct AudioChunk {
    /// Seek offset index associated with this chunk.
    /// NOTE(review): presumably a sample index into the source audio — confirm against callers.
    public var seekOffsetIndex: Int
    /// The audio samples contained in this chunk.
    public var audioSamples: [Float]

    /// Creates an audio chunk.
    /// - Parameters:
    ///   - seekOffsetIndex: Value stored in `seekOffsetIndex`.
    ///   - audioSamples: Value stored in `audioSamples`.
    public init(seekOffsetIndex: Int, audioSamples: [Float]) {
        self.seekOffsetIndex = seekOffsetIndex
        self.audioSamples = audioSamples
    }
}
284
304
285
305
// MARK: - Decoding
@@ -351,7 +371,12 @@ public struct DecodingCache {
351
371
public var keyCache : MLMultiArray ?
352
372
public var valueCache : MLMultiArray ?
353
373
public var alignmentWeights : MLMultiArray ?
354
- public init ( keyCache: MLMultiArray ? = nil , valueCache: MLMultiArray ? = nil , alignmentWeights: MLMultiArray ? = nil ) {
374
+
375
+ public init (
376
+ keyCache: MLMultiArray ? = nil ,
377
+ valueCache: MLMultiArray ? = nil ,
378
+ alignmentWeights: MLMultiArray ? = nil
379
+ ) {
355
380
self . keyCache = keyCache
356
381
self . valueCache = valueCache
357
382
self . alignmentWeights = alignmentWeights
@@ -432,7 +457,20 @@ public struct DecodingResult {
432
457
fallback: nil )
433
458
}
434
459
435
- public init ( language: String , languageProbs: [ String : Float ] , tokens: [ Int ] , tokenLogProbs: [ [ Int : Float ] ] , text: String , avgLogProb: Float , noSpeechProb: Float , temperature: Float , compressionRatio: Float , cache: DecodingCache ? = nil , timings: TranscriptionTimings ? = nil , fallback: DecodingFallback ? = nil ) {
460
+ public init (
461
+ language: String ,
462
+ languageProbs: [ String : Float ] ,
463
+ tokens: [ Int ] ,
464
+ tokenLogProbs: [ [ Int : Float ] ] ,
465
+ text: String ,
466
+ avgLogProb: Float ,
467
+ noSpeechProb: Float ,
468
+ temperature: Float ,
469
+ compressionRatio: Float ,
470
+ cache: DecodingCache ? = nil ,
471
+ timings: TranscriptionTimings ? = nil ,
472
+ fallback: DecodingFallback ? = nil
473
+ ) {
436
474
self . language = language
437
475
self . languageProbs = languageProbs
438
476
self . tokens = tokens
@@ -510,6 +548,20 @@ public struct TranscriptionResult: Codable {
510
548
public var timings : TranscriptionTimings
511
549
public var seekTime : Float ?
512
550
551
/// Creates a transcription result.
/// - Parameters:
///   - text: Value stored in `text`.
///   - segments: Value stored in `segments`.
///   - language: Value stored in `language`.
///   - timings: Value stored in `timings`.
///   - seekTime: Value stored in `seekTime`; defaults to `nil`.
public init(
    text: String,
    segments: [TranscriptionSegment],
    language: String,
    timings: TranscriptionTimings,
    seekTime: Float? = nil
) {
    self.text = text
    self.segments = segments
    self.language = language
    self.timings = timings
    self.seekTime = seekTime
}
564
+
513
565
public func logSegments( ) {
514
566
for (i, segment) in segments. enumerated ( ) {
515
567
let start = segment. start
@@ -593,18 +645,51 @@ public extension TranscriptionResult {
593
645
}
594
646
595
647
/// A single segment of a transcription, with its tokens and decoding statistics.
///
/// Property defaults live on the initializer rather than on the stored
/// properties, so `TranscriptionSegment()` still yields the same values.
public struct TranscriptionSegment: Hashable, Codable {
    public var id: Int
    public var seek: Int
    /// Start of the segment; `duration` is computed as `end - start`.
    public var start: Float
    /// End of the segment; `duration` is computed as `end - start`.
    public var end: Float
    public var text: String
    public var tokens: [Int]
    public var tokenLogProbs: [[Int: Float]]
    public var temperature: Float
    public var avgLogprob: Float
    public var compressionRatio: Float
    public var noSpeechProb: Float
    /// Optional per-word timings for this segment.
    public var words: [WordTiming]?

    /// Duration of the segment, computed as `end - start`.
    public var duration: Float {
        return end - start
    }

    /// Creates a segment; every parameter has a default value.
    /// - Parameters:
    ///   - id: Defaults to `0`.
    ///   - seek: Defaults to `0`.
    ///   - start: Defaults to `0.0`.
    ///   - end: Defaults to `0.0`.
    ///   - text: Defaults to the empty string.
    ///   - tokens: Defaults to an empty array.
    ///   - tokenLogProbs: Defaults to an array holding one empty dictionary.
    ///   - temperature: Defaults to `1.0`.
    ///   - avgLogprob: Defaults to `0.0`.
    ///   - compressionRatio: Defaults to `1.0`.
    ///   - noSpeechProb: Defaults to `0.0`.
    ///   - words: Defaults to `nil`.
    public init(
        id: Int = 0,
        seek: Int = 0,
        start: Float = 0.0,
        end: Float = 0.0,
        text: String = "",
        tokens: [Int] = [],
        tokenLogProbs: [[Int: Float]] = [[:]],
        temperature: Float = 1.0,
        avgLogprob: Float = 0.0,
        compressionRatio: Float = 1.0,
        noSpeechProb: Float = 0.0,
        words: [WordTiming]? = nil
    ) {
        self.id = id
        self.seek = seek
        self.start = start
        self.end = end
        self.text = text
        self.tokens = tokens
        self.tokenLogProbs = tokenLogProbs
        self.temperature = temperature
        self.avgLogprob = avgLogprob
        self.compressionRatio = compressionRatio
        self.noSpeechProb = noSpeechProb
        self.words = words
    }
}
609
694
610
695
public struct WordTiming : Hashable , Codable {
@@ -613,6 +698,19 @@ public struct WordTiming: Hashable, Codable {
613
698
public var start : Float
614
699
public var end : Float
615
700
public var probability : Float
701
+
702
/// Duration of the word, computed as `end - start`.
public var duration: Float {
    return end - start
}
706
+
707
/// Creates a word timing.
/// - Parameters:
///   - word: Value stored in `word`.
///   - tokens: Value stored in `tokens`.
///   - start: Value stored in `start`.
///   - end: Value stored in `end`.
///   - probability: Value stored in `probability`.
public init(word: String, tokens: [Int], start: Float, end: Float, probability: Float) {
    self.word = word
    self.tokens = tokens
    self.start = start
    self.end = end
    self.probability = probability
}
616
714
}
617
715
618
716
public struct TranscriptionProgress {
@@ -1198,17 +1296,40 @@ public struct SpecialTokens {
1198
1296
}
1199
1297
}
1200
1298
1201
/// Tokenizer interface used by WhisperKit.
///
/// The protocol declares the tokenizer operations it needs directly instead of
/// inheriting from `Tokenizer`.
public protocol WhisperTokenizer {
    // MARK: swift-transformers pass through

    /// Encodes `text` into token ids.
    func encode(text: String) -> [Int]
    /// Decodes token ids into a string.
    func decode(tokens: [Int]) -> String
    /// Returns the id for `token`, or `nil` if there is none.
    func convertTokenToId(_ token: String) -> Int?
    /// Returns the token for `id`, or `nil` if there is none.
    func convertIdToToken(_ id: Int) -> String?

    // MARK: WhisperKit specific

    var specialTokens: SpecialTokens { get }
    var allLanguageTokens: Set<Int> { get }

    /// Decodes token ids into individual words and per-word subtokens.
    /// - Parameter tokenIds: Array of tokens to decode and then split.
    /// - Returns: Tuple containing the split words and the tokens for each word.
    func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]])
}
1207
1312
1208
- struct WhisperTokenizerWrapper : WhisperTokenizer {
1313
+ open class WhisperTokenizerWrapper : WhisperTokenizer {
1209
1314
let tokenizer : any Tokenizer
1210
- let specialTokens : SpecialTokens
1211
- let allLanguageTokens : Set < Int >
1315
+ public let specialTokens : SpecialTokens
1316
+ public let allLanguageTokens : Set < Int >
1317
+
1318
/// Encodes `text` by forwarding to the wrapped `tokenizer`.
public func encode(text: String) -> [Int] {
    tokenizer.encode(text: text)
}
1321
+
1322
/// Decodes `tokens` by forwarding to the wrapped `tokenizer`.
public func decode(tokens: [Int]) -> String {
    tokenizer.decode(tokens: tokens)
}
1325
+
1326
/// Looks up the id for `token` by forwarding to the wrapped `tokenizer`.
public func convertTokenToId(_ token: String) -> Int? {
    tokenizer.convertTokenToId(token)
}
1329
+
1330
/// Looks up the token for `id` by forwarding to the wrapped `tokenizer`.
public func convertIdToToken(_ id: Int) -> String? {
    tokenizer.convertIdToToken(id)
}
1212
1333
1213
1334
init ( tokenizer: any Tokenizer ) {
1214
1335
let specialTokens = SpecialTokens (
@@ -1300,7 +1421,7 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
1300
1421
/// Decodes token ids into individual words and per-word subtokens
1301
1422
/// - Parameter tokenIds: Array of tokens to decode and then split
1302
1423
/// - Returns: Tuple containing and array of the split words and all tokens for each word
1303
- func splitToWordTokens( tokenIds: [ Int ] ) -> ( words: [ String ] , wordTokens: [ [ Int ] ] ) {
1424
+ public func splitToWordTokens( tokenIds: [ Int ] ) -> ( words: [ String ] , wordTokens: [ [ Int ] ] ) {
1304
1425
let decodedWords = tokenizer. decode ( tokens: tokenIds. filter { $0 < specialTokens. specialTokenBegin } )
1305
1426
1306
1427
// Detect language of input text
@@ -1316,52 +1437,6 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
1316
1437
}
1317
1438
}
1318
1439
1319
- extension WhisperTokenizerWrapper : Tokenizer {
1320
- func tokenize( text: String ) -> [ String ] {
1321
- tokenizer. tokenize ( text: text)
1322
- }
1323
-
1324
- func encode( text: String ) -> [ Int ] {
1325
- tokenizer. encode ( text: text)
1326
- }
1327
-
1328
- func decode( tokens: [ Int ] ) -> String {
1329
- tokenizer. decode ( tokens: tokens)
1330
- }
1331
-
1332
- func convertTokenToId( _ token: String ) -> Int ? {
1333
- tokenizer. convertTokenToId ( token)
1334
- }
1335
-
1336
- func convertIdToToken( _ id: Int ) -> String ? {
1337
- tokenizer. convertIdToToken ( id)
1338
- }
1339
-
1340
- var bosToken : String ? {
1341
- tokenizer. bosToken
1342
- }
1343
-
1344
- var bosTokenId : Int ? {
1345
- tokenizer. bosTokenId
1346
- }
1347
-
1348
- var eosToken : String ? {
1349
- tokenizer. eosToken
1350
- }
1351
-
1352
- var eosTokenId : Int ? {
1353
- tokenizer. eosTokenId
1354
- }
1355
-
1356
- var unknownToken : String ? {
1357
- tokenizer. unknownToken
1358
- }
1359
-
1360
- var unknownTokenId : Int ? {
1361
- tokenizer. unknownTokenId
1362
- }
1363
- }
1364
-
1365
1440
extension WhisperTokenizerWrapper {
1366
1441
/// Default values for each token, using base vocab
1367
1442
static var defaultWhitespaceToken : Int { 220 }
@@ -1512,6 +1587,9 @@ public enum Constants {
1512
1587
1513
1588
public static let defaultWindowSamples : Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models
1514
1589
1590
+ public static let defaultPrependPunctuations : String = " \" '“¡¿([{- "
1591
+ public static let defaultAppendPunctuations : String = " \" '.。,,!!??::”)]}、 "
1592
+
1515
1593
public static let fallbackModelSupportConfig : ModelSupportConfig = {
1516
1594
var config = ModelSupportConfig (
1517
1595
repoName: " whisperkit-coreml-fallback " ,
0 commit comments