Skip to content

Commit

Permalink
feat: deduplicate generated file with case-sensitive names in split m…
Browse files Browse the repository at this point in the history
…ode (#280)

* Deduplicate generated file with case-sensitive names in split mode

* Move `thrift_with_split_case_sensitive` to `thrift_with_split`

* Add a note about macOS

* Use AHashSet instead of HashSet

* Fix lint
  • Loading branch information
missingdays authored Oct 31, 2024
1 parent 9aca81f commit b1078bc
Show file tree
Hide file tree
Showing 12 changed files with 625 additions and 6 deletions.
24 changes: 22 additions & 2 deletions pilota-build/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
sync::Arc,
};

use ahash::AHashMap;
use ahash::{AHashMap, AHashSet};
use dashmap::{mapref::one::RefMut, DashMap};
use faststr::FastStr;
use itertools::Itertools;
Expand Down Expand Up @@ -536,6 +536,8 @@ where
let mod_file_name = format!("{}/mod.rs", base_mod_name);
let mut mod_stream = String::new();

let mut existing_file_names: AHashSet<String> = AHashSet::new();

for def_id in def_ids.iter() {
let mut item_stream = String::new();
let node = this.db.node(def_id.def_id).unwrap();
Expand All @@ -556,7 +558,10 @@ where

let mod_dir = base_dir.join(base_mod_name.clone());

let file_name = format!("{}_{}.rs", name_prefix, node.name());
let simple_name = format!("{}_{}", name_prefix, node.name());
let unique_name = Self::generate_unique_name(&existing_file_names, &simple_name);
existing_file_names.insert(unique_name.to_ascii_lowercase().clone());
let file_name = format!("{}.rs", unique_name);
this.write_item(&mut item_stream, *def_id, dup);

let full_path = mod_dir.join(file_name.clone());
Expand All @@ -582,6 +587,21 @@ where
stream.push_str(format!("include!(\"{}\");\n", mod_file_name).as_str());
}

