Skip to content

Commit fa7cbbc

Browse files
committed
wasi-nn: track upstream specification
In bytecodealliance#8873, we stopped tracking the wasi-nn's upstream WIT files temporarily because it was not clear to me at the time how to implement errors as CM resources. This PR fixes that, resuming tracking in the `vendor-wit.sh` and implementing what is needed in the wasi-nn crate. This leaves several threads unresolved, though: - it looks like the `vendor-wit.sh` has a new mechanism for retrieving files into `wit/deps`--at some point wasi-nn should migrate to use that (?) - it's not clear to me that "errors as resources" is even the best approach here; I've opened [bytecodealliance#75] to consider the possibility of using "errors as records" instead. - this PR identifies that the constructor for errors is in fact unnecessary, prompting an upstream change ([bytecodealliance#76]) that should eventually be implemented here. [bytecodealliance#75]: WebAssembly/wasi-nn#75 [bytecodealliance#76]: WebAssembly/wasi-nn#76 prtest:full
1 parent ba864e9 commit fa7cbbc

File tree

3 files changed

+141
-57
lines changed

3 files changed

+141
-57
lines changed

ci/vendor-wit.sh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ rm -rf $cache_dir
6565
# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is
6666
# slightly different than above.
6767
repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
68-
revision=e2310b
68+
revision=0.2.0-rc-2024-06-25
6969
curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
70-
# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
71-
# latest WIT specification on `main`. To create a baseline for moving forward,
72-
# the in-tree WIT incorporates some but not all of the upstream changes. This
73-
# TODO can be removed once the implementation catches up with the spec.
74-
# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
70+
curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit

crates/wasi-nn/src/wit.rs

Lines changed: 121 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
1818
use crate::backend::Id;
1919
use crate::{Backend, Registry};
20+
use anyhow::anyhow;
2021
use std::collections::HashMap;
2122
use std::hash::Hash;
2223
use std::{fmt, str::FromStr};
@@ -54,29 +55,63 @@ impl<'a> WasiNnView<'a> {
5455
}
5556
}
5657

57-
pub enum Error {
58+
/// A wasi-nn error; this appears on the Wasm side as a component model
59+
/// resource.
60+
#[derive(Debug)]
61+
pub struct Error {
62+
code: ErrorCode,
63+
data: anyhow::Error,
64+
}
65+
66+
/// Construct an [`Error`] resource and immediately return it.
67+
///
68+
/// The WIT specification currently relies on "errors as resources;" this helper
69+
/// macro hides some of that complexity. If [#75] is adopted ("errors as
70+
/// records"), this macro is no longer necessary.
71+
///
72+
/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
73+
macro_rules! bail {
74+
($self:ident, $code:expr, $data:expr) => {
75+
let e = Error {
76+
code: $code,
77+
data: $data.into(),
78+
};
79+
tracing::error!("failure: {e:?}");
80+
let r = $self.table.push(e)?;
81+
return Ok(Err(r));
82+
};
83+
}
84+
85+
impl From<wasmtime::component::ResourceTableError> for Error {
86+
fn from(error: wasmtime::component::ResourceTableError) -> Self {
87+
Self {
88+
code: ErrorCode::Trap,
89+
data: error.into(),
90+
}
91+
}
92+
}
93+
94+
/// The list of error codes available to the `wasi-nn` API; this should match
95+
/// what is specified in WIT.
96+
#[derive(Debug)]
97+
pub enum ErrorCode {
5898
/// Caller module passed an invalid argument.
5999
InvalidArgument,
60100
/// Invalid encoding.
61101
InvalidEncoding,
62102
/// The operation timed out.
63103
Timeout,
64-
/// Runtime Error.
104+
/// Runtime error.
65105
RuntimeError,
66106
/// Unsupported operation.
67107
UnsupportedOperation,
68108
/// Graph is too large.
69109
TooLarge,
70110
/// Graph not found.
71111
NotFound,
72-
/// A runtime error occurred that we should trap on; see `StreamError`.
73-
Trap(anyhow::Error),
74-
}
75-
76-
impl From<wasmtime::component::ResourceTableError> for Error {
77-
fn from(error: wasmtime::component::ResourceTableError) -> Self {
78-
Self::Trap(error.into())
79-
}
112+
/// A runtime error that Wasmtime should trap on; this will not appear in
113+
/// the WIT specification.
114+
Trap,
80115
}
81116

82117
/// Generate the traits and types from the `wasi-nn` WIT specification.
@@ -91,6 +126,7 @@ mod gen_ {
91126
"wasi:nn/graph/graph": crate::Graph,
92127
"wasi:nn/tensor/tensor": crate::Tensor,
93128
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
129+
"wasi:nn/errors/error": super::Error,
94130
},
95131
trappable_error_type: {
96132
"wasi:nn/errors/error" => super::Error,
@@ -131,36 +167,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131167
builders: Vec<GraphBuilder>,
132168
encoding: GraphEncoding,
133169
target: ExecutionTarget,
134-
) -> Result<Resource<crate::Graph>, Error> {
170+
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
135171
tracing::debug!("load {encoding:?} {target:?}");
136172
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
137173
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
138174
match backend.load(&slices, target.into()) {
139175
Ok(graph) => {
140176
let graph = self.table.push(graph)?;
141-
Ok(graph)
177+
Ok(Ok(graph))
142178
}
143179
Err(error) => {
144-
tracing::error!("failed to load graph: {error:?}");
145-
Err(Error::RuntimeError)
180+
bail!(self, ErrorCode::RuntimeError, error);
146181
}
147182
}
148183
} else {
149-
Err(Error::InvalidEncoding)
184+
bail!(
185+
self,
186+
ErrorCode::InvalidEncoding,
187+
anyhow!("unable to find a backend for this encoding")
188+
);
150189
}
151190
}
152191

