Skip to content

Commit

Permalink
Don't generate prop if custom prop exists in base (#4682)
Browse files Browse the repository at this point in the history
This PR address the following bugs:
- If a model contains a spec property that was also added as a custom
model in it's base model, the derived model will not generate the
property.
- if a custom property is added to a base model and it includes the
`CodeGenSerialization` attribute, that property is included in
serialization ctor for the derived model.

fixes: #4629
  • Loading branch information
jorgerangel-msft authored Oct 11, 2024
1 parent 3660144 commit ff1725c
Show file tree
Hide file tree
Showing 11 changed files with 434 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Linq;
using System.Net;
using System.Text.Json;
using Microsoft.CodeAnalysis;
using Microsoft.Generator.CSharp.ClientModel.Snippets;
using Microsoft.Generator.CSharp.Expressions;
using Microsoft.Generator.CSharp.Input;
Expand Down Expand Up @@ -710,6 +711,24 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
Dictionary<JsonValueKind, List<MethodBodyStatement>> additionalPropsValueKindBodyStatements = [];
var parameters = SerializationConstructor.Signature.Parameters;

// Parse the custom serialization attributes
List<AttributeData> serializationAttributes = _model.CustomCodeView?.GetAttributes()
.Where(a => a.AttributeClass?.Name == CodeGenAttributes.CodeGenSerializationAttributeName)
.ToList() ?? [];
var baseModelProvider = _model.BaseModelProvider;

