Skip to content

Commit

Permalink
[Rust-Axum][Breaking Change] Improve the oneOf model generator (#20336
Browse files Browse the repository at this point in the history
)

* Improve the implementation of oneOf

* Fixed 2.0 schemas; possible freeze present

* Fix generate-samples.sh freezing

* Fix validation of primitive types

* Move oneOf handling to its own method

* Fix formatting and add comments

* Remove allOf based discriminator handling

* Implement a test for v3 oneOf

* Implement oneOf tests for rust axum

* Fix circle CI

* Fix pom path, ensure cargo is present

* Implement untagged test

* Add final and fix double underscore typo
  • Loading branch information
Victoria-Casasampere-BeeTheData authored Jan 8, 2025
1 parent eb96380 commit 3d65786
Show file tree
Hide file tree
Showing 54 changed files with 2,909 additions and 30 deletions.
10 changes: 10 additions & 0 deletions CI/circle_parallel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,22 @@ if [ "$NODE_INDEX" = "1" ]; then

sudo apt-get -y install cpanminus

# install rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"

echo "Testing perl"
(cd samples/client/petstore/perl && /bin/bash ./test.bash)

echo "Testing ruby"
(cd samples/client/petstore/ruby && mvn integration-test)
(cd samples/client/petstore/ruby-faraday && mvn integration-test)
(cd samples/client/petstore/ruby-httpx && mvn integration-test)
(cd samples/client/petstore/ruby-autoload && mvn integration-test)

echo "Testing rust"
(cd samples/server/petstore/rust-axum && mvn integration-test)

