// bitwarden_shared_unlock/wasm/drivers.rs

1use bitwarden_core::UserId;
2use bitwarden_crypto::SymmetricCryptoKey;
3use bitwarden_ipc::{Endpoint, HostId};
4use bitwarden_threading::ThreadBoundRunner;
5use wasm_bindgen::{JsValue, prelude::wasm_bindgen};
6use wasm_bindgen_futures::js_sys;
7
8use crate::{LockState, SharedUnlockDriver};
9
// TypeScript declaration of the `SharedUnlockDriver` interface that JavaScript callers implement.
// wasm-bindgen appends this text to the generated TypeScript definitions, and the
// `typescript_type = "SharedUnlockDriver"` attribute on `RawJsSharedUnlockDriver` below refers to
// it. Method names and signatures here must stay in sync with the `extern "C"` declarations in
// this file. `suppression_duration` is a number of milliseconds (the Rust side passes
// `Duration::as_millis()`).
#[wasm_bindgen(typescript_custom_section)]
const TS_CUSTOM_TYPES: &'static str = r#"
export interface SharedUnlockDriver {
    lock_user(user_id: UserId): Promise<void>;
    unlock_user(user_id: UserId, user_key: SymmetricKey): Promise<void>;
    list_users(): Promise<UserId[]>;
    get_user_key(user_id: UserId): Promise<SymmetricKey | undefined>;
    suppress_vault_timeout(user_id: UserId, suppression_duration: number): Promise<void>;
    get_client_name(): Promise<string>;
    get_vault_url(user_id: UserId): Promise<string | undefined>;
}
"#;
22
#[wasm_bindgen]
extern "C" {
    /// JavaScript implementation of the shared unlock operations used by the shared unlock
    /// protocol.
    ///
    /// Its TypeScript shape is the `SharedUnlockDriver` interface declared in `TS_CUSTOM_TYPES`.
    /// Every method below is declared with `catch`, so a rejected promise (or a thrown exception)
    /// is returned as `Err(JsValue)` rather than propagating as a JS exception.
    #[wasm_bindgen(js_name = SharedUnlockDriver, typescript_type = "SharedUnlockDriver")]
    pub type RawJsSharedUnlockDriver;

    /// Ask the JS driver to lock the user with the given ID.
    #[wasm_bindgen(method, catch)]
    async fn lock_user(this: &RawJsSharedUnlockDriver, user_id: UserId) -> Result<(), JsValue>;
    /// Ask the JS driver to unlock the user with the given ID using `user_key`.
    #[wasm_bindgen(method, catch)]
    async fn unlock_user(
        this: &RawJsSharedUnlockDriver,
        user_id: UserId,
        user_key: SymmetricCryptoKey,
    ) -> Result<(), JsValue>;
    /// List the users known to the JS driver.
    ///
    /// Returned as a raw `js_sys::Array`. The TypeScript declaration types it as `UserId[]`, but
    /// each entry still has to be converted from a `JsValue` on the Rust side.
    #[wasm_bindgen(method, catch)]
    async fn list_users(this: &RawJsSharedUnlockDriver) -> Result<js_sys::Array, JsValue>;
    /// Fetch the user key for the given user.
    ///
    /// Resolves to `None` when the JS side returns `undefined` (declared as
    /// `SymmetricKey | undefined` in TypeScript).
    #[wasm_bindgen(method, catch)]
    async fn get_user_key(
        this: &RawJsSharedUnlockDriver,
        user_id: UserId,
    ) -> Result<Option<SymmetricCryptoKey>, JsValue>;

    /// Suppress the vault timeout for the given user for the given duration (in milliseconds).
    #[wasm_bindgen(method, catch)]
    async fn suppress_vault_timeout(
        this: &RawJsSharedUnlockDriver,
        user_id: UserId,
        suppression_duration: f64,
    ) -> Result<(), JsValue>;

    /// Get the client name of the current device.
    ///
    /// Returned as a raw `JsValue`. The TypeScript declaration types it as `string`.
    #[wasm_bindgen(method, catch)]
    async fn get_client_name(this: &RawJsSharedUnlockDriver) -> Result<JsValue, JsValue>;

    /// Get the vault URL for the user with the given ID, if available. This is used to verify IPC
    /// message sources.
    ///
    /// Returned as a raw `JsValue`. The TypeScript declaration types it as `string | undefined`.
    #[wasm_bindgen(method, catch)]
    async fn get_vault_url(
        this: &RawJsSharedUnlockDriver,
        user_id: UserId,
    ) -> Result<JsValue, JsValue>;
}
65
66pub(super) struct JsSharedUnlockDriver {
67    runner: ThreadBoundRunner<RawJsSharedUnlockDriver>,
68}
69
70impl JsSharedUnlockDriver {
71    pub(super) fn new(driver: RawJsSharedUnlockDriver) -> Self {
72        Self {
73            runner: ThreadBoundRunner::new(driver),
74        }
75    }
76}
77
78#[async_trait::async_trait]
79impl SharedUnlockDriver for JsSharedUnlockDriver {
80    async fn lock_user(&self, user_id: UserId) -> Result<(), ()> {
81        self.runner
82            .run_in_thread(
83                move |driver| async move { driver.lock_user(user_id).await.map_err(|_| ()) },
84            )
85            .await
86            .map_err(|_| ())?
87    }
88
89    async fn unlock_user(&self, user_id: UserId, user_key: SymmetricCryptoKey) -> Result<(), ()> {
90        self.runner
91            .run_in_thread(move |driver| async move {
92                driver.unlock_user(user_id, user_key).await.map_err(|_| ())
93            })
94            .await
95            .map_err(|_| ())?
96    }
97
98    async fn list_users(&self) -> Vec<UserId> {
99        self.runner
100            .run_in_thread(move |driver| async move {
101                match driver.list_users().await {
102                    Ok(array) => array
103                        .iter()
104                        .filter_map(|js_value| js_value.as_string())
105                        .filter_map(|s| s.parse().ok())
106                        .collect(),
107                    Err(_) => vec![],
108                }
109            })
110            .await
111            .unwrap_or_default()
112    }
113
114    async fn get_user_lock_state(&self, user_id: UserId) -> LockState {
115        self.runner
116            .run_in_thread(move |driver| async move {
117                match driver.get_user_key(user_id).await.ok().flatten() {
118                    Some(user_key) => LockState::Unlocked { user_key },
119                    None => LockState::Locked,
120                }
121            })
122            .await
123            .unwrap_or(LockState::Locked)
124    }
125
126    async fn get_vault_url(&self, user_id: UserId) -> Option<String> {
127        self.runner
128            .run_in_thread(move |driver| async move {
129                driver
130                    .get_vault_url(user_id)
131                    .await
132                    .ok()
133                    .and_then(|js_value| js_value.as_string())
134            })
135            .await
136            .ok()
137            .flatten()
138    }
139
140    async fn suppress_vault_timeout(
141        &self,
142        user_id: UserId,
143        suppression_duration: std::time::Duration,
144    ) {
145        let result = self
146            .runner
147            .run_in_thread(move |driver| async move {
148                driver
149                    .suppress_vault_timeout(user_id, suppression_duration.as_millis() as f64)
150                    .await
151            })
152            .await;
153        match result {
154            Ok(Ok(())) => {}
155            Ok(Err(error)) => {
156                tracing::error!(
157                    ?error,
158                    "Failed to suppress vault timeout for user_id: {}",
159                    user_id
160                )
161            }
162            Err(error) => {
163                tracing::error!(
164                    ?error,
165                    "Failed to suppress vault timeout for user_id: {}",
166                    user_id
167                )
168            }
169        }
170    }
171
172    async fn discover_leader(&self) -> Option<Endpoint> {
173        self.runner
174            .run_in_thread(move |driver| async move {
175                let client_name = match driver.get_client_name().await {
176                    Ok(name) => name.as_string()?,
177                    Err(_) => return None,
178                };
179                match client_name.as_str() {
180                    "web" => Some(Endpoint::BrowserBackground { id: HostId::Own }),
181                    "browser" => Some(Endpoint::DesktopRenderer),
182                    "cli" => Some(Endpoint::DesktopRenderer),
183                    _ => None,
184                }
185            })
186            .await
187            .ok()
188            .flatten()
189    }
190}