153-
fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
192+
fn load_by_name(
193+
&mut self,
194+
name: String,
195+
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
154196
use core::result::Result::*;
155197
tracing::debug!("load by name {name:?}");
156198
let registry = &self.ctx.registry;
157199
if let Some(graph) = registry.get(&name) {
158200
let graph = graph.clone();
159201
let graph = self.table.push(graph)?;
160-
Ok(graph)
202+
Ok(Ok(graph))
161203
} else {
162-
tracing::error!("failed to find graph with name: {name}");
163-
Err(Error::NotFound)
204+
bail!(
205+
self,
206+
ErrorCode::NotFound,
207+
anyhow!("failed to find graph with name: {name}")
208+
);
164209
}
165210
}
166211
}
@@ -169,18 +214,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169214
fn init_execution_context(
170215
&mut self,
171216
graph: Resource<Graph>,
172-
) -> Result<Resource<GraphExecutionContext>, Error> {
217+
) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
173218
use core::result::Result::*;
174219
tracing::debug!("initialize execution context");
175220
let graph = self.table.get(&graph)?;
176221
match graph.init_execution_context() {
177222
Ok(exec_context) => {
178223
let exec_context = self.table.push(exec_context)?;
179-
Ok(exec_context)
224+
Ok(Ok(exec_context))
180225
}
181226
Err(error) => {
182-
tracing::error!("failed to initialize execution context: {error:?}");
183-
Err(Error::RuntimeError)
227+
bail!(self, ErrorCode::RuntimeError, error);
184228
}
185229
}
186230
}
@@ -197,47 +241,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197241
exec_context: Resource<GraphExecutionContext>,
198242
name: String,
199243
tensor: Resource<Tensor>,
200-
) -> Result<(), Error> {
244+
) -> wasmtime::Result<Result<(), Resource<Error>>> {
201245
let tensor = self.table.get(&tensor)?;
202246
tracing::debug!("set input {name:?}: {tensor:?}");
203247
let tensor = tensor.clone(); // TODO: avoid copying the tensor
204248
let exec_context = self.table.get_mut(&exec_context)?;
205-
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
206-
tracing::error!("failed to set input: {e:?}");
207-
Err(Error::InvalidArgument)
249+
if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
250+
bail!(self, ErrorCode::InvalidArgument, error);
208251
} else {
209-
Ok(())
252+
Ok(Ok(()))
210253
}
211254
}
212255

213-
fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
256+
fn compute(
257+
&mut self,
258+
exec_context: Resource<GraphExecutionContext>,
259+
) -> wasmtime::Result<Result<(), Resource<Error>>> {
214260
let exec_context = &mut self.table.get_mut(&exec_context)?;
215261
tracing::debug!("compute");
216262
match exec_context.compute() {
217-
Ok(()) => Ok(()),
263+
Ok(()) => Ok(Ok(())),
218264
Err(error) => {
219-
tracing::error!("failed to compute: {error:?}");
220-
Err(Error::RuntimeError)
265+
bail!(self, ErrorCode::RuntimeError, error);
221266
}
222267
}
223268
}
224269

