@@ -4,7 +4,7 @@ from pyspark.sql.dataframe import DataFrame
from pyspark.sql.column import Column
from pyspark.sql.types import StructType
from pyspark.sql.types import DataType
- from pyspark.sql.functions import col, lit
+ from pyspark.sql.functions import col, lit, when, explode, row_number, monotonically_increasing_id, bool_and  # bool_and requires PySpark >= 3.5
+ from pyspark.sql.window import Window
from typing import List
import types
#foreach ($import in $record.baseImports)
@@ -46,6 +46,7 @@ class ${record.capitalizedName}SchemaBase(ABC):

    def __init__(self):
        self._schema = StructType()
+         self.validation_result_column = '__VALIDATE_RESULT_${record.capitalizedName}_'

    ## Setting the nullable parameter to True for every column due to inconsistencies in the behavior from different data sources/toolings (Spark vs Pyspark)
    ## This allows all data to be read in, and None values will be filtered out as part of the validate_dataset method if the field is required
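The two template comments above describe a read-everything-then-validate design. As a minimal, self-contained sketch of that pattern (the schema and field names here are illustrative, not generated by this template):

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType

    spark = SparkSession.builder.getOrCreate()

    # every column is declared nullable=True so the read never rejects rows up front
    schema = StructType([StructField("name", StringType(), True)])
    raw = spark.createDataFrame([("alice",), (None,)], schema)

    # required-field enforcement happens afterwards, as validate_dataset does
    valid = raw.filter(raw["name"].isNotNull())
    valid.show()  # only the "alice" row survives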
@@ -136,16 +137,35 @@ class ${record.capitalizedName}SchemaBase(ABC):
    def validate_dataset(self, ingest_dataset: DataFrame) -> DataFrame:
        return self.validate_dataset_with_prefix(ingest_dataset, "")

-     def validate_dataset_with_prefix(self, ingest_dataset: DataFrame, column_prefix: str) -> DataFrame:
+     def validate_dataset_with_prefix(self, ingest_dataset: DataFrame, column_prefix: str, valid_data_only: bool = True) -> DataFrame:
        """
        Validates the given dataset and returns the lists of validated records.
+         When valid_data_only is False, invalid rows are kept and flagged in the validation result column instead of being filtered out.
        """
        data_with_validations = ingest_dataset
+         #if ($record.hasRelations())
+         # relation record validation
+         #foreach ($relation in $record.relations)
+         #if (!$relation.isNullable())
+         # filter out null data for the required relation
+         data_with_validations = data_with_validations.withColumn(self.validation_result_column + self.${relationVars[$relation.name]} + "_IS_NOT_NULL",
+             col(column_prefix + self.${relationVars[$relation.name]}).isNotNull())
+
+         #end
+         #if ($relation.isOneToManyRelation())
+         data_with_validations = self.with_${relation.snakeCaseName}_validation(data_with_validations, '${relation.columnName}')
+         #else
+         ${relation.snakeCaseName}_schema = ${relation.capitalizedName}Schema()
+         data_with_validations = ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations, '${relation.columnName}.', False)
+         #end
+         #end
+         #end
+
+         # record field validation
        #foreach ($field in $record.fields)
        #if (${field.isRequired()})
        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NOT_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNotNull())
        #else
-         data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull())
+         data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull())
        #end
        #if (${field.getValidation().getMinValue()})
        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(column_prefix + self.${columnVars[$field.name]}).cast('double') >= ${field.getValidation().getMinValue()})
@@ -172,18 +192,6 @@ class ${record.capitalizedName}SchemaBase(ABC):
        #end
        #end

-         ## TODO revise validation for relations
-         #if (false)
-         #foreach($relation in $record.relations)
-         #if($relation.isOneToManyRelation())
-         data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relationVars[$relation.name]})))))
-         #else
-         ${relation.snakeCaseName}_schema = ${relation.name}Schema()
-         data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relationVars[$relation.name]})), '${relation.columnName}.').isEmpty()))
-         #end
-         #end
-         #end
-
        column_filter_schemas = []
        validation_columns = [col for col in data_with_validations.columns if col not in ingest_dataset.columns]
@@ -194,6 +202,9 @@ class ${record.capitalizedName}SchemaBase(ABC):
        columns_grouped_by_field.append([col for col in validation_columns if col.startswith(self.${columnVars[$field.name]})])
        #end

+         if valid_data_only:
+             columns_grouped_by_field.append([col for col in validation_columns if col.startswith('__VALIDATE_')])
+
        # Create a schema filter for each field represented as a column group
        for column_group in columns_grouped_by_field:
            column_group_filter_schema = None
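For context, the column grouping added above feeds the filter-building loop that follows: the mechanics reduce to ANDing boolean validation columns into a single filter expression. A minimal sketch, with illustrative column names:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [(1, True, True), (2, True, False)],
        ["id", "NAME_IS_NOT_NULL", "AGE_GREATER_THAN_MIN"],
    )

    # AND the per-field validation columns into one filter expression
    combined = None
    for c in ["NAME_IS_NOT_NULL", "AGE_GREATER_THAN_MIN"]:
        combined = col(c) if combined is None else combined & col(c)

    df.filter(combined).show()  # only row 1 passes both checks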
@@ -231,19 +242,67 @@ class ${record.capitalizedName}SchemaBase(ABC):
            else:
                final_column_filter_schemas = column_group_filter_schema

-             valid_data = data_with_validations.filter(final_column_filter_schemas)
+             if valid_data_only:
+                 # keep only the rows that passed every validation check
+                 valid_data = data_with_validations.filter(final_column_filter_schemas)
+             else:
+                 # keep every row and record the outcome in the validation result column
+                 valid_data = data_with_validations.withColumn(self.validation_result_column, when(final_column_filter_schemas, lit(True)).otherwise(lit(False)))
+         else:
+             if not valid_data_only:
+                 # no validation checks apply, so every row is trivially valid
+                 valid_data = data_with_validations.withColumn(self.validation_result_column, lit(True))

        valid_data = valid_data.drop(*validation_columns)
        return valid_data

-     ## TODO revise validation for relations
-     #if (false)
-     #foreach($relation in $record.relations)
-     #if($relation.isOneToManyRelation())
-     def _validate_with_${relation.snakeCaseName}_schema(self, dataset: DataFrame) -> bool:
-         raise NotImplementedError
+     #set($hasOneToManyRelation = false)
+     #foreach ($relation in $record.relations)
+     #if ($relation.isOneToManyRelation())
+     #set($hasOneToManyRelation = true)
+     def with_${relation.snakeCaseName}_validation(self, dataset: DataFrame, validation_column: str) -> DataFrame:
+         """
+         Validates the given ${relation.capitalizedName} 1:M relation dataset against ${relation.capitalizedName}Schema.
+         Returns a dataset with the validation result in the __VALIDATE_<validation_column> column.
+         """
+         ${relation.snakeCaseName}_schema = ${relation.capitalizedName}Schema()
+         return self.validate_with_relation_record_schema(dataset, validation_column,
+             ${relation.snakeCaseName}_schema.validate_dataset_with_prefix, ${relation.snakeCaseName}_schema.validation_result_column, #if (${relation.isNullable()}) False #else True #end)
+
    #end
    #end
-     #end

+     #if ($hasOneToManyRelation)
+     def validate_with_relation_record_schema(self, ingest_dataset: DataFrame, validation_column: str, validate_dataset_with_prefix, relation_result_column: str, is_required: bool = False) -> DataFrame:
+         """
+         Validates a column that contains an array of related records by applying the given
+         validate_dataset_with_prefix function to each record and aggregating the per-record results.
+         Returns the dataset with the aggregated validation result in the __VALIDATE_<validation_column> column.
+         """
+         # use a name that is unlikely to collide with an ingest column
+         row_id = "__row_id"
+         expanded_column = "expanded_column"
+         aggregated_result_column = "bool_and({})".format(relation_result_column)
+         result_column = "__VALIDATE_{}".format(validation_column)
+
+         # add a stable row id so per-record results can be joined back to the original rows
+         ingest_dataset = ingest_dataset.withColumn(row_id, row_number().over(Window.orderBy(monotonically_increasing_id())))
+
+         # flatten the relation array so each related record can be validated as its own row
+         validation_dataset = ingest_dataset.select(validation_column, row_id).withColumn(expanded_column, explode(validation_column)).drop(validation_column)
+
+         # validate the flattened dataset
+         validation_dataset = validate_dataset_with_prefix(validation_dataset, expanded_column + ".", False) \
+             .drop(expanded_column)
+         # aggregate the per-record results into one boolean per original row
+         validation_dataset = validation_dataset.groupBy(row_id).agg(bool_and(col(relation_result_column))) \
+             .withColumn(result_column, col(aggregated_result_column))
+
+         # cleanup
+         validation_dataset = validation_dataset.drop(validation_column, aggregated_result_column)
+         ingest_dataset = ingest_dataset.join(validation_dataset, row_id, "outer").drop(row_id)
+
+         # note the parentheses: == binds looser than & and | on Column expressions
+         if is_required:
+             # a required relation must be present and every related record must be valid
+             ingest_dataset = ingest_dataset.withColumn(result_column, when(col(result_column).isNotNull() & (col(result_column) == True), lit(True)).otherwise(lit(False)))
+         else:
+             # an optional relation passes when it is absent or when every related record is valid
+             ingest_dataset = ingest_dataset.withColumn(result_column, when(col(result_column).isNull() | (col(result_column) == True), lit(True)).otherwise(lit(False)))
+
+         return ingest_dataset
+     #end
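To sanity-check the explode → validate → bool_and round trip used by validate_with_relation_record_schema, here is a minimal sketch with illustrative data; a non-null check stands in for the real per-record validation, and note that bool_and ships with pyspark.sql.functions from PySpark 3.5, with expr("bool_and(...)") as the usual fallback on older versions:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import bool_and, col, explode, monotonically_increasing_id, row_number
    from pyspark.sql.window import Window

    spark = SparkSession.builder.getOrCreate()

    # each row holds an array of child records; the second row has one invalid child
    df = spark.createDataFrame([(["a", "b"],), (["c", None],)], ["children"])

    # stable per-row id, as in the method above
    df = df.withColumn("__row_id", row_number().over(Window.orderBy(monotonically_increasing_id())))

    # explode the children, apply the stand-in validation, and AND the results per row
    exploded = df.select("__row_id", explode("children").alias("child"))
    validated = exploded.withColumn("child_valid", col("child").isNotNull())
    per_row = validated.groupBy("__row_id").agg(bool_and(col("child_valid")).alias("all_children_valid"))

    df.join(per_row, "__row_id").drop("__row_id").show()
    # first row -> all_children_valid = true, second -> false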