Skip to main content

bitwarden_vault/cipher/cipher_client/
bulk_update_collections.rs

1use std::collections::HashSet;
2
3use bitwarden_api_api::models::CipherBulkUpdateCollectionsRequestModel;
4use bitwarden_collections::collection::CollectionId;
5use bitwarden_core::{ApiError, OrganizationId};
6use bitwarden_error::bitwarden_error;
7use bitwarden_state::repository::{RepositoryError, RepositoryOption};
8use thiserror::Error;
9#[cfg(feature = "wasm")]
10use wasm_bindgen::prelude::wasm_bindgen;
11
12use crate::{CipherId, CiphersClient};
13
/// Errors returned by [`CiphersClient::bulk_update_collections`].
#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum BulkUpdateCollectionsCipherError {
    /// An error from the Bitwarden API client, e.g. the bulk-collections request failed.
    #[error(transparent)]
    Api(#[from] ApiError),
    /// A local repository operation failed (this also covers `require()` on the client's
    /// repository option returning an error).
    #[error(transparent)]
    Repository(#[from] RepositoryError),
}
23
24impl<T> From<bitwarden_api_api::apis::Error<T>> for BulkUpdateCollectionsCipherError {
25    fn from(value: bitwarden_api_api::apis::Error<T>) -> Self {
26        Self::Api(value.into())
27    }
28}
29
30#[cfg_attr(feature = "wasm", wasm_bindgen)]
31impl CiphersClient {
32    /// Updates collection membership for multiple [`Cipher`](crate::Cipher) objects.
33    ///
34    /// When `remove_collections` is `true`, the given collection IDs are removed from each cipher.
35    /// When `false`, they are added without introducing duplicates.
36    pub async fn bulk_update_collections(
37        &self,
38        organization_id: OrganizationId,
39        cipher_ids: Vec<CipherId>,
40        collection_ids: Vec<CollectionId>,
41        remove_collections: bool,
42    ) -> Result<(), BulkUpdateCollectionsCipherError> {
43        self.api_configurations
44            .api_client
45            .ciphers_api()
46            .post_bulk_collections(Some(CipherBulkUpdateCollectionsRequestModel {
47                organization_id: Some(organization_id.into()),
48                cipher_ids: Some(cipher_ids.iter().map(|id| (*id).into()).collect()),
49                collection_ids: Some(collection_ids.iter().map(|id| (*id).into()).collect()),
50                remove_collections: Some(remove_collections),
51            }))
52            .await?;
53
54        let repository = self.repository.require()?;
55        let mut updated_ciphers = Vec::new();
56        let collection_ids = collection_ids.iter().copied().collect::<HashSet<_>>();
57        for cipher_id in cipher_ids {
58            if let Some(mut cipher) = repository.get(cipher_id).await? {
59                if remove_collections {
60                    cipher
61                        .collection_ids
62                        .retain(|id| !collection_ids.contains(id));
63                } else {
64                    let existing = cipher
65                        .collection_ids
66                        .iter()
67                        .copied()
68                        .collect::<HashSet<_>>();
69                    cipher.collection_ids = cipher
70                        .collection_ids
71                        .into_iter()
72                        .chain(
73                            collection_ids
74                                .clone()
75                                .into_iter()
76                                .filter(|id| !existing.contains(id)),
77                        )
78                        .collect();
79                }
80                updated_ciphers.push((cipher_id, cipher));
81            }
82        }
83        repository.set_bulk(updated_ciphers).await?;
84
85        Ok(())
86    }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bitwarden_api_api::apis::ApiClient;
    use bitwarden_collections::collection::CollectionId;
    use bitwarden_core::{
        OrganizationId, client::ApiConfigurations, key_management::create_test_crypto_with_user_key,
    };
    use bitwarden_crypto::SymmetricCryptoKey;
    use bitwarden_state::repository::Repository;
    use bitwarden_test::MemoryRepository;

    use crate::{Cipher, CipherId, CiphersClient};

    const TEST_CIPHER_ID: &str = "5faa9684-c793-4a2d-8a12-b33900187097";
    const TEST_ORG_ID: &str = "7faa9684-c793-4a2d-8a12-b33900187099";
    const TEST_COLLECTION_ID_1: &str = "8faa9684-c793-4a2d-8a12-b33900187100";
    const TEST_COLLECTION_ID_2: &str = "9faa9684-c793-4a2d-8a12-b33900187101";

    /// Builds a minimal login cipher with no collections assigned.
    fn generate_test_cipher() -> Cipher {
        Cipher {
            id: TEST_CIPHER_ID.parse().ok(),
            name: "2.pMS6/icTQABtulw52pq2lg==|XXbxKxDTh+mWiN1HjH2N1w==|Q6PkuT+KX/axrgN9ubD5Ajk2YNwxQkgs3WJM0S0wtG8=".parse().unwrap(),
            r#type: crate::CipherType::Login,
            notes: Default::default(),
            organization_id: Default::default(),
            folder_id: Default::default(),
            favorite: Default::default(),
            reprompt: Default::default(),
            fields: Default::default(),
            collection_ids: Default::default(),
            key: Default::default(),
            login: Default::default(),
            identity: Default::default(),
            card: Default::default(),
            secure_note: Default::default(),
            ssh_key: Default::default(),
            bank_account: Default::default(),
            drivers_license: Default::default(),
            passport: Default::default(),
            organization_use_totp: Default::default(),
            edit: Default::default(),
            permissions: Default::default(),
            view_password: Default::default(),
            local_data: Default::default(),
            attachments: Default::default(),
            password_history: Default::default(),
            creation_date: Default::default(),
            deleted_date: Default::default(),
            revision_date: Default::default(),
            archived_date: Default::default(),
            data: Default::default(),
        }
    }

    /// Creates a client backed by an in-memory repository, returning both so tests can seed
    /// and inspect local state.
    fn create_test_client(api_client: ApiClient) -> (CiphersClient, Arc<MemoryRepository<Cipher>>) {
        let repository = Arc::new(MemoryRepository::<Cipher>::default());
        #[allow(deprecated)]
        let client = CiphersClient {
            key_store: create_test_crypto_with_user_key(
                SymmetricCryptoKey::make_aes256_cbc_hmac_key(),
            ),
            api_configurations: Arc::new(ApiConfigurations::from_api_client(api_client)),
            repository: Some(repository.clone() as Arc<dyn Repository<Cipher>>),
            client: bitwarden_core::Client::new_test(None),
        };
        (client, repository)
    }

    /// Mocked API client that expects exactly one successful bulk-collections request.
    fn make_api_client() -> ApiClient {
        ApiClient::new_mocked(|mock| {
            mock.ciphers_api
                .expect_post_bulk_collections()
                .times(1)
                .returning(|_| Ok(()));
        })
    }

    #[tokio::test]
    async fn test_bulk_update_adds_collections() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        repository
            .set(cipher_id, generate_test_cipher())
            .await
            .unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], false)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(c.collection_ids.contains(&collection_id));
    }

    #[tokio::test]
    async fn test_bulk_update_removes_collections() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![collection_id];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], true)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(!c.collection_ids.contains(&collection_id));
    }

    #[tokio::test]
    async fn test_bulk_update_remove_keeps_other_collections() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let kept: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();
        let removed: CollectionId = TEST_COLLECTION_ID_2.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![kept, removed];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![removed], true)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(
            c.collection_ids == vec![kept],
            "only the requested collection is removed"
        );
    }

    #[tokio::test]
    async fn test_bulk_update_no_duplicates_when_adding() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![collection_id];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], false)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert_eq!(
            c.collection_ids.len(),
            1,
            "no duplicates introduced when collection already present"
        );
    }

    #[tokio::test]
    async fn test_bulk_update_add_preserves_existing_and_ignores_repeated_ids() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let existing: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();
        let added: CollectionId = TEST_COLLECTION_ID_2.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![existing];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![added, added], false)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(
            c.collection_ids == vec![existing, added],
            "existing collections are kept and a repeated request ID is added only once"
        );
    }

    #[tokio::test]
    async fn test_bulk_update_skips_missing_ciphers() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        let result = client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], false)
            .await;
        assert!(result.is_ok());

        // A cipher that was not cached locally must not be created by the update.
        assert!(repository.get(cipher_id).await.unwrap().is_none());
    }
}