Skip to content

Commit d3dbe2d

Browse files
akalotkinAliaksei Kalotkin
and
Aliaksei Kalotkin
authored
Alternative aggregate functions to calculate histogram values. (#475)
* Alternative aggregate functions to calculate histogram values. * Reorder expected json * Alternative aggregate functions to calculate histogram values. * Alternative aggregate functions to calculate histogram values * Alternative aggregate functions to calculate histogram values --------- Co-authored-by: Aliaksei Kalotkin <aliaksei.kalotkin@nielsen.com>
1 parent f53283e commit d3dbe2d

File tree

5 files changed

+222
-17
lines changed

5 files changed

+222
-17
lines changed

src/main/scala/com/amazon/deequ/analyzers/Histogram.scala

+68-11
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,20 @@
1616

1717
package com.amazon.deequ.analyzers
1818

19+
import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count}
1920
import com.amazon.deequ.analyzers.runners.{IllegalAnalyzerParameterException, MetricCalculationException}
2021
import com.amazon.deequ.metrics.{Distribution, DistributionValue, HistogramMetric}
2122
import org.apache.spark.sql.expressions.UserDefinedFunction
22-
import org.apache.spark.sql.functions.col
23-
import org.apache.spark.sql.types.{StringType, StructType}
23+
import org.apache.spark.sql.functions.{col, sum}
24+
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
2425
import org.apache.spark.sql.{DataFrame, Row}
26+
2527
import scala.util.{Failure, Try}
2628

2729
/**
2830
* Histogram is the summary of values in a column of a DataFrame. Groups the given column's values,
29-
* and calculates the number of rows with that specific value and the fraction of this value.
31+
* and calculates either the number of rows with that specific value or the sum of the values of
32+
* another column for those rows, along with the fraction of the total that this represents.
3033
*
3134
* @param column Column to do histogram analysis on
3235
* @param binningUdf Optional binning function to run before grouping to re-categorize the
@@ -37,13 +40,15 @@ import scala.util.{Failure, Try}
3740
* maxBins sets the N.
3841
* This limit does not affect what is being returned as number of bins. It
3942
* always returns the distinct value count.
43+
* @param aggregateFunction function that implements aggregation logic.
4044
*/
4145
case class Histogram(
4246
column: String,
4347
binningUdf: Option[UserDefinedFunction] = None,
4448
maxDetailBins: Integer = Histogram.MaximumAllowedDetailBins,
4549
where: Option[String] = None,
46-
computeFrequenciesAsRatio: Boolean = true)
50+
computeFrequenciesAsRatio: Boolean = true,
51+
aggregateFunction: AggregateFunction = Count)
4752
extends Analyzer[FrequenciesAndNumRows, HistogramMetric]
4853
with FilterableAnalyzer {
4954

@@ -58,19 +63,15 @@ case class Histogram(
5863

5964
// TODO figure out a way to pass this in if its known before hand
6065
val totalCount = if (computeFrequenciesAsRatio) {
61-
data.count()
66+
aggregateFunction.total(data)
6267
} else {
6368
1
6469
}
6570

66-
val frequencies = data
71+
val df = data
6772
.transform(filterOptional(where))
6873
.transform(binOptional(binningUdf))
69-
.select(col(column).cast(StringType))
70-
.na.fill(Histogram.NullFieldReplacement)
71-
.groupBy(column)
72-
.count()
73-
.withColumnRenamed("count", Analyzers.COUNT_COL)
74+
val frequencies = query(df)
7475

7576
Some(FrequenciesAndNumRows(frequencies, totalCount))
7677
}
@@ -125,11 +126,67 @@ case class Histogram(
125126
case _ => data
126127
}
127128
}
129+
130+
/** Runs the configured aggregate function over `data`, grouped by this analyzer's column. */
private def query(data: DataFrame): DataFrame =
  aggregateFunction.query(column, data)
128133
}
129134

130135
/**
 * Constants and aggregation strategies for the [[Histogram]] analyzer.
 */
object Histogram {
  val NullFieldReplacement = "NullValue"
  val MaximumAllowedDetailBins = 1000
  // Names under which the aggregate functions are identified, e.g. when (de)serialized.
  val count_function = "count"
  val sum_function = "sum"

