
Commit d6d093c

Merge pull request #581 from boozallen/572-filter-out-records-with-invalid-relations
572-relation-schema-validation
2 parents f93c103 + e87c14c commit d6d093c

Showing 12 changed files with 714 additions and 477 deletions.

foundation/foundation-mda/src/main/java/com/boozallen/aiops/mda/metamodel/element/pyspark/PySparkSchemaRecord.java (+7 -3)
@@ -34,6 +34,8 @@ public class PySparkSchemaRecord extends PythonRecord {
 
     private static final String SCHEMA_PACKAGE = "from ...schema.%s_schema import %sSchema";
     private static final String PYSPARK_ARRAY_IMPORT = "from pyspark.sql.types import ArrayType";
+    private static final String PYSPARK_COL_FUNCTIONS = "from pyspark.sql.functions import bool_and, explode, monotonically_increasing_id, row_number";
+    private static final String PYSPARK_WINDOW_IMPORT = "from pyspark.sql.window import Window";
     private Set<String> imports = new TreeSet<>();
 
     /**
@@ -79,12 +81,14 @@ public Set<String> getBaseImports() {
                 imports.add(dictionaryTypeImport);
             }
         }
-        boolean isArrayImportAdded = false;
+        boolean isPysparkImportAdded = false;
         for (Relation relation : getRelations()) {
             PythonRecordRelation wrappedRelation = new PythonRecordRelation(relation);
-            if(wrappedRelation.isOneToManyRelation() && !isArrayImportAdded) {
-                isArrayImportAdded = true;
+            if(wrappedRelation.isOneToManyRelation() && !isPysparkImportAdded) {
+                isPysparkImportAdded = true;
                 imports.add(PYSPARK_ARRAY_IMPORT);
+                imports.add(PYSPARK_COL_FUNCTIONS);
+                imports.add(PYSPARK_WINDOW_IMPORT);
             }
             imports.add(String.format(SCHEMA_PACKAGE, wrappedRelation.getSnakeCaseName(), relation.getName()));
         }
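Taken together, these constants mean any generated schema whose record has a 1:M relation now imports the aggregate and window helpers alongside ArrayType. Since imports is a TreeSet, the rendered import block comes out sorted, roughly like this sketch (the relation-specific line assumes a hypothetical relation named Address; see SCHEMA_PACKAGE above):

    from ...schema.address_schema import AddressSchema
    from pyspark.sql.functions import bool_and, explode, monotonically_increasing_id, row_number
    from pyspark.sql.types import ArrayType
    from pyspark.sql.window import Window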

foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/pyspark.schema.base.py.vm (+82 -23)
@@ -4,7 +4,7 @@ from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.column import Column
 from pyspark.sql.types import StructType
 from pyspark.sql.types import DataType
-from pyspark.sql.functions import col, lit
+from pyspark.sql.functions import col, lit, when
 from typing import List
 import types
 #foreach ($import in $record.baseImports)
@@ -46,6 +46,7 @@ class ${record.capitalizedName}SchemaBase(ABC):
 
     def __init__(self):
         self._schema = StructType()
+        self.validation_result_column = '__VALIDATE_RESULT_${record.capitalizedName}_'
 
     ## Setting the nullable parameter to True for every column due to inconsistencies in the behavior from different data sources/toolings (Spark vs Pyspark)
     ## This allows all data to be read in, and None values will be filtered out as part of the validate_dataset method if the field is required
@@ -136,16 +137,35 @@ class ${record.capitalizedName}SchemaBase(ABC):
     def validate_dataset(self, ingest_dataset: DataFrame) -> DataFrame:
         return self.validate_dataset_with_prefix(ingest_dataset, "")
 
-    def validate_dataset_with_prefix(self, ingest_dataset: DataFrame, column_prefix: str) -> DataFrame:
+    def validate_dataset_with_prefix(self, ingest_dataset: DataFrame, column_prefix: str, valid_data_only = True) -> DataFrame:
         """
         Validates the given dataset and returns the lists of validated records.
         """
         data_with_validations = ingest_dataset
+#if ($record.hasRelations())
+        # relation records validation
+#foreach($relation in $record.relations)
+#if(!$relation.isNullable())
+        # filter out null data for the required relation
+        data_with_validations = data_with_validations.withColumn(self.validation_result_column + self.${relationVars[$relation.name]} + "_IS_NOT_NULL",
+            col(column_prefix + self.${relationVars[$relation.name]}).isNotNull())
+
+#end
+#if($relation.isOneToManyRelation())
+        data_with_validations = self.with_${relation.snakeCaseName}_validation(data_with_validations, '${relation.columnName}')
+#else
+        ${relation.snakeCaseName}_schema = ${relation.capitalizedName}Schema()
+        data_with_validations = ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations, '${relation.columnName}.', False)
+#end
+#end
+#end
+
+        # record fields validation
 #foreach ($field in $record.fields)
 #if (${field.isRequired()})
         data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NOT_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNotNull())
 #else
-        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull())
+        data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull())
 #end
 #if (${field.getValidation().getMinValue()})
         data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(column_prefix + self.${columnVars[$field.name]}).cast('double') >= ${field.getValidation().getMinValue()})
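(For orientation, here is roughly what the new relation block above renders to for a hypothetical record with one required 1:M relation named Address; every identifier below is illustrative, not from the commit.)

        # relation records validation
        # filter out null data for the required relation
        data_with_validations = data_with_validations.withColumn(self.validation_result_column + self.ADDRESS_COLUMN + "_IS_NOT_NULL",
            col(column_prefix + self.ADDRESS_COLUMN).isNotNull())
        data_with_validations = self.with_address_validation(data_with_validations, 'address')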
@@ -172,18 +192,6 @@ class ${record.capitalizedName}SchemaBase(ABC):
 #end
 #end
 
-        ## TODO revise validation for relations
-#if (false)
-#foreach($relation in $record.relations)
-#if($relation.isOneToManyRelation())
-        data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relationVars[$relation.name]})))))
-#else
-        ${relation.snakeCaseName}_schema = ${relation.name}Schema()
-        data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relationVars[$relation.name]})), '${relation.columnName}.').isEmpty()))
-#end
-#end
-#end
-
         column_filter_schemas = []
         validation_columns = [col for col in data_with_validations.columns if col not in ingest_dataset.columns]
 
@@ -194,6 +202,9 @@ class ${record.capitalizedName}SchemaBase(ABC):
         columns_grouped_by_field.append([col for col in validation_columns if col.startswith(self.${columnVars[$field.name]})])
 #end
 
+        if valid_data_only:
+            columns_grouped_by_field.append([col for col in validation_columns if col.startswith('__VALIDATE_')])
+
         # Create a schema filter for each field represented as a column group
         for column_group in columns_grouped_by_field:
             column_group_filter_schema = None
@@ -231,19 +242,67 @@ class ${record.capitalizedName}SchemaBase(ABC):
             else:
                 final_column_filter_schemas = column_group_filter_schema
 
-        valid_data = data_with_validations.filter(final_column_filter_schemas)
+            if valid_data_only:
+                valid_data = data_with_validations.filter(final_column_filter_schemas)
+            else:
+                valid_data = data_with_validations.withColumn(self.validation_result_column, when(final_column_filter_schemas, lit(True)).otherwise(lit(False)))
+        else:
+            if not valid_data_only:
+                valid_data = data_with_validations.withColumn(self.validation_result_column, lit(True))
 
         valid_data = valid_data.drop(*validation_columns)
         return valid_data
 
-        ## TODO revise validation for relations
-#if (false)
-#foreach($relation in $record.relations)
-#if($relation.isOneToManyRelation())
-    def _validate_with_${relation.snakeCaseName}_schema(self, dataset: DataFrame) -> bool:
-        raise NotImplementedError
+#set($hasOneToManyRelation = false)
+#foreach ($relation in $record.relations)
+#if ($relation.isOneToManyRelation())
+#set($hasOneToManyRelation = true)
+    def with_${relation.snakeCaseName}_validation(self, dataset: DataFrame, validation_column: str) -> DataFrame:
+        """
+        Validates the given ${relation.capitalizedName} 1:M multiplicity relation dataset against ${relation.capitalizedName}Schema
+        Returns a dataset with a validation result __VALIDATE_${relationVars[$relation.name]} column
+        """
+        ${relation.snakeCaseName}_schema = ${relation.capitalizedName}Schema()
+        return self.validate_with_relation_record_schema(dataset, validation_column,
+            ${relation.snakeCaseName}_schema.validate_dataset_with_prefix, ${relation.snakeCaseName}_schema.validation_result_column, #if (${relation.isNullable()}) False #else True #end)
+
 #end
 #end
-#end
 
+#if ($hasOneToManyRelation)
+    def validate_with_relation_record_schema(self, ingest_dataset: DataFrame, validation_column: str, validate_dataset_with_prefix, relation_result_column: str, is_required=False) -> DataFrame:
+        """
+        Validates the given dataset column, which contains an array of relation data records,
+        against the relation schema using the given validate_dataset_with_prefix function
+        Returns the dataset including validation results in a __VALIDATE_<column> column
+        """
+        id = "id"
+        expanded_column = "expanded_column"
+        aggregated_result_column = "bool_and({})".format(relation_result_column)
+        result_column = "__VALIDATE_{}".format(validation_column)
+
+        # add a row id
+        ingest_dataset = ingest_dataset.withColumn(id, row_number().over(Window.orderBy(monotonically_increasing_id())))
+
+        # flatten the relation array so each relation record can be validated individually
+        validation_dataset = ingest_dataset.select(validation_column, id).withColumn(expanded_column, explode(validation_column)).drop(validation_column)
+
+        # validate the flattened dataset
+        validation_dataset = validate_dataset_with_prefix(validation_dataset, expanded_column + ".", False) \
+            .drop(expanded_column)
+        # group the validation results by the original dataset row id
+        validation_dataset = validation_dataset.groupBy(id).agg(bool_and(col(relation_result_column))) \
+            .withColumn(result_column, col(aggregated_result_column))
+
+        # cleanup
+        validation_dataset = validation_dataset.drop(validation_column, aggregated_result_column)
+        ingest_dataset = ingest_dataset.join(validation_dataset, id, "outer").drop(id)
+
+        if is_required:
+            ingest_dataset = ingest_dataset.withColumn(result_column, when(col(result_column).isNotNull() & (col(result_column) == True), lit(True)).otherwise(lit(False)))
+        else:
+            ingest_dataset = ingest_dataset.withColumn(result_column, when(col(result_column).isNull() | (col(result_column) == True), lit(True)).otherwise(lit(False)))
+
+        return ingest_dataset
+#end
 

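The core of the new template is the explode-validate-aggregate pattern in validate_with_relation_record_schema. The following standalone sketch (not part of the commit) applies the same pattern to a toy dataset; it assumes PySpark 3.5+, where bool_and is exposed in pyspark.sql.functions, and substitutes a trivial non-null check for the generated relation schema validation:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import (bool_and, col, explode,
                                       monotonically_increasing_id, row_number)
    from pyspark.sql.window import Window

    spark = SparkSession.builder.getOrCreate()

    # Toy parent rows, each carrying an array of child values (the 1:M relation).
    parents = spark.createDataFrame(
        [("a", [1, 2]), ("b", [3, None]), ("c", None)],
        "name STRING, children ARRAY<INT>",
    )

    # Tag each parent row with a stable row id, as the template does.
    parents = parents.withColumn(
        "id", row_number().over(Window.orderBy(monotonically_increasing_id())))

    # Flatten the relation array so each child record is validated on its own row.
    flattened = parents.select("id", explode("children").alias("child"))

    # Stand-in for the generated relation schema validation: non-null children pass.
    validated = flattened.withColumn("child_valid", col("child").isNotNull())

    # A parent passes only if every one of its children passed.
    per_parent = validated.groupBy("id").agg(
        bool_and(col("child_valid")).alias("children_valid"))

    # The outer join keeps parents with no children; their null result is resolved
    # by the required/optional branch in the committed template.
    parents.join(per_parent, "id", "outer").drop("id").show()

The outer join is the key design choice: parents whose relation array is empty or null never appear in the exploded dataset, so their aggregated result comes back null, and the template's is_required branch then resolves null to False for required relations and True for optional ones.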