225-
#[doc = r" Extract the outputs after inference."]
226270
fn get_output(
227271
&mut self,
228272
exec_context: Resource<GraphExecutionContext>,
229273
name: String,
230-
) -> Result<Resource<Tensor>, Error> {
274+
) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
231275
let exec_context = self.table.get_mut(&exec_context)?;
232276
tracing::debug!("get output {name:?}");
233277
match exec_context.get_output(Id::Name(name)) {
234278
Ok(tensor) => {
235279
let tensor = self.table.push(tensor)?;
236-
Ok(tensor)
280+
Ok(Ok(tensor))
237281
}
238282
Err(error) => {
239-
tracing::error!("failed to get output: {error:?}");
240-
Err(Error::RuntimeError)
283+
bail!(self, ErrorCode::RuntimeError, error);
241284
}
242285
}
243286
}
@@ -285,21 +328,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285328
}
286329
}
287330

288-
impl gen::tensor::Host for WasiNnView<'_> {}
331+
impl gen::errors::HostError for WasiNnView<'_> {
332+
fn new(
333+
&mut self,
334+
_code: gen::errors::ErrorCode,
335+
_data: String,
336+
) -> wasmtime::Result<Resource<Error>> {
337+
unimplemented!("this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76")
338+
}
339+
340+
fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<gen::errors::ErrorCode> {
341+
let error = self.table.get(&error)?;
342+
match error.code {
343+
ErrorCode::InvalidArgument => Ok(gen::errors::ErrorCode::InvalidArgument),
344+
ErrorCode::InvalidEncoding => Ok(gen::errors::ErrorCode::InvalidEncoding),
345+
ErrorCode::Timeout => Ok(gen::errors::ErrorCode::Timeout),
346+
ErrorCode::RuntimeError => Ok(gen::errors::ErrorCode::RuntimeError),
347+
ErrorCode::UnsupportedOperation => Ok(gen::errors::ErrorCode::UnsupportedOperation),
348+
ErrorCode::TooLarge => Ok(gen::errors::ErrorCode::TooLarge),
349+
ErrorCode::NotFound => Ok(gen::errors::ErrorCode::NotFound),
350+
ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
351+
}
352+
}
353+
354+
fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
355+
let error = self.table.get(&error)?;
356+
Ok(error.data.to_string())
357+
}
358+
359+
fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
360+
self.table.delete(error)?;
361+
Ok(())
362+
}
363+
}
364+
289365
impl gen::errors::Host for WasiNnView<'_> {
290-
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
291-
match err {
292-
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
293-
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
294-
Error::Timeout => Ok(gen::errors::Error::Timeout),
295-
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
296-
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
297-
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
298-
Error::NotFound => Ok(gen::errors::Error::NotFound),
299-
Error::Trap(e) => Err(e),
366+
fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
367+
if matches!(err.code, ErrorCode::Trap) {
368+
Err(err.data)
369+
} else {
370+
Ok(err)
300371
}
301372
}
302373
}
374+
375+
impl gen::tensor::Host for WasiNnView<'_> {}
303376
impl gen::inference::Host for WasiNnView<'_> {}
304377

305378
impl Hash for gen::graph::GraphEncoding {

crates/wasi-nn/wit/wasi-nn.wit

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package wasi:nn;
1+
package wasi:nn@0.2.0-rc-2024-06-25;
22

33
/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
44
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
@@ -134,7 +134,7 @@ interface inference {
134134

135135
/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
136136
interface errors {
137-
enum error {
137+
enum error-code {
138138
// Caller module passed an invalid argument.
139139
invalid-argument,
140140
// Invalid encoding.
@@ -148,6 +148,21 @@ interface errors {
148148
// Graph is too large.
149149
too-large,
150150
// Graph not found.
151-
not-found
151+
not-found,
152+
// The operation is insecure or has insufficient privilege to be performed.
153+
// e.g., cannot access a hardware feature requested
154+
security,
155+
// The operation failed for an unspecified reason.
156+
unknown
157+
}
158+
159+
resource error {
160+
constructor(code: error-code, data: string);
161+
162+
/// Return the error code.
163+
code: func() -> error-code;
164+
165+
/// Errors can propagated with backend specific status through a string value.
166+
data: func() -> string;
152167
}
153168
}

0 commit comments

Comments
 (0)