while (baseModelProvider != null)
{
var customCodeView = baseModelProvider.CustomCodeView;
if (customCodeView != null)
{
serializationAttributes
.AddRange(customCodeView.GetAttributes()
.Where(a => a.AttributeClass?.Name == CodeGenAttributes.CodeGenSerializationAttributeName));
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}

// Create each property's deserialization statement
for (int i = 0; i < parameters.Count; i++)
{
Expand All @@ -731,7 +750,7 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
var propertySerializationName = wireInfo.SerializedName;
var checkIfJsonPropEqualsName = new IfStatement(jsonProperty.NameEquals(propertySerializationName))
{
DeserializeProperty(property, jsonProperty)
DeserializeProperty(property, jsonProperty, serializationAttributes)
};
propertyDeserializationStatements.Add(checkIfJsonPropEqualsName);
}
Expand All @@ -752,7 +771,7 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
var rawBinaryData = _rawDataField;
if (rawBinaryData == null)
{
var baseModelProvider = _model.BaseModelProvider;
baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
{
var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
Expand Down Expand Up @@ -1033,7 +1052,8 @@ private static SwitchStatement CreateDeserializeAdditionalPropsValueKindCheck(

private MethodBodyStatement[] DeserializeProperty(
PropertyProvider property,
ScopedApi<JsonProperty> jsonProperty)
ScopedApi<JsonProperty> jsonProperty,
IEnumerable<AttributeData> serializationAttributes)
{
var serializationFormat = property.WireInfo?.SerializationFormat ?? SerializationFormat.Default;
var propertyVarReference = property.AsVariableExpression;
Expand All @@ -1043,8 +1063,7 @@ private MethodBodyStatement[] DeserializeProperty(
propertyVarReference.Assign(value).Terminate()
};

foreach (var attribute in _model.CustomCodeView?.GetAttributes()
.Where(a => a.AttributeClass?.Name == CodeGenAttributes.CodeGenSerializationAttributeName) ?? [])
foreach (var attribute in serializationAttributes)
{
if (CodeGenAttributes.TryGetCodeGenSerializationAttributeValue(
attribute,
Expand All @@ -1059,6 +1078,7 @@ private MethodBodyStatement[] DeserializeProperty(
deserializationHook,
jsonProperty,
ByRef(propertyVarReference)).Terminate()];
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,52 @@ public async Task CanChangePropertyName()
var expected = Helpers.GetExpectedFromFile();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Validates that if a custom property is added to the base model, and the CodeGenSerialization attribute is used,
// then the derived model includes the custom property in the serialization ctor.
[Test]
public async Task CanSerializeCustomPropertyFromBase()
{
var baseModel = InputFactory.Model(
"baseModel",
usage: InputModelTypeUsage.Input,
properties: [InputFactory.Property("BaseProp", InputPrimitiveType.Int32, isRequired: true)]);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModels: () => [
InputFactory.Model(
"mockInputModel",
// use Input so that we generate a public ctor
usage: InputModelTypeUsage.Input,
properties:
[
InputFactory.Property("OtherProp", InputPrimitiveType.Int32, isRequired: true),
],
baseModel: baseModel),
],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var modelTypeProvider = plugin.Object.OutputLibrary.TypeProviders.FirstOrDefault(t => t is ModelProvider && t.Name == "MockInputModel");
Assert.IsNotNull(modelTypeProvider);

var baseModelTypeProvider = (modelTypeProvider as ModelProvider)?.BaseModelProvider;
Assert.IsNotNull(baseModelTypeProvider);
var customCodeView = baseModelTypeProvider!.CustomCodeView;
Assert.IsNotNull(customCodeView);
Assert.IsNull(modelTypeProvider!.CustomCodeView);

Assert.AreEqual(1, baseModelTypeProvider!.Properties.Count);
Assert.AreEqual("BaseProp", baseModelTypeProvider.Properties[0].Name);
Assert.AreEqual(new CSharpType(typeof(int)), baseModelTypeProvider.Properties[0].Type);
Assert.AreEqual(1, customCodeView!.Properties.Count);
Assert.AreEqual("Prop1", customCodeView.Properties[0].Name);

Assert.AreEqual(1, modelTypeProvider.Properties.Count);
Assert.AreEqual("OtherProp", modelTypeProvider.Properties[0].Name);

// the custom property should exist in the full ctor
var fullCtor = modelTypeProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal));
Assert.IsNotNull(fullCtor);
Assert.IsTrue(fullCtor!.Signature.Parameters.Any(p => p.Name == "prop1"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,38 @@ public async Task CanCustomizeSerializationMethodForRenamedProperty()
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Validates that the custom serialization method is used in the serialization provider
// for the custom property that exists in the base model.
[Test]
public async Task CanCustomizeSerializationMethodForPropertyInBase()
{
var baseModel = InputFactory.Model(
"baseModel",
usage: InputModelTypeUsage.Input,
properties: [InputFactory.Property("Prop1", InputPrimitiveType.Int32, isRequired: true)]);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModels: () => [
InputFactory.Model(
"mockInputModel",
usage: InputModelTypeUsage.Json,
properties:
[
InputFactory.Property("OtherProp", InputPrimitiveType.Int32, isRequired: true),
],
baseModel: baseModel),
],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var modelProvider = plugin.Object.OutputLibrary.TypeProviders.FirstOrDefault(t => t is ModelProvider);
Assert.IsNotNull(modelProvider);
var serializationProvider = modelProvider!.SerializationProviders.Single(t => t is MrwSerializationTypeDefinition);
Assert.IsNotNull(serializationProvider);

var writer = new TypeProviderWriter(serializationProvider);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Validates that a properties serialization name can be changed using custom code.
[Test]
public async Task CanChangePropertySerializedName()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Sample;
using System;
using System.Collections.Generic;
using Microsoft.Generator.CSharp.Customization;

namespace Sample.Models;

[CodeGenSerialization(nameof(Prop1), DeserializationValueHook = nameof(DeserializationMethod))]
public partial class BaseModel
{
internal string Prop1 { get; set; }

private static void DeserializationMethod(JsonProperty property, ref string fieldValue)
=> fieldValue = property.Value.GetString();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Text.Json;
using Sample;

namespace Sample.Models
{
/// <summary></summary>
public partial class MockInputModel : global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>
{
internal MockInputModel()
{
}

void global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
writer.WriteStartObject();
this.JsonModelWriteCore(writer, options);
writer.WriteEndObject();
}

/// <param name="writer"> The JSON writer. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{format}' format.");
}
base.JsonModelWriteCore(writer, options);
writer.WritePropertyName("otherProp"u8);
writer.WriteNumberValue(OtherProp);
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override global::Sample.Models.BaseModel JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support reading '{format}' format.");
}
using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.ParseValue(ref reader);
return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, options);
}

internal static global::Sample.Models.MockInputModel DeserializeMockInputModel(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
if ((element.ValueKind == global::System.Text.Json.JsonValueKind.Null))
{
return null;
}
int otherProp = default;
int prop1 = default;
global::System.Collections.Generic.IDictionary<string, global::System.BinaryData> additionalBinaryDataProperties = new global::Sample.ChangeTrackingDictionary<string, global::System.BinaryData>();
foreach (var prop in element.EnumerateObject())
{
if (prop.NameEquals("otherProp"u8))
{
otherProp = prop.Value.GetInt32();
continue;
}
if (prop.NameEquals("prop1"u8))
{
DeserializationMethod(prop, ref prop1);
continue;
}
if ((options.Format != "W"))
{
additionalBinaryDataProperties.Add(prop.Name, global::System.BinaryData.FromString(prop.Value.GetRawText()));
}
}
return new global::Sample.Models.MockInputModel(otherProp, prop1, additionalBinaryDataProperties);
}

global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options);

/// <param name="options"> The client options for reading and writing models. </param>
protected override global::System.BinaryData PersistableModelWriteCore(global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
switch (format)
{
case "J":
return global::System.ClientModel.Primitives.ModelReaderWriter.Write(this, options);
default:
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{options.Format}' format.");
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override global::Sample.Models.BaseModel PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
switch (format)
{
case "J":
using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(data))
{
return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, options);
}
default:
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support reading '{options.Format}' format.");
}
}

string global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J";

/// <param name="mockInputModel"> The <see cref="global::Sample.Models.MockInputModel"/> to serialize into <see cref="global::System.ClientModel.BinaryContent"/>. </param>
public static implicit operator BinaryContent(global::Sample.Models.MockInputModel mockInputModel)
{
return global::System.ClientModel.BinaryContent.Create(mockInputModel, global::Sample.ModelSerializationExtensions.WireOptions);
}

/// <param name="result"> The <see cref="global::System.ClientModel.ClientResult"/> to deserialize the <see cref="global::Sample.Models.MockInputModel"/> from. </param>
public static explicit operator MockInputModel(global::System.ClientModel.ClientResult result)
{
using global::System.ClientModel.Primitives.PipelineResponse response = result.GetRawResponse();
using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(response.Content);
return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

using Microsoft.Generator.CSharp.Customization;

namespace Sample.Models
{
[CodeGenSerialization(nameof(Prop1), SerializationValueHook = nameof(SerializationMethod), DeserializationValueHook = nameof(DeserializationMethod))]
public partial class BaseModel
{
private void SerializationMethod(Utf8JsonWriter writer, ModelReaderWriterOptions options)
=> writer.WriteObjectValue(Prop1, options);

private static void DeserializationMethod(JsonProperty property, ref string fieldValue)
=> fieldValue = property.Value.GetString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private IReadOnlyList<ModelProvider> BuildDerivedModels()

return [.. derivedModels];
}
internal override TypeProvider? BaseTypeProvider => BaseModelProvider;

public ModelProvider? BaseModelProvider
=> _baseModelProvider ??= (_baseTypeProvider?.Value is ModelProvider baseModelProvider ? baseModelProvider : null);
Expand Down
Loading

0 comments on commit ff1725c

Please sign in to comment.