Skip to content

Commit

Permalink
refactor: use Entity derive macro for PersistedMlsGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonThormeyer committed Jan 30, 2025
1 parent df0a711 commit 8a01ccc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 176 deletions.
4 changes: 3 additions & 1 deletion keystore/src/entities/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ use openmls_traits::types::SignatureScheme;
use zeroize::Zeroize;

/// Entity representing a persisted `MlsGroup`
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[zeroize(drop)]
#[entity(collection_name = "mls_groups")]
#[cfg_attr(
any(target_family = "wasm", feature = "serde"),
derive(serde::Serialize, serde::Deserialize)
)]
pub struct PersistedMlsGroup {
#[id(hex, column = "id_hex")]
pub id: Vec<u8>,
pub state: Vec<u8>,
pub parent_id: Option<Vec<u8>>,
Expand Down
177 changes: 2 additions & 175 deletions keystore/src/entities/platform/generic/mls/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,182 +15,9 @@
// along with this program. If not, see http://www.gnu.org/licenses/.

use crate::{
connection::TransactionWrapper,
entities::{EntityIdStringExt, EntityTransactionExt},
entities::{Entity, EntityBase, PersistedMlsGroup, PersistedMlsGroupExt},
CryptoKeystoreResult,
};
use crate::{
connection::{DatabaseConnection, KeystoreDatabaseConnection},
entities::{Entity, EntityBase, EntityFindParams, PersistedMlsGroup, PersistedMlsGroupExt, StringEntityId},
CryptoKeystoreResult, MissingKeyErrorKind,
};

#[async_trait::async_trait]
impl Entity for PersistedMlsGroup {
fn id_raw(&self) -> &[u8] {
self.id.as_slice()
}
async fn find_all(conn: &mut Self::ConnectionType, params: EntityFindParams) -> CryptoKeystoreResult<Vec<Self>> {
let transaction = conn.transaction()?;
let query: String = format!("SELECT rowid, id_hex FROM mls_groups {}", params.to_sql());

let mut stmt = transaction.prepare_cached(&query)?;
let mut rows = stmt.query_map([], |r| {
let rowid: i64 = r.get(0)?;
let id_hex: String = r.get(1)?;
Ok((rowid, id_hex))
})?;
let entities = rows.try_fold(Vec::new(), |mut acc, row_result| {
use std::io::Read as _;
let (rowid, id_hex) = row_result?;

let id = Self::id_from_hex(&id_hex)?;

let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "state", rowid, true)?;
let mut state = vec![];
blob.read_to_end(&mut state)?;
blob.close()?;

let mut parent_id = None;
if let Ok(mut blob) =
transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "parent_id", rowid, true)
{
if !blob.is_empty() {
let mut tmp = Vec::with_capacity(blob.len());
blob.read_to_end(&mut tmp)?;
parent_id.replace(tmp);
}
blob.close()?;
}

acc.push(Self { id, parent_id, state });
crate::CryptoKeystoreResult::Ok(acc)
})?;

Ok(entities)
}

async fn find_one(
conn: &mut Self::ConnectionType,
id: &StringEntityId,
) -> crate::CryptoKeystoreResult<Option<Self>> {
use rusqlite::OptionalExtension as _;
let transaction = conn.transaction()?;
let mut rowid: Option<i64> = transaction
.query_row(
"SELECT rowid FROM mls_groups WHERE id_hex = ?",
[id.as_hex_string()],
|r| r.get::<_, i64>(0),
)
.optional()?;

if let Some(rowid) = rowid.take() {
let id = id.as_slice().to_vec();

use std::io::Read as _;
let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "state", rowid, true)?;
let mut state = Vec::with_capacity(blob.len());
blob.read_to_end(&mut state)?;
blob.close()?;

let mut parent_id = None;
if let Ok(mut blob) =
transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "parent_id", rowid, true)
{
if !blob.is_empty() {
let mut tmp = Vec::with_capacity(blob.len());
blob.read_to_end(&mut tmp)?;
parent_id.replace(tmp);
}
blob.close()?;
}

Ok(Some(Self { id, parent_id, state }))
} else {
Ok(None)
}
}

async fn find_many(
conn: &mut Self::ConnectionType,
_ids: &[StringEntityId],
) -> crate::CryptoKeystoreResult<Vec<Self>> {
// Plot twist: we always select ALL the persisted groups. Unsure if we want to make it a real API with selection
Self::find_all(conn, EntityFindParams::default()).await
}

async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult<usize> {
Ok(conn.query_row("SELECT COUNT(*) FROM mls_groups", [], |r| r.get(0))?)
}
}

#[async_trait::async_trait]
impl EntityBase for PersistedMlsGroup {
type ConnectionType = KeystoreDatabaseConnection;
type AutoGeneratedFields = ();
const COLLECTION_NAME: &'static str = "mls_groups";

fn to_missing_key_err_kind() -> MissingKeyErrorKind {
MissingKeyErrorKind::MlsGroup
}

fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity {
crate::transaction::dynamic_dispatch::Entity::PersistedMlsGroup(self)
}
}

#[async_trait::async_trait]
impl EntityTransactionExt for PersistedMlsGroup {
async fn save(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> {
use rusqlite::ToSql as _;

let state = &self.state;
let parent_id = self.parent_id.as_ref();

Self::ConnectionType::check_buffer_size(state.len())?;
Self::ConnectionType::check_buffer_size(parent_id.map(Vec::len).unwrap_or_default())?;

let zbs = rusqlite::blob::ZeroBlob(state.len() as i32);
let zbpid = rusqlite::blob::ZeroBlob(parent_id.map(Vec::len).unwrap_or_default() as i32);

// Use UPSERT (ON CONFLICT DO UPDATE)
let sql = "
INSERT INTO mls_groups (id_hex, state, parent_id)
VALUES (?, ?, ?)
ON CONFLICT(id_hex) DO UPDATE SET state = excluded.state, parent_id = excluded.parent_id
RETURNING rowid";

let rowid: i64 =
transaction.query_row(sql, [&self.id_hex().to_sql()?, &zbs.to_sql()?, &zbpid.to_sql()?], |r| {
r.get(0)
})?;

let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "state", rowid, false)?;
use std::io::Write as _;
blob.write_all(state)?;
blob.close()?;

let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "parent_id", rowid, false)?;
if let Some(parent_id) = parent_id {
blob.write_all(parent_id)?;
}
blob.close()?;

Ok(())
}

async fn delete_fail_on_missing_id(
transaction: &TransactionWrapper<'_>,
id: StringEntityId<'_>,
) -> CryptoKeystoreResult<()> {
let updated = transaction.execute("DELETE FROM mls_groups WHERE id_hex = ?", [id.as_hex_string()])?;

if updated > 0 {
Ok(())
} else {
Err(Self::to_missing_key_err_kind().into())
}
}
}

#[async_trait::async_trait]
impl PersistedMlsGroupExt for PersistedMlsGroup {
Expand Down

0 comments on commit 8a01ccc

Please sign in to comment.