Skip to content

Commit 2eff5b0

Browse files
committed
Use compiler flags for mltensor sampling
1 parent 4915574 commit 2eff5b0

File tree

1 file changed

+173
-150
lines changed

1 file changed

+173
-150
lines changed

Sources/WhisperKit/Core/Text/TokenSampler.swift

Lines changed: 173 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -28,183 +28,206 @@ open class GreedyTokenSampler: TokenSampling {
2828
self.decodingOptions = decodingOptions
2929
}
3030

31-
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
32-
var nextTokens = tokens
33-
var nextLogprobs = logProbs
34-
var completed = false
35-
if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
36-
// Use MLTensor operations if available for sampling
37-
// Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
38-
var logitsTensor = MLTensor(MLShapedArray<FloatType>(logits)).cast(to: Float.self)
39-
var nextTokenTensor: MLTensor
40-
var nextLogprobTensor: MLTensor
41-
42-
if temperature != 0.0 {
43-
// Scale logits by temperature if > 0
44-
logitsTensor = logitsTensor / temperature
45-
}
31+
#if swift(>=5.10)
32+
/// Picks the next token from `logits` using `MLTensor` operations.
///
/// With a zero `temperature` the arg-max token is returned. Otherwise the logits
/// are divided by `temperature` and one token is drawn at random from the
/// `decodingOptions.topK` most probable candidates, weighted by probability.
///
/// - Parameter logits: Raw logits for the next position.
/// - Returns: The chosen token id and the natural log of its softmax probability.
// Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
@available(macOS 15, iOS 18, watchOS 11, visionOS 2, *)
private func sampleWithMLTensor(logits: MLMultiArray) -> (token: Int, logprob: Float) {
    let usesSampling = temperature != 0.0

    // Temperature scaling is only applied when sampling; greedy decoding uses the raw logits.
    let rawLogits = MLTensor(MLShapedArray<FloatType>(logits)).cast(to: Float.self)
    let scaledLogits = usesSampling ? rawLogits / temperature : rawLogits

    // A single softmax feeds both the candidate selection and the reported log-probability.
    let probabilities = scaledLogits.softmax(alongAxis: -1)

    let chosenToken: MLTensor
    let chosenLogprob: MLTensor

    if usesSampling {
        // Multinomial draw restricted to the top-k candidates.
        let (candidateProbs, candidateTokens) = probabilities.topK(decodingOptions.topK)
        let lastAxis = candidateTokens.rank - 1

        // Threshold drawn uniformly from [0, total top-k mass).
        let threshold = candidateProbs.sum() * Float.random(in: 0..<1)

        // Push every cumulative value still below the threshold far up, so the
        // smallest remaining entry is the first candidate that reaches it.
        var cumulative = candidateProbs.cumulativeSum(alongAxis: -1)
        cumulative += (cumulative .< threshold) * 100.0
        let pickedSlot = cumulative.argsort()[..., 0]

        chosenToken = candidateTokens.gathering(atIndices: pickedSlot, alongAxis: lastAxis)
        chosenLogprob = candidateProbs.gathering(atIndices: pickedSlot, alongAxis: lastAxis).log()
    } else {
        // Greedy decoding: most likely token and its log-probability.
        chosenToken = scaledLogits.argmax(alongAxis: -1)
        chosenLogprob = probabilities.gathering(atIndices: chosenToken, alongAxis: -1).log()
    }

    return (
        token: chosenToken.asIntArray()[0],
        logprob: chosenLogprob.asFloatArray()[0]
    )
}
75+
#endif
5376

