@@ -29,6 +29,7 @@ pub(crate) struct OperationMethod {
29
29
params : Vec < OperationParameter > ,
30
30
responses : Vec < OperationResponse > ,
31
31
dropshot_paginated : Option < DropshotPagination > ,
32
+ dropshot_websocket : bool ,
32
33
}
33
34
34
35
enum HttpMethod {
@@ -189,6 +190,7 @@ impl OperationResponseStatus {
189
190
matches ! (
190
191
self ,
191
192
OperationResponseStatus :: Default
193
+ | OperationResponseStatus :: Code ( 101 )
192
194
| OperationResponseStatus :: Code ( 200 ..=299 )
193
195
| OperationResponseStatus :: Range ( 2 )
194
196
)
@@ -225,6 +227,7 @@ enum OperationResponseType {
225
227
Type ( TypeId ) ,
226
228
None ,
227
229
Raw ,
230
+ Upgrade ,
228
231
}
229
232
230
233
impl Generator {
@@ -338,6 +341,12 @@ impl Generator {
338
341
} )
339
342
. collect :: < Result < Vec < _ > > > ( ) ?;
340
343
344
+ let dropshot_websocket =
345
+ operation. extensions . get ( "x-dropshot-websocket" ) . is_some ( ) ;
346
+ if dropshot_websocket {
347
+ self . uses_websockets = true ;
348
+ }
349
+
341
350
if let Some ( body_param) = self . get_body_param ( operation, components) ? {
342
351
params. push ( body_param) ;
343
352
}
@@ -378,9 +387,10 @@ impl Generator {
378
387
let ( status_code, response) = v?;
379
388
380
389
// We categorize responses as "typed" based on the
381
- // "application/json" content type, "raw" if there's any other
382
- // response content type (we don't investigate further), or
383
- // "none" if there is no content.
390
+ // "application/json" content type, "upgrade" if it's a
391
+ // websocket channel without a meaningful content-type,
392
+ // "raw" if there's any other response content type (we don't
393
+ // investigate further), or "none" if there is no content.
384
394
// TODO if there are multiple response content types we could
385
395
// treat those like different response types and create an
386
396
// enum; the generated client method would check for the
@@ -407,6 +417,8 @@ impl Generator {
407
417
} ;
408
418
409
419
OperationResponseType :: Type ( typ)
420
+ } else if dropshot_websocket {
421
+ OperationResponseType :: Upgrade
410
422
} else if response. content . first ( ) . is_some ( ) {
411
423
OperationResponseType :: Raw
412
424
} else {
@@ -449,9 +461,25 @@ impl Generator {
449
461
} ) ;
450
462
}
451
463
464
+ // Must accept HTTP 101 Switching Protocols
465
+ if dropshot_websocket {
466
+ responses. push ( OperationResponse {
467
+ status_code : OperationResponseStatus :: Code ( 101 ) ,
468
+ typ : OperationResponseType :: Upgrade ,
469
+ description : None ,
470
+ } )
471
+ }
472
+
452
473
let dropshot_paginated =
453
474
self . dropshot_pagination_data ( operation, & params, & responses) ;
454
475
476
+ if dropshot_websocket && dropshot_paginated. is_some ( ) {
477
+ return Err ( Error :: InvalidExtension ( format ! (
478
+ "conflicting extensions in {:?}" ,
479
+ operation_id
480
+ ) ) ) ;
481
+ }
482
+
455
483
Ok ( OperationMethod {
456
484
operation_id : sanitize ( operation_id, Case :: Snake ) ,
457
485
tags : operation. tags . clone ( ) ,
@@ -465,6 +493,7 @@ impl Generator {
465
493
params,
466
494
responses,
467
495
dropshot_paginated,
496
+ dropshot_websocket,
468
497
} )
469
498
}
470
499
@@ -705,6 +734,20 @@ impl Generator {
705
734
( query_build, query_use)
706
735
} ;
707
736
737
+ let websock_hdrs = if method. dropshot_websocket {
738
+ quote ! {
739
+ . header( reqwest:: header:: CONNECTION , "Upgrade" )
740
+ . header( reqwest:: header:: UPGRADE , "websocket" )
741
+ . header( reqwest:: header:: SEC_WEBSOCKET_VERSION , "13" )
742
+ . header(
743
+ reqwest:: header:: SEC_WEBSOCKET_KEY ,
744
+ base64:: encode( rand:: random:: <[ u8 ; 16 ] >( ) ) ,
745
+ )
746
+ }
747
+ } else {
748
+ quote ! { }
749
+ } ;
750
+
708
751
// Generate the path rename map; then use it to generate code for
709
752
// assigning the path parameters to the `url` variable.
710
753
let url_renames = method
@@ -791,6 +834,11 @@ impl Generator {
791
834
Ok ( ResponseValue :: stream( response) )
792
835
}
793
836
}
837
+ OperationResponseType :: Upgrade => {
838
+ quote ! {
839
+ ResponseValue :: upgrade( response) . await
840
+ }
841
+ }
794
842
} ;
795
843
796
844
quote ! { #pat => { #decode } }
@@ -842,6 +890,13 @@ impl Generator {
842
890
) )
843
891
}
844
892
}
893
+ OperationResponseType :: Upgrade => {
894
+ if response. status_code == OperationResponseStatus :: Default {
895
+ return quote ! { } // catch-all handled below
896
+ } else {
897
+ todo ! ( "non-default error response handling for upgrade requests is not yet implemented" ) ;
898
+ }
899
+ }
845
900
} ;
846
901
847
902
quote ! { #pat => { #decode } }
@@ -879,6 +934,7 @@ impl Generator {
879
934
. #method_func ( url)
880
935
#( #body_func) *
881
936
#query_use
937
+ #websock_hdrs
882
938
. build( ) ?;
883
939
#pre_hook
884
940
let result = #client. client
@@ -988,6 +1044,9 @@ impl Generator {
988
1044
OperationResponseType :: Raw => {
989
1045
quote ! { ByteStream }
990
1046
}
1047
+ OperationResponseType :: Upgrade => {
1048
+ quote ! { reqwest:: Upgraded }
1049
+ }
991
1050
} )
992
1051
// TODO should this be a bytestream?
993
1052
. unwrap_or_else ( || quote ! { ( ) } ) ;
0 commit comments