Skip to content

Commit ca49596

Browse files
Improve SegmentSeeker word alignment (#305)
1 parent df5b1a2 commit ca49596

File tree

11 files changed

+551
-120
lines changed

11 files changed

+551
-120
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,6 @@ upload-benchmark-results:
131131
@fastlane upload_results
132132

133133
clean-package-caches:
134-
@trash ~/Library/Developer/Xcode/DerivedData/WhisperKit*
134+
@trash ~/Library/Developer/Xcode/DerivedData/WhisperKit* || true
135135
@swift package purge-cache
136136
@swift package reset

Package.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ let package = Package(
2020
),
2121
],
2222
dependencies: [
23-
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
24-
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
23+
.package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "0.1.8")),
24+
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"),
2525
],
2626
targets: [
2727
.target(

Sources/WhisperKit/Core/Audio/AudioProcessor.swift

+5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ public typealias DeviceID = String
1616
public struct AudioDevice: Identifiable, Hashable {
1717
public let id: DeviceID
1818
public let name: String
19+
20+
public init(id: DeviceID, name: String) {
21+
self.id = id
22+
self.name = name
23+
}
1924
}
2025

2126
public protocol AudioProcessing {

Sources/WhisperKit/Core/FeatureExtractor.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
3737
guard inputDescription.type == .multiArray else { return nil }
3838
guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
3939
let shape = shapeConstraint.shape.map { $0.intValue }
40-
return shape[0] // The audio input is a 1D array
40+
return shape[0] // The audio input is a 1D array
4141
}
4242

4343
public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
@@ -54,5 +54,4 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
5454
let output = MelSpectrogramOutput(features: outputFeatures)
5555
return output.melspectrogramFeatures
5656
}
57-
5857
}

Sources/WhisperKit/Core/Models.swift

+143-65
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,26 @@ public struct ModelSupport: Codable, Equatable {
176176
private enum CodingKeys: String, CodingKey {
177177
case `default`, supported
178178
}
179+
180+
public init(
181+
default: String,
182+
supported: [String],
183+
disabled: [String] = []
184+
) {
185+
self.default = `default`
186+
self.supported = supported
187+
self.disabled = disabled
188+
}
179189
}
180190

181191
public struct DeviceSupport: Codable {
182192
public let identifiers: [String]
183193
public var models: ModelSupport
194+
195+
public init(identifiers: [String], models: ModelSupport) {
196+
self.identifiers = identifiers
197+
self.models = models
198+
}
184199
}
185200