elif [ "$NODE_INDEX" = "2" ]; then
echo "Running node $NODE_INDEX to test Go"
# install haskell
Expand Down
11 changes: 11 additions & 0 deletions bin/configs/manual/rust-axum-oneof-v3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
generatorName: rust-axum
outputDir: samples/server/petstore/rust-axum/output/rust-axum-oneof
inputSpec: modules/openapi-generator/src/test/resources/3_0/rust-axum/rust-axum-oneof.yaml
templateDir: modules/openapi-generator/src/main/resources/rust-axum
generateAliasAsModel: true
additionalProperties:
hideGenerationTimestamp: "true"
packageName: rust-axum-oneof
globalProperties:
skipFormModel: "false"
enablePostProcessFile: true
5 changes: 5 additions & 0 deletions bin/utils/test_file_list.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@
sha256: 67a9e63e13ebddac21cb236aa015edce30f5d3bd8d6adcf50044cad00d48c45e
- filename: "samples/openapi3/client/petstore/java/jersey2-java8/src/test/java/org/openapitools/client/model/ZebraTest.java"
sha256: 15eeb6d8a9a79d0f1930b861540d9c5780d6c49ea4fdb68269ac3e7ec481e142
# rust axum test files
- filename: "samples/server/petstore/rust-axum/output/rust-axum-oneof/src/tests.rs"
sha256: 3d4198174018cc7fd9d4bcffd950609a5bd306cf03b2fa780516f4e22a566e8c
- filename: "samples/server/petstore/rust-axum/output/openapi-v3/src/tests.rs"
sha256: 356ac684b1fce91b153c63caefc1fe7472ea600ac436a19631e16bc00e986c50
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,8 @@ public String toString() {
sb.append(", items='").append(items).append('\'');
sb.append(", additionalProperties='").append(additionalProperties).append('\'');
sb.append(", isModel='").append(isModel).append('\'');
sb.append(", isNull='").append(isNull);
sb.append(", hasValidation='").append(hasValidation);
sb.append(", isNull='").append(isNull).append('\'');
sb.append(", hasValidation='").append(hasValidation).append('\'');
sb.append(", getAdditionalPropertiesIsAnyType=").append(getAdditionalPropertiesIsAnyType());
sb.append(", getHasDiscriminatorWithNonEmptyMapping=").append(hasDiscriminatorWithNonEmptyMapping);
sb.append(", getIsAnyType=").append(getIsAnyType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ public RustAxumServerCodegen() {
supportingFiles.add(new SupportingFile("header.mustache", "src", "header.rs"));
supportingFiles.add(new SupportingFile("server-mod.mustache", "src/server", "mod.rs"));
supportingFiles.add(new SupportingFile("apis-mod.mustache", apiPackage().replace('.', File.separatorChar), "mod.rs"));
supportingFiles.add(new SupportingFile("README.mustache", "", "README.md")
.doNotOverwrite());
// The file gets overwritten regardless
supportingFiles.add(new SupportingFile("tests.mustache", "src", "tests.rs").doNotOverwrite());
supportingFiles.add(new SupportingFile("README.mustache", "", "README.md").doNotOverwrite());
}

@Override
Expand Down Expand Up @@ -594,8 +595,105 @@ public CodegenOperation fromOperation(String path, String httpMethod, Operation
return op;
}

private void postProcessOneOfModels(List<ModelMap> allModels) {
final HashMap<String, List<String>> oneOfMapDiscriminator = new HashMap<>();

for (ModelMap mo : allModels) {
final CodegenModel cm = mo.getModel();

final CodegenComposedSchemas cs = cm.getComposedSchemas();
if (cs != null) {
final List<CodegenProperty> csOneOf = cs.getOneOf();

if (csOneOf != null) {
for (CodegenProperty model : csOneOf) {
// Generate a valid name for the enum variant.
// Mainly needed for primitive types.
String[] modelParts = model.dataType.replace("<", "Of").replace(">", "").split("::");
model.datatypeWithEnum = camelize(modelParts[modelParts.length - 1]);

// Primitive type is not properly set, this overrides it to guarantee adequate model generation.
if (!model.getDataType().matches(String.format(Locale.ROOT, ".*::%s", model.getDatatypeWithEnum()))) {
model.isPrimitiveType = true;
}
}

cs.setOneOf(csOneOf);
cm.setComposedSchemas(cs);
}
}

if (cm.discriminator != null) {
for (String model : cm.oneOf) {
List<String> discriminators = oneOfMapDiscriminator.getOrDefault(model, new ArrayList<>());
discriminators.add(cm.discriminator.getPropertyName());
oneOfMapDiscriminator.put(model, discriminators);
}
}
}

for (ModelMap mo : allModels) {
final CodegenModel cm = mo.getModel();

for (CodegenProperty var : cm.vars) {
var.isDiscriminator = false;
}

final List<String> discriminatorsForModel = oneOfMapDiscriminator.get(cm.getSchemaName());

if (discriminatorsForModel != null) {
for (String discriminator : discriminatorsForModel) {
boolean hasDiscriminatorDefined = false;

for (CodegenProperty var : cm.vars) {
if (var.baseName.equals(discriminator)) {
var.isDiscriminator = true;
hasDiscriminatorDefined = true;
break;
}
}

// If the discriminator field is not a defined attribute in the variant structure, create it.
if (!hasDiscriminatorDefined) {
CodegenProperty property = new CodegenProperty();

// Static attributes
// Only strings are supported by serde for tag field types, so it's the only one we'll deal with
property.openApiType = "string";
property.complexType = "string";
property.dataType = "String";
property.datatypeWithEnum = "String";
property.baseType = "string";
property.required = true;
property.isPrimitiveType = true;
property.isString = true;
property.isDiscriminator = true;

// Attributes based on the discriminator value
property.baseName = discriminator;
property.name = discriminator;
property.nameInCamelCase = camelize(discriminator);
property.nameInPascalCase = property.nameInCamelCase.substring(0, 1).toUpperCase(Locale.ROOT) + property.nameInCamelCase.substring(1);
property.nameInSnakeCase = underscore(discriminator).toUpperCase(Locale.ROOT);
property.getter = String.format(Locale.ROOT, "get%s", property.nameInPascalCase);
property.setter = String.format(Locale.ROOT, "set%s", property.nameInPascalCase);
property.defaultValueWithParam = String.format(Locale.ROOT, " = data.%s;", property.name);

// Attributes based on the model name
property.defaultValue = String.format(Locale.ROOT, "r#\"%s\"#.to_string()", cm.getSchemaName());
property.jsonSchema = String.format(Locale.ROOT, "{ \"default\":\"%s\"; \"type\":\"string\" }", cm.getSchemaName());

cm.vars.add(property);
}
}
}
}
}

@Override
public OperationsMap postProcessOperationsWithModels(final OperationsMap operationsMap, List<ModelMap> allModels) {
postProcessOneOfModels(allModels);

final OperationMap operations = operationsMap.getOperations();
operations.put("classnamePascalCase", camelize(operations.getClassname()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ pub mod apis;

#[cfg(feature = "server")]
pub(crate) mod header;

#[cfg(test)]
mod tests;
Original file line number Diff line number Diff line change
Expand Up @@ -573,21 +573,70 @@ impl PartialEq for {{{classname}}} {
self.0.get() == other.0.get()
}
}
{{/anyOf.size}}
{{#oneOf.size}}
/// One of:
{{#oneOf}}
/// - {{{.}}}
{{/oneOf}}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct {{{classname}}}(Box<serde_json::value::RawValue>);
{{#discriminator}}
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
#[serde(tag = "{{{propertyBaseName}}}")]
{{/discriminator}}
{{^discriminator}}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
{{/discriminator}}
#[allow(non_camel_case_types)]
pub enum {{{classname}}} {
{{#composedSchemas}}
{{#oneOf}}
{{{datatypeWithEnum}}}(Box<{{{dataType}}}>),
{{/oneOf}}
{{/composedSchemas}}
}
impl validator::Validate for {{{classname}}}
{
fn validate(&self) -> std::result::Result<(), validator::ValidationErrors> {
std::result::Result::Ok(())
match self {
{{#composedSchemas}}
{{#oneOf}}
{{#isPrimitiveType}}
Self::{{{datatypeWithEnum}}}(_) => std::result::Result::Ok(()),
{{/isPrimitiveType}}
{{^isPrimitiveType}}
Self::{{{datatypeWithEnum}}}(x) => x.validate(),
{{/isPrimitiveType}}
{{/oneOf}}
{{/composedSchemas}}
}
}
}
{{#discriminator}}
impl serde::Serialize for {{{classname}}} {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: serde::Serializer {
match self {
{{#composedSchemas}}
{{#oneOf}}
Self::{{{datatypeWithEnum}}}(x) => x.serialize(serializer),
{{/oneOf}}
{{/composedSchemas}}
}
}
}
{{/discriminator}}
{{#composedSchemas}}
{{#oneOf}}
impl From<{{{dataType}}}> for {{{classname}}} {
fn from(value: {{{dataType}}}) -> Self {
Self::{{{datatypeWithEnum}}}(Box::new(value))
}
}
{{/oneOf}}
{{/composedSchemas}}
/// Converts Query Parameters representation (style=form, explode=false) to a {{{classname}}} value
/// as specified in https://swagger.io/docs/specification/serialization/
Expand All @@ -600,11 +649,6 @@ impl std::str::FromStr for {{{classname}}} {
}
}
impl PartialEq for {{{classname}}} {
fn eq(&self, other: &Self) -> bool {
self.0.get() == other.0.get()
}
}
{{/oneOf.size}}
{{^anyOf.size}}
{{^oneOf.size}}
Expand All @@ -613,11 +657,15 @@ impl PartialEq for {{{classname}}} {
pub struct {{{classname}}} {
{{#vars}}
{{#description}}
/// {{{.}}}
/// {{{.}}}
{{/description}}
{{#isEnum}}
/// Note: inline enums are not fully supported by openapi-generator
/// Note: inline enums are not fully supported by openapi-generator
{{/isEnum}}
{{#isDiscriminator}}
#[serde(default = "{{{classname}}}::_name_for_{{{name}}}")]
#[serde(serialize_with = "{{{classname}}}::_serialize_{{{name}}}")]
{{/isDiscriminator}}
#[serde(rename = "{{{baseName}}}")]
{{#hasValidation}}
#[validate(
Expand Down Expand Up @@ -685,6 +733,25 @@ pub struct {{{classname}}} {
{{/vars}}
}
{{#vars}}
{{#isDiscriminator}}
impl {{{classname}}} {
fn _name_for_{{{name}}}() -> String {
String::from("{{{classname}}}")
}
fn _serialize_{{{name}}}<S>(_: &String, s: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
s.serialize_str(&Self::_name_for_{{{name}}}())
}
}
{{/isDiscriminator}}
{{/vars}}
{{#vars}}
{{#hasValidation}}
{{#pattern}}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#[test]
fn std_test() {
assert!(true);
}

#[tokio::test]
async fn tokio_test() {
assert!(true);
}
Loading

0 comments on commit 3d65786

Please sign in to comment.