54-
let rnd = topKProbs.sum() * Float.random(in: 0..<1)
55-
var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1)
56-
accumTopKProbs += (accumTopKProbs .< rnd) * 100.0
57-
let topKIndex = accumTopKProbs.argsort()[..., 0]
77+
/// Picks the next token from `logits` using BNNS operations.
///
/// With a zero `temperature` the arg-max token is returned. Otherwise the logits
/// are scaled by `1 / temperature` and one token is drawn at random from the
/// `decodingOptions.topK` most probable candidates, weighted by probability.
///
/// Any BNNS failure is logged and answered with `sampleGreedyWithoutBNNS(logits:)`
/// instead of trapping on a force-unwrap.
///
/// - Parameter logits: Raw logits for the next position, stored as `FloatType`.
/// - Returns: The chosen token id and the natural log of its softmax probability.
private func sampleWithBNNS(logits: MLMultiArray) -> (token: Int, logprob: Float) {
    // TODO: BNNS operations here are deprecated, replace with vDSP or MLX

    // Every buffer allocated below is released on every exit path.
    var scratchBuffers: [BNNSNDArrayDescriptor] = []
    defer {
        for buffer in scratchBuffers {
            buffer.deallocate()
        }
    }

    let usesSampling = temperature != 0.0

    do {
        // View over the caller's logits; this memory is not owned here and is never deallocated.
        let logitsRawPointer = UnsafeMutableRawBufferPointer(
            start: logits.dataPointer,
            count: logits.count * MemoryLayout<FloatType>.stride
        )

        guard let logitsDescriptor = BNNSNDArrayDescriptor(
            data: logitsRawPointer,
            scalarType: FloatType.self,
            shape: .vector(logits.count, stride: 1)
        ) else {
            Logging.error("Sampling error: unable to create a BNNS descriptor for the logits")
            return sampleGreedyWithoutBNNS(logits: logits)
        }

        var softmaxInput = logitsDescriptor

        // Scale logits by 1 / temperature when sampling.
        if usesSampling {
            let scaledLogits = BNNSNDArrayDescriptor.allocateUninitialized(
                scalarType: FloatType.self,
                shape: .vector(logits.count, stride: 1)
            )
            scratchBuffers.append(scaledLogits)

            try BNNS.applyActivation(
                activation: BNNS.ActivationFunction.linear(alpha: Float(1 / temperature)),
                input: logitsDescriptor,
                output: scaledLogits,
                batchSize: 1
            )

            softmaxInput = scaledLogits
        }

        // Always softmax once
        let softmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized(
            scalarType: Float.self,
            shape: .vector(logits.count, stride: 1)
        )
        scratchBuffers.append(softmaxOutput)

        try BNNS.applyActivation(
            activation: BNNS.ActivationFunction.softmax,
            input: softmaxInput,
            output: softmaxOutput,
            batchSize: 1
        )

        let sampledToken: Int?

        if usesSampling {
            // top-k multinomial sampling; k is clamped so BNNS never sees k < 1 or k > vocabulary size.
            let k = max(1, min(decodingOptions.topK, logits.count))
            let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(
                scalarType: Float.self,
                shape: .vector(k, stride: 1)
            )
            scratchBuffers.append(bestValues)
            let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(
                scalarType: Int32.self,
                shape: .vector(k, stride: 1)
            )
            scratchBuffers.append(bestIndices)

            try BNNS.applyTopK(
                k: k,
                input: softmaxOutput,
                bestValues: bestValues,
                bestIndices: bestIndices,
                axis: 0,
                batchSize: 1
            )

            let candidateProbs = bestValues.makeArray(of: Float.self) ?? []
            let candidateTokens = bestIndices.makeArray(of: Int32.self) ?? []

            // `Float.random(in:)` traps on an empty or non-finite range, so a degenerate
            // probability mass falls back to a zero threshold (first candidate wins).
            let candidateMass = candidateProbs.reduce(0, +)
            let threshold: Float = (candidateMass > 0 && candidateMass.isFinite)
                ? Float.random(in: 0..<candidateMass)
                : 0

            // Walk the cumulative distribution; default to the first candidate if
            // rounding keeps the running sum from ever exceeding the threshold.
            var picked = candidateTokens.first
            var runningMass = Float(0.0)
            for (probability, candidate) in zip(candidateProbs, candidateTokens) {
                runningMass += probability
                if threshold < runningMass {
                    picked = candidate
                    break
                }
            }

            sampledToken = picked.map { Int($0) }
        } else {
            // Argmax sampling on the unscaled logits; BNNS writes the winning index as a Float.
            let argmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized(
                scalarType: Float.self,
                shape: .vector(1, stride: 1)
            )
            scratchBuffers.append(argmaxOutput)

            try BNNS.applyReduction(
                BNNS.ReductionFunction.argMax,
                input: logitsDescriptor,
                output: argmaxOutput,
                weights: nil
            )

            sampledToken = argmaxOutput.makeArray(of: Float.self)?.first.map { Int($0) }
        }

        // Log of softmax probability of chosen token
        if let token = sampledToken,
           let softmaxResult = softmaxOutput.makeArray(of: Float.self),
           softmaxResult.indices.contains(token)
        {
            return (token: token, logprob: log(softmaxResult[token]))
        }

        Logging.error("Sampling error: BNNS did not produce a usable token")
    } catch {
        Logging.error("Sampling error: \(error)")
    }

    return sampleGreedyWithoutBNNS(logits: logits)
}

/// Last-resort greedy sampler used when the BNNS path fails.
///
/// Reads the logits element by element, returns the first arg-max index and its
/// log-softmax value. `temperature` and `decodingOptions.topK` are ignored here.
/// For empty or all-NaN logits it returns `eotToken` with a log-probability of
/// negative infinity so that `update` reports the sequence as completed.
private func sampleGreedyWithoutBNNS(logits: MLMultiArray) -> (token: Int, logprob: Float) {
    let values = (0..<logits.count).map { logits[$0].floatValue }

    guard let maxLogit = values.max(), let token = values.firstIndex(of: maxLogit) else {
        return (token: eotToken, logprob: -Float.infinity)
    }

    // log softmax(max) = -log(sum(exp(x_i - max))); subtracting the max keeps exp() from overflowing.
    let normalizer = values.reduce(Float(0.0)) { $0 + exp($1 - maxLogit) }
    return (token: token, logprob: -log(normalizer))
}
198205

199-
// Deallocations
200-
softmaxOutput?.deallocate()
201-
argmaxOutput?.deallocate()
202-
if softmaxInputNeedsDeallocate {
203-
softmaxInput?.deallocate()
204-
}
205-
}
206+
/// Samples one token from `logits` and appends it to the running sequence.
///
/// The `MLTensor` sampler is used when the code is compiled with Swift 5.10 or
/// newer and the OS availability check passes at runtime; otherwise the BNNS
/// sampler is used.
///
/// - Parameters:
///   - tokens: Tokens decoded so far.
///   - logits: Raw logits for the next position.
///   - logProbs: Log-probabilities recorded for `tokens`.
/// - Returns: The extended token and log-probability lists, with `completed`
///   set when the sampled token equals `eotToken`.
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
    let sampled: (token: Int, logprob: Float)

    #if swift(>=5.10)
    if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
        sampled = sampleWithMLTensor(logits: logits)
    } else {
        sampled = sampleWithBNNS(logits: logits)
    }
    #else
    sampled = sampleWithBNNS(logits: logits)
    #endif

    return SamplingResult(
        tokens: tokens + [sampled.token],
        logProbs: logProbs + [sampled.logprob],
        completed: sampled.token == eotToken
    )
}
209232

210233
public func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult {

0 commit comments

Comments
 (0)