Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonThormeyer committed Jan 30, 2025
1 parent 832fea1 commit 5ea09ea
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 199 deletions.
22 changes: 13 additions & 9 deletions crypto-macros/src/entity_derive/derive_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ impl KeyStoreEntityFlattened {
};

let id_from_transformed = match id_transformation {
Some(IdTransformation::Hex) => quote! { let #id = Self::id_from_hex(id)?; },
Some(IdTransformation::Hex) => {
quote! { let #id = <Self as crate::entities::EntityIdStringExt>::id_from_hex(id)?; }
}
Some(IdTransformation::Sha256) => todo!(),
None => quote! {},
};
Expand Down Expand Up @@ -156,14 +158,14 @@ impl KeyStoreEntityFlattened {

#string_id_conversion

let mut row_id = transaction
let mut rowid: Option<i64> = transaction
.query_row(&#find_one_query, [#transformed_id], |r| {
r.get::<_, i64>(0)
})
.optional()?;

use std::io::Read as _;
if let Some(rowid) = row_id.take() {
if let Some(rowid) = rowid.take() {
#(
let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?;
let mut #blob_columns = Vec::with_capacity(blob.len());
Expand Down Expand Up @@ -329,12 +331,14 @@ impl KeyStoreEntityFlattened {
crate::connection::KeystoreDatabaseConnection::check_buffer_size(self.#blob_columns.len())?;
)*
#(
crate::connection::KeystoreDatabaseConnection::check_buffer_size(self.#optional_blob_columns.map(Vec::len).unwrap_or_default())?;
crate::connection::KeystoreDatabaseConnection::check_buffer_size(
self.#optional_blob_columns.as_ref().map(|v| v.len()).unwrap_or_default()
)?;
)*

let sql = #upsert_query;

let row_id_result: Result<i64, rusqlite::Error> =
let rowid_result: Result<i64, rusqlite::Error> =
transaction.query_row(&sql, [
self.#id.to_sql()?
#(
Expand All @@ -343,19 +347,19 @@ impl KeyStoreEntityFlattened {
)*
#(
,
rusqlite::blob::ZeroBlob(self.#optional_blob_columns.map(Vec::len).unwrap_or_default()).to_sql()?
rusqlite::blob::ZeroBlob(self.#optional_blob_columns.as_ref().map(|v| v.len() as i32).unwrap_or_default()).to_sql()?
)*
], |r| r.get(0));

use std::io::Write as _;
match row_id_result {
Ok(row_id) => {
match rowid_result {
Ok(rowid) => {
#(
let mut blob = transaction.blob_open(
rusqlite::DatabaseName::Main,
#collection_name,
#blob_column_names,
row_id,
rowid,
false,
)?;

Expand Down
31 changes: 16 additions & 15 deletions crypto-macros/src/entity_derive/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,23 @@ impl IdColumn {

if let Some(attr) = field.attrs.iter().find(|attr| attr.path().is_ident("id")) {
let meta = &attr.meta;
let list = meta.require_list()?;
list.parse_nested_meta(|meta| {
let ident = meta.path.require_ident()?;
match ident.to_string().as_str() {
"column" => {
meta.input.parse::<Token![=]>()?;
column_name = Some(meta.input.parse::<syn::LitStr>()?.value());
Ok(())
if let Ok(list) = meta.require_list() {
list.parse_nested_meta(|meta| {
let ident = meta.path.require_ident()?;
match ident.to_string().as_str() {
"column" => {
meta.input.parse::<Token![=]>()?;
column_name = Some(meta.input.parse::<syn::LitStr>()?.value());
Ok(())
}
"hex" => {
transformation = Some(IdTransformation::Hex);
Ok(())
}
_ => Err(syn::Error::new_spanned(ident, format!("unknown argument: {ident}"))),
}
"hex" => {
transformation = Some(IdTransformation::Hex);
Ok(())
}
_ => Err(syn::Error::new_spanned(ident, format!("unknown argument: {ident}"))),
}
})?;
})?;
}

let column_type = IdColumnType::parse(&field.ty)?;

Expand Down
1 change: 1 addition & 0 deletions keystore/src/entities/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use zeroize::Zeroize;
derive(serde::Serialize, serde::Deserialize)
)]
pub struct PersistedMlsGroup {
#[id(hex, column = "id_hex")]
pub id: Vec<u8>,
pub state: Vec<u8>,
pub parent_id: Option<Vec<u8>>,
Expand Down
177 changes: 2 additions & 175 deletions keystore/src/entities/platform/generic/mls/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,182 +15,9 @@
// along with this program. If not, see http://www.gnu.org/licenses/.

use crate::{
connection::TransactionWrapper,
entities::{EntityIdStringExt, EntityTransactionExt},
entities::{Entity, EntityBase, PersistedMlsGroup, PersistedMlsGroupExt},
CryptoKeystoreResult,
};
use crate::{
connection::{DatabaseConnection, KeystoreDatabaseConnection},
entities::{Entity, EntityBase, EntityFindParams, PersistedMlsGroup, PersistedMlsGroupExt, StringEntityId},
CryptoKeystoreResult, MissingKeyErrorKind,
};

#[async_trait::async_trait]
impl Entity for PersistedMlsGroup {
fn id_raw(&self) -> &[u8] {
self.id.as_slice()
}
async fn find_all(conn: &mut Self::ConnectionType, params: EntityFindParams) -> CryptoKeystoreResult<Vec<Self>> {
let transaction = conn.transaction()?;
let query: String = format!("SELECT rowid, id_hex FROM mls_groups {}", params.to_sql());

let mut stmt = transaction.prepare_cached(&query)?;
let mut rows = stmt.query_map([], |r| {
let rowid: i64 = r.get(0)?;
let id_hex: String = r.get(1)?;
Ok((rowid, id_hex))
})?;
let entities = rows.try_fold(Vec::new(), |mut acc, row_result| {
use std::io::Read as _;
let (rowid, id_hex) = row_result?;

let id = Self::id_from_hex(&id_hex)?;

let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "state", rowid, true)?;
let mut state = vec![];
blob.read_to_end(&mut state)?;
blob.close()?;

let mut parent_id = None;
if let Ok(mut blob) =
transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "parent_id", rowid, true)
{
if !blob.is_empty() {
let mut tmp = Vec::with_capacity(blob.len());
blob.read_to_end(&mut tmp)?;
parent_id.replace(tmp);
}
blob.close()?;
}

acc.push(Self { id, parent_id, state });
crate::CryptoKeystoreResult::Ok(acc)
})?;

Ok(entities)
}

async fn find_one(
conn: &mut Self::ConnectionType,
id: &StringEntityId,
) -> crate::CryptoKeystoreResult<Option<Self>> {
use rusqlite::OptionalExtension as _;
let transaction = conn.transaction()?;
let mut rowid: Option<i64> = transaction
.query_row(
"SELECT rowid FROM mls_groups WHERE id_hex = ?",
[id.as_hex_string()],
|r| r.get::<_, i64>(0),
)
.optional()?;

if let Some(rowid) = rowid.take() {
let id = id.as_slice().to_vec();

use std::io::Read as _;
let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "state", rowid, true)?;
let mut state = Vec::with_capacity(blob.len());
blob.read_to_end(&mut state)?;
blob.close()?;

let mut parent_id = None;
if let Ok(mut blob) =
transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "parent_id", rowid, true)
{
if !blob.is_empty() {
let mut tmp = Vec::with_capacity(blob.len());
blob.read_to_end(&mut tmp)?;
parent_id.replace(tmp);
}
blob.close()?;
}

Ok(Some(Self { id, parent_id, state }))
} else {
Ok(None)
}
}

async fn find_many(
conn: &mut Self::ConnectionType,
_ids: &[StringEntityId],
) -> crate::CryptoKeystoreResult<Vec<Self>> {
// Plot twist: we always select ALL the persisted groups. Unsure if we want to make it a real API with selection
Self::find_all(conn, EntityFindParams::default()).await
}

async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult<usize> {
Ok(conn.query_row("SELECT COUNT(*) FROM mls_groups", [], |r| r.get(0))?)
}
}

#[async_trait::async_trait]
impl EntityBase for PersistedMlsGroup {
type ConnectionType = KeystoreDatabaseConnection;
type AutoGeneratedFields = ();
const COLLECTION_NAME: &'static str = "mls_groups";

fn to_missing_key_err_kind() -> MissingKeyErrorKind {
MissingKeyErrorKind::MlsGroup
}

fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity {
crate::transaction::dynamic_dispatch::Entity::PersistedMlsGroup(self)
}
}

#[async_trait::async_trait]
impl EntityTransactionExt for PersistedMlsGroup {
async fn save(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> {
use rusqlite::ToSql as _;

let state = &self.state;
let parent_id = self.parent_id.as_ref();

Self::ConnectionType::check_buffer_size(state.len())?;
Self::ConnectionType::check_buffer_size(parent_id.map(Vec::len).unwrap_or_default())?;

let zbs = rusqlite::blob::ZeroBlob(state.len() as i32);
let zbpid = rusqlite::blob::ZeroBlob(parent_id.map(Vec::len).unwrap_or_default() as i32);

// Use UPSERT (ON CONFLICT DO UPDATE)
let sql = "
INSERT INTO mls_groups (id_hex, state, parent_id)
VALUES (?, ?, ?)
ON CONFLICT(id_hex) DO UPDATE SET state = excluded.state, parent_id = excluded.parent_id
RETURNING rowid";

let rowid: i64 =
transaction.query_row(sql, [&self.id_hex().to_sql()?, &zbs.to_sql()?, &zbpid.to_sql()?], |r| {
r.get(0)
})?;

let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "state", rowid, false)?;
use std::io::Write as _;
blob.write_all(state)?;
blob.close()?;

let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_groups", "parent_id", rowid, false)?;
if let Some(parent_id) = parent_id {
blob.write_all(parent_id)?;
}
blob.close()?;

Ok(())
}

async fn delete_fail_on_missing_id(
transaction: &TransactionWrapper<'_>,
id: StringEntityId<'_>,
) -> CryptoKeystoreResult<()> {
let updated = transaction.execute("DELETE FROM mls_groups WHERE id_hex = ?", [id.as_hex_string()])?;

if updated > 0 {
Ok(())
} else {
Err(Self::to_missing_key_err_kind().into())
}
}
}

#[async_trait::async_trait]
impl PersistedMlsGroupExt for PersistedMlsGroup {
Expand Down

0 comments on commit 5ea09ea

Please sign in to comment.