Skip to content

Commit 22d3801

Browse files
authored
Implement InheritableSystemObjectModelProvider to replace models with existing Azure types (Azure#49653)
1 parent f64182c commit 22d3801

30 files changed

+354
-2133
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using Azure.Generator.Management.Providers;
5+
using Azure.Generator.Mgmt.Primitives;
6+
using Microsoft.TypeSpec.Generator.ClientModel;
7+
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
8+
using Microsoft.TypeSpec.Generator.Expressions;
9+
using Microsoft.TypeSpec.Generator.Input;
10+
using Microsoft.TypeSpec.Generator.Primitives;
11+
using Microsoft.TypeSpec.Generator.Providers;
12+
using Microsoft.TypeSpec.Generator.Snippets;
13+
using Microsoft.TypeSpec.Generator.Statements;
14+
using System;
15+
using System.Collections.Generic;
16+
using System.Diagnostics.CodeAnalysis;
17+
using System.Linq;
18+
19+
namespace Azure.Generator.Management
20+
{
21+
internal class InheritableSystemObjectModelVisitor : ScmLibraryVisitor
22+
{
23+
protected override ModelProvider? PreVisitModel(InputModelType model, ModelProvider? type)
24+
{
25+
if (type is InheritableSystemObjectModelProvider systemType)
26+
{
27+
UpdateNamespace(systemType);
28+
}
29+
30+
if (type is not InheritableSystemObjectModelProvider && type?.BaseModelProvider is InheritableSystemObjectModelProvider baseSystemType)
31+
{
32+
Update(baseSystemType, type);
33+
}
34+
return type;
35+
}
36+
37+
protected override TypeProvider? VisitType(TypeProvider type)
38+
{
39+
if (type is ModelFactoryProvider modelFactory)
40+
{
41+
UpdateModelFactory(modelFactory);
42+
}
43+
44+
if (type is InheritableSystemObjectModelProvider systemType)
45+
{
46+
UpdateNamespace(systemType);
47+
}
48+
49+
if (type is ModelProvider model && model is not InheritableSystemObjectModelProvider && model.BaseModelProvider is InheritableSystemObjectModelProvider baseSystemType)
50+
{
51+
Update(baseSystemType, model);
52+
}
53+
return type;
54+
}
55+
56+
private static void UpdateModelFactory(ModelFactoryProvider modelFactory)
57+
{
58+
var methods = new List<MethodProvider>();
59+
foreach (var method in modelFactory.Methods)
60+
{
61+
var returnType = method.Signature.ReturnType;
62+
if (returnType is not null && KnownManagementTypes.IsKnownManagementType(returnType))
63+
{
64+
continue;
65+
}
66+
methods.Add(method);
67+
}
68+
modelFactory.Update(methods: methods);
69+
}
70+
71+
private static void UpdateNamespace(InheritableSystemObjectModelProvider systemType)
72+
{
73+
// This is needed because we updated the namespace with NamespaceVisitor in Azure generator earlier
74+
systemType.Type.Update(@namespace: systemType._type.Namespace);
75+
}
76+
77+
private HashSet<ModelProvider> _updated = new();
78+
private void Update(InheritableSystemObjectModelProvider baseSystemType, ModelProvider model)
79+
{
80+
// Add cache to avoid duplicated update of PreVisitModel and VisitType
81+
if (_updated.Contains(model))
82+
{
83+
return;
84+
}
85+
86+
var rawDataField = CreateRawDataField(model);
87+
UpdateSerialization(model);
88+
89+
UpdateFullConstructor(model, rawDataField);
90+
91+
var baseSystemPropertyNames = EnumerateBaseModelProperties(baseSystemType);
92+
var properties = model.Properties.Where(prop => !baseSystemPropertyNames.Contains(prop.Name));
93+
94+
model.Update(properties: properties, fields: [.. model.Fields, rawDataField]);
95+
96+
_updated.Add(model);
97+
}
98+
99+
private static readonly HashSet<string> _methodNamesToUpdate = new(){ "JsonModelCreateCore", "PersistableModelCreateCore", "PersistableModelWriteCore" };
100+
private static void UpdateSerialization(ModelProvider model)
101+
{
102+
var serializationProvider = model.SerializationProviders;
103+
foreach (var provider in model.SerializationProviders)
104+
{
105+
if (provider is MrwSerializationTypeDefinition serializationTypeDefinition)
106+
{
107+
foreach (var method in serializationTypeDefinition.Methods.Where(m => _methodNamesToUpdate.Contains(m.Signature.Name)))
108+
{
109+
var modifiers = method.Signature.Modifiers | MethodSignatureModifiers.Virtual;
110+
modifiers &= ~MethodSignatureModifiers.Override;
111+
method.Signature.Update(modifiers: modifiers);
112+
}
113+
}
114+
}
115+
}
116+
117+
private static void UpdateFullConstructor(ModelProvider model, FieldProvider rawDataField)
118+
{
119+
var signature = model.FullConstructor.Signature;
120+
var initializer = signature.Initializer;
121+
if (initializer is not null)
122+
{
123+
var updatedInitializer = new ConstructorInitializer(initializer.IsBase, initializer.Arguments.Where(arg => arg is VariableExpression variable && variable.Declaration.RequestedName != RawDataParameterName).ToArray());
124+
var updatedSignature = new ConstructorSignature(signature.Type, signature.Description, signature.Modifiers, signature.Parameters, signature.Attributes, updatedInitializer);
125+
model.FullConstructor.Update(signature: updatedSignature);
126+
}
127+
128+
var body = model.FullConstructor.BodyStatements;
129+
var statement = rawDataField.Assign(model.FullConstructor.Signature.Parameters.Single(f => f.Name.Equals(RawDataParameterName))).Terminate();
130+
MethodBodyStatement[] updatedBody = [statement, .. body!.Flatten()];
131+
model.FullConstructor.Update(bodyStatements: updatedBody);
132+
}
133+
134+
private const string RawDataParameterName = "additionalBinaryDataProperties";
135+
private static FieldProvider CreateRawDataField(ModelProvider model)
136+
{
137+
var modifiers = FieldModifiers.Private;
138+
if (!model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Sealed) && !model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Struct))
139+
{
140+
modifiers |= FieldModifiers.Protected;
141+
}
142+
modifiers |= FieldModifiers.ReadOnly;
143+
144+
var rawDataField = new FieldProvider(
145+
modifiers: modifiers,
146+
type: typeof(IDictionary<string, BinaryData>),
147+
description: FromString("Keeps track of any properties unknown to the library."),
148+
name: $"_{RawDataParameterName}",
149+
enclosingType: model);
150+
151+
return rawDataField;
152+
}
153+
154+
[return: NotNullIfNotNull(nameof(s))]
155+
private static FormattableString? FromString(string? s) =>
156+
s is null ? null : s.Length == 0 ? (FormattableString)$"" : $"{s}";
157+
158+
private static HashSet<string> EnumerateBaseModelProperties(InheritableSystemObjectModelProvider baseSystemModel)
159+
{
160+
var baseSystemPropertyNames = new HashSet<string>();
161+
ModelProvider? baseModel = baseSystemModel;
162+
while (baseModel != null)
163+
{
164+
foreach (var property in baseModel.Properties)
165+
{
166+
baseSystemPropertyNames.Add(property.Name);
167+
}
168+
baseModel = baseModel.BaseModelProvider;
169+
}
170+
return baseSystemPropertyNames;
171+
}
172+
}
173+
}

