diff --git a/CHANGELOG.md b/CHANGELOG.md index db93d1a65..408ffe7d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # UNRELEASED +- Add `OlmMachine.registerDevicesUpdatedCallback` to notify when devices have + been updated. + ([#88](https://github.com/matrix-org/matrix-rust-sdk-crypto-wasm/pull/88)) + # matrix-sdk-crypto-wasm v4.0.1 - `PickledInboundGroupSession.sender_signing_key` is now optional. diff --git a/src/machine.rs b/src/machine.rs index a38c87c17..9710c4c6e 100644 --- a/src/machine.rs +++ b/src/machine.rs @@ -1,6 +1,10 @@ //! The crypto specific Olm objects. -use std::{collections::BTreeMap, ops::Deref, time::Duration}; +use std::{ + collections::{BTreeMap, HashSet}, + ops::Deref, + time::Duration, +}; use futures_util::{pin_mut, StreamExt}; use js_sys::{Array, Function, JsString, Map, Promise, Set}; @@ -1267,6 +1271,26 @@ impl OlmMachine { }); } + /// Register a callback which will be called whenever there is an update to + /// a device. + /// + /// `callback` should be a function that takes a single argument (an array + /// of user IDs as strings) and returns a Promise. + #[wasm_bindgen(js_name = "registerDevicesUpdatedCallback")] + pub async fn register_devices_updated_callback(&self, callback: Function) { + let stream = self.inner.store().identities_stream_raw(); + + // fire up a promise chain which will call `callback` on each result from the + // stream + spawn_local(async move { + // take a reference to `callback` (which we then pass into the closure), to stop + // the callback being moved into the closure (which would mean we could only + // call the closure once) + let callback_ref = &callback; + stream.for_each(move |item| send_device_updates_to_callback(callback_ref, item)).await; + }); + } + /// Register a callback which will be called whenever a secret /// (`m.secret.send`) is received. /// @@ -1443,6 +1467,34 @@ async fn send_user_identities_to_callback( } } +// helper for register_device_updated_callback: passes the user IDs into +// the javascript function +async fn send_device_updates_to_callback( + callback: &Function, + (_, device_updates): (IdentityChanges, DeviceChanges), +) { + // get the user IDs of all the devices that have changed + let updated_chain = device_updates + .new + .into_iter() + .chain(device_updates.changed.into_iter()) + .chain(device_updates.deleted.into_iter()); + // put them in a set to make them unique + let updated_users: HashSet = + HashSet::from_iter(updated_chain.map(|device| device.user_id().to_string())); + let updated_users_vec = Vec::from_iter(updated_users.iter()); + match promise_result_to_future( + callback.call1(&JsValue::NULL, &serde_wasm_bindgen::to_value(&updated_users_vec).unwrap()), + ) + .await + { + Ok(_) => (), + Err(e) => { + warn!("Error calling device-updated callback: {:?}", e); + } + } +} + // helper for register_secret_receive_callback: passes the secret name and value // into the javascript function async fn send_secret_gossip_to_callback(callback: &Function, secret: &GossippedSecret) { diff --git a/tests/machine.test.ts b/tests/machine.test.ts index e9f4390e9..16dd5c989 100644 --- a/tests/machine.test.ts +++ b/tests/machine.test.ts @@ -1388,4 +1388,37 @@ describe(OlmMachine.name, () => { expect(toDeviceRequests).toHaveLength(0); }); }); + + test("Updating devices should call devicesUpdatedCallback", async () => { + const userId = new UserId("@alice:example.org"); + const deviceId = new DeviceId("ABCDEF"); + const machine = await OlmMachine.initialize(userId, deviceId); + + const callback = jest.fn().mockImplementation(() => Promise.resolve(undefined)); + machine.registerDevicesUpdatedCallback(callback); + + const outgoingRequests = await machine.outgoingRequests(); + let deviceKeys; + // outgoingRequests will have a KeysUploadRequest before the + // KeysQueryRequest, so we grab the device upload and put it in the + // response to the /keys/query + for (const request of outgoingRequests) { + if (request instanceof KeysUploadRequest) { + deviceKeys = JSON.parse(request.body).device_keys; + } else if (request instanceof KeysQueryRequest) { + await machine.markRequestAsSent( + request.id, + request.type, + JSON.stringify({ + device_keys: { + "@alice:example.org": { + ABCDEF: deviceKeys, + }, + }, + }), + ); + } + } + expect(callback).toHaveBeenCalledWith(["@alice:example.org"]); + }); });