Skip to content

Commit bd03cd7

Browse files
authored
Sync All Preference Updates Across Devices (#1400)
* Sync preference in batches * a bit of cleanup * update than handling of the incoming mls sync msg * comments * wip * rely on mls for ordering * move the group_id to the info field during hkdf as suggested * lint * Constrain params to only what they need to be * lint * Store and emit hmac key updates * fix wasm * worker handle * fix user preferences * fix tests * cleanup * test fix * cleanup * gate the worker handle behind test utils * lint * lint * cleanup * try this * add test-utils flag * cleanup * revert * revert * test fix * xmtp_common exists now
1 parent 600a8e2 commit bd03cd7

File tree

12 files changed

+1222
-1019
lines changed

12 files changed

+1222
-1019
lines changed

bindings_wasm/Cargo.toml

+5-6
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,24 @@ version.workspace = true
77
crate-type = ["cdylib", "rlib"]
88

99
[dependencies]
10+
console_error_panic_hook.workspace = true
1011
hex.workspace = true
1112
js-sys.workspace = true
1213
prost.workspace = true
1314
serde-wasm-bindgen = "0.6.5"
1415
serde.workspace = true
1516
tokio.workspace = true
17+
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
18+
tracing-web = "0.1"
19+
tracing.workspace = true
1620
wasm-bindgen-futures.workspace = true
1721
wasm-bindgen.workspace = true
1822
xmtp_api_http = { path = "../xmtp_api_http" }
23+
xmtp_common.workspace = true
1924
xmtp_cryptography = { path = "../xmtp_cryptography" }
2025
xmtp_id = { path = "../xmtp_id" }
2126
xmtp_mls = { path = "../xmtp_mls", features = ["test-utils", "http-api"] }
22-
xmtp_common.workspace = true
2327
xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
24-
tracing-web = "0.1"
25-
tracing.workspace = true
26-
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
27-
console_error_panic_hook.workspace = true
2828

2929
[dev-dependencies]
3030
wasm-bindgen-test.workspace = true
31-
xmtp_mls = { path = "../xmtp_mls", features = ["test-utils", "http-api"] }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
DROP TABLE user_preferences;
2+
3+
CREATE TABLE user_preferences(
4+
id INTEGER PRIMARY KEY ASC,
5+
hmac_key BLOB NOT NULL
6+
);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
DROP TABLE user_preferences;
2+
3+
CREATE TABLE user_preferences(
4+
id INTEGER PRIMARY KEY ASC NOT NULL,
5+
hmac_key BLOB
6+
);

xmtp_mls/src/client.rs

+11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ use xmtp_proto::xmtp::mls::api::v1::{
3030
GroupMessage, WelcomeMessage,
3131
};
3232

33+
#[cfg(any(test, feature = "test-utils"))]
34+
use crate::groups::device_sync::WorkerHandle;
35+
3336
use crate::{
3437
api::ApiClientWrapper,
3538
groups::{
@@ -144,6 +147,9 @@ pub struct Client<ApiClient, V = RemoteSignatureVerifier<ApiClient>> {
144147
pub(crate) local_events: broadcast::Sender<LocalEvents<Self>>,
145148
/// The method of verifying smart contract wallet signatures for this Client
146149
pub(crate) scw_verifier: Arc<V>,
150+
151+
#[cfg(any(test, feature = "test-utils"))]
152+
pub(crate) sync_worker_handle: Arc<parking_lot::Mutex<Option<Arc<WorkerHandle>>>>,
147153
}
148154

149155
// most of these things are `Arc`'s
@@ -155,6 +161,9 @@ impl<ApiClient, V> Clone for Client<ApiClient, V> {
155161
history_sync_url: self.history_sync_url.clone(),
156162
local_events: self.local_events.clone(),
157163
scw_verifier: self.scw_verifier.clone(),
164+
165+
#[cfg(any(test, feature = "test-utils"))]
166+
sync_worker_handle: self.sync_worker_handle.clone(),
158167
}
159168
}
160169
}
@@ -240,6 +249,8 @@ where
240249
context,
241250
history_sync_url,
242251
local_events: tx,
252+
#[cfg(any(test, feature = "test-utils"))]
253+
sync_worker_handle: Arc::new(parking_lot::Mutex::default()),
243254
scw_verifier: scw_verifier.into(),
244255
}
245256
}

