|
1 | 1 | using System;
|
2 | 2 | using System.Collections.Generic;
|
| 3 | +using System.Linq; |
| 4 | +using System.Reflection; |
| 5 | +using System.Web.Http.Description; |
3 | 6 | using Swashbuckle.Swagger;
|
4 | 7 | using VirtoCommerce.Platform.Core.Common;
|
5 | 8 |
|
6 | 9 | namespace VirtoCommerce.Platform.Web.Swagger
|
7 | 10 | {
|
8 | 11 | public class PolymorphismDocumentFilter : IDocumentFilter
|
9 | 12 | {
|
10 |
| - private readonly string _moduleName; |
11 |
| - private readonly IPolymorphismRegistrar _polymorphismRegistrar; |
12 | 13 | private readonly bool _useFullTypeNames;
|
13 | 14 |
|
14 |
| - public PolymorphismDocumentFilter(IPolymorphismRegistrar polymorphismRegistrar, string moduleName, bool useFullTypeNames) |
| 15 | + public PolymorphismDocumentFilter(bool useFullTypeNames) |
15 | 16 | {
|
16 |
| - _polymorphismRegistrar = polymorphismRegistrar ?? throw new ArgumentNullException(nameof(polymorphismRegistrar)); |
17 |
| - _moduleName = moduleName; |
18 | 17 | _useFullTypeNames = useFullTypeNames;
|
19 | 18 | }
|
20 | 19 |
|
21 | 20 | [CLSCompliant(false)]
|
22 |
| - public void Apply(SwaggerDocument swaggerDoc, SchemaRegistry schemaRegistry, System.Web.Http.Description.IApiExplorer apiExplorer) |
| 21 | + public void Apply(SwaggerDocument swaggerDoc, SchemaRegistry schemaRegistry, IApiExplorer apiExplorer) |
23 | 22 | {
|
24 |
| - foreach (var polymorphicBaseTypeInfo in _polymorphismRegistrar.GetPolymorphicBaseTypes(_moduleName)) |
25 |
| - { |
26 |
| - RegisterSubClasses(schemaRegistry, polymorphicBaseTypeInfo); |
27 |
| - } |
| 23 | + RegisterSubClasses(schemaRegistry, apiExplorer); |
28 | 24 | }
|
29 | 25 |
|
30 |
| - private void RegisterSubClasses(SchemaRegistry schemaRegistry, IPolymorphicBaseTypeInfo polymorphicBaseTypeInfo) |
| 26 | + private void RegisterSubClasses(SchemaRegistry schemaRegistry, IApiExplorer apiExplorer) |
31 | 27 | {
|
32 |
| - var abstractType = polymorphicBaseTypeInfo.Type; |
33 |
| - var discriminatorName = polymorphicBaseTypeInfo.DiscriminatorName; |
| 28 | + foreach (var type in apiExplorer.ApiDescriptions.Select(x => x.ResponseType()).Where(x => x != null).Distinct()) |
| 29 | + { |
| 30 | + var schema = GetTypeSchema(schemaRegistry, type, false); |
| 31 | + |
| 32 | + // IApiExplorer contains types from all api controllers, so some of them could be not presented in specific module schemaRegistry. |
| 33 | + if (schema != null) |
| 34 | + { |
| 35 | + // Find if type is registered in AbstractTypeFactory with descendants and TypeInfo have Discriminator filled |
| 36 | + var polymorphicBaseTypeInfoType = typeof(PolymorphicBaseTypeInfo<>).MakeGenericType(type); |
| 37 | + var polymorphicBaseTypeInfoInstance = Activator.CreateInstance(polymorphicBaseTypeInfoType); |
| 38 | + var derivedTypesPropertyGetter = polymorphicBaseTypeInfoType.GetProperty("DerivedTypes", BindingFlags.Instance | BindingFlags.Public).GetGetMethod(); |
| 39 | + var derivedTypes = (derivedTypesPropertyGetter.Invoke(polymorphicBaseTypeInfoInstance, null) as IEnumerable<Type>).ToArray(); |
| 40 | + var discriminatorPropertyGetter = polymorphicBaseTypeInfoType.GetProperty("Discriminator", BindingFlags.Instance | BindingFlags.Public).GetGetMethod(); |
| 41 | + var discriminator = discriminatorPropertyGetter.Invoke(polymorphicBaseTypeInfoInstance, null) as string; |
| 42 | + |
| 43 | + // Polymorphism registration required if we have at least one TypeInfo in AbstractTypeFactory and Discriminator is set |
| 44 | + if (derivedTypes.Length > 0 && !string.IsNullOrEmpty(discriminator)) |
| 45 | + { |
| 46 | + foreach (var derivedType in derivedTypes) |
| 47 | + { |
| 48 | + var derivedTypeSchema = GetTypeSchema(schemaRegistry, derivedType, false); |
| 49 | + |
| 50 | + // Make sure all derivedTypes are in schemaRegistry |
| 51 | + if (derivedTypeSchema == null) |
| 52 | + { |
| 53 | + derivedTypeSchema = schemaRegistry.GetOrRegister(derivedType); |
| 54 | + } |
| 55 | + |
| 56 | + AddInheritanceToDerivedTypeSchema(derivedTypeSchema, type); |
| 57 | + } |
34 | 58 |
|
35 |
| - // Need to make first property character lower to avoid properties duplication because of case, as all properties in OpenApi spec are in camelCase |
36 |
| - discriminatorName = char.ToLowerInvariant(discriminatorName[0]) + discriminatorName.Substring(1); |
| 59 | + AddDiscriminatorToBaseType(schemaRegistry, type, discriminator); |
| 60 | + } |
| 61 | + } |
| 62 | + } |
| 63 | + } |
37 | 64 |
|
38 |
| - var typeName = _useFullTypeNames ? abstractType.FullName : abstractType.FriendlyId(); |
39 |
| - var parentSchema = schemaRegistry.Definitions[typeName]; |
| 65 | + private Schema GetTypeSchema(SchemaRegistry schemaRegistry, Type type, bool throwOnEmpty) |
| 66 | + { |
| 67 | + Schema result = null; |
| 68 | + var typeName = _useFullTypeNames ? type.FullName : type.Name; |
40 | 69 |
|
41 |
| - //set up a discriminator property (it must be required) |
42 |
| - parentSchema.discriminator = discriminatorName; |
43 |
| - parentSchema.required = new List<string> { discriminatorName }; |
| 70 | + // IApiExplorer contains types from all api controllers, so some of them could be not presented in specific module schemaRegistry. |
| 71 | + if (schemaRegistry.Definitions.ContainsKey(typeName)) |
| 72 | + { |
| 73 | + result = schemaRegistry.Definitions[typeName]; |
| 74 | + } |
44 | 75 |
|
45 |
| - if (!parentSchema.properties.ContainsKey(discriminatorName)) |
| 76 | + if (throwOnEmpty && result == null) |
46 | 77 | {
|
47 |
| - parentSchema.properties.Add(discriminatorName, new Schema { type = "string" }); |
| 78 | + throw new KeyNotFoundException($"Derived type \"{type.FullName}\" does not exist in SchemaRegistry."); |
48 | 79 | }
|
49 | 80 |
|
50 |
| - //register all subclasses |
51 |
| - var derivedTypes = polymorphicBaseTypeInfo.DerivedTypes; |
| 81 | + return result; |
| 82 | + } |
| 83 | + |
| 84 | + private void AddDiscriminatorToBaseType(SchemaRegistry schemaRegistry, Type baseType, string discriminator) |
| 85 | + { |
| 86 | + // Need to make first discriminator character lower to avoid properties duplication because of case, as all properties in OpenApi spec are in camelCase |
| 87 | + discriminator = char.ToLowerInvariant(discriminator[0]) + discriminator.Substring(1); |
| 88 | + |
| 89 | + var baseTypeSchema = GetTypeSchema(schemaRegistry, baseType, true); |
52 | 90 |
|
53 |
| - foreach (var item in derivedTypes) |
| 91 | + // Set up a discriminator property (it must be required) |
| 92 | + baseTypeSchema.discriminator = discriminator; |
| 93 | + baseTypeSchema.required = new List<string> { discriminator }; |
| 94 | + |
| 95 | + if (!baseTypeSchema.properties.ContainsKey(discriminator)) |
54 | 96 | {
|
55 |
| - schemaRegistry.GetOrRegister(item); |
| 97 | + baseTypeSchema.properties.Add(discriminator, new Schema { type = "string" }); |
56 | 98 | }
|
57 | 99 | }
|
| 100 | + |
| 101 | + private void AddInheritanceToDerivedTypeSchema(Schema derivedTypeSchema, Type baseType) |
| 102 | + { |
| 103 | + var clonedSchema = new Schema |
| 104 | + { |
| 105 | + properties = derivedTypeSchema.properties, |
| 106 | + type = derivedTypeSchema.type, |
| 107 | + required = derivedTypeSchema.required |
| 108 | + }; |
| 109 | + |
| 110 | + var baseTypeName = _useFullTypeNames ? baseType.FullName : baseType.FriendlyId(); |
| 111 | + |
| 112 | + var parentSchema = new Schema { @ref = "#/definitions/" + baseTypeName }; |
| 113 | + |
| 114 | + derivedTypeSchema.allOf = new List<Schema> { parentSchema, clonedSchema }; |
| 115 | + |
| 116 | + //reset properties for they are included in allOf, should be null but code does not handle it |
| 117 | + derivedTypeSchema.properties = new Dictionary<string, Schema>(); |
| 118 | + } |
| 119 | + |
| 120 | + // This private class is used to simplify querying from AbstractTypeFactory<T>.AllTypeInfos properties (no need to implement Linq queries using reflection) |
| 121 | + private class PolymorphicBaseTypeInfo<T> |
| 122 | + { |
| 123 | + public string Discriminator { get => AbstractTypeFactory<T>.AllTypeInfos.FirstOrDefault()?.Discriminator; } |
| 124 | + public IEnumerable<Type> DerivedTypes { get => AbstractTypeFactory<T>.AllTypeInfos.Select(x => x.Type) ?? Enumerable.Empty<Type>(); } |
| 125 | + } |
58 | 126 | }
|
59 | 127 | }
|
0 commit comments