Skip to main content

bitwarden_vault/
collection_client.rs

1use std::collections::HashMap;
2
3use bitwarden_collections::{
4    collection::{Collection, CollectionId, CollectionView},
5    tree::{NodeItem, Tree},
6};
7use bitwarden_core::Client;
8#[cfg(feature = "wasm")]
9use serde::{Deserialize, Serialize};
10#[cfg(feature = "wasm")]
11use tsify::Tsify;
12#[cfg(feature = "wasm")]
13use wasm_bindgen::prelude::wasm_bindgen;
14
15use crate::{DecryptError, EncryptError};
16
17#[allow(missing_docs)]
18#[cfg_attr(feature = "wasm", wasm_bindgen)]
19#[derive(Clone)]
20pub struct CollectionsClient {
21    pub(crate) client: Client,
22}
23
24#[cfg_attr(feature = "wasm", wasm_bindgen)]
25impl CollectionsClient {
26    /// Encrypts a [CollectionView] into an encrypted [Collection] using the organization key.
27    pub fn encrypt(&self, collection_view: CollectionView) -> Result<Collection, EncryptError> {
28        let key_store = self.client.internal.get_key_store();
29        let collection = key_store.encrypt(collection_view)?;
30        Ok(collection)
31    }
32
33    /// Encrypts a list of [CollectionView]s into encrypted [Collection]s using the organization
34    /// key.
35    pub fn encrypt_list(
36        &self,
37        collection_views: Vec<CollectionView>,
38    ) -> Result<Vec<Collection>, EncryptError> {
39        let key_store = self.client.internal.get_key_store();
40        let collections = key_store.encrypt_list(&collection_views)?;
41        Ok(collections)
42    }
43
44    #[allow(missing_docs)]
45    pub fn decrypt(&self, collection: Collection) -> Result<CollectionView, DecryptError> {
46        let key_store = self.client.internal.get_key_store();
47        let view = key_store.decrypt(&collection)?;
48        Ok(view)
49    }
50
51    #[allow(missing_docs)]
52    pub fn decrypt_list(
53        &self,
54        collections: Vec<Collection>,
55    ) -> Result<Vec<CollectionView>, DecryptError> {
56        let key_store = self.client.internal.get_key_store();
57        let views = key_store.decrypt_list(&collections)?;
58        Ok(views)
59    }
60
61    ///
62    /// Returns the vector of CollectionView objects in a tree structure based on its implemented
63    /// path().
64    pub fn get_collection_tree(&self, collections: Vec<CollectionView>) -> CollectionViewTree {
65        CollectionViewTree {
66            tree: Tree::from_items(collections),
67        }
68    }
69}
70
/// Tree of [CollectionView]s, as returned by [CollectionsClient::get_collection_tree].
#[cfg_attr(feature = "wasm", wasm_bindgen)]
pub struct CollectionViewTree {
    // Underlying generic tree that stores the collection views.
    tree: Tree<CollectionView>,
}
75
/// A single node handed out by [CollectionViewTree], wrapping the tree's [NodeItem] for a
/// [CollectionView].
#[cfg_attr(feature = "wasm", wasm_bindgen)]
pub struct CollectionViewNodeItem {
    // The wrapped tree node; its item, parent, children and ancestors are exposed via getters.
    node_item: NodeItem<CollectionView>,
}
80
/// Ancestors of a tree node keyed by [CollectionId], as returned by
/// [CollectionViewNodeItem::get_ancestors].
#[cfg_attr(
    feature = "wasm",
    derive(Tsify, Serialize, Deserialize),
    tsify(into_wasm_abi, from_wasm_abi)
)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Record))]
pub struct AncestorMap {
    /// Maps each ancestor's [CollectionId] to the string stored for it on the tree node
    /// (presumably the ancestor's name; confirm against `NodeItem`).
    pub ancestors: HashMap<CollectionId, String>,
}
90
91#[cfg_attr(feature = "wasm", wasm_bindgen)]
92impl CollectionViewNodeItem {
93    pub fn get_item(&self) -> CollectionView {
94        self.node_item.item.clone()
95    }
96
97    pub fn get_parent(&self) -> Option<CollectionView> {
98        self.node_item.parent.clone()
99    }
100
101    pub fn get_children(&self) -> Vec<CollectionView> {
102        self.node_item.children.clone()
103    }
104
105    pub fn get_ancestors(&self) -> AncestorMap {
106        AncestorMap {
107            ancestors: self
108                .node_item
109                .ancestors
110                .iter()
111                .map(|(&uuid, name)| (CollectionId::new(uuid), name.clone()))
112                .collect(),
113        }
114    }
115}
116
117#[cfg_attr(feature = "wasm", wasm_bindgen)]
118impl CollectionViewTree {
119    pub fn get_item_for_view(
120        &self,
121        collection_view: CollectionView,
122    ) -> Option<CollectionViewNodeItem> {
123        self.tree
124            .get_item_by_id(collection_view.id.unwrap_or_default().into())
125            .map(|n| CollectionViewNodeItem { node_item: n })
126    }
127
128    pub fn get_root_items(&self) -> Vec<CollectionViewNodeItem> {
129        self.tree
130            .get_root_items()
131            .into_iter()
132            .map(|n| CollectionViewNodeItem { node_item: n })
133            .collect()
134    }
135
136    pub fn get_flat_items(&self) -> Vec<CollectionViewNodeItem> {
137        self.tree
138            .get_flat_items()
139            .into_iter()
140            .map(|n| CollectionViewNodeItem { node_item: n })
141            .collect()
142    }
143}
144
#[cfg(test)]
mod tests {
    use bitwarden_collections::collection::CollectionType;
    use bitwarden_core::client::test_accounts::test_bitwarden_com_account;

