Skip to content

Commit 9c7072c

Browse files
committed
Experiment with rate limiting at the membrane layer
1 parent f958b8d commit 9c7072c

File tree

11 files changed

+143
-8
lines changed

11 files changed

+143
-8
lines changed

Cargo.lock

+42
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dart_example/test/main_test.dart

+13
Original file line numberDiff line numberDiff line change
@@ -506,4 +506,17 @@ void main() {
506506
id: Filter(value: [Match(field: "id", value: "1")]),
507507
withinGdpr: GDPR(value: true));
508508
});
509+
510+
test('test that functions can be rate limited', () async {
511+
final contact =
512+
Contact(id: 1, fullName: "Alice Smith", status: Status.pending);
513+
final accounts = AccountsApi();
514+
515+
assert(await accounts.rateLimitedFunction(contact: contact) ==
516+
contact.fullName);
517+
expect(() async => await accounts.rateLimitedFunction(contact: contact),
518+
throwsA(isA<MembraneRateLimited>()));
519+
expect(() async => await accounts.rateLimitedFunction(contact: contact),
520+
throwsA(isA<MembraneRateLimited>()));
521+
});
509522
}

example/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ skip-codegen = ["membrane/skip-generate"]
2222

2323
[dependencies]
2424
async-stream = "0.3"
25+
derated = {path = "../../derated"}
2526
futures = "0.3"
2627
membrane = {path = "../membrane"}
28+
once_cell = "*"
2729
serde = {version = "1.0", features = ["derive"]}
2830
serde_bytes = "0.11"
2931
tokio = {version = "1", features = ["full"]}

example/src/application/advanced.rs

+29
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use data::OptionsDemo;
22
use membrane::emitter::{emitter, Emitter, StreamEmitter};
33
use membrane::{async_dart, sync_dart};
4+
use once_cell::sync::Lazy;
45
use tokio_stream::Stream;
56

7+
use std::collections::hash_map::DefaultHasher;
68
// used for background threading examples
79
use std::{thread, time::Duration};
810

