Skip to content

Commit

Permalink
fix: generate exception correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
Millione committed Sep 23, 2024
1 parent 8402807 commit c1dca4b
Show file tree
Hide file tree
Showing 11 changed files with 2,133 additions and 1,156 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pilota-build/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pilota-build"
version = "0.11.19"
version = "0.11.20"
edition = "2021"
description = "Compile thrift and protobuf idl into rust code at compile-time."
documentation = "https://docs.rs/pilota-build"
Expand Down
17 changes: 8 additions & 9 deletions pilota-build/src/codegen/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ impl ThriftBackend {
&self,
helper: &DecodeHelper,
s: &rir::Message,
name: Symbol,
keep: bool,
is_arg: bool,
) -> String {
Expand Down Expand Up @@ -291,7 +292,7 @@ impl ThriftBackend {
}
};

let format_msg = format!("decode struct `{}` field(#{{}}) failed", s.name);
let format_msg = format!("decode struct `{}` field(#{{}}) failed", name);

let mut fields = s
.fields
Expand Down Expand Up @@ -459,7 +460,6 @@ impl CodegenBackend for ThriftBackend {
fn codegen_struct_impl(&self, def_id: DefId, stream: &mut String, s: &Message) {
let keep = self.keep_unknown_fields.contains(&def_id);
let name = self.cx.rust_name(def_id);
let name_str = &**s.name;
let mut encode_fields = self.codegen_encode_fields(&s.fields).join("");
if keep {
encode_fields.push_str(
Expand All @@ -477,10 +477,10 @@ impl CodegenBackend for ThriftBackend {
}
stream.push_str(&self.codegen_impl_message_with_helper(
def_id,
name,
name.clone(),
format! {
r#"let struct_ident =::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}};
__protocol.write_struct_begin(&struct_ident)?;
Expand All @@ -492,10 +492,10 @@ impl CodegenBackend for ThriftBackend {
},
format! {
r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}}) + {encode_fields_size} __protocol.field_stop_len() + __protocol.struct_end_len()"#
},
|helper| self.codegen_decode(helper, s, keep, self.is_arg(def_id)),
|helper| self.codegen_decode(helper, s, name.clone(), keep, self.is_arg(def_id)),
));
}

Expand Down Expand Up @@ -536,7 +536,6 @@ impl CodegenBackend for ThriftBackend {
None if is_entry_message => self.codegen_entry_enum(def_id, stream, e),
None => {
let name = self.rust_name(def_id);
let name_str = &**e.name;
let mut encode_variants = e
.variants
.iter()
Expand Down Expand Up @@ -603,7 +602,7 @@ impl CodegenBackend for ThriftBackend {
name.clone(),
format! {
r#"__protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}})?;
match self {{
{encode_variants}
Expand All @@ -614,7 +613,7 @@ impl CodegenBackend for ThriftBackend {
},
format! {
r#"__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {{
name: "{name_str}",
name: "{name}",
}}) + match self {{
{variants_size}
}} + __protocol.field_stop_len() + __protocol.struct_end_len()"#
Expand Down
2 changes: 1 addition & 1 deletion pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ impl Context {
}

if !self.change_case || self.names.contains(&def_id) {
return node.name().0.into();
return node.name().0.replace(' ', "").into();
}

match self.node(def_id).unwrap().kind {
Expand Down
46 changes: 13 additions & 33 deletions pilota-build/src/parser/thrift/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::{path::PathBuf, str::FromStr, sync::Arc};

use faststr::FastStr;
use heck::ToUpperCamelCase;
use itertools::Itertools;
use normpath::PathExt;
use pilota_thrift_parser as thrift_parser;
use pilota_thrift_parser::parser::Parser as _;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::FxHashMap;
use salsa::ParallelDatabase;
use thrift_parser::Annotations;

Expand Down Expand Up @@ -134,20 +133,6 @@ impl ThriftLower {

let mut related_items = Vec::default();

let mut seen = FxHashSet::default();
let mut duplicate_function_names = FxHashSet::default();
for name in service.functions.iter().map(|f| {
self.extract_tags(&f.annotations)
.get::<PilotaName>()
.map(|name| &*name.0)
.unwrap_or_else(|| &*f.name)
.to_upper_camel_case()
}) {
if !seen.insert(name.clone()) {
duplicate_function_names.insert(name);
}
}

service.functions.iter().for_each(|f| {
let exception = f
.throws
Expand All @@ -174,17 +159,11 @@ impl ThriftLower {

let method_name = tags
.get::<PilotaName>()
.map(|name| name.0.to_upper_camel_case())
.unwrap_or_else(|| {
let method_name = f.name.to_upper_camel_case();
if duplicate_function_names.contains(&method_name) {
f.name.to_string()
} else {
method_name
}
});
.map(|name| &*name.0)
.unwrap_or_else(|| &*f.name);

let name: Ident = format!("{}{}ResultRecv", service.name.as_str(), method_name).into();
let name: Ident =
format!("{} {} ResultRecv", service.name.as_str(), method_name).into();

let mut tags = self.extract_tags(&f.result_type.1);
tags.remove::<RustWrapperArc>();
Expand All @@ -206,7 +185,8 @@ impl ThriftLower {
tags.insert(crate::tags::KeepUnknownFields(false));
result.push(self.mk_item(kind, tags.into()));

let name: Ident = format!("{}{}ResultSend", service.name.as_str(), method_name).into();
let name: Ident =
format!("{} {} ResultSend", service.name.as_str(), method_name).into();
let kind = ir::ItemKind::Enum(ir::Enum {
name: name.clone(),
variants: std::iter::once(ir::EnumVariant {
Expand All @@ -227,7 +207,7 @@ impl ThriftLower {

if !exception.is_empty() {
let name: Ident =
format!("{}{}Exception", service.name.as_str(), method_name).into();
format!("{} {} Exception", service.name.as_str(), method_name).into();
let kind = ir::ItemKind::Enum(ir::Enum {
name: name.clone(),
variants: exception,
Expand All @@ -239,7 +219,7 @@ impl ThriftLower {
result.push(self.mk_item(kind, tags.into()));
}

let name: Ident = format!("{}{}ArgsSend", service.name.as_str(), method_name).into();
let name: Ident = format!("{} {} ArgsSend", service.name.as_str(), method_name).into();
let kind = ir::ItemKind::Message(ir::Message {
name: name.clone(),
fields: f.arguments.iter().map(|a| self.lower_field(a)).collect(),
Expand All @@ -249,7 +229,7 @@ impl ThriftLower {
tags.insert(crate::tags::KeepUnknownFields(false));
result.push(self.mk_item(kind, tags.into()));

let name: Ident = format!("{}{}ArgsRecv", service.name.as_str(), method_name).into();
let name: Ident = format!("{} {} ArgsRecv", service.name.as_str(), method_name).into();
let kind = ir::ItemKind::Message(ir::Message {
name: name.clone(),
fields: f
Expand Down Expand Up @@ -298,9 +278,9 @@ impl ThriftLower {
} else {
Some(Path {
segments: Arc::from([Ident::from(format!(
"{}{}Exception",
service.name.to_upper_camel_case().as_str(),
method.name.to_upper_camel_case(),
"{} {} Exception",
service.name.as_str(),
method.name.as_str(),
))]),
})
},
Expand Down
Loading

0 comments on commit c1dca4b

Please sign in to comment.