/**
On Windows and macOS, files names are case-insensitive
To avoid problems when generating files for services with similar names, e.g.
testService and TestService, such names are de-duplicated by adding a number to their nam5e
*/
fn generate_unique_name(existing_names: &AHashSet<String>, simple_name: &String) -> String {

Check warning on line 595 in pilota-build/src/codegen/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

writing `&String` instead of `&str` involves a new object where a slice will do

warning: writing `&String` instead of `&str` involves a new object where a slice will do --> pilota-build/src/codegen/mod.rs:595:77 | 595 | fn generate_unique_name(existing_names: &AHashSet<String>, simple_name: &String) -> String { | ^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#ptr_arg = note: `#[warn(clippy::ptr_arg)]` on by default help: change this to | 595 ~ fn generate_unique_name(existing_names: &AHashSet<String>, simple_name: &str) -> String { 596 | let mut counter = 1; 597 ~ let mut name = simple_name.to_owned(); 598 | while existing_names.contains(name.to_ascii_lowercase().as_str()) { 599 | counter += 1; 600 ~ name = format!("{}_{}", simple_name.to_owned(), counter) |
let mut counter = 1;
let mut name = simple_name.clone();
while existing_names.contains(name.to_ascii_lowercase().as_str()) {
counter += 1;
name = format!("{}_{}", simple_name.clone(), counter)
}
name
}

pub fn write_file(self, ns_name: Symbol, file_name: impl AsRef<Path>) {
let base_dir = file_name.as_ref().parent().unwrap();
let mut stream = String::default();
Expand Down
2 changes: 1 addition & 1 deletion pilota-build/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ fn test_with_builder_workspace<F: FnOnce(&Path, &Path)>(
f: F,
) {
if std::env::var("UPDATE_TEST_DATA").as_deref() == Ok("1") {
_ = fs::remove_dir(&target);
fs::remove_dir_all(&target).unwrap();
fs::create_dir_all(&target).unwrap();
let cargo_toml_path = target.as_ref().join("Cargo.toml");
File::create(cargo_toml_path).unwrap();
Expand Down
4 changes: 4 additions & 0 deletions pilota-build/test_data/thrift_with_split/wrapper_arc.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ struct TEST {
service TestService {
TEST(pilota.rust_wrapper_arc="true") test(1: TEST req(pilota.rust_wrapper_arc="true"));
}

service testService {
TEST(pilota.rust_wrapper_arc="true") test(1: TEST req(pilota.rust_wrapper_arc="true"));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#[derive(Debug, ::pilota::derivative::Derivative)]
#[derivative(Default)]
#[derive(Clone, PartialEq)]
pub enum testServiceTestResultRecv {
#[derivative(Default)]
Ok(Test),
}

impl ::pilota::thrift::Message for testServiceTestResultRecv {
fn encode<T: ::pilota::thrift::TOutputProtocol>(
&self,
__protocol: &mut T,
) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> {
#[allow(unused_imports)]
use ::pilota::thrift::TOutputProtocolExt;
__protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {
name: "testServiceTestResultRecv",
})?;
match self {
testServiceTestResultRecv::Ok(ref value) => {
__protocol.write_struct_field(0, value, ::pilota::thrift::TType::Struct)?;
}
}
__protocol.write_field_stop()?;
__protocol.write_struct_end()?;
::std::result::Result::Ok(())
}

fn decode<T: ::pilota::thrift::TInputProtocol>(
__protocol: &mut T,
) -> ::std::result::Result<Self, ::pilota::thrift::ThriftException> {
#[allow(unused_imports)]
use ::pilota::{thrift::TLengthProtocolExt, Buf};
let mut ret = None;
__protocol.read_struct_begin()?;
loop {
let field_ident = __protocol.read_field_begin()?;
if field_ident.field_type == ::pilota::thrift::TType::Stop {
__protocol.field_stop_len();
break;
} else {
__protocol.field_begin_len(field_ident.field_type, field_ident.id);
}
match field_ident.id {
Some(0) => {
if ret.is_none() {
let field_ident = ::pilota::thrift::Message::decode(__protocol)?;
__protocol.struct_len(&field_ident);
ret = Some(testServiceTestResultRecv::Ok(field_ident));
} else {
return ::std::result::Result::Err(
::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received multiple fields for union from remote Message",
),
);
}
}
_ => {
__protocol.skip(field_ident.field_type)?;
}
}
}
__protocol.read_field_end()?;
__protocol.read_struct_end()?;
if let Some(ret) = ret {
::std::result::Result::Ok(ret)
} else {
::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received empty union from remote Message",
))
}
}

fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>(
__protocol: &'a mut T,
) -> ::std::pin::Pin<
::std::boxed::Box<
dyn ::std::future::Future<
Output = ::std::result::Result<Self, ::pilota::thrift::ThriftException>,
> + Send
+ 'a,
>,
> {
::std::boxed::Box::pin(async move {
let mut ret = None;
__protocol.read_struct_begin().await?;
loop {
let field_ident = __protocol.read_field_begin().await?;
if field_ident.field_type == ::pilota::thrift::TType::Stop {
break;
} else {
}
match field_ident.id {
Some(0) => {
if ret.is_none() {
let field_ident =
<Test as ::pilota::thrift::Message>::decode_async(__protocol)
.await?;

ret = Some(testServiceTestResultRecv::Ok(field_ident));
} else {
return ::std::result::Result::Err(
::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received multiple fields for union from remote Message",
),
);
}
}
_ => {
__protocol.skip(field_ident.field_type).await?;
}
}
}
__protocol.read_field_end().await?;
__protocol.read_struct_end().await?;
if let Some(ret) = ret {
::std::result::Result::Ok(ret)
} else {
::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received empty union from remote Message",
))
}
})
}

fn size<T: ::pilota::thrift::TLengthProtocol>(&self, __protocol: &mut T) -> usize {
#[allow(unused_imports)]
use ::pilota::thrift::TLengthProtocolExt;
__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {
name: "testServiceTestResultRecv",
}) + match self {
testServiceTestResultRecv::Ok(ref value) => __protocol.struct_field_len(Some(0), value),
} + __protocol.field_stop_len()
+ __protocol.struct_end_len()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#[derive(Debug, ::pilota::derivative::Derivative)]
#[derivative(Default)]
#[derive(Clone, PartialEq)]
pub enum testServiceTestResultSend {
#[derivative(Default)]
Ok(::std::sync::Arc<Test>),
}