xmtp_mls/src/groups/device_sync.rs

+97-32
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ use futures::{Stream, StreamExt};
2424
use preference_sync::UserPreferenceUpdate;
2525
use rand::{Rng, RngCore};
2626
use serde::{Deserialize, Serialize};
27+
use std::future::Future;
2728
use std::pin::Pin;
29+
use std::sync::atomic::{AtomicUsize, Ordering};
30+
use std::sync::Arc;
2831
use thiserror::Error;
29-
use tokio::sync::OnceCell;
32+
use tokio::sync::{Notify, OnceCell};
33+
use tokio::time::error::Elapsed;
34+
use tokio::time::timeout;
3035
use tracing::{instrument, warn};
3136
use xmtp_common::time::{now_ns, Duration};
3237
use xmtp_common::{retry_async, Retry, RetryableError};
@@ -104,8 +109,8 @@ pub enum DeviceSyncError {
104109
SyncPayloadTooOld,
105110
#[error(transparent)]
106111
Subscribe(#[from] SubscribeError),
107-
#[error("Unable to serialize: {0}")]
108-
Bincode(String),
112+
#[error(transparent)]
113+
Bincode(#[from] bincode::Error),
109114
}
110115

111116
impl RetryableError for DeviceSyncError {
@@ -114,6 +119,17 @@ impl RetryableError for DeviceSyncError {
114119
}
115120
}
116121

122+
#[cfg(any(test, feature = "test-utils"))]
123+
impl<ApiClient, V> Client<ApiClient, V> {
124+
pub fn sync_worker_handle(&self) -> Option<Arc<WorkerHandle>> {
125+
self.sync_worker_handle.lock().clone()
126+
}
127+
128+
pub(crate) fn set_sync_worker_handle(&self, handle: Arc<WorkerHandle>) {
129+
*self.sync_worker_handle.lock() = Some(handle);
130+
}
131+
}
132+
117133
impl<ApiClient, V> Client<ApiClient, V>
118134
where
119135
ApiClient: XmtpApi + Send + Sync + 'static,
@@ -128,7 +144,10 @@ where
128144
"starting sync worker"
129145
);
130146

131-
SyncWorker::new(client).spawn_worker();
147+
let worker = SyncWorker::new(client);
148+
#[cfg(any(test, feature = "test-utils"))]
149+
self.set_sync_worker_handle(worker.handle.clone());
150+
worker.spawn_worker();
132151
}
133152
}
134153

@@ -141,6 +160,57 @@ pub struct SyncWorker<ApiClient, V> {
141160
>,
142161
init: OnceCell<()>,
143162
retry: Retry,
163+
164+
// Number of events processed
165+
#[cfg(any(test, feature = "test-utils"))]
166+
handle: Arc<WorkerHandle>,
167+
}
168+
169+
#[cfg(any(test, feature = "test-utils"))]
170+
pub struct WorkerHandle {
171+
processed: AtomicUsize,
172+
notify: Notify,
173+
}
174+
175+
#[cfg(any(test, feature = "test-utils"))]
176+
impl WorkerHandle {
177+
pub async fn wait_for_new_events(&self, mut count: usize) -> Result<(), Elapsed> {
178+
timeout(Duration::from_secs(3), async {
179+
while count > 0 {
180+
self.notify.notified().await;
181+
count -= 1;
182+
}
183+
})
184+
.await?;
185+
186+
Ok(())
187+
}
188+
189+
pub async fn wait_for_processed_count(&self, expected: usize) -> Result<(), Elapsed> {
190+
timeout(Duration::from_secs(3), async {
191+
while self.processed.load(Ordering::SeqCst) < expected {
192+
self.notify.notified().await;
193+
}
194+
})
195+
.await?;
196+
197+
Ok(())
198+
}
199+
200+
pub async fn block_for_num_events<Fut>(&self, num_events: usize, op: Fut) -> Result<(), Elapsed>
201+
where
202+
Fut: Future<Output = ()>,
203+
{
204+
let processed_count = self.processed_count();
205+
op.await;
206+
self.wait_for_processed_count(processed_count + num_events)
207+
.await?;
208+
Ok(())
209+
}
210+
211+
pub fn processed_count(&self) -> usize {
212+
self.processed.load(Ordering::SeqCst)
213+
}
144214
}
145215