    use super::*;
    use crate::VaultClientExt;

    /// Plaintext name the encrypted fixture is expected to decrypt to.
    const DECRYPTED_NAME: &str = "Default collection";

    /// Builds the encrypted [Collection] fixture shared by all tests in this module.
    fn encrypted_fixture() -> Collection {
        Collection {
            id: Some("66c5ca57-0868-4c7e-902f-b181009709c0".parse().unwrap()),
            organization_id: "1bc9ac1e-f5aa-45f2-94bf-b181009709b8".parse().unwrap(),
            name: "2.EI9Km5BfrIqBa1W+WCccfA==|laWxNnx+9H3MZww4zm7cBSLisjpi81zreaQntRhegVI=|x42+qKFf5ga6DIL0OW5pxCdLrC/gm8CXJvf3UASGteI=".parse().unwrap(),
            external_id: None,
            hide_passwords: false,
            read_only: false,
            manage: false,
            default_user_collection_email: None,
            r#type: CollectionType::SharedCollection,
        }
    }

    #[tokio::test]
    async fn test_decrypt_list() {
        let client = Client::init_test_account(test_bitwarden_com_account()).await;
        let collections = client.vault().collections();

        let views = collections
            .decrypt_list(vec![encrypted_fixture()])
            .unwrap();

        assert_eq!(views[0].name, DECRYPTED_NAME);
    }

    #[tokio::test]
    async fn test_decrypt() {
        let client = Client::init_test_account(test_bitwarden_com_account()).await;
        let collections = client.vault().collections();

        let view = collections.decrypt(encrypted_fixture()).unwrap();

        assert_eq!(view.name, DECRYPTED_NAME);
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_roundtrip() {
        let client = Client::init_test_account(test_bitwarden_com_account()).await;
        let collections = client.vault().collections();

        let view = collections.decrypt(encrypted_fixture()).unwrap();
        assert_eq!(view.name, DECRYPTED_NAME);

        // Capture the identifiers before `view` is moved into `encrypt`.
        let (expected_id, expected_org_id) = (view.id, view.organization_id);

        // Encrypt the decrypted view again and make sure it still decrypts to the same data.
        let re_encrypted = collections.encrypt(view).unwrap();
        let re_decrypted = collections.decrypt(re_encrypted).unwrap();

        assert_eq!(re_decrypted.name, DECRYPTED_NAME);
        assert_eq!(re_decrypted.id, expected_id);
        assert_eq!(re_decrypted.organization_id, expected_org_id);
    }

    #[tokio::test]
    async fn test_encrypt_list_decrypt_list_roundtrip() {
        let client = Client::init_test_account(test_bitwarden_com_account()).await;
        let collections = client.vault().collections();

        let views = collections
            .decrypt_list(vec![encrypted_fixture()])
            .unwrap();
        assert_eq!(views.len(), 1);
        assert_eq!(views[0].name, DECRYPTED_NAME);

        // Capture the identifiers before `views` is moved into `encrypt_list`.
        let expected_id = views[0].id;
        let expected_org_id = views[0].organization_id;

        let re_encrypted = collections.encrypt_list(views).unwrap();
        assert_eq!(re_encrypted.len(), 1);

        let re_decrypted = collections.decrypt_list(re_encrypted).unwrap();
        assert_eq!(re_decrypted.len(), 1);
        assert_eq!(re_decrypted[0].name, DECRYPTED_NAME);
        assert_eq!(re_decrypted[0].id, expected_id);
        assert_eq!(re_decrypted[0].organization_id, expected_org_id);
    }
}