@@ -28,183 +28,206 @@ open class GreedyTokenSampler: TokenSampling {
28
28
self . decodingOptions = decodingOptions
29
29
}
30
30
31
- public func update( tokens: [ Int ] , logits: MLMultiArray , logProbs: [ Float ] ) -> SamplingResult {
32
- var nextTokens = tokens
33
- var nextLogprobs = logProbs
34
- var completed = false
35
- if #available( macOS 15 . 0 , iOS 18 . 0 , watchOS 11 . 0 , visionOS 2 . 0 , * ) {
36
- // Use MLTensor operations if available for sampling
37
- // Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
38
- var logitsTensor = MLTensor ( MLShapedArray < FloatType > ( logits) ) . cast ( to: Float . self)
39
- var nextTokenTensor : MLTensor
40
- var nextLogprobTensor : MLTensor
41
-
42
- if temperature != 0.0 {
43
- // Scale logits by temperature if > 0
44
- logitsTensor = logitsTensor / temperature
45
- }
31
+ #if swift(>=5.10)
32
+ @available ( macOS 15 , iOS 18 , watchOS 11 , visionOS 2 , * )
33
+ private func sampleWithMLTensor( logits: MLMultiArray ) -> ( token: Int , logprob: Float ) {
34
+ // Use MLTensor operations if available for sampling
35
+ // Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
36
+ var logitsTensor = MLTensor ( MLShapedArray < FloatType > ( logits) ) . cast ( to: Float . self)
37
+ var nextTokenTensor : MLTensor
38
+ var nextLogprobTensor : MLTensor
39
+
40
+ if temperature != 0.0 {
41
+ // Scale logits by temperature if > 0
42
+ logitsTensor = logitsTensor / temperature
43
+ }
46
44
47
- // Always softmax once
48
- let softmaxScores = logitsTensor. softmax ( alongAxis: - 1 )
45
+ // Always softmax once
46
+ let softmaxScores = logitsTensor. softmax ( alongAxis: - 1 )
47
+
48
+ if temperature != 0.0 {
49
+ // top-k multinomial sampling
50
+ let ( topKProbs, topKIndices) = softmaxScores. topK ( decodingOptions. topK)
51
+
52
+ let rnd = topKProbs. sum ( ) * Float. random ( in: 0 ..< 1 )
53
+ var accumTopKProbs = topKProbs. cumulativeSum ( alongAxis: - 1 )
54
+ accumTopKProbs += ( accumTopKProbs .< rnd) * 100.0
55
+ let topKIndex = accumTopKProbs. argsort ( ) [ ... , 0 ]
56
+
57
+ nextTokenTensor = topKIndices. gathering (
58
+ atIndices: topKIndex,
59
+ alongAxis: topKIndices. rank - 1
60
+ )
61
+ nextLogprobTensor = topKProbs. gathering (
62
+ atIndices: topKIndex,
63
+ alongAxis: topKIndices. rank - 1
64
+ ) . log ( )
65
+ } else {
66
+ nextTokenTensor = logitsTensor. argmax ( alongAxis: - 1 )
67
+ nextLogprobTensor = softmaxScores. gathering ( atIndices: nextTokenTensor, alongAxis: - 1 ) . log ( )
68
+ }
49
69
50
- if temperature != 0.0 {
51
- // top-k multinomial sampling
52
- let ( topKProbs, topKIndices) = softmaxScores. topK ( decodingOptions. topK)
70
+ return (
71
+ token: nextTokenTensor. asIntArray ( ) [ 0 ] ,
72
+ logprob: nextLogprobTensor. asFloatArray ( ) [ 0 ]
73
+ )
74
+ }
75
+ #endif
53
76
54
- let rnd = topKProbs. sum ( ) * Float. random ( in: 0 ..< 1 )
55
- var accumTopKProbs = topKProbs. cumulativeSum ( alongAxis: - 1 )
56
- accumTopKProbs += ( accumTopKProbs .< rnd) * 100.0
57
- let topKIndex = accumTopKProbs. argsort ( ) [ ... , 0 ]
77
+ private func sampleWithBNNS( logits: MLMultiArray ) -> ( token: Int , logprob: Float ) {
78
+ // TODO: BNNS operations here are deprecated, replace with vDSP or MLX
79
+ var softmaxOutput : BNNSNDArrayDescriptor ?
80
+ var argmaxOutput : BNNSNDArrayDescriptor ?
81
+ var softmaxInput : BNNSNDArrayDescriptor ?
82
+ var softmaxInputNeedsDeallocate = false
58
83
59
- nextTokenTensor = topKIndices. gathering (
60
- atIndices: topKIndex,
61
- alongAxis: topKIndices. rank - 1
62
- )
63
- nextLogprobTensor = topKProbs. gathering (
64
- atIndices: topKIndex,
65
- alongAxis: topKIndices. rank - 1
66
- ) . log ( )
67
- } else {
68
- nextTokenTensor = logitsTensor. argmax ( alongAxis: - 1 )
69
- nextLogprobTensor = softmaxScores. gathering ( atIndices: nextTokenTensor, alongAxis: - 1 ) . log ( )
70
- }
84
+ var nextToken : Int ?
71
85
72
- let nextToken = nextTokenTensor. asIntArray ( ) [ 0 ]
73
- let nextLogprob = nextLogprobTensor. asFloatArray ( ) [ 0 ]
86
+ do {
87
+ let logitsRawPointer = UnsafeMutableRawBufferPointer (
88
+ start: logits. dataPointer,
89
+ count: logits. count * MemoryLayout< FloatType> . stride
90
+ )
74
91
75
- nextTokens = tokens + [ nextToken]
76
- nextLogprobs = logProbs + [ nextLogprob]
77
- completed = nextToken == eotToken
92
+ let logitsDescriptor = BNNSNDArrayDescriptor (
93
+ data: logitsRawPointer,
94
+ scalarType: FloatType . self,
95
+ shape: . vector( logits. count, stride: 1 )
96
+ ) !
78
97
79
- } else {
80
- // TODO: BNNS operations here are deprecated, replace with vDSP or MLX
81
- var softmaxOutput : BNNSNDArrayDescriptor ?
82
- var argmaxOutput : BNNSNDArrayDescriptor ?
83
- var softmaxInput : BNNSNDArrayDescriptor ?
84
- var softmaxInputNeedsDeallocate = false
85
-
86
- var nextToken : Int ?
87
-
88
- do {
89
- let logitsRawPointer = UnsafeMutableRawBufferPointer (
90
- start: logits. dataPointer,
91
- count: logits. count * MemoryLayout< FloatType> . stride
92
- )
98
+ softmaxInput = logitsDescriptor
93
99
94
- let logitsDescriptor = BNNSNDArrayDescriptor (
95
- data: logitsRawPointer,
100
+ // Scale logits by temperature if > 0
101
+ if temperature != 0.0 {
102
+ let scaledLogits = BNNSNDArrayDescriptor . allocateUninitialized (
96
103
scalarType: FloatType . self,
97
104
shape: . vector( logits. count, stride: 1 )
98
- ) !
99
-
100
- softmaxInput = logitsDescriptor
101
-
102
- // Scale logits by temperature if > 0
103
- if temperature != 0.0 {
104
- let scaledLogits = BNNSNDArrayDescriptor . allocateUninitialized (
105
- scalarType: FloatType . self,
106
- shape: . vector( logits. count, stride: 1 )
107
- )
108
-
109
- try ! BNNS . applyActivation (
110
- activation: BNNS . ActivationFunction. linear ( alpha: Float ( 1 / temperature) ) ,
111
- input: logitsDescriptor,
112
- output: scaledLogits,
113
- batchSize: 1
114
- )
115
-
116
- softmaxInput = scaledLogits
117
- softmaxInputNeedsDeallocate = true
118
- }
105
+ )
106
+
107
+ try ! BNNS . applyActivation (
108
+ activation: BNNS . ActivationFunction. linear ( alpha: Float ( 1 / temperature) ) ,
109
+ input: logitsDescriptor,
110
+ output: scaledLogits,
111
+ batchSize: 1
112
+ )
119
113
120
- // Always softmax once
121
- softmaxOutput = BNNSNDArrayDescriptor . allocateUninitialized (
114
+ softmaxInput = scaledLogits
115
+ softmaxInputNeedsDeallocate = true
116
+ }
117
+
118
+ // Always softmax once
119
+ softmaxOutput = BNNSNDArrayDescriptor . allocateUninitialized (
120
+ scalarType: Float . self,
121
+ shape: . vector( logits. count, stride: 1 )
122
+ )
123
+
124
+ try BNNS . applyActivation (
125
+ activation: BNNS . ActivationFunction. softmax,
126
+ input: softmaxInput!,
127
+ output: softmaxOutput!,
128
+ batchSize: 1
129
+ )
130
+
131
+ if temperature != 0.0 {
132
+ // top-k multinomial sampling
133
+ let k = decodingOptions. topK
134
+ let bestValues = BNNSNDArrayDescriptor . allocateUninitialized (
122
135
scalarType: Float . self,
123
- shape: . vector( logits. count, stride: 1 )
136
+ shape: . vector( k, stride: 1 )
137
+ )
138
+ let bestIndices = BNNSNDArrayDescriptor . allocateUninitialized (
139
+ scalarType: Int32 . self,
140
+ shape: . vector( k, stride: 1 )
124
141
)
125
142
126
- try BNNS . applyActivation (
127
- activation: BNNS . ActivationFunction. softmax,
128
- input: softmaxInput!,
129
- output: softmaxOutput!,
143
+ try ! BNNS . applyTopK (
144
+ k: k,
145
+ input: softmaxOutput!,
146
+ bestValues: bestValues,
147
+ bestIndices: bestIndices,
148
+ axis: 0 ,
130
149
batchSize: 1
131
150
)
132
151
133
- if temperature != 0.0 {
134
- // top-k multinomial sampling
135
- let k = decodingOptions. topK
136
-
137
- let bestValues = BNNSNDArrayDescriptor . allocateUninitialized ( scalarType: Float . self, shape: . vector( k, stride: 1 ) )
138
- let bestIndices = BNNSNDArrayDescriptor . allocateUninitialized ( scalarType: Int32 . self, shape: . vector( k, stride: 1 ) )
139
-
140
- try ! BNNS . applyTopK (
141
- k: k,
142
- input: softmaxOutput!,
143
- bestValues: bestValues,
144
- bestIndices: bestIndices,
145
- axis: 0 ,
146
- batchSize: 1
147
- )
148
-
149
- let bestValuesResult = bestValues. makeArray ( of: Float . self) !
150
- let bestIndicesResult = bestIndices. makeArray ( of: Int32 . self) !
151
-
152
- bestValues. deallocate ( )
153
- bestIndices. deallocate ( )
154
-
155
- // multinomial sample from top-k
156
- let sumOfbestIndicesResult = bestValuesResult. reduce ( 0 , + )
157
- let rnd = Float . random ( in: 0 ..< sumOfbestIndicesResult)
158
- var accumulator = Float ( 0.0 )
159
- var chosenIndex = 0
160
- for i in 0 ..< bestValuesResult. count {
161
- accumulator += bestValuesResult [ i]
162
- if rnd < accumulator {
163
- chosenIndex = i
164
- break
165
- }
152
+ let bestValuesResult = bestValues. makeArray ( of: Float . self) !
153
+ let bestIndicesResult = bestIndices. makeArray ( of: Int32 . self) !
154
+
155
+ bestValues. deallocate ( )
156
+ bestIndices. deallocate ( )
157
+
158
+ // multinomial sample from top-k
159
+ let sumOfbestIndicesResult = bestValuesResult. reduce ( 0 , + )
160
+ let rnd = Float . random ( in: 0 ..< sumOfbestIndicesResult)
161
+ var accumulator = Float ( 0.0 )
162
+ var chosenIndex = 0
163
+ for i in 0 ..< bestValuesResult. count {
164
+ accumulator += bestValuesResult [ i]
165
+ if rnd < accumulator {
166
+ chosenIndex = i
167
+ break
166
168
}
169
+ }
167
170
168
- nextToken = Int ( bestIndicesResult [ chosenIndex] )
169
- } else {
170
- // Argmax sampling
171
- argmaxOutput = BNNSNDArrayDescriptor . allocateUninitialized (
172
- scalarType: Float . self,
173
- shape: . vector( 1 , stride: 1 )
174
- )
171
+ nextToken = Int ( bestIndicesResult [ chosenIndex] )
172
+ } else {
173
+ argmaxOutput = BNNSNDArrayDescriptor . allocateUninitialized (
174
+ scalarType: Float . self,
175
+ shape: . vector( 1 , stride: 1 )
176
+ )
175
177
176
- try ! BNNS . applyReduction (
177
- BNNS . ReductionFunction. argMax,
178
- input: logitsDescriptor,
179
- output: argmaxOutput!,
180
- weights: nil
181
- )
178
+ try ! BNNS . applyReduction (
179
+ BNNS . ReductionFunction. argMax,
180
+ input: logitsDescriptor,
181
+ output: argmaxOutput!,
182
+ weights: nil
183
+ )
182
184
183
- let argmaxResult = argmaxOutput!. makeArray ( of: Float . self) !
185
+ let argmaxResult = argmaxOutput!. makeArray ( of: Float . self) !
184
186
185
- nextToken = Int ( argmaxResult [ 0 ] )
186
- }
187
- } catch {
188
- Logging . error ( " Sampling error: \( error) " )
187
+ nextToken = Int ( argmaxResult [ 0 ] )
189
188
}
189
+ } catch {
190
+ Logging . error ( " Sampling error: \( error) " )
191
+ }
190
192
191
- // Log of softmax probability of chosen token
192
- let softmaxResult = softmaxOutput!. makeArray ( of: Float . self) !
193
- let nextLogprob = log ( Float ( softmaxResult [ nextToken!] ) )
193
+ // Log of softmax probability of chosen token
194
+ let softmaxResult = softmaxOutput!. makeArray ( of: Float . self) !
195
+ let nextLogprob = log ( Float ( softmaxResult [ nextToken!] ) )
196
+ // Deallocations
197
+ softmaxOutput? . deallocate ( )
198
+ argmaxOutput? . deallocate ( )
199
+ if softmaxInputNeedsDeallocate {
200
+ softmaxInput? . deallocate ( )
201
+ }
194
202
195
- nextTokens = tokens + [ nextToken!]
196
- nextLogprobs = logProbs + [ nextLogprob]
197
- completed = nextToken == eotToken
203
+ return ( token: nextToken!, logprob: nextLogprob)
204
+ }
198
205
199
- // Deallocations
200
- softmaxOutput? . deallocate ( )
201
- argmaxOutput? . deallocate ( )
202
- if softmaxInputNeedsDeallocate {
203
- softmaxInput? . deallocate ( )
204
- }
205
- }
206
+ public func update( tokens: [ Int ] , logits: MLMultiArray , logProbs: [ Float ] ) -> SamplingResult {
207
+ var nextTokens = tokens
208
+ var nextLogprobs = logProbs
209
+ var completed = false
206
210
207
- return SamplingResult ( tokens: nextTokens, logProbs: nextLogprobs, completed: completed)
211
+ var result : ( token: Int , logprob: Float )
212
+ #if swift(>=5.10)
213
+ if #available( macOS 15 . 0 , iOS 18 . 0 , watchOS 11 . 0 , visionOS 2 . 0 , * ) {
214
+ result = sampleWithMLTensor ( logits: logits)
215
+ } else {
216
+ result = sampleWithBNNS ( logits: logits)
217
+ }
218
+ #else
219
+ result = sampleWithBNNS ( logits: logits)
220
+ #endif
221
+
222
+ nextTokens = tokens + [ result. token]
223
+ nextLogprobs = logProbs + [ result. logprob]
224
+ completed = result. token == eotToken
225
+
226
+ return SamplingResult (
227
+ tokens: nextTokens,
228
+ logProbs: nextLogprobs,
229
+ completed: completed
230
+ )
208
231
}
209
232
210
233
public func finalize( tokens: [ Int ] , logProbs: [ Float ] ) -> SamplingResult {
0 commit comments