Skip to content

Commit 4e2dcc5

Browse files
authored
Generate a Client method for Dropshot websocket channels (#183)
Generated methods return `ResponseValue<reqwest::Upgrade`, which may be passed to a websocket protocol implementation such as `tokio_tungstenite::WebSocketStream::from_raw_stream(rv.into_inner(), ...)` for the purpose of implementing against the raw websocket connection, but may later be extended as a generic to allow higher-level channel message definitions. Per the changelog, consumers will need to depend on reqwest 0.11.12 or newer for HTTP Upgrade support, as well as base64 and rand if any endpoints are websocket channels: ``` [dependencies] reqwest = { version = "0.11.12" features = ["json", "stream"] } base64 = "0.13" rand = "0.8" ``` Co-authored-by: lif <>
1 parent fd1ae2b commit 4e2dcc5

16 files changed

+5510
-11
lines changed

CHANGELOG.adoc

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ https://github.com/oxidecomputer/progenitor/compare/v0.1.1\...v0.2.0[Full list o
2424
* Derive `Debug` for `Client` and builders for the various operations (#145)
2525
* Builders for `struct` types (#171)
2626
* Add a prelude that include the `Client` and any extension traits (#176)
27+
* Add support for upgrading connections, which requires a version bump to reqwest. (#183)
2728

2829
== 0.1.1 (released 2022-05-13)
2930

Cargo.lock

+40
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

+9-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ Similarly if there is a `format` field set to `uuid`:
5454
+uuid = { version = "1.0.0", features = ["serde", "v4"] }
5555
```
5656

57+
And if there are any websocket channel endpoints:
58+
```diff
59+
[dependencies]
60+
+base64 = "0.13"
61+
+rand = "0.8"
62+
```
63+
5764
The macro has some additional fancy options to control the generated code:
5865

5966
```rust
@@ -116,7 +123,7 @@ You'll need to add add the following to `Cargo.toml`:
116123
+serde_json = "1.0"
117124
```
118125

119-
(`chrono` and `uuid` as above)
126+
(`chrono`, `uuid`, `base64`, and `rand` as above)
120127

121128
Note that `progenitor` is used by `build.rs`, but the generated code required
122129
`progenitor-client`.
@@ -290,4 +297,4 @@ let result = client
290297
```
291298

292299
Consumers do not need to specify parameters and struct properties that are not
293-
required or for which the API specifies defaults. Neat!
300+
required or for which the API specifies defaults. Neat!

example-build/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ edition = "2021"
77
[dependencies]
88
chrono = { version = "0.4", features = ["serde"] }
99
progenitor-client = { path = "../progenitor-client" }
10-
reqwest = { version = "0.11", features = ["json", "stream"] }
10+
reqwest = { version = "0.11.12", features = ["json", "stream"] }
11+
base64 = "0.13"
12+
rand = "0.8"
1113
serde = { version = "1.0", features = ["derive"] }
1214
uuid = { version = "1.0", features = ["serde", "v4"] }
1315

example-macro/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ edition = "2021"
77
[dependencies]
88
chrono = { version = "0.4", features = ["serde"] }
99
progenitor = { path = "../progenitor" }
10-
reqwest = { version = "0.11", features = ["json", "stream"] }
10+
reqwest = { version = "0.11.12", features = ["json", "stream"] }
1111
schemars = { version = "0.8.10", features = ["uuid1"] }
1212
serde = { version = "1.0", features = ["derive"] }
1313
uuid = { version = "1.0", features = ["serde", "v4"] }

progenitor-client/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ description = "An OpenAPI client generator - client support"
1010
bytes = "1.2.1"
1111
futures-core = "0.3.24"
1212
percent-encoding = "2.2"
13-
reqwest = { version = "0.11", default-features = false, features = ["json", "stream"] }
13+
reqwest = { version = "0.11.12", default-features = false, features = ["json", "stream"] }
1414
serde = "1.0"
1515
serde_json = "1.0"
1616
serde_urlencoded = "0.7.1"

progenitor-client/src/progenitor_client.rs

+24
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,30 @@ impl<T: DeserializeOwned> ResponseValue<T> {
7676
}
7777
}
7878

79+
impl ResponseValue<reqwest::Upgraded> {
80+
#[doc(hidden)]
81+
pub async fn upgrade<E: std::fmt::Debug>(
82+
response: reqwest::Response,
83+
) -> Result<Self, Error<E>> {
84+
let status = response.status();
85+
let headers = response.headers().clone();
86+
if status == reqwest::StatusCode::SWITCHING_PROTOCOLS {
87+
let inner = response
88+
.upgrade()
89+
.await
90+
.map_err(Error::InvalidResponsePayload)?;
91+
92+
Ok(Self {
93+
inner,
94+
status,
95+
headers,
96+
})
97+
} else {
98+
Err(Error::UnexpectedResponse(response))
99+
}
100+
}
101+
}
102+
79103
impl ResponseValue<ByteStream> {
80104
#[doc(hidden)]
81105
pub fn stream(response: reqwest::Response) -> Self {

progenitor-impl/src/lib.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ pub enum Error {
2626
UnexpectedFormat(String),
2727
#[error("invalid operation path {0}")]
2828
InvalidPath(String),
29+
#[error("invalid dropshot extension use: {0}")]
30+
InvalidExtension(String),
2931
#[error("internal error {0}")]
3032
InternalError(String),
3133
}
@@ -36,6 +38,7 @@ pub struct Generator {
3638
type_space: TypeSpace,
3739
settings: GenerationSettings,
3840
uses_futures: bool,
41+
uses_websockets: bool,
3942
}
4043

4144
#[derive(Default, Clone)]
@@ -116,6 +119,7 @@ impl Default for Generator {
116119
),
117120
settings: Default::default(),
118121
uses_futures: Default::default(),
122+
uses_websockets: Default::default(),
119123
}
120124
}
121125
}
@@ -133,6 +137,7 @@ impl Generator {
133137
type_space: TypeSpace::new(&type_settings),
134138
settings: settings.clone(),
135139
uses_futures: false,
140+
uses_websockets: false,
136141
}
137142
}
138143

@@ -426,7 +431,7 @@ impl Generator {
426431
"bytes = \"1.1\"",
427432
"futures-core = \"0.3\"",
428433
"percent-encoding = \"2.1\"",
429-
"reqwest = { version = \"0.11\", features = [\"json\", \"stream\"] }",
434+
"reqwest = { version = \"0.11.12\", features = [\"json\", \"stream\"] }",
430435
"serde = { version = \"1.0\", features = [\"derive\"] }",
431436
"serde_urlencoded = \"0.7\"",
432437
];
@@ -444,6 +449,10 @@ impl Generator {
444449
if self.uses_futures {
445450
deps.push("futures = \"0.3\"")
446451
}
452+
if self.uses_websockets {
453+
deps.push("base64 = \"0.13\"");
454+
deps.push("rand = \"0.8\"");
455+
}
447456
if self.type_space.uses_serde_json() {
448457
deps.push("serde_json = \"1.0\"")
449458
}

progenitor-impl/src/method.rs

+62-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(crate) struct OperationMethod {
2929
params: Vec<OperationParameter>,
3030
responses: Vec<OperationResponse>,
3131
dropshot_paginated: Option<DropshotPagination>,
32+
dropshot_websocket: bool,
3233
}
3334

3435
enum HttpMethod {
@@ -189,6 +190,7 @@ impl OperationResponseStatus {
189190
matches!(
190191
self,
191192
OperationResponseStatus::Default
193+
| OperationResponseStatus::Code(101)
192194
| OperationResponseStatus::Code(200..=299)
193195
| OperationResponseStatus::Range(2)
194196
)
@@ -225,6 +227,7 @@ enum OperationResponseType {
225227
Type(TypeId),
226228
None,
227229
Raw,
230+
Upgrade,
228231
}
229232

230233
impl Generator {
@@ -338,6 +341,12 @@ impl Generator {
338341
})
339342
.collect::<Result<Vec<_>>>()?;
340343

344+
let dropshot_websocket =
345+
operation.extensions.get("x-dropshot-websocket").is_some();
346+
if dropshot_websocket {
347+
self.uses_websockets = true;
348+
}
349+
341350
if let Some(body_param) = self.get_body_param(operation, components)? {
342351
params.push(body_param);
343352
}
@@ -378,9 +387,10 @@ impl Generator {
378387
let (status_code, response) = v?;
379388

380389
// We categorize responses as "typed" based on the
381-
// "application/json" content type, "raw" if there's any other
382-
// response content type (we don't investigate further), or
383-
// "none" if there is no content.
390+
// "application/json" content type, "upgrade" if it's a
391+
// websocket channel without a meaningful content-type,
392+
// "raw" if there's any other response content type (we don't
393+
// investigate further), or "none" if there is no content.
384394
// TODO if there are multiple response content types we could
385395
// treat those like different response types and create an
386396
// enum; the generated client method would check for the
@@ -407,6 +417,8 @@ impl Generator {
407417
};
408418

409419
OperationResponseType::Type(typ)
420+
} else if dropshot_websocket {
421+
OperationResponseType::Upgrade
410422
} else if response.content.first().is_some() {
411423
OperationResponseType::Raw
412424
} else {
@@ -449,9 +461,25 @@ impl Generator {
449461
});
450462
}
451463

464+
// Must accept HTTP 101 Switching Protocols
465+
if dropshot_websocket {
466+
responses.push(OperationResponse {
467+
status_code: OperationResponseStatus::Code(101),
468+
typ: OperationResponseType::Upgrade,
469+
description: None,
470+
})
471+
}
472+
452473
let dropshot_paginated =
453474
self.dropshot_pagination_data(operation, &params, &responses);
454475

476+
if dropshot_websocket && dropshot_paginated.is_some() {
477+
return Err(Error::InvalidExtension(format!(
478+
"conflicting extensions in {:?}",
479+
operation_id
480+
)));
481+
}
482+
455483
Ok(OperationMethod {
456484
operation_id: sanitize(operation_id, Case::Snake),
457485
tags: operation.tags.clone(),
@@ -465,6 +493,7 @@ impl Generator {
465493
params,
466494
responses,
467495
dropshot_paginated,
496+
dropshot_websocket,
468497
})
469498
}
470499

@@ -705,6 +734,20 @@ impl Generator {
705734
(query_build, query_use)
706735
};
707736

737+
let websock_hdrs = if method.dropshot_websocket {
738+
quote! {
739+
.header(reqwest::header::CONNECTION, "Upgrade")
740+
.header(reqwest::header::UPGRADE, "websocket")
741+
.header(reqwest::header::SEC_WEBSOCKET_VERSION, "13")
742+
.header(
743+
reqwest::header::SEC_WEBSOCKET_KEY,
744+
base64::encode(rand::random::<[u8; 16]>()),
745+
)
746+
}
747+
} else {
748+
quote! {}
749+
};
750+
708751
// Generate the path rename map; then use it to generate code for
709752
// assigning the path parameters to the `url` variable.
710753
let url_renames = method
@@ -791,6 +834,11 @@ impl Generator {
791834
Ok(ResponseValue::stream(response))
792835
}
793836
}
837+
OperationResponseType::Upgrade => {
838+
quote! {
839+
ResponseValue::upgrade(response).await
840+
}
841+
}
794842
};
795843

796844
quote! { #pat => { #decode } }
@@ -842,6 +890,13 @@ impl Generator {
842890
))
843891
}
844892
}
893+
OperationResponseType::Upgrade => {
894+
if response.status_code == OperationResponseStatus::Default {
895+
return quote! { } // catch-all handled below
896+
} else {
897+
todo!("non-default error response handling for upgrade requests is not yet implemented");
898+
}
899+
}
845900
};
846901

847902
quote! { #pat => { #decode } }
@@ -879,6 +934,7 @@ impl Generator {
879934
. #method_func (url)
880935
#(#body_func)*
881936
#query_use
937+
#websock_hdrs
882938
.build()?;
883939
#pre_hook
884940
let result = #client.client
@@ -988,6 +1044,9 @@ impl Generator {
9881044
OperationResponseType::Raw => {
9891045
quote! { ByteStream }
9901046
}
1047+
OperationResponseType::Upgrade => {
1048+
quote! { reqwest::Upgraded }
1049+
}
9911050
})
9921051
// TODO should this be a bytestream?
9931052
.unwrap_or_else(|| quote! { () });

0 commit comments

Comments
 (0)