eng/packages/http-client-csharp-mgmt/generator/Azure.Generator.Mgmt/src/ManagementClientGenerator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ protected override void Configure()
5151
AddMetadataReference(MetadataReference.CreateFromFile(typeof(ArmClient).Assembly.Location));
5252
AddVisitor(new RestClientVisitor());
5353
AddVisitor(new ResourceVisitor());
54+
AddVisitor(new InheritableSystemObjectModelVisitor());
5455
}
5556
}
5657
}

eng/packages/http-client-csharp-mgmt/generator/Azure.Generator.Mgmt/src/ManagementOutputLibrary.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ private static void BuildResourceCore(List<ResourceClientProvider> resources, Li
5757
protected override TypeProvider[] BuildTypeProviders()
5858
{
5959
var (resources, collections) = BuildResources();
60-
return [.. base.BuildTypeProviders(), ArmOperation, GenericArmOperation, .. resources, .. collections, .. resources.Select(r => r.Source)];
60+
return [.. base.BuildTypeProviders().Where(t => t is not InheritableSystemObjectModelProvider), ArmOperation, GenericArmOperation, .. resources, .. collections, .. resources.Select(r => r.Source)];
6161
}
6262
}
6363
}

eng/packages/http-client-csharp-mgmt/generator/Azure.Generator.Mgmt/src/ManagementTypeFactory.cs