  /**
   * Strategy that decides what a histogram bin contains.
   */
  sealed trait AggregateFunction {
    /**
     * Groups `data` by `column` and computes one aggregated value per group.
     *
     * @param column column whose values form the bins
     * @param data   input data
     * @return one row per bin: the bin value and its aggregate
     */
    def query(column: String, data: DataFrame): DataFrame

    /** Aggregate over the whole of `data`. */
    def total(data: DataFrame): Long

    /** Additional column the aggregation reads, if any. */
    def aggregateColumn(): Option[String]

    /** Name identifying this aggregation (see `count_function` / `sum_function`). */
    def function(): String
  }

  /** Default aggregation: every bin holds the number of rows with that value. */
  case object Count extends AggregateFunction {
    override def query(column: String, data: DataFrame): DataFrame = {
      data
        .select(col(column).cast(StringType))
        // nulls in the string-cast column become their own bin
        .na.fill(Histogram.NullFieldReplacement)
        .groupBy(column)
        .count()
        .withColumnRenamed("count", Analyzers.COUNT_COL)
    }

    override def aggregateColumn(): Option[String] = None

    override def function(): String = count_function

    override def total(data: DataFrame): Long = {
      data.count()
    }
  }

  /**
   * Aggregation in which every bin holds the sum of `aggColumn` (cast to long) over the rows
   * with that value.
   *
   * @param aggColumn column whose values are summed
   */
  case class Sum(aggColumn: String) extends AggregateFunction {
    override def query(column: String, data: DataFrame): DataFrame = {
      data
        .select(col(column).cast(StringType), col(aggColumn).cast(LongType))
        // NOTE(review): filling with a string should only affect the string-cast bin column,
        // leaving null values of aggColumn untouched - confirm.
        .na.fill(Histogram.NullFieldReplacement)
        .groupBy(column)
        // Alias the aggregate directly: the column produced by a grouped sum is not named
        // "count", so renaming "count" would leave the result without COUNT_COL.
        .agg(sum(col(aggColumn)).alias(Analyzers.COUNT_COL))
    }

    override def total(data: DataFrame): Long = {
      // Same cast as in query, so the bin values and the total are computed the same way.
      val totalRow = data.select(sum(col(aggColumn).cast(LongType))).first()
      // The sum is null when there is nothing to add up; report 0 like Count does for no rows.
      if (totalRow.isNullAt(0)) 0L else totalRow.getLong(0)
    }

    override def aggregateColumn(): Option[String] = {
      Some(aggColumn)
    }

    override def function(): String = sum_function
  }
}
134191

135192
object OrderByAbsoluteCount extends Ordering[Row] {

src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala

+25-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import scala.collection._
3030
import scala.collection.JavaConverters._
3131
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
3232
import JsonSerializationConstants._
33+
import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count => HistogramCount, Sum => HistogramSum}
3334
import org.apache.spark.sql.Column
3435
import org.apache.spark.sql.functions.expr
3536

@@ -302,6 +303,12 @@ private[deequ] object AnalyzerSerializer
302303
result.addProperty(ANALYZER_NAME_FIELD, "Histogram")
303304
result.addProperty(COLUMN_FIELD, histogram.column)
304305
result.addProperty("maxDetailBins", histogram.maxDetailBins)
306+
// Count is the original and default aggregation for Histogram.
307+
// For Count the fields below are omitted from the JSON to preserve backward compatibility.
308+
if (histogram.aggregateFunction != Histogram.Count) {
309+
result.addProperty("aggregateFunction", histogram.aggregateFunction.function())
310+
result.addProperty("aggregateColumn", histogram.aggregateFunction.aggregateColumn().get)
311+
}
305312

306313
case _ : Histogram =>
307314
throw new IllegalArgumentException("Unable to serialize Histogram with binningUdf!")
@@ -436,7 +443,10 @@ private[deequ] object AnalyzerDeserializer
436443
Histogram(
437444
json.get(COLUMN_FIELD).getAsString,
438445
None,
439-
json.get("maxDetailBins").getAsInt)
446+
json.get("maxDetailBins").getAsInt,
447+
aggregateFunction = createAggregateFunction(
448+
getOptionalStringParam(json, "aggregateFunction").getOrElse(Histogram.count_function),
449+
getOptionalStringParam(json, "aggregateColumn").getOrElse("")))
440450

