Skip to content

Commit cd16844

Browse files
authored
Fix modelSupport where prefix overlaps with different hardware chips (#326)
Before this change, the code was matching first prefix which could lead to selecting incorrect configuration. For example, iPad 13,1 (A14) is prefix for iPad13,10, iPad13,11, iPad13,16, and iPad13,17 (M2). Fixing this issue by using longest prefix match. Tested with a unit test that covers all iPad / iPhone devices.
1 parent 11a1fab commit cd16844

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

Sources/WhisperKit/Core/Models.swift

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,25 @@ public struct ModelSupportConfig: Codable {
254254

255255
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
256256
public func modelSupport(for deviceIdentifier: String = WhisperKit.deviceName()) -> ModelSupport {
257+
// Find the support with the longest matching identifier prefix
258+
// i.e. `iPad13,16` should match exact `iPad13,16` instead of first prefix like `iPad13,1`
259+
var bestMatch: (support: DeviceSupport, prefixLength: Int)? = nil
257260
for support in deviceSupports {
258-
if support.identifiers.contains(where: { deviceIdentifier.hasPrefix($0) }) {
259-
return support.models
261+
for identifier in support.identifiers {
262+
if deviceIdentifier.hasPrefix(identifier) {
263+
let matchLength = identifier.count
264+
if bestMatch == nil || matchLength > bestMatch!.prefixLength {
265+
bestMatch = (support, matchLength)
266+
}
267+
}
260268
}
261269
}
262270

271+
if let match = bestMatch {
272+
Logging.debug("Matched \(deviceIdentifier) to devices: \(match.support.identifiers)")
273+
return match.support.models
274+
}
275+
263276
Logging.info("No device support found for \(deviceIdentifier), using default")
264277
return defaultSupport.models
265278
}

Tests/WhisperKitTests/UnitTests.swift

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,78 @@ final class UnitTests: XCTestCase {
218218
}
219219
}
220220

221+
func testModelSupportForiOSDevices() throws {
222+
let defaultDevicesList = [
223+
// --- iPads ---
224+
"iPad2,1", "iPad2,2", "iPad2,3", "iPad2,4", "iPad2,5", "iPad2,6", "iPad2,7",
225+
"iPad3,1", "iPad3,2", "iPad3,3", "iPad3,4", "iPad3,5", "iPad3,6",
226+
"iPad4,1", "iPad4,2", "iPad4,3", "iPad4,4", "iPad4,5", "iPad4,6", "iPad4,7", "iPad4,8", "iPad4,9",
227+
"iPad5,1", "iPad5,2", "iPad5,3", "iPad5,4", "iPad6,11", "iPad6,12", "iPad6,3", "iPad6,4", "iPad6,7", "iPad6,8",
228+
"iPad7,1", "iPad7,2", "iPad7,3", "iPad7,4", "iPad7,5", "iPad7,6", "iPad7,11", "iPad7,12",
229+
"iPad8,1", "iPad8,2", "iPad8,3", "iPad8,4", "iPad8,5", "iPad8,6", "iPad8,7", "iPad8,8", "iPad8,9", "iPad8,10", "iPad8,11", "iPad8,12",
230+
"iPad11,1", "iPad11,2", "iPad11,3", "iPad11,4", "iPad11,6", "iPad11,7",
231+
"iPad12,1", "iPad12,2", // (should be A13 config, need new config)
232+
233+
// -- iPhones ---
234+
"iPhone10,1", "iPhone10,2", "iPhone10,3", "iPhone10,4", "iPhone10,5", "iPhone10,6",
235+
"iPhone4,1",
236+
"iPhone5,1", "iPhone5,2", "iPhone5,3", "iPhone5,4",
237+
"iPhone6,1", "iPhone6,2",
238+
"iPhone7,1", "iPhone7,2",
239+
"iPhone8,1", "iPhone8,2", "iPhone8,4",
240+
"iPhone9,1", "iPhone9,2", "iPhone9,3", "iPhone9,4"
241+
]
242+
243+
let deviceMap = [
244+
"A12": ["iPhone11,2", "iPhone11,4", "iPhone11,6", "iPhone11,8"],
245+
"A13": ["iPhone12,1", "iPhone12,3", "iPhone12,5", "iPhone12,8"],
246+
"A14": ["iPad13,1", "iPad13,18", "iPad13,19", "iPad13,2", "iPhone13,1", "iPhone13,2", "iPhone13,3", "iPhone13,4"],
247+
"A15": ["iPad14,1", "iPad14,2", "iPhone14,2", "iPhone14,3", "iPhone14,4", "iPhone14,5", "iPhone14,6", "iPhone14,7", "iPhone14,8"],
248+
"A16": ["iPhone15,2", "iPhone15,3", "iPhone15,4", "iPhone15,5"],
249+
"A17": ["iPhone16,1", "iPhone16,2"],
250+
"A18": ["iPhone17,1", "iPhone17,2", "iPhone17,3", "iPhone17,4"],
251+
"M1": ["iPad13,10", "iPad13,11", "iPad13,16", "iPad13,17", "iPad13,4", "iPad13,5", "iPad13,6", "iPad13,7", "iPad13,8", "iPad13,9"],
252+
"M2": ["iPad14,3", "iPad14,4", "iPad14,5", "iPad14,6", "iPad14,10", "iPad14,11", "iPad14,8", "iPad14,9"],
253+
"A17 Pro": ["iPad16,1", "iPad16,2"],
254+
"M4": ["iPad16,3", "iPad16,4", "iPad16,5", "iPad16,6"]
255+
]
256+
257+
let configFilePath = try XCTUnwrap(
258+
Bundle.current(for: self).path(forResource: "config-v03", ofType: "json"),
259+
"Config file not found"
260+
)
261+
262+
let jsonData = try Data(contentsOf: URL(fileURLWithPath: configFilePath))
263+
let decoder = JSONDecoder()
264+
let loadedConfig = try decoder.decode(ModelSupportConfig.self, from: jsonData)
265+
266+
func supportedModels(chip: String) -> ModelSupport {
267+
var supportedModels = loadedConfig.defaultSupport.models
268+
for deviceSupport in loadedConfig.deviceSupports {
269+
if let chips = deviceSupport.chips, chips.contains(chip) {
270+
supportedModels = deviceSupport.models
271+
break
272+
}
273+
}
274+
return supportedModels
275+
}
276+
277+
for device in deviceMap {
278+
let supportedModelsForChip = supportedModels(chip: device.key)
279+
for deviceIdentifier in device.value {
280+
let modelSupport = modelSupport(for: deviceIdentifier, from: loadedConfig)
281+
XCTAssertEqual(modelSupport, supportedModelsForChip, "Device: \(deviceIdentifier) (\(device.key))")
282+
}
283+
}
284+
285+
// Test devices that should use default configuration
286+
let defaultModelSupport = loadedConfig.defaultSupport.models
287+
for deviceIdentifier in defaultDevicesList {
288+
let modelSupport = modelSupport(for: deviceIdentifier, from: loadedConfig)
289+
XCTAssertEqual(modelSupport, defaultModelSupport, "Device: \(deviceIdentifier)")
290+
}
291+
}
292+
221293
func testModelSupportConfigFetch() async throws {
222294
// Make sure remote repo config loads successfully from HF
223295
let modelRepoConfig = await WhisperKit.fetchModelSupportConfig()

0 commit comments

Comments
 (0)