Skip to content

Commit c9a339a

Browse files
committed
Refine message bus topic and endpoint validations
1 parent 380709d commit c9a339a

File tree

4 files changed

+98
-55
lines changed

4 files changed

+98
-55
lines changed

crates/common/src/msgbus/core.rs

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,24 @@ use ahash::{AHashMap, AHashSet};
2626
use handler::ShareableMessageHandler;
2727
use indexmap::IndexMap;
2828
use matching::is_matching_backtracking;
29-
use nautilus_core::UUID4;
29+
use nautilus_core::{
30+
UUID4,
31+
correctness::{FAILED, check_predicate_true, check_valid_string},
32+
};
3033
use nautilus_model::identifiers::TraderId;
3134
use switchboard::MessagingSwitchboard;
3235
use ustr::Ustr;
3336

3437
use super::{handler, matching, set_message_bus, switchboard};
3538

39+
#[inline(always)]
40+
fn check_fully_qualified_string(value: &Ustr, key: &str) -> anyhow::Result<()> {
41+
check_predicate_true(
42+
!value.chars().any(|c| c == '*' || c == '?'),
43+
&format!("{key} `value` contained invalid characters, was {value}"),
44+
)
45+
}
46+
3647
/// A string pattern for a subscription. The pattern is used to match topics.
3748
///
3849
/// A pattern is made of characters:
@@ -80,20 +91,29 @@ impl From<&Topic> for Pattern {
8091
pub struct Topic(pub Ustr);
8192

8293
impl Topic {
83-
pub fn as_pattern(&self) -> Pattern {
84-
Pattern(self.0)
94+
/// Creates a new [`Topic`] instance.
95+
///
96+
/// # Errors
97+
///
98+
/// Returns an error if:
99+
/// - `value` string is empty, all whitespace, or contains non-ASCII characters.
100+
/// - `value` string contains wildcard characters '*' or '?'.
101+
pub fn new<T: AsRef<str>>(value: T) -> anyhow::Result<Self> {
102+
let topic = Ustr::from(value.as_ref());
103+
check_valid_string(value, stringify!(value))?;
104+
check_fully_qualified_string(&topic, stringify!(Topic))?;
105+
106+
Ok(Self(topic))
85107
}
86108

87-
// TODO: Every initialization should validate
88-
pub fn validate<T: AsRef<str>>(topic: T) -> bool {
89-
let topic = Ustr::from(topic.as_ref());
90-
!topic.chars().any(|c| c == '*' || c == '?')
109+
pub fn as_pattern(&self) -> Pattern {
110+
Pattern(self.0)
91111
}
92112
}
93113

94114
impl<T: AsRef<str>> From<T> for Topic {
95115
fn from(value: T) -> Self {
96-
Self(Ustr::from(value.as_ref()))
116+
Self::new(value).expect(FAILED)
97117
}
98118
}
99119

@@ -118,16 +138,25 @@ impl Display for Topic {
118138
pub struct Endpoint(pub Ustr);
119139

120140
impl Endpoint {
121-
// TODO: Every initialization should validate
122-
pub fn validate<T: AsRef<str>>(endpoint: T) -> bool {
123-
let endpoint = Ustr::from(endpoint.as_ref());
124-
!endpoint.chars().any(|c| c == '*' || c == '?')
141+
/// Creates a new [`Endpoint`] instance.
142+
///
143+
/// # Errors
144+
///
145+
/// Returns an error if:
146+
/// - `value` string is empty, all whitespace, or contains non-ASCII characters.
147+
/// - `value` string contains wildcard characters '*' or '?'.
148+
pub fn new<T: AsRef<str>>(value: T) -> anyhow::Result<Self> {
149+
let endpoint = Ustr::from(value.as_ref());
150+
check_valid_string(value, stringify!(value))?;
151+
check_fully_qualified_string(&endpoint, stringify!(Endpoint))?;
152+
153+
Ok(Self(endpoint))
125154
}
126155
}
127156

128157
impl<T: AsRef<str>> From<T> for Endpoint {
129158
fn from(value: T) -> Self {
130-
Self(Ustr::from(value.as_ref()))
159+
Self::new(value).expect(FAILED)
131160
}
132161
}
133162

@@ -305,12 +334,12 @@ impl MessageBus {
305334
.collect()
306335
}
307336

308-
/// Returns whether there are subscribers for the given `topic`.
337+
/// Returns whether there are subscribers for the `topic`.
309338
pub fn has_subscribers<T: AsRef<str>>(&self, topic: T) -> bool {
310339
self.subscriptions_count(topic) > 0
311340
}
312341

313-
/// Returns the count of subscribers for the given `topic`.
342+
/// Returns the count of subscribers for the `topic`.
314343
#[must_use]
315344
pub fn subscriptions_count<T: AsRef<str>>(&self, topic: T) -> usize {
316345
let topic = Topic::from(topic);
@@ -335,14 +364,14 @@ impl MessageBus {
335364
.collect()
336365
}
337366

338-
/// Returns whether there is a registered endpoint for the given `pattern`.
367+
/// Returns whether there is a registered endpoint for the `pattern`.
339368
#[must_use]
340369
pub fn is_registered<T: AsRef<str>>(&self, endpoint: T) -> bool {
341370
self.endpoints
342371
.contains_key(&Endpoint::from(endpoint.as_ref()))
343372
}
344373

345-
/// Returns whether the given `handler` is subscribed to the given `pattern`.
374+
/// Returns whether the `handler` is subscribed to the `pattern`.
346375
#[must_use]
347376
pub fn is_subscribed<T: AsRef<str>>(
348377
&self,
@@ -358,25 +387,25 @@ impl MessageBus {
358387
///
359388
/// # Errors
360389
///
361-
/// This function never returns an error (TBD).
390+
/// This function never returns an error (TBD once backing database added).
362391
pub const fn close(&self) -> anyhow::Result<()> {
363392
// TODO: Integrate the backing database
364393
Ok(())
365394
}
366395

367-
/// Returns the handler for the given `endpoint`.
396+
/// Returns the handler for the `endpoint`.
368397
#[must_use]
369398
pub fn get_endpoint(&self, endpoint: &Endpoint) -> Option<&ShareableMessageHandler> {
370399
self.endpoints.get(endpoint)
371400
}
372401

373-
/// Returns the handler for the given `correlation_id`.
402+
/// Returns the handler for the `correlation_id`.
374403
#[must_use]
375404
pub fn get_response_handler(&self, correlation_id: &UUID4) -> Option<&ShareableMessageHandler> {
376405
self.correlation_index.get(correlation_id)
377406
}
378407

379-
/// Finds the subscriptions with pattern matching the given `topic`.
408+
/// Finds the subscriptions with pattern matching the `topic`.
380409
pub(crate) fn find_topic_matches(&self, topic: &Topic) -> Vec<Subscription> {
381410
self.subscriptions
382411
.iter()
@@ -390,7 +419,7 @@ impl MessageBus {
390419
.collect()
391420
}
392421

393-
/// Finds the subscriptions which match the given `topic` and caches the
422+
/// Finds the subscriptions which match the `topic` and caches the
394423
/// results in the `patterns` map.
395424
#[must_use]
396425
pub fn matching_subscriptions<T: AsRef<str>>(&mut self, topic: T) -> Vec<Subscription> {
@@ -411,7 +440,7 @@ impl MessageBus {
411440
///
412441
/// # Errors
413442
///
414-
/// Returns an error if a handler is already registered for the given correlation ID.
443+
/// Returns an error if `handler` is already registered for the `correlation_id`.
415444
pub fn register_response_handler(
416445
&mut self,
417446
correlation_id: &UUID4,

crates/common/src/msgbus/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ pub fn get_message_bus() -> Rc<RefCell<MessageBus>> {
8787

8888
/// Sends the `message` to the `endpoint`.
8989
pub fn send(endpoint: &Endpoint, message: &dyn Any) {
90+
// TODO: This should return a Result (in case endpoint doesn't exist)
9091
let handler = get_message_bus().borrow().get_endpoint(endpoint).cloned();
9192
if let Some(handler) = handler {
9293
handler.0.handle(message);
@@ -178,7 +179,7 @@ pub fn deregister(endpoint: &Endpoint) {
178179
.shift_remove(endpoint);
179180
}
180181

181-
/// Subscribes the given `handler` to the `pattern` with an optional `priority`.
182+
/// Subscribes the `handler` to the `pattern` with an optional `priority`.
182183
///
183184
/// # Warnings
184185
///

crates/common/src/msgbus/tests.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ fn test_matching_subscriptions() {
208208
}
209209

210210
#[rstest]
211-
#[case("*", "*", true)]
212211
#[case("a", "*", true)]
213212
#[case("a", "a", true)]
214213
#[case("a", "b", false)]
@@ -310,9 +309,9 @@ fn test_subscription_pattern_matching() {
310309
let handler2 = get_stub_shareable_handler(Some(Ustr::from("2")));
311310
let handler3 = get_stub_shareable_handler(Some(Ustr::from("3")));
312311

313-
msgbus::subscribe(Pattern::from("data.quotes.*"), handler1, None);
314-
msgbus::subscribe(Pattern::from("data.trades.*"), handler2, None);
315-
msgbus::subscribe(Pattern::from("data.*.BINANCE.*"), handler3, None);
312+
msgbus::subscribe_str("data.quotes.*", handler1, None);
313+
msgbus::subscribe_str("data.trades.*", handler2, None);
314+
msgbus::subscribe_str("data.*.BINANCE.*", handler3, None);
316315
assert_eq!(msgbus.borrow().subscriptions().len(), 3);
317316

318317
let topic = "data.quotes.BINANCE.ETHUSDT";

crates/common/src/python/msgbus.rs

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515

1616
use std::rc::Rc;
1717

18-
use pyo3::{PyObject, pymethods};
19-
use ustr::Ustr;
18+
use nautilus_core::python::to_pyvalue_err;
19+
use pyo3::{PyObject, PyResult, pymethods};
2020

2121
use super::handler::PythonMessageHandler;
2222
use crate::msgbus::{
23-
BusMessage, Endpoint, MessageBus, Pattern, deregister, get_message_bus,
24-
handler::ShareableMessageHandler, register, subscribe, unsubscribe,
23+
BusMessage, Endpoint, MessageBus, Pattern, Topic, deregister, handler::ShareableMessageHandler,
24+
publish, register, send, subscribe, unsubscribe,
2525
};
2626

2727
#[pymethods]
2828
impl BusMessage {
2929
#[getter]
3030
#[pyo3(name = "topic")]
31-
fn py_close(&mut self) -> String {
31+
fn py_topic(&mut self) -> String {
3232
self.topic.to_string()
3333
}
3434

@@ -49,37 +49,45 @@ impl BusMessage {
4949

5050
#[pymethods]
5151
impl MessageBus {
52-
/// Sends a message to an endpoint.
52+
/// Sends the `message` to the `endpoint`.
53+
///
54+
/// # Errors
55+
///
56+
/// Returns an error if `endpoint` is invalid.
5357
#[pyo3(name = "send")]
54-
pub fn py_send(&self, endpoint: &str, message: PyObject) {
55-
let endpoint = Endpoint::from(endpoint);
56-
57-
if let Some(handler) = self.get_endpoint(&endpoint) {
58-
handler.0.handle(&message);
59-
}
58+
pub fn py_send(&self, endpoint: &str, message: PyObject) -> PyResult<()> {
59+
let endpoint = Endpoint::new(endpoint).map_err(to_pyvalue_err)?;
60+
send(&endpoint, &message);
61+
Ok(())
6062
}
6163

62-
/// Publish a message to a topic.
64+
/// Publishes the `message` to the `topic`.
65+
///
66+
/// # Errors
67+
///
68+
/// Returns an error if `topic` is invalid.
6369
#[pyo3(name = "publish")]
6470
#[staticmethod]
65-
pub fn py_publish(topic: &str, message: PyObject) {
66-
let topic = Ustr::from(topic);
67-
let matching_subs = get_message_bus().borrow_mut().matching_subscriptions(topic);
68-
69-
for sub in matching_subs {
70-
sub.handler.0.handle(&message);
71-
}
71+
pub fn py_publish(topic: &str, message: PyObject) -> PyResult<()> {
72+
let topic = Topic::new(topic).map_err(to_pyvalue_err)?;
73+
publish(&topic, &message);
74+
Ok(())
7275
}
7376

7477
/// Registers the given `handler` for the `endpoint` address.
78+
///
79+
/// Updates endpoint handler if already exists.
80+
///
81+
/// # Errors
82+
///
83+
/// Returns an error if `endpoint` is invalid.
7584
#[pyo3(name = "register")]
7685
#[staticmethod]
77-
pub fn py_register(endpoint: &str, handler: PythonMessageHandler) {
78-
let endpoint = Endpoint::from(endpoint);
79-
80-
// Updates value if key already exists
86+
pub fn py_register(endpoint: &str, handler: PythonMessageHandler) -> PyResult<()> {
87+
let endpoint = Endpoint::new(endpoint).map_err(to_pyvalue_err)?;
8188
let handler = ShareableMessageHandler(Rc::new(handler));
8289
register(endpoint, handler);
90+
Ok(())
8391
}
8492

8593
/// Subscribes the given `handler` to the `topic`.
@@ -90,6 +98,8 @@ impl MessageBus {
9098
///
9199
/// Safety: Priority should be between 0 and 255
92100
///
101+
/// Updates topic handler if already exists.
102+
///
93103
/// # Warnings
94104
///
95105
/// Assigning priority handling is an advanced feature which *shouldn't
@@ -102,7 +112,6 @@ impl MessageBus {
102112
#[pyo3(signature = (topic, handler, priority=None))]
103113
#[staticmethod]
104114
pub fn py_subscribe(topic: &str, handler: PythonMessageHandler, priority: Option<u8>) {
105-
// Updates value if key already exists
106115
let pattern = Pattern::from(topic);
107116
let handler = ShareableMessageHandler(Rc::new(handler));
108117
subscribe(pattern, handler, priority);
@@ -133,10 +142,15 @@ impl MessageBus {
133142
}
134143

135144
/// Deregisters the given `handler` for the `endpoint` address.
145+
///
146+
/// # Errors
147+
///
148+
/// Returns an error if `endpoint` is invalid.
136149
#[pyo3(name = "deregister")]
137150
#[staticmethod]
138-
pub fn py_deregister(endpoint: &str) {
139-
let endpoint = Endpoint::from(endpoint);
151+
pub fn py_deregister(endpoint: &str) -> PyResult<()> {
152+
let endpoint = Endpoint::new(endpoint).map_err(to_pyvalue_err)?;
140153
deregister(&endpoint);
154+
Ok(())
141155
}
142156
}

0 commit comments

Comments
 (0)