441451
case "DataType" =>
442452
DataType(
@@ -489,12 +499,24 @@ private[deequ] object AnalyzerDeserializer
489499
}
490500

491501
/** Reads the optional `where` filter of a serialized analyzer. */
private[this] def getOptionalWhereParam(jsonObject: JsonObject): Option[String] =
  getOptionalStringParam(jsonObject, WHERE_FIELD)
504+
505+
/**
 * Reads an optional string member of a JSON object.
 *
 * @param jsonObject object to read from
 * @param field      name of the member
 * @return the member's string value, or None when the object has no such member
 */
private[this] def getOptionalStringParam(jsonObject: JsonObject, field: String): Option[String] = {
  Some(field)
    .filter(name => jsonObject.has(name))
    .flatMap(name => Option(jsonObject.get(name).getAsString))
}
512+
513+
/**
 * Builds the histogram aggregate function identified by its serialized name.
 *
 * @param function        serialized function name (`Histogram.count_function` or
 *                        `Histogram.sum_function`)
 * @param aggregateColumn column to aggregate; only used by the sum function
 * @throws IllegalArgumentException if the name is unknown, or if sum is requested without a
 *                                  column to aggregate
 */
private[this] def createAggregateFunction(function: String, aggregateColumn: String): AggregateFunction = {
  function match {
    case Histogram.count_function => HistogramCount
    case Histogram.sum_function =>
      // Fail here rather than build a Sum over a column that cannot exist.
      if (aggregateColumn == null || aggregateColumn.isEmpty) {
        throw new IllegalArgumentException(
          "Aggregate column is required for aggregate function: " + function)
      }
      HistogramSum(aggregateColumn)
    case _ => throw new IllegalArgumentException("Wrong aggregate function name: " + function)
  }
}
498520
}
499521

500522
private[deequ] object MetricSerializer extends JsonSerializer[Metric[_]] {

src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala

+20
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,26 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with
260260
}
261261
}
262262

