@@ -32,31 +32,36 @@ class ${record.capitalizedName}SchemaBase(ABC):
     Generated from: ${templateName}
     """
 
+#set($columnVars = {})
 #foreach ($field in $record.fields)
-    ${field.upperSnakecaseName}_COLUMN: str = '${field.sparkAttributes.columnName}'
+    #set ($columnVars[$field.name] = "${field.upperSnakecaseName}_COLUMN")
+    ${columnVars[$field.name]}: str = '${field.sparkAttributes.columnName}'
 #end
+#set($relationVars = {})
 #foreach ($relation in $record.relations)
-    ${relation.upperSnakecaseName}_COLUMN: str = '${relation.columnName}'
+    #set ($relationVars[$relation.name] = "${relation.upperSnakecaseName}_COLUMN")
+    ${relationVars[$relation.name]}: str = '${relation.columnName}'
 #end
 
 
     def __init__(self):
         self._schema = StructType()
 
+        ## Set the nullable parameter to True for every column due to inconsistent behavior across data sources and tooling (Spark vs PySpark).
+        ## This allows all data to be read in; None values are filtered out by the validate_dataset method if the field is required.
+        ## Previously, PySpark would throw an exception when it encountered a None value with nullable set to False, resulting in all previously processed data being lost.
 #foreach ($field in $record.fields)
-    #set ($nullable = "#if($field.sparkAttributes.isNullable())True#{else}False#end")
 #if ($field.sparkAttributes.isDecimalType())
-        self.add(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), ${nullable})
+        self.add(self.${columnVars[$field.name]}, ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), True)
 #else
-        self.add(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, ${field.shortType}(), ${nullable})
+        self.add(self.${columnVars[$field.name]}, ${field.shortType}(), True)
 #end
 #end
 #foreach ($relation in $record.relations)
-    #set ($nullable = "#if($relation.isNullable())True#{else}False#end")
 #if ($relation.isOneToManyRelation())
-        self.add(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, ArrayType(${relation.capitalizedName}Schema().struct_type), ${nullable})
+        self.add(self.${relationVars[$relation.name]}, ArrayType(${relation.capitalizedName}Schema().struct_type), True)
 #else
-        self.add(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, ${relation.capitalizedName}Schema().struct_type, ${nullable})
+        self.add(self.${relationVars[$relation.name]}, ${relation.capitalizedName}Schema().struct_type, True)
 #end
 #end
 
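To make the effect of this hunk concrete, here is a minimal sketch of what the template might now render, assuming a hypothetical `Person` record with a string field and a decimal field (all names invented for illustration; only the all-`True` nullable flags and the constant definitions come from the diff):

```python
from abc import ABC

from pyspark.sql.types import StructType, StringType, DecimalType


class PersonSchemaBase(ABC):
    # Class-level constants generated by the #foreach loops; the new
    # $columnVars map only changes how the template refers to them internally.
    NAME_COLUMN: str = 'name'
    INCOME_COLUMN: str = 'income'

    def __init__(self):
        self._schema = StructType()
        # Every column is added with nullable=True; required-ness is enforced
        # later by validate_dataset rather than by the Spark schema itself.
        self.add(self.NAME_COLUMN, StringType(), True)
        self.add(self.INCOME_COLUMN, DecimalType(10, 2), True)

    def add(self, column_name, data_type, nullable):
        # Assumed helper: StructType.add returns a new StructType, so keep
        # the accumulated schema on the instance.
        self._schema = self._schema.add(column_name, data_type, nullable)
```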
@@ -66,18 +71,18 @@ class ${record.capitalizedName}SchemaBase(ABC):
         Returns the given dataset cast to this schema.
         """
 #foreach ($field in $record.fields)
-        ${field.snakeCaseName}_type = self.get_data_type(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN)
+        ${field.snakeCaseName}_type = self.get_data_type(self.${columnVars[$field.name]})
 #end
 #foreach ($relation in $record.relations)
-        ${relation.snakeCaseName}_type = self.get_data_type(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN)
+        ${relation.snakeCaseName}_type = self.get_data_type(self.${relationVars[$relation.name]})
 #end
 
         return dataset \
 #foreach ($field in $record.fields)
-            .withColumn(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, dataset[${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN].cast(${field.snakeCaseName}_type))#if ($foreach.hasNext || $record.hasRelations()) \\#end
+            .withColumn(self.${columnVars[$field.name]}, dataset[self.${columnVars[$field.name]}].cast(${field.snakeCaseName}_type))#if ($foreach.hasNext || $record.hasRelations()) \\#end
 #end
 #foreach ($relation in $record.relations)
-            .withColumn(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, dataset[${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN].cast(${relation.snakeCaseName}_type))#if ($foreach.hasNext) \\#end
+            .withColumn(self.${relationVars[$relation.name]}, dataset[self.${relationVars[$relation.name]}].cast(${relation.snakeCaseName}_type))#if ($foreach.hasNext) \\#end
 #end
 #end
 
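Continuing the hypothetical `Person` sketch, the cast hunk renders to roughly the following; the method name `cast_to_schema` and the `get_data_type` body are assumptions, while the chained `.withColumn(...).cast(...)` shape comes from the template:

```python
from pyspark.sql import DataFrame


class PersonSchema(PersonSchemaBase):
    def get_data_type(self, column_name):
        # Assumed helper: look the field's type up in the accumulated schema.
        return self._schema[column_name].dataType

    def cast_to_schema(self, dataset: DataFrame) -> DataFrame:
        """Returns the given dataset cast to this schema."""
        name_type = self.get_data_type(self.NAME_COLUMN)
        income_type = self.get_data_type(self.INCOME_COLUMN)

        # The template's trailing backslashes stitch one .withColumn call
        # per field into a single chained expression.
        return dataset \
            .withColumn(self.NAME_COLUMN, dataset[self.NAME_COLUMN].cast(name_type)) \
            .withColumn(self.INCOME_COLUMN, dataset[self.INCOME_COLUMN].cast(income_type))
```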
@@ -137,31 +142,32 @@ class ${record.capitalizedName}SchemaBase(ABC):
         """
         data_with_validations = ingest_dataset
 #foreach ($field in $record.fields)
-    #set ( $columnName = "#if($field.column)$field.column#{else}$field.upperSnakecaseName#end" )
 #if (${field.isRequired()})
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_IS_NOT_NULL", col(column_prefix + "${columnName}").isNotNull())
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NOT_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNotNull())
+#else
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull())
 #end
 #if (${field.getValidation().getMinValue()})
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_GREATER_THAN_MIN", col(column_prefix + "${columnName}").cast('double') >= ${field.getValidation().getMinValue()})
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(column_prefix + self.${columnVars[$field.name]}).cast('double') >= ${field.getValidation().getMinValue()})
 #end
 #if (${field.getValidation().getMaxValue()})
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_LESS_THAN_MAX", col(column_prefix + "${columnName}").cast('double') <= ${field.getValidation().getMaxValue()})
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_LESS_THAN_MAX", col(column_prefix + self.${columnVars[$field.name]}).cast('double') <= ${field.getValidation().getMaxValue()})
 #end
 #if (${field.getValidation().getScale()})
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_MATCHES_SCALE", col(column_prefix + "${columnName}").cast(StringType()).rlike(r"^[0-9]*(?:\.[0-9]{0,${field.getValidation().getScale()}})?$"))
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_MATCHES_SCALE", col(column_prefix + self.${columnVars[$field.name]}).cast(StringType()).rlike(r"^[0-9]*(?:\.[0-9]{0,${field.getValidation().getScale()}})?$"))
 #end
 #if (${field.getValidation().getMinLength()})
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_GREATER_THAN_OR_EQUAL_TO_MIN_LENGTH", col(column_prefix + "${columnName}").rlike("^.{${field.getValidation().getMinLength()},}"))
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_OR_EQUAL_TO_MIN_LENGTH", col(column_prefix + self.${columnVars[$field.name]}).rlike("^.{${field.getValidation().getMinLength()},}"))
 #end
 #if (${field.getValidation().getMaxLength()})
 #set($max = ${field.getValidation().getMaxLength()} + 1)
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_LESS_THAN_OR_EQUAL_TO_MAX_LENGTH", col(column_prefix + "${columnName}").rlike("^.{$max,}").eqNullSafe(False))
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_LESS_THAN_OR_EQUAL_TO_MAX_LENGTH", col(column_prefix + self.${columnVars[$field.name]}).rlike("^.{$max,}").eqNullSafe(False))
 #end
 #foreach ($format in $field.getValidation().getFormats())
 #if ($foreach.first)
-        data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_MATCHES_FORMAT", col(column_prefix + "${columnName}").rlike("$format.replace("\","\\")")#if($foreach.last))#end
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_MATCHES_FORMAT", col(column_prefix + self.${columnVars[$field.name]}).rlike("$format.replace("\","\\")")#if($foreach.last))#end
 #else
-            | col(column_prefix + "${columnName}").rlike("$format.replace("\","\\")")#if($foreach.last))#end
+            | col(column_prefix + self.${columnVars[$field.name]}).rlike("$format.replace("\","\\")")#if($foreach.last))#end
 #end
 #end
 #end
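For the validation hunk, a hypothetical rendering for an optional `income` field with invented bounds 0..1000000 shows the new `#else` branch in action (`column_prefix` and `ingest_dataset` play the same roles as in the generated method):

```python
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

INCOME_COLUMN = 'income'


def add_income_validations(ingest_dataset: DataFrame, column_prefix: str = "") -> DataFrame:
    data_with_validations = ingest_dataset
    # New #else branch: an optional field gains an _IS_NULL marker column so
    # that rows where the field is absent can still pass the row filter later.
    data_with_validations = data_with_validations.withColumn(
        INCOME_COLUMN + "_IS_NULL", col(column_prefix + INCOME_COLUMN).isNull())
    data_with_validations = data_with_validations.withColumn(
        INCOME_COLUMN + "_GREATER_THAN_MIN",
        col(column_prefix + INCOME_COLUMN).cast('double') >= 0)
    data_with_validations = data_with_validations.withColumn(
        INCOME_COLUMN + "_LESS_THAN_MAX",
        col(column_prefix + INCOME_COLUMN).cast('double') <= 1000000)
    return data_with_validations
```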
@@ -170,28 +176,63 @@ class ${record.capitalizedName}SchemaBase(ABC):
 #if (false)
 #foreach($relation in $record.relations)
 #if($relation.isOneToManyRelation())
-        data_with_validations = data_with_validations.withColumn(self.${relation.upperSnakecaseName}_COLUMN + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relation.upperSnakecaseName}_COLUMN)))))
+        data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relationVars[$relation.name]})))))
 #else
         ${relation.snakeCaseName}_schema = ${relation.name}Schema()
-        data_with_validations = data_with_validations.withColumn(self.${relation.upperSnakecaseName}_COLUMN + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relation.upperSnakecaseName}_COLUMN)), '${relation.columnName}.').isEmpty()))
+        data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relationVars[$relation.name]})), '${relation.columnName}.').isEmpty()))
 #end
 #end
 #end
 
-        validation_columns = [x for x in data_with_validations.columns if x not in ingest_dataset.columns]
+        column_filter_schemas = []
+        validation_columns = [col for col in data_with_validations.columns if col not in ingest_dataset.columns]
 
-        # Schema for filtering for valid data
-        filter_schema = None
-        for column_name in validation_columns:
-            if isinstance(filter_schema, Column):
-                filter_schema = filter_schema & col(column_name).eqNullSafe(True)
-            else:
-                filter_schema = col(column_name).eqNullSafe(True)
+        # Separate the validation columns into groups based on their field name
+        columns_grouped_by_field = []
+
+#foreach ($field in $record.fields)
+        columns_grouped_by_field.append([col for col in validation_columns if col.startswith(self.${columnVars[$field.name]})])
+#end
+
+        # Create a filter schema for each field, represented as a column group
+        for column_group in columns_grouped_by_field:
+            column_group_filter_schema = None
+
+            # This column tracks whether a non-required field is None, which enables
+            # non-required validated fields to still pass filtering when they are None
+            nullable_column = None
+
+            for column_name in column_group:
+                if column_name.endswith("_IS_NULL"):
+                    nullable_column = col(column_name).eqNullSafe(True)
+                elif column_group_filter_schema is not None:
+                    column_group_filter_schema = column_group_filter_schema & col(column_name).eqNullSafe(True)
+                else:
+                    column_group_filter_schema = col(column_name).eqNullSafe(True)
+
+            # Add the nullable column filter as an OR clause at the end of the field's filter schema.
+            # If there are no other filter schemas for the field, it can be ignored.
+            if nullable_column is not None and column_group_filter_schema is not None:
+                column_group_filter_schema = nullable_column | column_group_filter_schema
+
+            if column_group_filter_schema is not None:
+                column_filter_schemas.append(column_group_filter_schema)
 
-        valid_data = data_with_validations
         # Isolate the valid data and drop validation columns
-        if isinstance(filter_schema, Column):
-            valid_data = data_with_validations.filter(filter_schema)
+        valid_data = data_with_validations
+        if column_filter_schemas:
+
+            # Combine all the field filter schemas into one final schema for the row
+            final_column_filter_schemas = None
+
+            for column_group_filter_schema in column_filter_schemas:
+                if final_column_filter_schemas is not None:
+                    final_column_filter_schemas = final_column_filter_schemas & column_group_filter_schema
+                else:
+                    final_column_filter_schemas = column_group_filter_schema
+
+            valid_data = data_with_validations.filter(final_column_filter_schemas)
+
         valid_data = valid_data.drop(*validation_columns)
         return valid_data
 
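The new filtering strategy reads as: group the validation columns by field, AND the checks within a group, OR in the `_IS_NULL` escape hatch for optional fields, then AND the per-field filters together. A standalone sketch of that combination logic (the helper itself is illustrative, not part of the generated class):

```python
from functools import reduce

from pyspark.sql.functions import col


def build_row_filter(validation_columns, field_constants):
    """Combine per-field validation columns into one row-level predicate."""
    column_filter_schemas = []
    for field_constant in field_constants:
        column_group = [c for c in validation_columns if c.startswith(field_constant)]
        nullable_column = None
        group_filter = None
        for column_name in column_group:
            if column_name.endswith("_IS_NULL"):
                # Marker for an optional field that is absent in this row.
                nullable_column = col(column_name).eqNullSafe(True)
            elif group_filter is not None:
                group_filter = group_filter & col(column_name).eqNullSafe(True)
            else:
                group_filter = col(column_name).eqNullSafe(True)
        # A None in an optional field passes even if its other checks fail.
        if nullable_column is not None and group_filter is not None:
            group_filter = nullable_column | group_filter
        if group_filter is not None:
            column_filter_schemas.append(group_filter)
    # AND the per-field filters into the final row filter, or None if there
    # were no validation columns at all.
    if column_filter_schemas:
        return reduce(lambda acc, f: acc & f, column_filter_schemas)
    return None
```

Compared to the old single filter chain, a row with an absent optional field is no longer discarded just because that field's range or format checks evaluate to null.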