Skip to content

Commit

Permalink
fix: generate exception correctly (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
Millione authored Sep 24, 2024
1 parent 8402807 commit c50d2bc
Show file tree
Hide file tree
Showing 11 changed files with 2,353 additions and 1,191 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pilota-build/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pilota-build"
version = "0.11.19"
version = "0.11.20"
edition = "2021"
description = "Compile thrift and protobuf idl into rust code at compile-time."
documentation = "https://docs.rs/pilota-build"
Expand Down
17 changes: 8 additions & 9 deletions pilota-build/src/codegen/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ impl ThriftBackend {
&self,
helper: &DecodeHelper,
s: &rir::Message,
name: Symbol,
keep: bool,
is_arg: bool,
) -> String {
Expand Down Expand Up @@ -291,7 +292,7 @@ impl ThriftBackend {
}
};

let format_msg = format!("decode struct `{}` field(#{{}}) failed", s.name);
let format_msg = format!("decode struct `{}` field(#{{}}) failed", name);

let mut fields = s
.fields
Expand Down Expand Up @@ -459,7 +460,6 @@ impl CodegenBackend for ThriftBackend {
fn codegen_struct_impl(&self, def_id: DefId, stream: &mut String, s: &Message) {
let keep = self.keep_unknown_fields.contains(&def_id);
let name = self.cx.rust_name(def_id);
let name_str = &**s.name;
let mut encode_fields = self.codegen_encode_fields(&s.fields).join("");
if keep {
encode_fields.push_str(
Expand All @@ -477,10 +477,10 @@ impl CodegenBackend for ThriftBackend {
}
stream.push_str(&self.codegen_impl_message_with_helper(
def_id,
name,
name.clone(),
format! {
r#"let struct_ident =::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}};
__protocol.write_struct_begin(&struct_ident)?;
Expand All @@ -492,10 +492,10 @@ impl CodegenBackend for ThriftBackend {
},
format! {
r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}}) + {encode_fields_size} __protocol.field_stop_len() + __protocol.struct_end_len()"#
},
|helper| self.codegen_decode(helper, s, keep, self.is_arg(def_id)),
|helper| self.codegen_decode(helper, s, name.clone(), keep, self.is_arg(def_id)),
));
}

Expand Down Expand Up @@ -536,7 +536,6 @@ impl CodegenBackend for ThriftBackend {
None if is_entry_message => self.codegen_entry_enum(def_id, stream, e),
None => {
let name = self.rust_name(def_id);
let name_str = &**e.name;
let mut encode_variants = e
.variants
.iter()
Expand Down Expand Up @@ -603,7 +602,7 @@ impl CodegenBackend for ThriftBackend {
name.clone(),
format! {
r#"__protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}})?;
match self {{
{encode_variants}
Expand All @@ -614,7 +613,7 @@ impl CodegenBackend for ThriftBackend {
},
format! {
r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}}) + match self {{
{variants_size}
}} + __protocol.field_stop_len() + __protocol.struct_end_len()"#
Expand Down
2 changes: 1 addition & 1 deletion pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ impl Context {
}

if !self.change_case || self.names.contains(&def_id) {
return node.name().0.into();
return node.name();
}

match self.node(def_id).unwrap().kind {
Expand Down
133 changes: 89 additions & 44 deletions pilota-build/src/parser/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub struct ThriftLower {
file_ids_map: FxHashMap<Arc<PathBuf>, FileId>,
include_dirs: Vec<PathBuf>,
packages: FxHashMap<Path, Vec<Arc<PathBuf>>>,
service_name_duplicates: FxHashSet<String>,
}

impl ThriftLower {
Expand All @@ -84,6 +85,7 @@ impl ThriftLower {
file_ids_map: FxHashMap::default(),
include_dirs,
packages: Default::default(),
service_name_duplicates: Default::default(),
}
}

Expand Down Expand Up @@ -115,6 +117,33 @@ impl ThriftLower {
}

fn lower_service(&self, service: &thrift_parser::Service) -> Vec<ir::Item> {
let service_name = if self
.service_name_duplicates
.contains(&service.name.to_upper_camel_case())
{
service.name.to_string()
} else {
service.name.to_upper_camel_case()
};

let mut function_names: FxHashMap<String, Vec<String>> = FxHashMap::default();
service.functions.iter().for_each(|func| {
let name = self
.extract_tags(&func.annotations)
.get::<PilotaName>()
.map(|name| name.0.to_string())
.unwrap_or_else(|| func.name.to_string());
function_names
.entry(name.to_upper_camel_case())
.or_default()
.push(name);
});
let function_name_duplicates = function_names
.iter()
.filter(|(_, v)| v.len() > 1)
.map(|(k, _)| k)
.collect::<FxHashSet<_>>();

let kind = ir::ItemKind::Service(ir::Service {
name: self.lower_ident(&service.name),
extend: service
Expand All @@ -126,28 +155,14 @@ impl ThriftLower {
methods: service
.functions
.iter()
.map(|f| self.lower_method(service, f))
.map(|f| self.lower_method(&service_name, f, &function_name_duplicates))
.collect(),
});
let mut service_item = self.mk_item(kind, Default::default());
let mut result = vec![];

let mut related_items = Vec::default();

let mut seen = FxHashSet::default();
let mut duplicate_function_names = FxHashSet::default();
for name in service.functions.iter().map(|f| {
self.extract_tags(&f.annotations)
.get::<PilotaName>()
.map(|name| &*name.0)
.unwrap_or_else(|| &*f.name)
.to_upper_camel_case()
}) {
if !seen.insert(name.clone()) {
duplicate_function_names.insert(name);
}
}

service.functions.iter().for_each(|f| {
let exception = f
.throws
Expand All @@ -171,21 +186,17 @@ impl ThriftLower {
.collect::<Vec<_>>();

let tags = self.extract_tags(&f.annotations);

let method_name = tags
let name = tags
.get::<PilotaName>()
.map(|name| name.0.to_upper_camel_case())
.unwrap_or_else(|| {
let method_name = f.name.to_upper_camel_case();
if duplicate_function_names.contains(&method_name) {
f.name.to_string()
} else {
method_name
}
});

let name: Ident = format!("{}{}ResultRecv", service.name.as_str(), method_name).into();
.map(|name| name.0.to_string())
.unwrap_or_else(|| f.name.to_string());
let method_name = if function_name_duplicates.contains(&name.to_upper_camel_case()) {
name
} else {
name.to_upper_camel_case()
};

let name: Ident = format!("{}{}ResultRecv", service_name, method_name).into();
let mut tags = self.extract_tags(&f.result_type.1);
tags.remove::<RustWrapperArc>();
let kind = ir::ItemKind::Enum(ir::Enum {
Expand All @@ -201,12 +212,13 @@ impl ThriftLower {
.collect(),
repr: None,
});
related_items.push(name);
related_items.push(name.clone());
let mut tags = Tags::default();
tags.insert(crate::tags::KeepUnknownFields(false));
tags.insert(crate::tags::PilotaName(name.sym.0));
result.push(self.mk_item(kind, tags.into()));

let name: Ident = format!("{}{}ResultSend", service.name.as_str(), method_name).into();
let name: Ident = format!("{}{}ResultSend", service_name, method_name).into();
let kind = ir::ItemKind::Enum(ir::Enum {
name: name.clone(),
variants: std::iter::once(ir::EnumVariant {
Expand All @@ -220,36 +232,38 @@ impl ThriftLower {
.collect(),
repr: None,
});
related_items.push(name);
related_items.push(name.clone());
let mut tags = Tags::default();
tags.insert(crate::tags::KeepUnknownFields(false));
tags.insert(crate::tags::PilotaName(name.sym.0));
result.push(self.mk_item(kind, tags.into()));

if !exception.is_empty() {
let name: Ident =
format!("{}{}Exception", service.name.as_str(), method_name).into();
let name: Ident = format!("{}{}Exception", service_name, method_name).into();
let kind = ir::ItemKind::Enum(ir::Enum {
name: name.clone(),
variants: exception,
repr: None,
});
related_items.push(name);
related_items.push(name.clone());
let mut tags = Tags::default();
tags.insert(crate::tags::KeepUnknownFields(false));
tags.insert(crate::tags::PilotaName(name.sym.0));
result.push(self.mk_item(kind, tags.into()));
}

let name: Ident = format!("{}{}ArgsSend", service.name.as_str(), method_name).into();
let name: Ident = format!("{}{}ArgsSend", service_name, method_name).into();
let kind = ir::ItemKind::Message(ir::Message {
name: name.clone(),
fields: f.arguments.iter().map(|a| self.lower_field(a)).collect(),
});
related_items.push(name);
related_items.push(name.clone());
let mut tags = Tags::default();
tags.insert(crate::tags::KeepUnknownFields(false));
tags.insert(crate::tags::PilotaName(name.sym.0));
result.push(self.mk_item(kind, tags.into()));

let name: Ident = format!("{}{}ArgsRecv", service.name.as_str(), method_name).into();
let name: Ident = format!("{}{}ArgsRecv", service_name, method_name).into();
let kind = ir::ItemKind::Message(ir::Message {
name: name.clone(),
fields: f
Expand All @@ -262,9 +276,10 @@ impl ThriftLower {
})
.collect(),
});
related_items.push(name);
related_items.push(name.clone());
let mut tags: Tags = Tags::default();
tags.insert(crate::tags::KeepUnknownFields(false));
tags.insert(crate::tags::PilotaName(name.sym.0));
result.push(self.mk_item(kind, tags.into()));
});

Expand All @@ -275,9 +290,21 @@ impl ThriftLower {

fn lower_method(
&self,
service: &thrift_parser::Service,
service_name: &String,
method: &thrift_parser::Function,
function_name_duplicates: &FxHashSet<&String>,
) -> ir::Method {
let tags = self.extract_tags(&method.annotations);
let name = tags
.get::<PilotaName>()
.map(|name| name.0.to_string())
.unwrap_or_else(|| method.name.to_string());
let method_name = if function_name_duplicates.contains(&name.to_upper_camel_case()) {
name
} else {
name.to_upper_camel_case()
};

ir::Method {
name: self.lower_ident(&method.name),
args: method
Expand All @@ -292,15 +319,14 @@ impl ThriftLower {
.collect(),
ret: self.lower_ty(&method.result_type),
oneway: method.oneway,
tags: self.extract_tags(&method.annotations).into(),
tags: tags.into(),
exceptions: if method.throws.is_empty() {
None
} else {
Some(Path {
segments: Arc::from([Ident::from(format!(
"{}{}Exception",
service.name.to_upper_camel_case().as_str(),
method.name.to_upper_camel_case(),
service_name, method_name,
))]),
})
},
Expand Down Expand Up @@ -596,7 +622,23 @@ impl Lower<Arc<thrift_parser::File>> for ThriftLower {
.or_default()
.push(f.path.clone());

ir::File {
let mut service_names: FxHashMap<String, Vec<String>> = FxHashMap::default();
f.items.iter().for_each(|item| {
if let thrift_parser::Item::Service(service) = item {
service_names
.entry(service.name.to_upper_camel_case())
.or_default()
.push(service.name.to_string());
}
});
this.service_name_duplicates.extend(
service_names
.into_iter()
.filter(|(_, v)| v.len() > 1)
.map(|(k, _)| k),
);

let ret = ir::File {
package: file_package,
items: f
.items
Expand All @@ -607,7 +649,10 @@ impl Lower<Arc<thrift_parser::File>> for ThriftLower {
.collect(),
id: file_id,
uses,
}
};

this.service_name_duplicates.clear();
ret
});

file.id
Expand Down
Loading

0 comments on commit c50d2bc

Please sign in to comment.