263+
"compute correct sum metrics " in withSparkSession { sparkSession =>
  // Histogram over "product" whose bins hold the summed "units" instead of a row count.
  val salesDf = getDateDf(sparkSession)
  val sumHistogram = Histogram("product", aggregateFunction = Histogram.Sum("units")).calculate(salesDf)
  assert(sumHistogram.value.isSuccess)

  val distribution = sumHistogram.value.get
  val grandTotal = 55 + 20 + 60

  assert(distribution.numberOfBins == 3)
  assert(distribution.values.size == 3)
  assert(distribution.values.keys == Set("Furniture", "Cosmetics", "Electronics"))
  assert(distribution("Furniture").absolute == 55)
  assert(distribution("Furniture").ratio == 55.0 / grandTotal)
  assert(distribution("Cosmetics").absolute == 20)
  assert(distribution("Cosmetics").ratio == 20.0 / grandTotal)
  assert(distribution("Electronics").absolute == 60)
  assert(distribution("Electronics").ratio == 60.0 / grandTotal)
}
282+
263283
"compute correct metrics on numeric values" in withSparkSession { sparkSession =>
264284
val dfFull = getDfWithNumericValues(sparkSession)
265285
val histogram = Histogram("att2").calculate(dfFull)

src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ class AnalyzerContextTest extends AnyWordSpec
8686
|{"entity":"Column","instance":"item","name":"Distinctness","value":1.0},
8787
|{"entity":"Column","instance":"att1","name":"Completeness","value":1.0},
8888
|{"entity":"Multicolumn","instance":"att1,att2","name":"Uniqueness","value":0.25},
89+
|{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0},
90+
|{"entity":"Dataset","instance":"*","name":"Size","value":4.0},
8991
|{"entity":"Column","instance":"att1","name":"Histogram.bins","value":2.0},
9092
|{"entity":"Column","instance":"att1","name":"Histogram.abs.a","value":3.0},
9193
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.a","value":0.75},
9294
|{"entity":"Column","instance":"att1","name":"Histogram.abs.b","value":1.0},
93-
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25},
94-
|{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0},
95-
|{"entity":"Dataset","instance":"*","name":"Size","value":4.0}
95+
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25}
9696
|]"""
9797
.stripMargin.replaceAll("\n", "")
9898

src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala

+106
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,112 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers {
173173
assertCorrectlyConvertsAnalysisResults(Seq(result))
174174
}
175175

176+
val histogramSumJson =
177+
"""[
178+
| {
179+
| "resultKey": {
180+
| "dataSetDate": 0,
181+
| "tags": {}
182+
| },
183+
| "analyzerContext": {
184+
| "metricMap": [
185+
| {
186+
| "analyzer": {
187+
| "analyzerName": "Histogram",
188+
| "column": "columnA",
189+
| "maxDetailBins": 1000,
190+
| "aggregateFunction": "sum",
191+
| "aggregateColumn": "columnB"
192+
| },
193+
| "metric": {
194+
| "metricName": "HistogramMetric",
195+
| "column": "columnA",
196+
| "numberOfBins": 10,
197+
| "value": {
198+
| "numberOfBins": 10,
199+
| "values": {
200+
| "some": {
201+
| "absolute": 10,
202+
| "ratio": 0.5
203+
| }
204+
| }
205+
| }
206+
| }
207+
| }
208+
| ]
209+
| }
210+
| }
211+
|]""".stripMargin
212+
val histogramCountJson =
213+
"""[
214+
| {
215+
| "resultKey": {
216+
| "dataSetDate": 0,
217+
| "tags": {}
218+
| },
219+
| "analyzerContext": {
220+
| "metricMap": [
221+
| {
222+
| "analyzer": {
223+
| "analyzerName": "Histogram",
224+
| "column": "columnA",
225+
| "maxDetailBins": 1000
226+
| },
227+
| "metric": {
228+
| "metricName": "HistogramMetric",
229+
| "column": "columnA",
230+
| "numberOfBins": 10,
231+
| "value": {
232+
| "numberOfBins": 10,
233+
| "values": {
234+
| "some": {
235+
| "absolute": 10,
236+
| "ratio": 0.5
237+
| }
238+
| }
239+
| }
240+
| }
241+
| }
242+
| ]
243+
| }
244+
| }
245+
|]""".stripMargin
246+
247+
"Histogram serialization" should "be backward compatible for count" in {
  // A default (Count) histogram must serialize to the JSON without aggregate fields.
  val distribution = Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)
  val countHistogram = Histogram("columnA")
  val analysisResult = new AnalysisResult(
    ResultKey(0),
    AnalyzerContext(Map(countHistogram -> HistogramMetric("columnA", Success(distribution)))))

  assert(serialize(Seq(analysisResult)) == histogramCountJson)
}
255+
256+
"Histogram serialization" should "properly serialize sum" in {
  // A Sum histogram must serialize to the JSON that carries the aggregate function and column.
  val distribution = Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)
  val sumHistogram = Histogram("columnA", aggregateFunction = Histogram.Sum("columnB"))
  val analysisResult = new AnalysisResult(
    ResultKey(0),
    AnalyzerContext(Map(sumHistogram -> HistogramMetric("columnA", Success(distribution)))))

  assert(serialize(Seq(analysisResult)) == histogramSumJson)
}
264+
265+
"Histogram deserialization" should "be backward compatible for count" in {
  // JSON without aggregate fields must deserialize to a default (Count) histogram.
  val distribution = Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)
  val countHistogram = Histogram("columnA")
  val expectedResult = new AnalysisResult(
    ResultKey(0),
    AnalyzerContext(Map(countHistogram -> HistogramMetric("columnA", Success(distribution)))))

  assert(deserialize(histogramCountJson) == List(expectedResult))
}
272+
273+
"Histogram deserialization" should "properly deserialize sum" in {
  // JSON carrying the sum function and its column must deserialize to a Sum histogram.
  val distribution = Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)
  val sumHistogram = Histogram("columnA", aggregateFunction = Histogram.Sum("columnB"))
  val expectedResult = new AnalysisResult(
    ResultKey(0),
    AnalyzerContext(Map(sumHistogram -> HistogramMetric("columnA", Success(distribution)))))

  assert(deserialize(histogramSumJson) == List(expectedResult))
}
280+
281+
176282
def assertCorrectlyConvertsAnalysisResults(
177283
analysisResults: Seq[AnalysisResult],
178284
shouldFail: Boolean = false)

0 commit comments

Comments
 (0)