bitwarden_vault/cipher/cipher_client/
bulk_update_collections.rs1use std::collections::HashSet;
2
3use bitwarden_api_api::models::CipherBulkUpdateCollectionsRequestModel;
4use bitwarden_collections::collection::CollectionId;
5use bitwarden_core::{ApiError, OrganizationId};
6use bitwarden_error::bitwarden_error;
7use bitwarden_state::repository::{RepositoryError, RepositoryOption};
8use thiserror::Error;
9#[cfg(feature = "wasm")]
10use wasm_bindgen::prelude::wasm_bindgen;
11
12use crate::{CipherId, CiphersClient};
13
/// Errors returned by [`CiphersClient::bulk_update_collections`].
///
/// Both variants are `#[error(transparent)]`, so their `Display`/`source` forward directly to
/// the wrapped error.
#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum BulkUpdateCollectionsCipherError {
    /// The bulk-collections request to the server failed.
    #[error(transparent)]
    Api(#[from] ApiError),
    /// An operation on the local cipher repository failed.
    #[error(transparent)]
    Repository(#[from] RepositoryError),
}
23
24impl<T> From<bitwarden_api_api::apis::Error<T>> for BulkUpdateCollectionsCipherError {
25 fn from(value: bitwarden_api_api::apis::Error<T>) -> Self {
26 Self::Api(value.into())
27 }
28}
29
30#[cfg_attr(feature = "wasm", wasm_bindgen)]
31impl CiphersClient {
32 pub async fn bulk_update_collections(
37 &self,
38 organization_id: OrganizationId,
39 cipher_ids: Vec<CipherId>,
40 collection_ids: Vec<CollectionId>,
41 remove_collections: bool,
42 ) -> Result<(), BulkUpdateCollectionsCipherError> {
43 self.api_configurations
44 .api_client
45 .ciphers_api()
46 .post_bulk_collections(Some(CipherBulkUpdateCollectionsRequestModel {
47 organization_id: Some(organization_id.into()),
48 cipher_ids: Some(cipher_ids.iter().map(|id| (*id).into()).collect()),
49 collection_ids: Some(collection_ids.iter().map(|id| (*id).into()).collect()),
50 remove_collections: Some(remove_collections),
51 }))
52 .await?;
53
54 let repository = self.repository.require()?;
55 let mut updated_ciphers = Vec::new();
56 let collection_ids = collection_ids.iter().copied().collect::<HashSet<_>>();
57 for cipher_id in cipher_ids {
58 if let Some(mut cipher) = repository.get(cipher_id).await? {
59 if remove_collections {
60 cipher
61 .collection_ids
62 .retain(|id| !collection_ids.contains(id));
63 } else {
64 let existing = cipher
65 .collection_ids
66 .iter()
67 .copied()
68 .collect::<HashSet<_>>();
69 cipher.collection_ids = cipher
70 .collection_ids
71 .into_iter()
72 .chain(
73 collection_ids
74 .clone()
75 .into_iter()
76 .filter(|id| !existing.contains(id)),
77 )
78 .collect();
79 }
80 updated_ciphers.push((cipher_id, cipher));
81 }
82 }
83 repository.set_bulk(updated_ciphers).await?;
84
85 Ok(())
86 }
87}
88
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use bitwarden_api_api::apis::ApiClient;
    use bitwarden_collections::collection::CollectionId;
    use bitwarden_core::{
        OrganizationId, client::ApiConfigurations, key_management::create_test_crypto_with_user_key,
    };
    use bitwarden_crypto::SymmetricCryptoKey;
    use bitwarden_state::repository::Repository;
    use bitwarden_test::MemoryRepository;

    use crate::{Cipher, CipherId, CiphersClient};

    const TEST_CIPHER_ID: &str = "5faa9684-c793-4a2d-8a12-b33900187097";
    const TEST_ORG_ID: &str = "7faa9684-c793-4a2d-8a12-b33900187099";
    const TEST_COLLECTION_ID_1: &str = "8faa9684-c793-4a2d-8a12-b33900187100";
    const TEST_COLLECTION_ID_2: &str = "9faa9684-c793-4a2d-8a12-b33900187101";
    const TEST_COLLECTION_ID_3: &str = "afaa9684-c793-4a2d-8a12-b33900187102";

    /// Builds a minimal login cipher with id `TEST_CIPHER_ID` and no collections assigned.
    fn generate_test_cipher() -> Cipher {
        Cipher {
            id: TEST_CIPHER_ID.parse().ok(),
            name: "2.pMS6/icTQABtulw52pq2lg==|XXbxKxDTh+mWiN1HjH2N1w==|Q6PkuT+KX/axrgN9ubD5Ajk2YNwxQkgs3WJM0S0wtG8=".parse().unwrap(),
            r#type: crate::CipherType::Login,
            notes: Default::default(),
            organization_id: Default::default(),
            folder_id: Default::default(),
            favorite: Default::default(),
            reprompt: Default::default(),
            fields: Default::default(),
            collection_ids: Default::default(),
            key: Default::default(),
            login: Default::default(),
            identity: Default::default(),
            card: Default::default(),
            secure_note: Default::default(),
            ssh_key: Default::default(),
            bank_account: Default::default(),
            drivers_license: Default::default(),
            passport: Default::default(),
            organization_use_totp: Default::default(),
            edit: Default::default(),
            permissions: Default::default(),
            view_password: Default::default(),
            local_data: Default::default(),
            attachments: Default::default(),
            password_history: Default::default(),
            creation_date: Default::default(),
            deleted_date: Default::default(),
            revision_date: Default::default(),
            archived_date: Default::default(),
            data: Default::default(),
        }
    }

    /// Builds a `CiphersClient` backed by `api_client` and a fresh in-memory repository.
    /// Returns the client and the repository handle for seeding and inspection.
    fn create_test_client(api_client: ApiClient) -> (CiphersClient, Arc<MemoryRepository<Cipher>>) {
        let repository = Arc::new(MemoryRepository::<Cipher>::default());
        #[allow(deprecated)]
        let client = CiphersClient {
            key_store: create_test_crypto_with_user_key(
                SymmetricCryptoKey::make_aes256_cbc_hmac_key(),
            ),
            api_configurations: Arc::new(ApiConfigurations::from_api_client(api_client)),
            repository: Some(repository.clone() as Arc<dyn Repository<Cipher>>),
            client: bitwarden_core::Client::new_test(None),
        };
        (client, repository)
    }

    /// Mocked API client that accepts the bulk-collections request.
    /// `.times(1)` also verifies that the server is called exactly once per update.
    fn make_api_client() -> ApiClient {
        ApiClient::new_mocked(|mock| {
            mock.ciphers_api
                .expect_post_bulk_collections()
                .times(1)
                .returning(|_| Ok(()));
        })
    }

    #[tokio::test]
    async fn test_bulk_update_adds_collections() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        repository
            .set(cipher_id, generate_test_cipher())
            .await
            .unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], false)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(c.collection_ids.contains(&collection_id));
    }

    #[tokio::test]
    async fn test_bulk_update_add_preserves_existing_and_ignores_duplicate_input() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let existing: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();
        let added_a: CollectionId = TEST_COLLECTION_ID_2.parse().unwrap();
        let added_b: CollectionId = TEST_COLLECTION_ID_3.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![existing];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(
                org_id,
                vec![cipher_id],
                vec![added_a, added_b, added_a],
                false,
            )
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert_eq!(
            c.collection_ids.len(),
            3,
            "each collection appears once even when repeated in the input"
        );
        assert!(
            c.collection_ids.first() == Some(&existing),
            "pre-existing collections stay in place"
        );
        assert!(
            c.collection_ids.iter().copied().collect::<HashSet<_>>()
                == HashSet::from([existing, added_a, added_b])
        );
    }

    #[tokio::test]
    async fn test_bulk_update_removes_collections() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![collection_id];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], true)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(!c.collection_ids.contains(&collection_id));
    }

    #[tokio::test]
    async fn test_bulk_update_remove_keeps_untargeted_collections() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let removed: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();
        let kept: CollectionId = TEST_COLLECTION_ID_2.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![removed, kept];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![removed], true)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert!(
            c.collection_ids == vec![kept],
            "only the targeted collection is removed"
        );
    }

    #[tokio::test]
    async fn test_bulk_update_no_duplicates_when_adding() {
        let (client, repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        let mut cipher = generate_test_cipher();
        cipher.collection_ids = vec![collection_id];
        repository.set(cipher_id, cipher).await.unwrap();

        client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], false)
            .await
            .unwrap();

        let c: Cipher = repository.get(cipher_id).await.unwrap().unwrap();
        assert_eq!(
            c.collection_ids.len(),
            1,
            "no duplicates introduced when collection already present"
        );
    }

    #[tokio::test]
    async fn test_bulk_update_skips_missing_ciphers() {
        let (client, _repository) = create_test_client(make_api_client());

        let cipher_id: CipherId = TEST_CIPHER_ID.parse().unwrap();
        let org_id: OrganizationId = TEST_ORG_ID.parse().unwrap();
        let collection_id: CollectionId = TEST_COLLECTION_ID_1.parse().unwrap();

        let result = client
            .bulk_update_collections(org_id, vec![cipher_id], vec![collection_id], false)
            .await;
        assert!(result.is_ok());
    }
}