Skip to content

Commit

Permalink
Merge pull request #22 from blockscout/lok52/map-type
Browse files Browse the repository at this point in the history
Add support for map fields
  • Loading branch information
lok52 authored Mar 12, 2024
2 parents b04dccf + c377d7f commit 0f67be4
Show file tree
Hide file tree
Showing 26 changed files with 4,034 additions and 1,771 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ Cargo.lock

.pre-commit-config.yaml
config.toml

**/*.swagger.yaml
154 changes: 117 additions & 37 deletions actix-prost-build/src/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use std::{
use crate::helpers::extract_type_from_option;
use proc_macro2::{Ident, TokenStream};
use prost_build::Service;
use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
use prost_reflect::{
Cardinality, DescriptorPool, DynamicMessage, ExtensionDescriptor, FieldDescriptor, Kind,
MessageDescriptor,
};
use quote::quote;
use syn::{punctuated::Punctuated, Expr, Field, Fields, Lit, Meta, MetaNameValue, Token, Type};

Expand All @@ -20,6 +23,7 @@ pub struct ExtraFieldOptions {
}
#[derive(Debug)]
pub struct ConvertFieldOptions {
pub field: FieldDescriptor,
pub ty: Option<String>,
pub val_override: Option<String>,
pub required: bool,
Expand Down Expand Up @@ -66,11 +70,7 @@ impl TryFrom<(&DescriptorPool, &MessageDescriptor)> for ConvertOptions {
let fields = message
.fields()
.map(|f| {
let options = f.options();
let ext_val = options.get_extension(&fields_extension);
let ext_val = ext_val.as_message().unwrap();

let convert_options = ConvertFieldOptions::from(ext_val);
let convert_options = ConvertFieldOptions::from((&f, &fields_extension));

(String::from(f.name()), convert_options)
})
Expand All @@ -79,12 +79,17 @@ impl TryFrom<(&DescriptorPool, &MessageDescriptor)> for ConvertOptions {
}
}

impl From<&DynamicMessage> for ConvertFieldOptions {
fn from(value: &DynamicMessage) -> Self {
impl From<(&FieldDescriptor, &ExtensionDescriptor)> for ConvertFieldOptions {
fn from((f, ext): (&FieldDescriptor, &ExtensionDescriptor)) -> Self {
let options = f.options();
let ext_val = options.get_extension(ext);
let ext_val = ext_val.as_message().unwrap();

Self {
ty: get_string_field(value, "type"),
val_override: get_string_field(value, "override"),
required: match value.get_field_by_name("required") {
field: f.clone(),
ty: get_string_field(ext_val, "type"),
val_override: get_string_field(ext_val, "override"),
required: match ext_val.get_field_by_name("required") {
Some(v) => v.as_bool().unwrap(),
None => false,
},
Expand Down Expand Up @@ -226,9 +231,13 @@ impl ConversionsGenerator {
_ => {
let convert = &self.convert_prefix;

let from = match field_conversions.len() + extra_field_conversions.len() {
0 => quote!(_from),
_ => quote!(from),
};
quote!(
impl #convert<#from_struct_ident> for #to_struct_ident {
fn try_convert(from: #from_struct_ident) -> Result<Self, String> {
fn try_convert(#from: #from_struct_ident) -> Result<Self, String> {
Ok(Self {
#(#field_conversions,)*
#(#extra_field_conversions,)*
Expand Down Expand Up @@ -297,43 +306,114 @@ impl ConversionsGenerator {
f: &Field,
convert_field: Option<&ConvertFieldOptions>,
res: &mut Vec<TokenStream>,
) -> Option<ProcessedType> {
self.try_process_option(m_type, f, convert_field, res)
.or(self.try_process_map(m_type, f, convert_field, res))
}

fn try_process_option(
&mut self,
m_type: MessageType,
f: &Field,
convert_field: Option<&ConvertFieldOptions>,
res: &mut Vec<TokenStream>,
) -> Option<ProcessedType> {
let name = f.ident.as_ref().unwrap();

// Check if the field contains a nested message
let internal_struct = match extract_type_from_option(&f.ty) {
Some(Type::Path(ty)) => ty
.path
.segments
.first()
.and_then(|ty| self.messages.get(&ty.ident.to_string())),
match extract_type_from_option(&f.ty) {
Some(Type::Path(ty)) => {
let ty = ty.path.segments.first()?;
let rust_struct_name = self.messages.get(&ty.ident.to_string())?.ident.clone();
let new_struct_name =
self.build_internal_nested_struct(m_type, &rust_struct_name, res);
let convert = &self.convert_prefix;
let (ty, conversion) = match convert_field {
Some(ConvertFieldOptions { required: true, .. }) => {
let require_message = format!("field {} is required", name);
(
quote!(#new_struct_name),
quote!(#convert::try_convert(from.#name.ok_or(#require_message)?)?),
)
}
_ => (
quote!(::core::option::Option<#new_struct_name>),
quote!(#convert::try_convert(from.#name)?),
),
};
Some((ty, conversion))
}
_ => None,
}
}

fn try_process_map(
&mut self,
m_type: MessageType,
f: &Field,
convert_field: Option<&ConvertFieldOptions>,
res: &mut Vec<TokenStream>,
) -> Option<ProcessedType> {
let name = f.ident.as_ref().unwrap();

let field_desc = convert_field.map(|cf| &cf.field)?;
let map_type = match (field_desc.cardinality(), field_desc.kind()) {
(Cardinality::Repeated, Kind::Message(m)) => Some(m),
_ => None,
}?;
// Map keys can only be of scalar types, so we search for nested messages only in values
let map_value_type = match map_type.map_entry_value_field().kind() {
Kind::Message(m) => Some(m),
_ => None,
}?;
let map_key_type = map_type.map_entry_key_field().kind();
let map_key_rust_type = match map_key_type {
Kind::String => quote!(::prost::alloc::string::String),
Kind::Int32 => quote!(i32),
Kind::Int64 => quote!(i64),
Kind::Uint32 => quote!(u32),
Kind::Uint64 => quote!(u64),
Kind::Sint32 => quote!(i32),
Kind::Sint64 => quote!(i64),
Kind::Fixed32 => quote!(u32),
Kind::Fixed64 => quote!(u64),
Kind::Sfixed32 => quote!(i32),
Kind::Sfixed64 => quote!(i64),
Kind::Bool => quote!(bool),
_ => panic!("Map key type not supported {:?}", map_key_type),
};
// TODO: Proto name might not be the same as Rust struct name
let rust_struct_name = self.messages.get(map_value_type.name())?.ident.clone();

let new_struct_name = self.build_internal_nested_struct(m_type, &rust_struct_name, res);

// Process the nested message
let ident = &internal_struct.ident;
let convert = &self.convert_prefix;
let map_collection = if let Type::Path(p) = &f.ty {
match p.path.segments.iter().find(|s| s.ident == "HashMap") {
Some(_) => quote!(::std::collections::HashMap),
None => quote!(::std::collections::BTreeMap),
}
} else {
panic!("Type of map field is not a path")
};
let ty = quote!(#map_collection<#map_key_rust_type, #new_struct_name>);
let conversion = quote!(#convert::try_convert(from.#name)?);
Some((ty, conversion))
}

fn build_internal_nested_struct(
&mut self,
m_type: MessageType,
nested_struct_name: &Ident,
res: &mut Vec<TokenStream>,
) -> Ident {
// TODO: could incorrectly detect messages with same name in different packages
let message = self
.descriptors
.all_messages()
.find(|m| *ident == m.name())
.find(|m| *nested_struct_name == m.name())
.unwrap();
let new_struct_name = self.create_convert_struct(m_type, &message, &ident.to_string(), res);

let convert = &self.convert_prefix;
Some(match convert_field {
Some(ConvertFieldOptions { required: true, .. }) => {
let require_message = format!("field {} is required", name);
(
quote!(#new_struct_name),
quote!(#convert::try_convert(from.#name.ok_or(#require_message)?)?),
)
}
_ => (
quote!(::core::option::Option<#new_struct_name>),
quote!(#convert::try_convert(from.#name)?),
),
})
self.create_convert_struct(m_type, &message, &nested_struct_name.to_string(), res)
}

fn process_enum(m_type: MessageType, f: &Field) -> Option<ProcessedType> {
Expand Down
3 changes: 2 additions & 1 deletion actix-prost-build/src/generator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{config::HttpRule, conversions::ConversionsGenerator, method::Method, Config};
use crate::{config::HttpRule, method::Method, Config};
use proc_macro2::TokenStream;
use prost_build::{Service, ServiceGenerator};
use quote::quote;
Expand Down Expand Up @@ -135,6 +135,7 @@ impl ServiceGenerator for ActixGenerator {

#[cfg(feature = "conversions")]
{
use crate::conversions::ConversionsGenerator;
let conversions = ConversionsGenerator::new().ok().map(|mut g| {
g.messages = Rc::clone(&self.messages);
g.create_conversions(&service)
Expand Down
23 changes: 20 additions & 3 deletions convert-trait/src/impls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::TryConvert;
use std::collections::{BTreeMap, HashMap, HashSet};

impl<T, R: TryConvert<T>> TryConvert<Option<T>> for Option<R> {
fn try_convert(input: Option<T>) -> Result<Self, String> {
Expand All @@ -15,10 +16,26 @@ impl<T, R: TryConvert<T>> TryConvert<Vec<T>> for Vec<R> {
}
}

impl<T, R: TryConvert<T> + std::hash::Hash + Eq> TryConvert<Vec<T>>
for std::collections::HashSet<R>
{
impl<T, R: TryConvert<T> + std::hash::Hash + Eq> TryConvert<Vec<T>> for HashSet<R> {
fn try_convert(input: Vec<T>) -> Result<Self, String> {
input.into_iter().map(TryConvert::try_convert).collect()
}
}

impl<K: Eq + std::hash::Hash, T, R: TryConvert<T>> TryConvert<HashMap<K, T>> for HashMap<K, R> {
fn try_convert(input: HashMap<K, T>) -> Result<Self, String> {
input
.into_iter()
.map(|(k, v)| Ok((k, TryConvert::try_convert(v)?)))
.collect()
}
}

impl<K: std::cmp::Ord, T, R: TryConvert<T>> TryConvert<BTreeMap<K, T>> for BTreeMap<K, R> {
fn try_convert(input: BTreeMap<K, T>) -> Result<Self, String> {
input
.into_iter()
.map(|(k, v)| Ok((k, TryConvert::try_convert(v)?)))
.collect()
}
}
4 changes: 3 additions & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
actix-prost = { path = "../actix-prost" }
actix-prost-macros = { path = "../actix-prost-macros" }
async-trait = "0.1"
convert-trait ={ path = "../convert-trait" }
tonic = "0.8"
prost = "0.11"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
Expand All @@ -15,9 +16,10 @@ actix-web = "4"
http = "0.2"
serde_json = "1.0"
serde_with = { version = "2.0", features = ["base64"] }
ethers = "2.0.14"

[build-dependencies]
actix-prost-build = { path = "../actix-prost-build" }
actix-prost-build = { path = "../actix-prost-build", features = ["conversions"]}
tonic-build = "0.8"
prost-build = "0.11"

Expand Down
22 changes: 10 additions & 12 deletions tests/build.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use actix_prost_build::{ActixGenerator, GeneratorList};
use prost_build::{Config, ServiceGenerator};
use std::path::Path;
use std::{
env,
path::{Path, PathBuf},
};

// custom function to include custom generator
fn compile(
Expand All @@ -11,6 +14,10 @@ fn compile(
let mut config = Config::new();
config
.service_generator(generator)
.file_descriptor_set_path(
PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR environment variable not set"))
.join("file_descriptor_set.bin"),
)
.out_dir("src/proto")
.bytes(["."])
.compile_well_known_types()
Expand All @@ -19,17 +26,6 @@ fn compile(
.protoc_arg("grpc_api_configuration=proto/http_api.yaml,output_format=yaml")
.type_attribute(".", "#[actix_prost_macros::serde]");

// for path in protos.iter() {
// println!("cargo:rerun-if-changed={}", path.as_ref().display())
// }

// for path in includes.iter() {
// // Cargo will watch the **entire** directory recursively. If we
// // could figure out which files are imported by our protos we
// // could specify only those files instead.
// println!("cargo:rerun-if-changed={}", path.as_ref().display())
// }

config.compile_protos(protos, includes)?;
Ok(())
}
Expand All @@ -42,8 +38,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
compile(
&[
"proto/rest.proto",
"proto/simple.proto",
"proto/types.proto",
"proto/errors.proto",
"proto/conversions.proto",
],
&["proto/", "proto/googleapis", "proto/grpc-gateway"],
gens,
Expand Down
38 changes: 38 additions & 0 deletions tests/proto/conversions.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
syntax = "proto3";
package conversions;

import "convert_options.proto";

option go_package = "github.com/blockscout/actix-prost/tests";

service ConversionsRPC { rpc ConvertRPC(ConversionsRequest) returns (ConversionsResponse); }

message Nested {
string address = 3 [ (convert_options.convert) = { type : "ethers::types::Address" } ];
}

message MapValue {
string address = 1 [ (convert_options.convert) = { type : "ethers::types::Address" } ];
}

message ConversionsRequest {
option (convert_options.extra_fields) = { name: "field1", type: "String" };
option (convert_options.extra_fields) = { name: "field2", type: "i32" };
map<string, MapValue> map_field = 1;

enum NestedEnum {
NESTED_OK = 0;
NESTED_ERROR = 1;
}

string query = 2 [ (convert_options.convert) = { override : "Default::default()" } ];
repeated string addresses = 3 [ (convert_options.convert) = { type : "std::collections::HashSet<ethers::types::Address>" } ];
NestedEnum nested_enum = 4;
Nested nested = 5 [ (convert_options.convert) = { required : true } ];
}

message ConversionsResponse {
string address = 1 [ (convert_options.convert) = { type : "ethers::types::Address" } ];
Nested nested = 2;
map<string, MapValue> map_field = 3;
}
20 changes: 20 additions & 0 deletions tests/proto/convert_options.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
syntax = "proto3";

package convert_options;
option go_package = "github.com/blockscout/actix-prost/convert_options";

import "google/protobuf/descriptor.proto";

message ConvertOptions {
string type = 1;
string override = 2;
bool required = 3;
}

message ExtraFieldOptions {
string name = 1;
string type = 2;
}

extend google.protobuf.MessageOptions { repeated ExtraFieldOptions extra_fields = 50000; }
extend google.protobuf.FieldOptions { optional ConvertOptions convert = 50001; }
Loading

0 comments on commit 0f67be4

Please sign in to comment.