@@ -584,3 +586,30 @@ pub async fn get_org_with_borrowed_type(
584586
pub async fn unused_duplicate_borrows(_id: i64) -> Result<data::Organization, String> {
585587
todo!()
586588
}
589+
590+
struct MyLimit(RateLimit);
591+
592+
impl MyLimit {
593+
fn per_milliseconds(milliseconds: u64, max_queued: Option<u64>) -> Self {
594+
Self(RateLimit::per_milliseconds(milliseconds, max_queued))
595+
}
596+
597+
fn hash_rate_limited_function(&self, fn_name: &str, contact: &data::Contact) -> u64 {
598+
use std::hash::{Hash, Hasher};
599+
let mut s = DefaultHasher::new();
600+
(fn_name, contact.id).hash(&mut s);
601+
s.finish()
602+
}
603+
604+
async fn check(&self, key: &'static str, hash: u64) -> Result<(), derated::Dropped> {
605+
self.0.check(key, hash).await
606+
}
607+
}
608+
609+
use derated::RateLimit;
610+
static RATE_LIMIT: Lazy<MyLimit> = Lazy::new(|| MyLimit::per_milliseconds(100, None));
611+
612+
#[async_dart(namespace = "accounts", rate_limit = RATE_LIMIT)]
613+
pub async fn rate_limited_function(contact: data::Contact) -> Result<String, String> {
614+
Ok(contact.full_name)
615+
}

example/src/data.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
55

66
#[dart_enum(namespace = "accounts")]
77
#[dart_enum(namespace = "orgs")]
8-
#[derive(Debug, Clone, Deserialize, Serialize)]
8+
#[derive(Debug, Clone, Deserialize, Serialize, Hash)]
99
pub enum Status {
1010
Pending,
1111
Active,
@@ -42,7 +42,7 @@ pub struct Mixed {
4242
three: Option<VecWrapper>,
4343
}
4444

45-
#[derive(Debug, Clone, Deserialize, Serialize)]
45+
#[derive(Debug, Clone, Deserialize, Serialize, Hash)]
4646
pub struct Contact {
4747
pub id: i64,
4848
pub full_name: String,

membrane/src/generators/exceptions.rs

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class MembraneRustPanicException extends MembraneException {
3030
class MembraneUnknownResponseVariantException extends MembraneException {
3131
const MembraneUnknownResponseVariantException([String? message]) : super(message);
3232
}
33+
34+
class MembraneRateLimited extends MembraneException {
35+
const MembraneRateLimited([String? message]) : super(message);
36+
}
3337
"#
3438
.to_string()
3539
}

membrane/src/generators/functions.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,11 @@ impl Callable for Ffi {
338338
_log.{fine_logger}('Deserializing data from {fn_name}');
339339
}}
340340
final deserializer = BincodeDeserializer(data.asTypedList(length + 8).sublist(8));
341-
if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{
341+
final msgCode = deserializer.deserializeUint8();
342+
if (msgCode == MembraneMsgKind.ok) {{
342343
return {return_de};
344+
}} else if (msgCode == MembraneMsgKind.rateLimited) {{
345+
throw MembraneRateLimited();
343346
}}
344347
throw {class_name}ApiError({error_de});
345348
}} finally {{
@@ -362,8 +365,11 @@ impl Callable for Ffi {
362365
_log.{fine_logger}('Deserializing data from {fn_name}');
363366
}}
364367
final deserializer = BincodeDeserializer(input as Uint8List);
365-
if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{
368+
final msgCode = deserializer.deserializeUint8();
369+
if (msgCode == MembraneMsgKind.ok) {{
366370
return {return_de};
371+
}} else if (msgCode == MembraneMsgKind.rateLimited) {{
372+
throw MembraneRateLimited();
367373
}}
368374
throw {class_name}ApiError({error_de});
369375
}});
@@ -394,8 +400,11 @@ impl Callable for Ffi {
394400
_log.{fine_logger}('Deserializing data from {fn_name}');
395401
}}
396402
final deserializer = BincodeDeserializer(await _port.first{timeout} as Uint8List);
397-
if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{
403+
final msgCode = deserializer.deserializeUint8();
404+
if (msgCode == MembraneMsgKind.ok) {{
398405
return {return_de};
406+
}} else if (msgCode == MembraneMsgKind.rateLimited) {{
407+
throw MembraneRateLimited();
399408
}}
400409
throw {class_name}ApiError({error_de});
401410
}} finally {{

membrane/src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ impl<'a> Membrane {
656656
typedef enum MembraneMsgKind {
657657
Ok,
658658
Error,
659+
RateLimited,
659660
} MembraneMsgKind;
660661
661662
typedef enum MembraneResponseKind {
@@ -826,6 +827,7 @@ enums:
826827
'Error': 'error'
827828
'Ok': 'ok'
828829
'Panic': 'panic'
830+
'RateLimited': 'rateLimited'
829831
macros:
830832
include:
831833
- __none__
@@ -1354,10 +1356,11 @@ pub struct MembraneResponse {
13541356

13551357
#[doc(hidden)]
13561358
#[repr(u8)]
1357-
#[derive(serde::Serialize)]
1359+
#[derive(serde::Serialize, PartialEq)]
13581360
pub enum MembraneMsgKind {
13591361
Ok,
13601362
Error,
1363+
RateLimited,
13611364
}
13621365

13631366
#[doc(hidden)]

membrane/src/utils.rs

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ use crate::SourceCodeLocation;
22
use allo_isolate::Isolate;
33
use serde::ser::Serialize;
44

5+
pub fn send_rate_limited(isolate: Isolate) -> bool {
6+
if let Ok(buffer) = crate::bincode::serialize(&(crate::MembraneMsgKind::RateLimited as u8)) {
7+
isolate.post(crate::allo_isolate::ZeroCopyBuffer(buffer))
8+
} else {
9+
false
10+
}
11+
}
12+
513
pub fn send<T: Serialize, E: Serialize>(isolate: Isolate, result: Result<T, E>) -> bool {
614
match result {
715
Ok(value) => {

membrane_macro/src/lib.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ fn to_token_stream(
172172
timeout,
173173
os_thread,
174174
borrow,
175+
rate_limit,
175176
} = options;
176177

177178
let mut functions = TokenStream::new();
@@ -204,6 +205,18 @@ fn to_token_stream(
204205
let dart_transforms: Vec<String> = DartTransforms::try_from(&inputs)?.into();
205206
let dart_inner_args: Vec<String> = DartArgs::from(&inputs).into();
206207

208+
let rate_limit_condition = if let Some(limiter) = rate_limit {
209+
let hasher_function = Ident::new(&format!("hash_{}", &rust_fn_name), Span::call_site());
210+
quote! {
211+
let ::std::result::Result::Err(err) = {
212+
let hash = #limiter.#hasher_function(#rust_fn_name, #(&#rust_inner_args),*);
213+
#limiter.check(#rust_fn_name, hash).await
214+
}
215+
}
216+
} else {
217+
quote!(false)
218+
};
219+
207220
let return_statement = match output_style {
208221
OutputStyle::EmitterSerialized | OutputStyle::StreamEmitterSerialized if sync => {
209222
syn::Error::new(
@@ -281,9 +294,13 @@ fn to_token_stream(
281294
OutputStyle::Serialized => quote! {
282295
let membrane_join_handle = crate::RUNTIME.get().info_spawn(
283296
async move {
284-
let result: ::std::result::Result<#output, #error> = #fn_name(#(#rust_inner_args),*).await;
285297
let isolate = ::membrane::allo_isolate::Isolate::new(membrane_port);
286-
::membrane::utils::send::<#output, #error>(isolate, result);
298+
if #rate_limit_condition {
299+
::membrane::utils::send_rate_limited(isolate);
300+
} else {
301+
let result: ::std::result::Result<#output, #error> = #fn_name(#(#rust_inner_args),*).await;
302+
::membrane::utils::send::<#output, #error>(isolate, result);
303+
}
287304
},
288305
::membrane::runtime::Info { name: #rust_fn_name }
289306
);

membrane_macro/src/options.rs

+8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) struct Options {
88
pub timeout: Option<i32>,
99
pub os_thread: bool,
1010
pub borrow: Vec<String>,
11+
pub rate_limit: Option<syn::Path>,
1112
}
1213

1314
pub(crate) fn extract_options(
@@ -64,6 +65,13 @@ pub(crate) fn extract_options(
6465
options.disable_logging = val.value();
6566
options
6667
}
68+
Some((ident, syn::Expr::Path(syn::ExprPath { path, .. })))
69+
if ident == "rate_limit" && !sync =>
70+
{
71+
options.rate_limit = Some(path);
72+
options
73+
}
74+
// TODO handle the invalid rate_limit case
6775
Some((
6876
ident,
6977
Lit(ExprLit {

0 commit comments

Comments
 (0)