@@ -24,9 +24,14 @@ use futures::{Stream, StreamExt};
24
24
use preference_sync:: UserPreferenceUpdate ;
25
25
use rand:: { Rng , RngCore } ;
26
26
use serde:: { Deserialize , Serialize } ;
27
+ use std:: future:: Future ;
27
28
use std:: pin:: Pin ;
29
+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
30
+ use std:: sync:: Arc ;
28
31
use thiserror:: Error ;
29
- use tokio:: sync:: OnceCell ;
32
+ use tokio:: sync:: { Notify , OnceCell } ;
33
+ use tokio:: time:: error:: Elapsed ;
34
+ use tokio:: time:: timeout;
30
35
use tracing:: { instrument, warn} ;
31
36
use xmtp_common:: time:: { now_ns, Duration } ;
32
37
use xmtp_common:: { retry_async, Retry , RetryableError } ;
@@ -104,8 +109,8 @@ pub enum DeviceSyncError {
104
109
SyncPayloadTooOld ,
105
110
#[ error( transparent) ]
106
111
Subscribe ( #[ from] SubscribeError ) ,
107
- #[ error( "Unable to serialize: {0}" ) ]
108
- Bincode ( String ) ,
112
+ #[ error( transparent ) ]
113
+ Bincode ( # [ from ] bincode :: Error ) ,
109
114
}
110
115
111
116
impl RetryableError for DeviceSyncError {
@@ -114,6 +119,17 @@ impl RetryableError for DeviceSyncError {
114
119
}
115
120
}
116
121
122
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
123
+ impl < ApiClient , V > Client < ApiClient , V > {
124
+ pub fn sync_worker_handle ( & self ) -> Option < Arc < WorkerHandle > > {
125
+ self . sync_worker_handle . lock ( ) . clone ( )
126
+ }
127
+
128
+ pub ( crate ) fn set_sync_worker_handle ( & self , handle : Arc < WorkerHandle > ) {
129
+ * self . sync_worker_handle . lock ( ) = Some ( handle) ;
130
+ }
131
+ }
132
+
117
133
impl < ApiClient , V > Client < ApiClient , V >
118
134
where
119
135
ApiClient : XmtpApi + Send + Sync + ' static ,
@@ -128,7 +144,10 @@ where
128
144
"starting sync worker"
129
145
) ;
130
146
131
- SyncWorker :: new ( client) . spawn_worker ( ) ;
147
+ let worker = SyncWorker :: new ( client) ;
148
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
149
+ self . set_sync_worker_handle ( worker. handle . clone ( ) ) ;
150
+ worker. spawn_worker ( ) ;
132
151
}
133
152
}
134
153
@@ -141,6 +160,57 @@ pub struct SyncWorker<ApiClient, V> {
141
160
> ,
142
161
init : OnceCell < ( ) > ,
143
162
retry : Retry ,
163
+
164
+ // Number of events processed
165
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
166
+ handle : Arc < WorkerHandle > ,
167
+ }
168
+
169
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
170
+ pub struct WorkerHandle {
171
+ processed : AtomicUsize ,
172
+ notify : Notify ,
173
+ }
174
+
175
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
176
+ impl WorkerHandle {
177
+ pub async fn wait_for_new_events ( & self , mut count : usize ) -> Result < ( ) , Elapsed > {
178
+ timeout ( Duration :: from_secs ( 3 ) , async {
179
+ while count > 0 {
180
+ self . notify . notified ( ) . await ;
181
+ count -= 1 ;
182
+ }
183
+ } )
184
+ . await ?;
185
+
186
+ Ok ( ( ) )
187
+ }
188
+
189
+ pub async fn wait_for_processed_count ( & self , expected : usize ) -> Result < ( ) , Elapsed > {
190
+ timeout ( Duration :: from_secs ( 3 ) , async {
191
+ while self . processed . load ( Ordering :: SeqCst ) < expected {
192
+ self . notify . notified ( ) . await ;
193
+ }
194
+ } )
195
+ . await ?;
196
+
197
+ Ok ( ( ) )
198
+ }
199
+
200
+ pub async fn block_for_num_events < Fut > ( & self , num_events : usize , op : Fut ) -> Result < ( ) , Elapsed >
201
+ where
202
+ Fut : Future < Output = ( ) > ,
203
+ {
204
+ let processed_count = self . processed_count ( ) ;
205
+ op. await ;
206
+ self . wait_for_processed_count ( processed_count + num_events)
207
+ . await ?;
208
+ Ok ( ( ) )
209
+ }
210
+
211
+ pub fn processed_count ( & self ) -> usize {
212
+ self . processed . load ( Ordering :: SeqCst )
213
+ }
144
214
}
145
215
146
216
impl < ApiClient , V > SyncWorker < ApiClient , V >
@@ -168,33 +238,22 @@ where
168
238
self . on_request ( message_id, & provider) . await ?
169
239
}
170
240
} ,
171
- LocalEvents :: OutgoingPreferenceUpdates ( consent_records) => {
172
- let provider = self . client . mls_provider ( ) ?;
173
- for record in consent_records {
174
- let UserPreferenceUpdate :: ConsentUpdate ( consent_record) = record else {
175
- continue ;
176
- } ;
177
-
178
- self . client
179
- . send_consent_update ( & provider, consent_record)
180
- . await ?;
181
- }
241
+ LocalEvents :: OutgoingPreferenceUpdates ( preference_updates) => {
242
+ tracing:: error!( "Outgoing preference update {preference_updates:?}" ) ;
243
+ UserPreferenceUpdate :: sync_across_devices ( preference_updates, & self . client )
244
+ . await ?;
182
245
}
183
- LocalEvents :: IncomingPreferenceUpdate ( updates) => {
184
- let provider = self . client . mls_provider ( ) ?;
185
- let consent_records = updates
186
- . into_iter ( )
187
- . filter_map ( |pu| match pu {
188
- UserPreferenceUpdate :: ConsentUpdate ( cr) => Some ( cr) ,
189
- _ => None ,
190
- } )
191
- . collect :: < Vec < _ > > ( ) ;
192
- provider
193
- . conn_ref ( )
194
- . insert_or_replace_consent_records ( & consent_records) ?;
246
+ LocalEvents :: IncomingPreferenceUpdate ( _) => {
247
+ tracing:: error!( "Incoming preference update" ) ;
195
248
}
196
249
_ => { }
197
250
}
251
+
252
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
253
+ {
254
+ self . handle . processed . fetch_add ( 1 , Ordering :: SeqCst ) ;
255
+ self . handle . notify . notify_waiters ( ) ;
256
+ }
198
257
}
199
258
Ok ( ( ) )
200
259
}
@@ -319,6 +378,12 @@ where
319
378
stream,
320
379
init : OnceCell :: new ( ) ,
321
380
retry,
381
+
382
+ #[ cfg( any( test, feature = "test-utils" ) ) ]
383
+ handle : Arc :: new ( WorkerHandle {
384
+ processed : AtomicUsize :: new ( 0 ) ,
385
+ notify : Notify :: new ( ) ,
386
+ } ) ,
322
387
}
323
388
}
324
389
@@ -404,10 +469,10 @@ where
404
469
405
470
let _message_id = sync_group. prepare_message ( & content_bytes, provider, {
406
471
let request = request. clone ( ) ;
407
- move |_time_ns | PlaintextEnvelope {
472
+ move |now | PlaintextEnvelope {
408
473
content : Some ( Content :: V2 ( V2 {
409
474
message_type : Some ( MessageType :: DeviceSyncRequest ( request) ) ,
410
- idempotency_key : new_request_id ( ) ,
475
+ idempotency_key : now . to_string ( ) ,
411
476
} ) ) ,
412
477
}
413
478
} ) ?;
@@ -471,14 +536,14 @@ where
471
536
( content_bytes, contents)
472
537
} ;
473
538
474
- sync_group. prepare_message ( & content_bytes, provider, |_time_ns | PlaintextEnvelope {
539
+ sync_group. prepare_message ( & content_bytes, provider, |now | PlaintextEnvelope {
475
540
content : Some ( Content :: V2 ( V2 {
476
- idempotency_key : new_request_id ( ) ,
477
541
message_type : Some ( MessageType :: DeviceSyncReply ( contents) ) ,
542
+ idempotency_key : now. to_string ( ) ,
478
543
} ) ) ,
479
544
} ) ?;
480
545
481
- sync_group. sync_until_last_intent_resolved ( provider) . await ?;
546
+ sync_group. publish_intents ( provider) . await ?;
482
547
483
548
Ok ( ( ) )
484
549
}
0 commit comments