17
17
18
18
use crate :: backend:: Id ;
19
19
use crate :: { Backend , Registry } ;
20
+ use anyhow:: anyhow;
20
21
use std:: collections:: HashMap ;
21
22
use std:: hash:: Hash ;
22
23
use std:: { fmt, str:: FromStr } ;
@@ -54,29 +55,63 @@ impl<'a> WasiNnView<'a> {
54
55
}
55
56
}
56
57
57
- pub enum Error {
58
+ /// A wasi-nn error; this appears on the Wasm side as a component model
59
+ /// resource.
60
+ #[ derive( Debug ) ]
61
+ pub struct Error {
62
+ code : ErrorCode ,
63
+ data : anyhow:: Error ,
64
+ }
65
+
66
+ /// Construct an [`Error`] resource and immediately return it.
67
+ ///
68
+ /// The WIT specification currently relies on "errors as resources;" this helper
69
+ /// macro hides some of that complexity. If [#75] is adopted ("errors as
70
+ /// records"), this macro is no longer necessary.
71
+ ///
72
+ /// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
73
+ macro_rules! bail {
74
+ ( $self: ident, $code: expr, $data: expr) => {
75
+ let e = Error {
76
+ code: $code,
77
+ data: $data. into( ) ,
78
+ } ;
79
+ tracing:: error!( "failure: {e:?}" ) ;
80
+ let r = $self. table. push( e) ?;
81
+ return Ok ( Err ( r) ) ;
82
+ } ;
83
+ }
84
+
85
+ impl From < wasmtime:: component:: ResourceTableError > for Error {
86
+ fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
87
+ Self {
88
+ code : ErrorCode :: Trap ,
89
+ data : error. into ( ) ,
90
+ }
91
+ }
92
+ }
93
+
94
+ /// The list of error codes available to the `wasi-nn` API; this should match
95
+ /// what is specified in WIT.
96
+ #[ derive( Debug ) ]
97
+ pub enum ErrorCode {
58
98
/// Caller module passed an invalid argument.
59
99
InvalidArgument ,
60
100
/// Invalid encoding.
61
101
InvalidEncoding ,
62
102
/// The operation timed out.
63
103
Timeout ,
64
- /// Runtime Error .
104
+ /// Runtime error .
65
105
RuntimeError ,
66
106
/// Unsupported operation.
67
107
UnsupportedOperation ,
68
108
/// Graph is too large.
69
109
TooLarge ,
70
110
/// Graph not found.
71
111
NotFound ,
72
- /// A runtime error occurred that we should trap on; see `StreamError`.
73
- Trap ( anyhow:: Error ) ,
74
- }
75
-
76
- impl From < wasmtime:: component:: ResourceTableError > for Error {
77
- fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
78
- Self :: Trap ( error. into ( ) )
79
- }
112
+ /// A runtime error that Wasmtime should trap on; this will not appear in
113
+ /// the WIT specification.
114
+ Trap ,
80
115
}
81
116
82
117
/// Generate the traits and types from the `wasi-nn` WIT specification.
@@ -91,6 +126,7 @@ mod gen_ {
91
126
"wasi:nn/graph/graph" : crate :: Graph ,
92
127
"wasi:nn/tensor/tensor" : crate :: Tensor ,
93
128
"wasi:nn/inference/graph-execution-context" : crate :: ExecutionContext ,
129
+ "wasi:nn/errors/error" : super :: Error ,
94
130
} ,
95
131
trappable_error_type: {
96
132
"wasi:nn/errors/error" => super :: Error ,
@@ -131,36 +167,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131
167
builders : Vec < GraphBuilder > ,
132
168
encoding : GraphEncoding ,
133
169
target : ExecutionTarget ,
134
- ) -> Result < Resource < crate :: Graph > , Error > {
170
+ ) -> wasmtime :: Result < Result < Resource < Graph > , Resource < Error > > > {
135
171
tracing:: debug!( "load {encoding:?} {target:?}" ) ;
136
172
if let Some ( backend) = self . ctx . backends . get_mut ( & encoding) {
137
173
let slices = builders. iter ( ) . map ( |s| s. as_slice ( ) ) . collect :: < Vec < _ > > ( ) ;
138
174
match backend. load ( & slices, target. into ( ) ) {
139
175
Ok ( graph) => {
140
176
let graph = self . table . push ( graph) ?;
141
- Ok ( graph)
177
+ Ok ( Ok ( graph) )
142
178
}
143
179
Err ( error) => {
144
- tracing:: error!( "failed to load graph: {error:?}" ) ;
145
- Err ( Error :: RuntimeError )
180
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
146
181
}
147
182
}
148
183
} else {
149
- Err ( Error :: InvalidEncoding )
184
+ bail ! (
185
+ self ,
186
+ ErrorCode :: InvalidEncoding ,
187
+ anyhow!( "unable to find a backend for this encoding" )
188
+ ) ;
150
189
}
151
190
}
152
191
153
- fn load_by_name ( & mut self , name : String ) -> Result < Resource < Graph > , Error > {
192
+ fn load_by_name (
193
+ & mut self ,
194
+ name : String ,
195
+ ) -> wasmtime:: Result < Result < Resource < Graph > , Resource < Error > > > {
154
196
use core:: result:: Result :: * ;
155
197
tracing:: debug!( "load by name {name:?}" ) ;
156
198
let registry = & self . ctx . registry ;
157
199
if let Some ( graph) = registry. get ( & name) {
158
200
let graph = graph. clone ( ) ;
159
201
let graph = self . table . push ( graph) ?;
160
- Ok ( graph)
202
+ Ok ( Ok ( graph) )
161
203
} else {
162
- tracing:: error!( "failed to find graph with name: {name}" ) ;
163
- Err ( Error :: NotFound )
204
+ bail ! (
205
+ self ,
206
+ ErrorCode :: NotFound ,
207
+ anyhow!( "failed to find graph with name: {name}" )
208
+ ) ;
164
209
}
165
210
}
166
211
}
@@ -169,18 +214,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169
214
fn init_execution_context (
170
215
& mut self ,
171
216
graph : Resource < Graph > ,
172
- ) -> Result < Resource < GraphExecutionContext > , Error > {
217
+ ) -> wasmtime :: Result < Result < Resource < GraphExecutionContext > , Resource < Error > > > {
173
218
use core:: result:: Result :: * ;
174
219
tracing:: debug!( "initialize execution context" ) ;
175
220
let graph = self . table . get ( & graph) ?;
176
221
match graph. init_execution_context ( ) {
177
222
Ok ( exec_context) => {
178
223
let exec_context = self . table . push ( exec_context) ?;
179
- Ok ( exec_context)
224
+ Ok ( Ok ( exec_context) )
180
225
}
181
226
Err ( error) => {
182
- tracing:: error!( "failed to initialize execution context: {error:?}" ) ;
183
- Err ( Error :: RuntimeError )
227
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
184
228
}
185
229
}
186
230
}
@@ -197,47 +241,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197
241
exec_context : Resource < GraphExecutionContext > ,
198
242
name : String ,
199
243
tensor : Resource < Tensor > ,
200
- ) -> Result < ( ) , Error > {
244
+ ) -> wasmtime :: Result < Result < ( ) , Resource < Error > > > {
201
245
let tensor = self . table . get ( & tensor) ?;
202
246
tracing:: debug!( "set input {name:?}: {tensor:?}" ) ;
203
247
let tensor = tensor. clone ( ) ; // TODO: avoid copying the tensor
204
248
let exec_context = self . table . get_mut ( & exec_context) ?;
205
- if let Err ( e) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
206
- tracing:: error!( "failed to set input: {e:?}" ) ;
207
- Err ( Error :: InvalidArgument )
249
+ if let Err ( error) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
250
+ bail ! ( self , ErrorCode :: InvalidArgument , error) ;
208
251
} else {
209
- Ok ( ( ) )
252
+ Ok ( Ok ( ( ) ) )
210
253
}
211
254
}
212
255
213
- fn compute ( & mut self , exec_context : Resource < GraphExecutionContext > ) -> Result < ( ) , Error > {
256
+ fn compute (
257
+ & mut self ,
258
+ exec_context : Resource < GraphExecutionContext > ,
259
+ ) -> wasmtime:: Result < Result < ( ) , Resource < Error > > > {
214
260
let exec_context = & mut self . table . get_mut ( & exec_context) ?;
215
261
tracing:: debug!( "compute" ) ;
216
262
match exec_context. compute ( ) {
217
- Ok ( ( ) ) => Ok ( ( ) ) ,
263
+ Ok ( ( ) ) => Ok ( Ok ( ( ) ) ) ,
218
264
Err ( error) => {
219
- tracing:: error!( "failed to compute: {error:?}" ) ;
220
- Err ( Error :: RuntimeError )
265
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
221
266
}
222
267
}
223
268
}
224
269
225
- #[ doc = r" Extract the outputs after inference." ]
226
270
fn get_output (
227
271
& mut self ,
228
272
exec_context : Resource < GraphExecutionContext > ,
229
273
name : String ,
230
- ) -> Result < Resource < Tensor > , Error > {
274
+ ) -> wasmtime :: Result < Result < Resource < Tensor > , Resource < Error > > > {
231
275
let exec_context = self . table . get_mut ( & exec_context) ?;
232
276
tracing:: debug!( "get output {name:?}" ) ;
233
277
match exec_context. get_output ( Id :: Name ( name) ) {
234
278
Ok ( tensor) => {
235
279
let tensor = self . table . push ( tensor) ?;
236
- Ok ( tensor)
280
+ Ok ( Ok ( tensor) )
237
281
}
238
282
Err ( error) => {
239
- tracing:: error!( "failed to get output: {error:?}" ) ;
240
- Err ( Error :: RuntimeError )
283
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
241
284
}
242
285
}
243
286
}
@@ -285,21 +328,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285
328
}
286
329
}
287
330
288
- impl gen:: tensor:: Host for WasiNnView < ' _ > { }
331
+ impl gen:: errors:: HostError for WasiNnView < ' _ > {
332
+ fn new (
333
+ & mut self ,
334
+ _code : gen:: errors:: ErrorCode ,
335
+ _data : String ,
336
+ ) -> wasmtime:: Result < Resource < Error > > {
337
+ unimplemented ! ( "this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76" )
338
+ }
339
+
340
+ fn code ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < gen:: errors:: ErrorCode > {
341
+ let error = self . table . get ( & error) ?;
342
+ match error. code {
343
+ ErrorCode :: InvalidArgument => Ok ( gen:: errors:: ErrorCode :: InvalidArgument ) ,
344
+ ErrorCode :: InvalidEncoding => Ok ( gen:: errors:: ErrorCode :: InvalidEncoding ) ,
345
+ ErrorCode :: Timeout => Ok ( gen:: errors:: ErrorCode :: Timeout ) ,
346
+ ErrorCode :: RuntimeError => Ok ( gen:: errors:: ErrorCode :: RuntimeError ) ,
347
+ ErrorCode :: UnsupportedOperation => Ok ( gen:: errors:: ErrorCode :: UnsupportedOperation ) ,
348
+ ErrorCode :: TooLarge => Ok ( gen:: errors:: ErrorCode :: TooLarge ) ,
349
+ ErrorCode :: NotFound => Ok ( gen:: errors:: ErrorCode :: NotFound ) ,
350
+ ErrorCode :: Trap => Err ( anyhow ! ( error. data. to_string( ) ) ) ,
351
+ }
352
+ }
353
+
354
+ fn data ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < String > {
355
+ let error = self . table . get ( & error) ?;
356
+ Ok ( error. data . to_string ( ) )
357
+ }
358
+
359
+ fn drop ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < ( ) > {
360
+ self . table . delete ( error) ?;
361
+ Ok ( ( ) )
362
+ }
363
+ }
364
+
289
365
impl gen:: errors:: Host for WasiNnView < ' _ > {
290
- fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < gen:: errors:: Error > {
291
- match err {
292
- Error :: InvalidArgument => Ok ( gen:: errors:: Error :: InvalidArgument ) ,
293
- Error :: InvalidEncoding => Ok ( gen:: errors:: Error :: InvalidEncoding ) ,
294
- Error :: Timeout => Ok ( gen:: errors:: Error :: Timeout ) ,
295
- Error :: RuntimeError => Ok ( gen:: errors:: Error :: RuntimeError ) ,
296
- Error :: UnsupportedOperation => Ok ( gen:: errors:: Error :: UnsupportedOperation ) ,
297
- Error :: TooLarge => Ok ( gen:: errors:: Error :: TooLarge ) ,
298
- Error :: NotFound => Ok ( gen:: errors:: Error :: NotFound ) ,
299
- Error :: Trap ( e) => Err ( e) ,
366
+ fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < Error > {
367
+ if matches ! ( err. code, ErrorCode :: Trap ) {
368
+ Err ( err. data )
369
+ } else {
370
+ Ok ( err)
300
371
}
301
372
}
302
373
}
374
+
375
+ impl gen:: tensor:: Host for WasiNnView < ' _ > { }
303
376
impl gen:: inference:: Host for WasiNnView < ' _ > { }
304
377
305
378
impl Hash for gen:: graph:: GraphEncoding {
0 commit comments