From 0bd36765b1c81cd396d68dd5e5f8d0200bb9b607 Mon Sep 17 00:00:00 2001 From: SimonThormeyer Date: Tue, 21 Jan 2025 17:03:04 +0100 Subject: [PATCH] feat: support entity derive for tables with hex ids --- .../src/entity_derive/derive_impl.rs | 152 ++++++++++++++---- crypto-macros/src/entity_derive/mod.rs | 42 ++++- crypto-macros/src/entity_derive/parse.rs | 79 +++++++-- 3 files changed, 228 insertions(+), 45 deletions(-) diff --git a/crypto-macros/src/entity_derive/derive_impl.rs b/crypto-macros/src/entity_derive/derive_impl.rs index 8992b541e7..1e0d85e199 100644 --- a/crypto-macros/src/entity_derive/derive_impl.rs +++ b/crypto-macros/src/entity_derive/derive_impl.rs @@ -1,4 +1,4 @@ -use crate::entity_derive::{ColumnType, KeyStoreEntityFlattened}; +use crate::entity_derive::{IdColumnType, IdTransformation, KeyStoreEntityFlattened}; use quote::quote; impl quote::ToTokens for KeyStoreEntityFlattened { @@ -50,30 +50,54 @@ impl KeyStoreEntityFlattened { id, id_type, id_name, + id_transformation, blob_columns, blob_column_names, all_columns, + optional_blob_columns, + optional_blob_column_names, .. } = self; - let string_id_conversion = matches!(id_type, ColumnType::String).then(|| { + let string_id_conversion = (*id_type == IdColumnType::String).then(|| { quote! { let #id: String = id.try_into()?; } }); let id_to_byte_slice = match id_type { - ColumnType::String => quote! {self.#id.as_bytes() }, - ColumnType::Bytes => quote! { &self.#id[..] }, + IdColumnType::String => quote! {self.#id.as_bytes() }, + IdColumnType::Bytes => quote! { &self.#id.as_slice() }, }; - let id_field_construct_self = match id_type { - ColumnType::String => quote! { #id, }, - ColumnType::Bytes => quote! { #id: id.to_bytes(), }, + let id_field_find_one = match id_type { + IdColumnType::String => quote! { #id, }, + IdColumnType::Bytes => quote! { #id: id.to_bytes(), }, }; let id_slice = match id_type { - ColumnType::String => quote! { #id.as_str() }, - ColumnType::Bytes => quote! { #id.as_slice() }, + IdColumnType::String => quote! { #id.as_str() }, + IdColumnType::Bytes => quote! { #id.as_slice() }, }; + + let id_input_transformed = match id_transformation { + Some(IdTransformation::Hex) => quote! { id.as_hex_string() }, + Some(IdTransformation::Sha256) => todo!(), + None => id_slice, + }; + + let destructure_row = match id_transformation { + Some(IdTransformation::Hex) => quote! { let (rowid, #id): (_, String) = row?; }, + Some(IdTransformation::Sha256) => todo!(), + None => quote! { let (rowid, #id) = row?; }, + }; + + let id_from_transformed = match id_transformation { + Some(IdTransformation::Hex) => { + quote! { let #id = ::id_from_hex(&#id)?; } + } + Some(IdTransformation::Sha256) => todo!(), + None => quote! {}, + }; + let find_all_query = format!("SELECT rowid, {id_name} FROM {collection_name} "); let find_one_query = format!("SELECT rowid FROM {collection_name} WHERE {id_name} = ?"); @@ -100,15 +124,30 @@ impl KeyStoreEntityFlattened { let mut rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?; use std::io::Read as _; rows.map(|row| { - let (rowid, #id) = row?; + #destructure_row + #id_from_transformed #( - let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, false)?; - let mut #blob_columns = vec![]; + let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?; + let mut #blob_columns = Vec::with_capacity(blob.len()); blob.read_to_end(&mut #blob_columns)?; blob.close()?; )* + #( + let mut #optional_blob_columns = None; + if let Ok(mut blob) = + transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, true) + { + if !blob.is_empty() { + let mut blob_data = Vec::with_capacity(blob.len()); + blob.read_to_end(&mut blob_data)?; + #optional_blob_columns.replace(blob_data); + } + blob.close()?; + } + )* + Ok(Self { #id #( , #all_columns @@ -127,26 +166,43 @@ impl KeyStoreEntityFlattened { #string_id_conversion - let mut row_id = transaction - .query_row(&#find_one_query, [#id_slice], |r| { + let mut rowid: Option = transaction + .query_row(&#find_one_query, [#id_input_transformed], |r| { r.get::<_, i64>(0) }) .optional()?; - if let Some(rowid) = row_id.take() { + use std::io::Read as _; + if let Some(rowid) = rowid.take() { #( let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?; - use std::io::Read as _; let mut #blob_columns = Vec::with_capacity(blob.len()); blob.read_to_end(&mut #blob_columns)?; blob.close()?; )* + #( + let mut #optional_blob_columns = None; + if let Ok(mut blob) = + transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, true) + { + if !blob.is_empty() { + let mut blob_data = Vec::with_capacity(blob.len()); + blob.read_to_end(&mut blob_data)?; + #optional_blob_columns.replace(blob_data); + } + blob.close()?; + } + )* + Ok(Some(Self { - #id_field_construct_self + #id_field_find_one #( #blob_columns, )* + #( + #optional_blob_columns, + )* })) } else { Ok(None) @@ -172,8 +228,8 @@ impl KeyStoreEntityFlattened { } = self; let id_to_byte_slice = match id_type { - ColumnType::String => quote! {self.#id.as_bytes() }, - ColumnType::Bytes => quote! { self.#id.as_slice() }, + IdColumnType::String => quote! {self.#id.as_bytes() }, + IdColumnType::Bytes => quote! { self.#id.as_slice() }, }; quote! { @@ -226,6 +282,9 @@ impl KeyStoreEntityFlattened { all_column_names, blob_columns, blob_column_names, + optional_blob_columns, + optional_blob_column_names, + id_transformation, no_upsert, id_type, .. @@ -246,16 +305,34 @@ impl KeyStoreEntityFlattened { .collect::>() .join(", "); + let import_id_string_ext = match id_transformation { + Some(IdTransformation::Hex) => quote! { use crate::entities::EntityIdStringExt as _; }, + Some(IdTransformation::Sha256) => todo!(), + None => quote! {}, + }; + let upsert_query = format!( "INSERT INTO {collection_name} ({id_name}, {column_list}) VALUES (?{}){upsert_postfix} RETURNING rowid", ", ?".repeat(self.all_columns.len()), ); + let self_id_transformed = match id_transformation { + Some(IdTransformation::Hex) => quote! { self.id_hex() }, + Some(IdTransformation::Sha256) => todo!(), + None => quote! { self.#id }, + }; + let delete_query = format!("DELETE FROM {collection_name} WHERE {id_name} = ?"); let id_slice_delete = match id_type { - ColumnType::String => quote! { id.try_as_str()? }, - ColumnType::Bytes => quote! { id.as_slice() }, + IdColumnType::String => quote! { id.try_as_str()? }, + IdColumnType::Bytes => quote! { id.as_slice() }, + }; + + let id_input_transformed_delete = match id_transformation { + Some(IdTransformation::Hex) => quote! { id.as_hex_string() }, + Some(IdTransformation::Sha256) => todo!(), + None => id_slice_delete, }; quote! { @@ -274,27 +351,38 @@ impl KeyStoreEntityFlattened { #( crate::connection::KeystoreDatabaseConnection::check_buffer_size(self.#blob_columns.len())?; )* + #( + crate::connection::KeystoreDatabaseConnection::check_buffer_size( + self.#optional_blob_columns.as_ref().map(|v| v.len()).unwrap_or_default() + )?; + )* + + #import_id_string_ext let sql = #upsert_query; - let row_id_result: Result = + let rowid_result: Result = transaction.query_row(&sql, [ - self.#id.to_sql()? + #self_id_transformed.to_sql()? #( , rusqlite::blob::ZeroBlob(self.#blob_columns.len() as i32).to_sql()? )* + #( + , + rusqlite::blob::ZeroBlob(self.#optional_blob_columns.as_ref().map(|v| v.len() as i32).unwrap_or_default()).to_sql()? + )* ], |r| r.get(0)); use std::io::Write as _; - match row_id_result { - Ok(row_id) => { + match rowid_result { + Ok(rowid) => { #( let mut blob = transaction.blob_open( rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, - row_id, + rowid, false, )?; @@ -302,6 +390,14 @@ impl KeyStoreEntityFlattened { blob.close()?; )* + #( + let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, false)?; + if let Some(#optional_blob_columns) = self.#optional_blob_columns.as_ref() { + blob.write_all(#optional_blob_columns)?; + } + blob.close()?; + )* + Ok(()) } Err(rusqlite::Error::SqliteFailure(e, _)) if e.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE => { @@ -316,9 +412,9 @@ impl KeyStoreEntityFlattened { id: crate::entities::StringEntityId<'_>, ) -> crate::CryptoKeystoreResult<()> { use crate::entities::EntityBase as _; - let updated = transaction.execute(&#delete_query, [#id_slice_delete])?; + let deleted = transaction.execute(&#delete_query, [#id_input_transformed_delete])?; - if updated > 0 { + if deleted > 0 { Ok(()) } else { Err(Self::to_missing_key_err_kind().into()) diff --git a/crypto-macros/src/entity_derive/mod.rs b/crypto-macros/src/entity_derive/mod.rs index 8abf4db1e3..ce00912a52 100644 --- a/crypto-macros/src/entity_derive/mod.rs +++ b/crypto-macros/src/entity_derive/mod.rs @@ -35,11 +35,20 @@ impl KeyStoreEntity { .map(|column| column.name.clone()) .collect::>(); - let blob_column_names = blob_columns.iter().map(ToString::to_string).collect(); + let optional_blob_columns = self + .columns + .0 + .iter() + .filter(|column| column.column_type == ColumnType::OptionalBytes) + .map(|column| column.name.clone()) + .collect::>(); + let all_column_names = all_columns.iter().map(ToString::to_string).collect(); + let blob_column_names = blob_columns.iter().map(ToString::to_string).collect(); + let optional_blob_column_names = optional_blob_columns.iter().map(ToString::to_string).collect(); let id = self.id.name; - let id_name = id.to_string(); + let id_name = self.id.column_name.unwrap_or_else(|| id.to_string()); let id_type = self.id.column_type; KeyStoreEntityFlattened { @@ -49,10 +58,13 @@ impl KeyStoreEntity { id, id_type, id_name, + id_transformation: self.id.transformation, all_columns, all_column_names, blob_columns, blob_column_names, + optional_blob_columns, + optional_blob_column_names, } } } @@ -64,19 +76,36 @@ pub(super) struct KeyStoreEntityFlattened { collection_name: String, id: Ident, id_name: String, - id_type: ColumnType, + id_type: IdColumnType, + id_transformation: Option, all_columns: Vec, all_column_names: Vec, blob_columns: Vec, blob_column_names: Vec, + optional_blob_columns: Vec, + optional_blob_column_names: Vec, no_upsert: bool, } -// Now identical to column, but -// subject to change once more diverse entities are supported. +#[derive(PartialEq, Eq)] +enum IdColumnType { + String, + Bytes, +} + struct IdColumn { name: Ident, - column_type: ColumnType, + column_type: IdColumnType, + /// Only present if it differs from the name + column_name: Option, + /// If the ID cannot be stored as-is because of indexing limitations + transformation: Option, +} + +enum IdTransformation { + Hex, + #[expect(dead_code)] + Sha256, } struct Columns(Vec); @@ -90,4 +119,5 @@ struct Column { enum ColumnType { String, Bytes, + OptionalBytes, } diff --git a/crypto-macros/src/entity_derive/parse.rs b/crypto-macros/src/entity_derive/parse.rs index bafe1cbe93..4bb16ab6be 100644 --- a/crypto-macros/src/entity_derive/parse.rs +++ b/crypto-macros/src/entity_derive/parse.rs @@ -1,4 +1,4 @@ -use crate::entity_derive::{Column, ColumnType, Columns, IdColumn, KeyStoreEntity}; +use crate::entity_derive::{Column, ColumnType, Columns, IdColumn, IdColumnType, IdTransformation, KeyStoreEntity}; use heck::ToSnakeCase; use proc_macro2::{Ident, Span}; use quote::ToTokens; @@ -74,28 +74,79 @@ impl IdColumn { fn parse(named_fields: &FieldsNamed) -> syn::Result { let mut id = None; let mut implicit_id = None; - for field in named_fields.named.iter() { + + for field in &named_fields.named { let name = field .ident .as_ref() .expect("named fields always have identifiers") .clone(); - let column_type = ColumnType::parse(&field.ty)?; - if field.attrs.iter().any(|attr| attr.path().is_ident("id")) { - if id.is_some() { + if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("id")) { + let mut column_name = None; + let mut transformation = None; + let column_type = IdColumnType::parse(&field.ty)?; + + if let Ok(list) = attr.meta.require_list() { + list.parse_nested_meta(|meta| { + match meta.path.require_ident()?.to_string().as_str() { + "column" => { + meta.input.parse::()?; + column_name = Some(meta.input.parse::()?.value()); + } + "hex" => transformation = Some(IdTransformation::Hex), + _ => return Err(syn::Error::new_spanned(meta.path, "unknown argument")), + } + Ok(()) + })?; + } + + if id + .replace(IdColumn { + name, + column_type, + column_name, + transformation, + }) + .is_some() + { return Err(syn::Error::new_spanned( field, - "Ambiguous `#[id] attributes. Provide exactly one.", + "Ambiguous `#[id]` attributes. Provide exactly one.", )); } - id = Some(IdColumn { name, column_type }); } else if name == "id" { - implicit_id = Some(IdColumn { name, column_type }); + let column_type = IdColumnType::parse(&field.ty)?; + implicit_id = Some(IdColumn { + name, + column_type, + column_name: None, + transformation: None, + }); } } - id = id.or(implicit_id); - id.ok_or(syn::Error::new_spanned(named_fields, "No `#[id]` attribute provided.")) + + id.or(implicit_id).ok_or_else(|| { + syn::Error::new_spanned( + named_fields, + "No field named `id` or annotated `#[id]` attribute provided.", + ) + }) + } +} + +impl IdColumnType { + fn parse(ty: &Type) -> Result { + let mut type_string = ty.to_token_stream().to_string(); + type_string.retain(|c| !c.is_whitespace()); + match type_string.as_str() { + "String" | "std::string::String" => Ok(Self::String), + "Vec" | "std::vec::Vec" => Ok(Self::Bytes), + type_string => Err(syn::Error::new_spanned( + ty, + format!("Expected `String` or `Vec`, not `{type_string}`."), + )), + } } } @@ -137,9 +188,15 @@ impl ColumnType { match type_string.as_str() { "String" | "std::string::String" => Ok(Self::String), "Vec" | "std::vec::Vec" => Ok(Self::Bytes), + "Option>" + | "Option>" + | "core::option::Option>" + | "core::option::Option>" + | "std::option::Option>" + | "std::option::Option>" => Ok(Self::OptionalBytes), type_string => Err(syn::Error::new_spanned( ty, - format!("Expected `String` or `Vec`, not `{type_string}`."), + format!("Expected `String`, `Vec`, or `Option>` not `{type_string}`."), )), } }