|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +using Azure.Generator.Management.Providers; |
| 5 | +using Azure.Generator.Mgmt.Primitives; |
| 6 | +using Microsoft.TypeSpec.Generator.ClientModel; |
| 7 | +using Microsoft.TypeSpec.Generator.ClientModel.Providers; |
| 8 | +using Microsoft.TypeSpec.Generator.Expressions; |
| 9 | +using Microsoft.TypeSpec.Generator.Input; |
| 10 | +using Microsoft.TypeSpec.Generator.Primitives; |
| 11 | +using Microsoft.TypeSpec.Generator.Providers; |
| 12 | +using Microsoft.TypeSpec.Generator.Snippets; |
| 13 | +using Microsoft.TypeSpec.Generator.Statements; |
| 14 | +using System; |
| 15 | +using System.Collections.Generic; |
| 16 | +using System.Diagnostics.CodeAnalysis; |
| 17 | +using System.Linq; |
| 18 | + |
| 19 | +namespace Azure.Generator.Management |
| 20 | +{ |
| 21 | + internal class InheritableSystemObjectModelVisitor : ScmLibraryVisitor |
| 22 | + { |
| 23 | + protected override ModelProvider? PreVisitModel(InputModelType model, ModelProvider? type) |
| 24 | + { |
| 25 | + if (type is InheritableSystemObjectModelProvider systemType) |
| 26 | + { |
| 27 | + UpdateNamespace(systemType); |
| 28 | + } |
| 29 | + |
| 30 | + if (type is not InheritableSystemObjectModelProvider && type?.BaseModelProvider is InheritableSystemObjectModelProvider baseSystemType) |
| 31 | + { |
| 32 | + Update(baseSystemType, type); |
| 33 | + } |
| 34 | + return type; |
| 35 | + } |
| 36 | + |
| 37 | + protected override TypeProvider? VisitType(TypeProvider type) |
| 38 | + { |
| 39 | + if (type is ModelFactoryProvider modelFactory) |
| 40 | + { |
| 41 | + UpdateModelFactory(modelFactory); |
| 42 | + } |
| 43 | + |
| 44 | + if (type is InheritableSystemObjectModelProvider systemType) |
| 45 | + { |
| 46 | + UpdateNamespace(systemType); |
| 47 | + } |
| 48 | + |
| 49 | + if (type is ModelProvider model && model is not InheritableSystemObjectModelProvider && model.BaseModelProvider is InheritableSystemObjectModelProvider baseSystemType) |
| 50 | + { |
| 51 | + Update(baseSystemType, model); |
| 52 | + } |
| 53 | + return type; |
| 54 | + } |
| 55 | + |
| 56 | + private static void UpdateModelFactory(ModelFactoryProvider modelFactory) |
| 57 | + { |
| 58 | + var methods = new List<MethodProvider>(); |
| 59 | + foreach (var method in modelFactory.Methods) |
| 60 | + { |
| 61 | + var returnType = method.Signature.ReturnType; |
| 62 | + if (returnType is not null && KnownManagementTypes.IsKnownManagementType(returnType)) |
| 63 | + { |
| 64 | + continue; |
| 65 | + } |
| 66 | + methods.Add(method); |
| 67 | + } |
| 68 | + modelFactory.Update(methods: methods); |
| 69 | + } |
| 70 | + |
| 71 | + private static void UpdateNamespace(InheritableSystemObjectModelProvider systemType) |
| 72 | + { |
| 73 | + // This is needed because we updated the namespace with NamespaceVisitor in Azure generator earlier |
| 74 | + systemType.Type.Update(@namespace: systemType._type.Namespace); |
| 75 | + } |
| 76 | + |
| 77 | + private HashSet<ModelProvider> _updated = new(); |
| 78 | + private void Update(InheritableSystemObjectModelProvider baseSystemType, ModelProvider model) |
| 79 | + { |
| 80 | + // Add cache to avoid duplicated update of PreVisitModel and VisitType |
| 81 | + if (_updated.Contains(model)) |
| 82 | + { |
| 83 | + return; |
| 84 | + } |
| 85 | + |
| 86 | + var rawDataField = CreateRawDataField(model); |
| 87 | + UpdateSerialization(model); |
| 88 | + |
| 89 | + UpdateFullConstructor(model, rawDataField); |
| 90 | + |
| 91 | + var baseSystemPropertyNames = EnumerateBaseModelProperties(baseSystemType); |
| 92 | + var properties = model.Properties.Where(prop => !baseSystemPropertyNames.Contains(prop.Name)); |
| 93 | + |
| 94 | + model.Update(properties: properties, fields: [.. model.Fields, rawDataField]); |
| 95 | + |
| 96 | + _updated.Add(model); |
| 97 | + } |
| 98 | + |
| 99 | + private static readonly HashSet<string> _methodNamesToUpdate = new(){ "JsonModelCreateCore", "PersistableModelCreateCore", "PersistableModelWriteCore" }; |
| 100 | + private static void UpdateSerialization(ModelProvider model) |
| 101 | + { |
| 102 | + var serializationProvider = model.SerializationProviders; |
| 103 | + foreach (var provider in model.SerializationProviders) |
| 104 | + { |
| 105 | + if (provider is MrwSerializationTypeDefinition serializationTypeDefinition) |
| 106 | + { |
| 107 | + foreach (var method in serializationTypeDefinition.Methods.Where(m => _methodNamesToUpdate.Contains(m.Signature.Name))) |
| 108 | + { |
| 109 | + var modifiers = method.Signature.Modifiers | MethodSignatureModifiers.Virtual; |
| 110 | + modifiers &= ~MethodSignatureModifiers.Override; |
| 111 | + method.Signature.Update(modifiers: modifiers); |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + private static void UpdateFullConstructor(ModelProvider model, FieldProvider rawDataField) |
| 118 | + { |
| 119 | + var signature = model.FullConstructor.Signature; |
| 120 | + var initializer = signature.Initializer; |
| 121 | + if (initializer is not null) |
| 122 | + { |
| 123 | + var updatedInitializer = new ConstructorInitializer(initializer.IsBase, initializer.Arguments.Where(arg => arg is VariableExpression variable && variable.Declaration.RequestedName != RawDataParameterName).ToArray()); |
| 124 | + var updatedSignature = new ConstructorSignature(signature.Type, signature.Description, signature.Modifiers, signature.Parameters, signature.Attributes, updatedInitializer); |
| 125 | + model.FullConstructor.Update(signature: updatedSignature); |
| 126 | + } |
| 127 | + |
| 128 | + var body = model.FullConstructor.BodyStatements; |
| 129 | + var statement = rawDataField.Assign(model.FullConstructor.Signature.Parameters.Single(f => f.Name.Equals(RawDataParameterName))).Terminate(); |
| 130 | + MethodBodyStatement[] updatedBody = [statement, .. body!.Flatten()]; |
| 131 | + model.FullConstructor.Update(bodyStatements: updatedBody); |
| 132 | + } |
| 133 | + |
| 134 | + private const string RawDataParameterName = "additionalBinaryDataProperties"; |
| 135 | + private static FieldProvider CreateRawDataField(ModelProvider model) |
| 136 | + { |
| 137 | + var modifiers = FieldModifiers.Private; |
| 138 | + if (!model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Sealed) && !model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Struct)) |
| 139 | + { |
| 140 | + modifiers |= FieldModifiers.Protected; |
| 141 | + } |
| 142 | + modifiers |= FieldModifiers.ReadOnly; |
| 143 | + |
| 144 | + var rawDataField = new FieldProvider( |
| 145 | + modifiers: modifiers, |
| 146 | + type: typeof(IDictionary<string, BinaryData>), |
| 147 | + description: FromString("Keeps track of any properties unknown to the library."), |
| 148 | + name: $"_{RawDataParameterName}", |
| 149 | + enclosingType: model); |
| 150 | + |
| 151 | + return rawDataField; |
| 152 | + } |
| 153 | + |
| 154 | + [return: NotNullIfNotNull(nameof(s))] |
| 155 | + private static FormattableString? FromString(string? s) => |
| 156 | + s is null ? null : s.Length == 0 ? (FormattableString)$"" : $"{s}"; |
| 157 | + |
| 158 | + private static HashSet<string> EnumerateBaseModelProperties(InheritableSystemObjectModelProvider baseSystemModel) |
| 159 | + { |
| 160 | + var baseSystemPropertyNames = new HashSet<string>(); |
| 161 | + ModelProvider? baseModel = baseSystemModel; |
| 162 | + while (baseModel != null) |
| 163 | + { |
| 164 | + foreach (var property in baseModel.Properties) |
| 165 | + { |
| 166 | + baseSystemPropertyNames.Add(property.Name); |
| 167 | + } |
| 168 | + baseModel = baseModel.BaseModelProvider; |
| 169 | + } |
| 170 | + return baseSystemPropertyNames; |
| 171 | + } |
| 172 | + } |
| 173 | +} |
0 commit comments