186201
public struct ModelSupportConfig: Codable {
@@ -280,6 +295,11 @@ public struct ModelSupportConfig: Codable {
280295
public struct AudioChunk {
281296
public var seekOffsetIndex: Int
282297
public var audioSamples: [Float]
298+
299+
public init(seekOffsetIndex: Int, audioSamples: [Float]) {
300+
self.seekOffsetIndex = seekOffsetIndex
301+
self.audioSamples = audioSamples
302+
}
283303
}
284304

285305
// MARK: - Decoding
@@ -351,7 +371,12 @@ public struct DecodingCache {
351371
public var keyCache: MLMultiArray?
352372
public var valueCache: MLMultiArray?
353373
public var alignmentWeights: MLMultiArray?
354-
public init(keyCache: MLMultiArray? = nil, valueCache: MLMultiArray? = nil, alignmentWeights: MLMultiArray? = nil) {
374+
375+
public init(
376+
keyCache: MLMultiArray? = nil,
377+
valueCache: MLMultiArray? = nil,
378+
alignmentWeights: MLMultiArray? = nil
379+
) {
355380
self.keyCache = keyCache
356381
self.valueCache = valueCache
357382
self.alignmentWeights = alignmentWeights
@@ -432,7 +457,20 @@ public struct DecodingResult {
432457
fallback: nil)
433458
}
434459

435-
public init(language: String, languageProbs: [String: Float], tokens: [Int], tokenLogProbs: [[Int: Float]], text: String, avgLogProb: Float, noSpeechProb: Float, temperature: Float, compressionRatio: Float, cache: DecodingCache? = nil, timings: TranscriptionTimings? = nil, fallback: DecodingFallback? = nil) {
460+
public init(
461+
language: String,
462+
languageProbs: [String: Float],
463+
tokens: [Int],
464+
tokenLogProbs: [[Int: Float]],
465+
text: String,
466+
avgLogProb: Float,
467+
noSpeechProb: Float,
468+
temperature: Float,
469+
compressionRatio: Float,
470+
cache: DecodingCache? = nil,
471+
timings: TranscriptionTimings? = nil,
472+
fallback: DecodingFallback? = nil
473+
) {
436474
self.language = language
437475
self.languageProbs = languageProbs
438476
self.tokens = tokens
@@ -510,6 +548,20 @@ public struct TranscriptionResult: Codable {
510548
public var timings: TranscriptionTimings
511549
public var seekTime: Float?
512550

551+
public init(
552+
text: String,
553+
segments: [TranscriptionSegment],
554+
language: String,
555+
timings: TranscriptionTimings,
556+
seekTime: Float? = nil
557+
) {
558+
self.text = text
559+
self.segments = segments
560+
self.language = language
561+
self.timings = timings
562+
self.seekTime = seekTime
563+
}
564+
513565
public func logSegments() {
514566
for (i, segment) in segments.enumerated() {
515567
let start = segment.start
@@ -593,18 +645,51 @@ public extension TranscriptionResult {
593645
}
594646

595647
public struct TranscriptionSegment: Hashable, Codable {
596-
public var id: Int = 0
597-
public var seek: Int = 0
598-
public var start: Float = 0.0
599-
public var end: Float = 0.0
600-
public var text: String = ""
601-
public var tokens: [Int] = []
602-
public var tokenLogProbs: [[Int: Float]] = [[:]]
603-
public var temperature: Float = 1.0
604-
public var avgLogprob: Float = 0.0
605-
public var compressionRatio: Float = 1.0
606-
public var noSpeechProb: Float = 0.0
607-
public var words: [WordTiming]? = nil
648+
public var id: Int
649+
public var seek: Int
650+
public var start: Float
651+
public var end: Float
652+
public var text: String
653+
public var tokens: [Int]
654+
public var tokenLogProbs: [[Int: Float]]
655+
public var temperature: Float
656+
public var avgLogprob: Float
657+
public var compressionRatio: Float
658+
public var noSpeechProb: Float
659+
public var words: [WordTiming]?
660+
661+
/// Computed property for the duration of the segment
662+
public var duration: Float {
663+
return end - start
664+
}
665+
666+
public init(
667+
id: Int = 0,
668+
seek: Int = 0,
669+
start: Float = 0.0,
670+
end: Float = 0.0,
671+
text: String = "",
672+
tokens: [Int] = [],
673+
tokenLogProbs: [[Int: Float]] = [[:]],
674+
temperature: Float = 1.0,
675+
avgLogprob: Float = 0.0,
676+
compressionRatio: Float = 1.0,
677+
noSpeechProb: Float = 0.0,
678+
words: [WordTiming]? = nil
679+
) {
680+
self.id = id
681+
self.seek = seek
682+
self.start = start
683+
self.end = end
684+
self.text = text
685+
self.tokens = tokens
686+
self.tokenLogProbs = tokenLogProbs
687+
self.temperature = temperature
688+
self.avgLogprob = avgLogprob
689+
self.compressionRatio = compressionRatio
690+
self.noSpeechProb = noSpeechProb
691+
self.words = words
692+
}
608693
}
609694

610695
public struct WordTiming: Hashable, Codable {
@@ -613,6 +698,19 @@ public struct WordTiming: Hashable, Codable {
613698
public var start: Float
614699
public var end: Float
615700
public var probability: Float
701+
702+
/// Computed property for the duration of the word
703+
public var duration: Float {
704+
return end - start
705+
}
706+
707+
public init(word: String, tokens: [Int], start: Float, end: Float, probability: Float) {
708+
self.word = word
709+
self.tokens = tokens
710+
self.start = start
711+
self.end = end
712+
self.probability = probability
713+
}
616714
}
617715

618716
public struct TranscriptionProgress {
@@ -1198,17 +1296,40 @@ public struct SpecialTokens {
11981296
}
11991297
}
12001298

1201-
public protocol WhisperTokenizer: Tokenizer {
1299+
public protocol WhisperTokenizer {
1300+
/// swift-transformers pass through
1301+
func encode(text: String) -> [Int]
1302+
func decode(tokens: [Int]) -> String
1303+
func convertTokenToId(_ token: String) -> Int?
1304+
func convertIdToToken(_ id: Int) -> String?
1305+
1306+
/// WhisperKit specific
12021307
var specialTokens: SpecialTokens { get }
12031308
var allLanguageTokens: Set<Int> { get }
12041309

12051310
func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]])
12061311
}
12071312

1208-
struct WhisperTokenizerWrapper: WhisperTokenizer {
1313+
open class WhisperTokenizerWrapper: WhisperTokenizer {
12091314
let tokenizer: any Tokenizer
1210-
let specialTokens: SpecialTokens
1211-
let allLanguageTokens: Set<Int>
1315+
public let specialTokens: SpecialTokens
1316+
public let allLanguageTokens: Set<Int>
1317+
1318+
public func encode(text: String) -> [Int] {
1319+
tokenizer.encode(text: text)
1320+
}
1321+
1322+
public func decode(tokens: [Int]) -> String {
1323+
tokenizer.decode(tokens: tokens)
1324+
}
1325+
1326+
public func convertTokenToId(_ token: String) -> Int? {
1327+
tokenizer.convertTokenToId(token)
1328+
}
1329+
1330+
public func convertIdToToken(_ id: Int) -> String? {
1331+
tokenizer.convertIdToToken(id)
1332+
}
12121333

12131334
init(tokenizer: any Tokenizer) {
12141335
let specialTokens = SpecialTokens(
@@ -1300,7 +1421,7 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
13001421
/// Decodes token ids into individual words and per-word subtokens
13011422
/// - Parameter tokenIds: Array of tokens to decode and then split
13021423
/// - Returns: Tuple containing and array of the split words and all tokens for each word
1303-
func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
1424+
public func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
13041425
let decodedWords = tokenizer.decode(tokens: tokenIds.filter { $0 < specialTokens.specialTokenBegin })
13051426

13061427
// Detect language of input text
@@ -1316,52 +1437,6 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
13161437
}
13171438
}
13181439

1319-
extension WhisperTokenizerWrapper: Tokenizer {
1320-
func tokenize(text: String) -> [String] {
1321-
tokenizer.tokenize(text: text)
1322-
}
1323-
1324-
func encode(text: String) -> [Int] {
1325-
tokenizer.encode(text: text)
1326-
}
1327-
1328-
func decode(tokens: [Int]) -> String {
1329-
tokenizer.decode(tokens: tokens)
1330-
}
1331-
1332-
func convertTokenToId(_ token: String) -> Int? {
1333-
tokenizer.convertTokenToId(token)
1334-
}
1335-
1336-
func convertIdToToken(_ id: Int) -> String? {
1337-
tokenizer.convertIdToToken(id)
1338-
}
1339-
1340-
var bosToken: String? {
1341-
tokenizer.bosToken
1342-
}
1343-
1344-
var bosTokenId: Int? {
1345-
tokenizer.bosTokenId
1346-
}
1347-
1348-
var eosToken: String? {
1349-
tokenizer.eosToken
1350-
}
1351-
1352-
var eosTokenId: Int? {
1353-
tokenizer.eosTokenId
1354-
}
1355-
1356-
var unknownToken: String? {
1357-
tokenizer.unknownToken
1358-
}
1359-
1360-
var unknownTokenId: Int? {
1361-
tokenizer.unknownTokenId
1362-
}
1363-
}
1364-
13651440
extension WhisperTokenizerWrapper {
13661441
/// Default values for each token, using base vocab
13671442
static var defaultWhitespaceToken: Int { 220 }
@@ -1512,6 +1587,9 @@ public enum Constants {
15121587

15131588
public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models
15141589

1590+
public static let defaultPrependPunctuations: String = "\"'“¡¿([{-"
1591+
public static let defaultAppendPunctuations: String = "\"'.。,,!!??::”)]}、"
1592+
15151593
public static let fallbackModelSupportConfig: ModelSupportConfig = {
15161594
var config = ModelSupportConfig(
15171595
repoName: "whisperkit-coreml-fallback",

0 commit comments

Comments
 (0)