16
16
17
17
package com .amazon .deequ .analyzers
18
18
19
+ import com .amazon .deequ .analyzers .Histogram .{AggregateFunction , Count }
19
20
import com .amazon .deequ .analyzers .runners .{IllegalAnalyzerParameterException , MetricCalculationException }
20
21
import com .amazon .deequ .metrics .{Distribution , DistributionValue , HistogramMetric }
21
22
import org .apache .spark .sql .expressions .UserDefinedFunction
22
- import org .apache .spark .sql .functions .col
23
- import org .apache .spark .sql .types .{StringType , StructType }
23
+ import org .apache .spark .sql .functions .{ col , sum }
24
+ import org .apache .spark .sql .types .{DoubleType , LongType , StringType , StructType }
24
25
import org .apache .spark .sql .{DataFrame , Row }
26
+
25
27
import scala .util .{Failure , Try }
26
28
27
29
/**
28
30
* Histogram is the summary of values in a column of a DataFrame. Groups the given column's values,
29
- * and calculates the number of rows with that specific value and the fraction of this value.
31
+ * and calculates for each value either the number of rows with that value or the sum of the
32
+ * values in another column, together with the corresponding fraction of the overall total.
30
33
*
31
34
* @param column Column to do histogram analysis on
32
35
* @param binningUdf Optional binning function to run before grouping to re-categorize the
@@ -37,13 +40,15 @@ import scala.util.{Failure, Try}
37
40
* maxBins sets the N.
38
41
* This limit does not affect what is being returned as number of bins. It
39
42
* always returns the distinct value count.
43
+ * @param aggregateFunction function that implements aggregation logic.
40
44
*/
41
45
case class Histogram (
42
46
column : String ,
43
47
binningUdf : Option [UserDefinedFunction ] = None ,
44
48
maxDetailBins : Integer = Histogram .MaximumAllowedDetailBins ,
45
49
where : Option [String ] = None ,
46
- computeFrequenciesAsRatio : Boolean = true )
50
+ computeFrequenciesAsRatio : Boolean = true ,
51
+ aggregateFunction : AggregateFunction = Count )
47
52
extends Analyzer [FrequenciesAndNumRows , HistogramMetric ]
48
53
with FilterableAnalyzer {
49
54
@@ -58,19 +63,15 @@ case class Histogram(
58
63
59
64
// TODO figure out a way to pass this in if its known before hand
60
65
val totalCount = if (computeFrequenciesAsRatio) {
61
- data.count( )
66
+ aggregateFunction.total(data )
62
67
} else {
63
68
1
64
69
}
65
70
66
- val frequencies = data
71
+ val df = data
67
72
.transform(filterOptional(where))
68
73
.transform(binOptional(binningUdf))
69
- .select(col(column).cast(StringType ))
70
- .na.fill(Histogram .NullFieldReplacement )
71
- .groupBy(column)
72
- .count()
73
- .withColumnRenamed(" count" , Analyzers .COUNT_COL )
74
+ val frequencies = query(df)
74
75
75
76
Some (FrequenciesAndNumRows (frequencies, totalCount))
76
77
}
@@ -125,11 +126,67 @@ case class Histogram(
125
126
case _ => data
126
127
}
127
128
}
129
+
130
/** Hands the already filtered/binned data to the configured aggregate function for grouping. */
private def query(data: DataFrame): DataFrame =
  aggregateFunction.query(column, data)
128
133
}
129
134
130
135
object Histogram {
  val NullFieldReplacement = "NullValue"
  val MaximumAllowedDetailBins = 1000
  val count_function = "count"
  val sum_function = "sum"

  /**
    * Strategy that decides which value is accumulated per histogram bin.
    */
  sealed trait AggregateFunction {
    /**
      * Groups `data` by `column` and computes the per-group aggregate.
      *
      * @param column name of the column whose values form the bins
      * @param data   DataFrame to aggregate
      * @return DataFrame with the grouping column and the aggregate stored
      *         under `Analyzers.COUNT_COL`
      */
    def query(column: String, data: DataFrame): DataFrame

    /**
      * Computes the overall aggregate over `data`, used as the denominator
      * when frequencies are reported as ratios.
      */
    def total(data: DataFrame): Long

    /** Name of the column being aggregated, if the function needs one. */
    def aggregateColumn(): Option[String]

    /** Identifier of the aggregation (`count_function` or `sum_function`). */
    def function(): String
  }

  /** Counts the rows falling into each bin. */
  case object Count extends AggregateFunction {
    override def query(column: String, data: DataFrame): DataFrame = {
      data
        .select(col(column).cast(StringType))
        // null bin values are reported under a dedicated placeholder
        .na.fill(Histogram.NullFieldReplacement)
        .groupBy(column)
        .count()
        .withColumnRenamed("count", Analyzers.COUNT_COL)
    }

    override def aggregateColumn(): Option[String] = None

    override def function(): String = count_function

    override def total(data: DataFrame): Long = {
      data.count()
    }
  }

  /**
    * Sums the values of `aggColumn` (cast to long) for each bin.
    *
    * @param aggColumn name of the column whose values are summed
    */
  case class Sum(aggColumn: String) extends AggregateFunction {
    override def query(column: String, data: DataFrame): DataFrame = {
      data
        .select(col(column).cast(StringType), col(aggColumn).cast(LongType))
        // a String fill value only touches the (string) bin column
        .na.fill(Histogram.NullFieldReplacement)
        .groupBy(column)
        .sum(aggColumn)
        // groupBy(...).sum names its output "sum(<column>)"; there is no
        // "count" column here, so renaming "count" would silently do nothing.
        .withColumnRenamed(s"sum($aggColumn)", Analyzers.COUNT_COL)
    }

    override def total(data: DataFrame): Long = {
      // Cast exactly like `query` does, so the total matches the per-bin sums
      // and the aggregate is a long even for non-integral input columns.
      val totals = data
        .select(col(aggColumn).cast(LongType))
        .groupBy()
        .sum(aggColumn)
        .first()

      // Summing zero rows (or only nulls) yields NULL; report that as 0,
      // consistent with Count.total on empty data.
      if (totals.isNullAt(0)) 0L else totals.getLong(0)
    }

    override def aggregateColumn(): Option[String] = {
      Some(aggColumn)
    }

    override def function(): String = sum_function
  }
}
134
191
135
192
object OrderByAbsoluteCount extends Ordering [Row ] {
0 commit comments