impl ::pilota::thrift::Message for testServiceTestResultSend {
fn encode<T: ::pilota::thrift::TOutputProtocol>(
&self,
__protocol: &mut T,
) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> {
#[allow(unused_imports)]
use ::pilota::thrift::TOutputProtocolExt;
__protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {
name: "testServiceTestResultSend",
})?;
match self {
testServiceTestResultSend::Ok(ref value) => {
__protocol.write_struct_field(0, value, ::pilota::thrift::TType::Struct)?;
}
}
__protocol.write_field_stop()?;
__protocol.write_struct_end()?;
::std::result::Result::Ok(())
}

fn decode<T: ::pilota::thrift::TInputProtocol>(
__protocol: &mut T,
) -> ::std::result::Result<Self, ::pilota::thrift::ThriftException> {
#[allow(unused_imports)]
use ::pilota::{thrift::TLengthProtocolExt, Buf};
let mut ret = None;
__protocol.read_struct_begin()?;
loop {
let field_ident = __protocol.read_field_begin()?;
if field_ident.field_type == ::pilota::thrift::TType::Stop {
__protocol.field_stop_len();
break;
} else {
__protocol.field_begin_len(field_ident.field_type, field_ident.id);
}
match field_ident.id {
Some(0) => {
if ret.is_none() {
let field_ident =
::std::sync::Arc::new(::pilota::thrift::Message::decode(__protocol)?);
__protocol.struct_len(&field_ident);
ret = Some(testServiceTestResultSend::Ok(field_ident));
} else {
return ::std::result::Result::Err(
::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received multiple fields for union from remote Message",
),
);
}
}
_ => {
__protocol.skip(field_ident.field_type)?;
}
}
}
__protocol.read_field_end()?;
__protocol.read_struct_end()?;
if let Some(ret) = ret {
::std::result::Result::Ok(ret)
} else {
::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received empty union from remote Message",
))
}
}

fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>(
__protocol: &'a mut T,
) -> ::std::pin::Pin<
::std::boxed::Box<
dyn ::std::future::Future<
Output = ::std::result::Result<Self, ::pilota::thrift::ThriftException>,
> + Send
+ 'a,
>,
> {
::std::boxed::Box::pin(async move {
let mut ret = None;
__protocol.read_struct_begin().await?;
loop {
let field_ident = __protocol.read_field_begin().await?;
if field_ident.field_type == ::pilota::thrift::TType::Stop {
break;
} else {
}
match field_ident.id {
Some(0) => {
if ret.is_none() {
let field_ident = ::std::sync::Arc::new(
<Test as ::pilota::thrift::Message>::decode_async(__protocol)
.await?,
);

ret = Some(testServiceTestResultSend::Ok(field_ident));
} else {
return ::std::result::Result::Err(
::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received multiple fields for union from remote Message",
),
);
}
}
_ => {
__protocol.skip(field_ident.field_type).await?;
}
}
}
__protocol.read_field_end().await?;
__protocol.read_struct_end().await?;
if let Some(ret) = ret {
::std::result::Result::Ok(ret)
} else {
::std::result::Result::Err(::pilota::thrift::new_protocol_exception(
::pilota::thrift::ProtocolExceptionKind::InvalidData,
"received empty union from remote Message",
))
}
})
}

fn size<T: ::pilota::thrift::TLengthProtocol>(&self, __protocol: &mut T) -> usize {
#[allow(unused_imports)]
use ::pilota::thrift::TLengthProtocolExt;
__protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier {
name: "testServiceTestResultSend",
}) + match self {
testServiceTestResultSend::Ok(ref value) => __protocol.struct_field_len(Some(0), value),
} + __protocol.field_stop_len()
+ __protocol.struct_end_len()
}
}
Loading

0 comments on commit b1078bc

Please sign in to comment.