146216
impl<ApiClient, V> SyncWorker<ApiClient, V>
@@ -168,33 +238,22 @@ where
168238
self.on_request(message_id, &provider).await?
169239
}
170240
},
171-
LocalEvents::OutgoingPreferenceUpdates(consent_records) => {
172-
let provider = self.client.mls_provider()?;
173-
for record in consent_records {
174-
let UserPreferenceUpdate::ConsentUpdate(consent_record) = record else {
175-
continue;
176-
};
177-
178-
self.client
179-
.send_consent_update(&provider, consent_record)
180-
.await?;
181-
}
241+
LocalEvents::OutgoingPreferenceUpdates(preference_updates) => {
242+
tracing::error!("Outgoing preference update {preference_updates:?}");
243+
UserPreferenceUpdate::sync_across_devices(preference_updates, &self.client)
244+
.await?;
182245
}
183-
LocalEvents::IncomingPreferenceUpdate(updates) => {
184-
let provider = self.client.mls_provider()?;
185-
let consent_records = updates
186-
.into_iter()
187-
.filter_map(|pu| match pu {
188-
UserPreferenceUpdate::ConsentUpdate(cr) => Some(cr),
189-
_ => None,
190-
})
191-
.collect::<Vec<_>>();
192-
provider
193-
.conn_ref()
194-
.insert_or_replace_consent_records(&consent_records)?;
246+
LocalEvents::IncomingPreferenceUpdate(_) => {
247+
tracing::error!("Incoming preference update");
195248
}
196249
_ => {}
197250
}
251+
252+
#[cfg(any(test, feature = "test-utils"))]
253+
{
254+
self.handle.processed.fetch_add(1, Ordering::SeqCst);
255+
self.handle.notify.notify_waiters();
256+
}
198257
}
199258
Ok(())
200259
}
@@ -319,6 +378,12 @@ where
319378
stream,
320379
init: OnceCell::new(),
321380
retry,
381+
382+
#[cfg(any(test, feature = "test-utils"))]
383+
handle: Arc::new(WorkerHandle {
384+
processed: AtomicUsize::new(0),
385+
notify: Notify::new(),
386+
}),
322387
}
323388
}
324389

@@ -404,10 +469,10 @@ where
404469

405470
let _message_id = sync_group.prepare_message(&content_bytes, provider, {
406471
let request = request.clone();
407-
move |_time_ns| PlaintextEnvelope {
472+
move |now| PlaintextEnvelope {
408473
content: Some(Content::V2(V2 {
409474
message_type: Some(MessageType::DeviceSyncRequest(request)),
410-
idempotency_key: new_request_id(),
475+
idempotency_key: now.to_string(),
411476
})),
412477
}
413478
})?;
@@ -471,14 +536,14 @@ where
471536
(content_bytes, contents)
472537
};
473538

474-
sync_group.prepare_message(&content_bytes, provider, |_time_ns| PlaintextEnvelope {
539+
sync_group.prepare_message(&content_bytes, provider, |now| PlaintextEnvelope {
475540
content: Some(Content::V2(V2 {
476-
idempotency_key: new_request_id(),
477541
message_type: Some(MessageType::DeviceSyncReply(contents)),
542+
idempotency_key: now.to_string(),
478543
})),
479544
})?;
480545

481-
sync_group.sync_until_last_intent_resolved(provider).await?;
546+
sync_group.publish_intents(provider).await?;
482547

483548
Ok(())
484549
}

0 commit comments

Comments
 (0)