+55
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,21 @@
33

44
using Azure.Generator.Management.InputTransformation;
55
using Azure.Generator.Management.Providers.Abstraction;
6+
using Azure.Generator.Mgmt.Primitives;
67
using Microsoft.TypeSpec.Generator;
78
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
9+
using Microsoft.TypeSpec.Generator.Expressions;
810
using Microsoft.TypeSpec.Generator.Input;
11+
using Microsoft.TypeSpec.Generator.Primitives;
12+
using Microsoft.TypeSpec.Generator.Providers;
13+
using Microsoft.TypeSpec.Generator.Snippets;
14+
using Microsoft.TypeSpec.Generator.Statements;
15+
using System.ClientModel.Primitives;
16+
using System;
917
using System.Collections.Generic;
18+
using System.Text.Json;
19+
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
20+
using Azure.Generator.Management.Providers;
1021

1122
namespace Azure.Generator.Management
1223
{
@@ -31,5 +42,49 @@ public class ManagementTypeFactory : AzureTypeFactory
3142
var transformedClient = InputClientTransformer.TransformInputClient(inputClient);
3243
return transformedClient is null ? null : base.CreateClientCore(transformedClient);
3344
}
45+
46+
// TODO: right now, we are missing the connection between CsharpType and TypeProvider, that's why we need both CreateCSharpTypeCore and CreateModelCore
47+
// Once we have the mapping between CsharpType and TypeProvider, we should only keep CreateModelCore
48+
/// <inheritdoc/>
49+
protected override CSharpType? CreateCSharpTypeCore(InputType inputType)
50+
{
51+
if (inputType is InputModelType model && KnownManagementTypes.TryGetManagementType(model.CrossLanguageDefinitionId, out var replacedType))
52+
{
53+
return replacedType;
54+
}
55+
return base.CreateCSharpTypeCore(inputType);
56+
}
57+
58+
/// <inheritdoc/>
59+
protected override ModelProvider? CreateModelCore(InputModelType model)
60+
{
61+
if (KnownManagementTypes.TryGetManagementType(model.CrossLanguageDefinitionId, out var replacedType))
62+
{
63+
return new InheritableSystemObjectModelProvider(replacedType.FrameworkType, model);
64+
}
65+
return base.CreateModelCore(model);
66+
}
67+
68+
/// <inheritdoc/>
69+
public override MethodBodyStatement SerializeJsonValue(Type valueType, ValueExpression value, ScopedApi<Utf8JsonWriter> utf8JsonWriter, ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter, SerializationFormat serializationFormat)
70+
{
71+
if (KnownManagementTypes.IsKnownManagementType(valueType))
72+
{
73+
return Static(typeof(JsonSerializer)).Invoke(nameof(JsonSerializer.Serialize), [value]).Terminate();
74+
}
75+
return base.SerializeJsonValue(valueType, value, utf8JsonWriter, mrwOptionsParameter, serializationFormat);
76+
}
77+
78+
/// <inheritdoc/>
79+
#pragma warning disable AZC0014 // Avoid using banned types in public API
80+
public override ValueExpression DeserializeJsonValue(Type valueType, ScopedApi<JsonElement> element, SerializationFormat format)
81+
#pragma warning restore AZC0014 // Avoid using banned types in public API
82+
{
83+
if (KnownManagementTypes.IsKnownManagementType(valueType))
84+
{
85+
return Static(typeof(JsonSerializer)).Invoke(nameof(JsonSerializer.Deserialize), [element], [valueType], false);
86+
}
87+
return base.DeserializeJsonValue(valueType, element, format);
88+
}
3489
}
3590
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using Azure.ResourceManager.Models;
5+
using Azure.ResourceManager.Resources.Models;
6+
using Microsoft.TypeSpec.Generator.Primitives;
7+
using System;
8+
using System.Collections.Generic;
9+
using System.Diagnostics.CodeAnalysis;
10+
using System.Linq;
11+
12+
namespace Azure.Generator.Mgmt.Primitives
13+
{
14+
internal class KnownManagementTypes
15+
{
16+
private static readonly IReadOnlyDictionary<string, CSharpType> _idToTypeMap = new Dictionary<string, CSharpType>()
17+
{
18+
["Azure.ResourceManager.CommonTypes.ExtendedLocation"] = typeof(ExtendedLocation),
19+
["Azure.ResourceManager.CommonTypes.ExtendedLocationType"] = typeof(ExtendedLocationType),
20+
["Azure.ResourceManager.CommonTypes.ManagedServiceIdentity"] = typeof(ManagedServiceIdentity),
21+
["Azure.ResourceManager.CommonTypes.ManagedServiceIdentityType"] = typeof(ManagedServiceIdentityType),
22+
["Azure.ResourceManager.CommonTypes.OperationStatusResult"] = typeof(OperationStatusResult),
23+
["Azure.ResourceManager.CommonTypes.ProxyResource"] = typeof(ResourceData),
24+
["Azure.ResourceManager.CommonTypes.Resource"] = typeof(ResourceData),
25+
["Azure.ResourceManager.CommonTypes.SystemData"] = typeof(SystemData),
26+
["Azure.ResourceManager.CommonTypes.TrackedResource"] = typeof(TrackedResourceData),
27+
["Azure.ResourceManager.CommonTypes.UserAssignedIdentity"] = typeof(UserAssignedIdentity),
28+
};
29+
30+
private static readonly HashSet<CSharpType> _knownTypes = _idToTypeMap.Values.ToHashSet(new CSharpFullNameComparer());
31+
32+
public static bool IsKnownManagementType(CSharpType type) => _knownTypes.Contains(type);
33+
34+
public static bool TryGetManagementType(string id, [MaybeNullWhen(false)] out CSharpType type) => _idToTypeMap.TryGetValue(id, out type);
35+
36+
// The comparison could be CSharpType from Azure.ResourceManager, which is a framework type
37+
// and CSharpType from InheritableSystemObjectModelProvider, which is not a framework type, they should still be equal if namespace and name match
38+
// Then, the default Equals of CSharpType doesn't apply here
39+
private class CSharpFullNameComparer : IEqualityComparer<CSharpType>
40+
{
41+
public bool Equals(CSharpType? x, CSharpType? y)
42+
{
43+
if (x is null)
44+
{
45+
if (y is null)
46+
{
47+
return true;
48+
}
49+
return false;
50+
}
51+
else
52+
{
53+
return x.AreNamesEqual(y);
54+
}
55+
}
56+
57+
public int GetHashCode([DisallowNull] CSharpType obj)
58+
{
59+
return HashCode.Combine(obj.Namespace, obj.Name);
60+
}
61+
}
62+
}
63+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using Microsoft.TypeSpec.Generator.Input;
5+
using Microsoft.TypeSpec.Generator.Providers;
6+
using System;
7+
8+
namespace Azure.Generator.Management.Providers
9+
{
10+
internal class InheritableSystemObjectModelProvider : ModelProvider
11+
{
12+
internal readonly Type _type;
13+
14+
public InheritableSystemObjectModelProvider(Type type, InputModelType inputModel) : base(inputModel)
15+
{
16+
_type = type;
17+
}
18+
19+
protected override string BuildName() => _type.Name;
20+
21+
protected override string BuildRelativeFilePath()
22+
=> throw new InvalidOperationException("This type should not be writing in generation");
23+
24+
protected override string BuildNamespace() => _type.Namespace!;
25+
}
26+
}

0 commit comments

Comments
 (0)