Skip to content

Commit 036a880

Browse files
committed
The result of JSD calculation can be slightly greater than 1.0 due to floating point error. This CL fixes the bug by capping the result at 1.0.
PiperOrigin-RevId: 692261708
1 parent 58af9e7 commit 036a880

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

tensorflow_data_validation/anomalies/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ cc_library(
116116
"@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto_cc_pb2",
117117
"@com_google_absl//absl/log:check",
118118
"@com_google_absl//absl/status",
119+
"@com_google_absl//absl/types:optional",
119120
],
120121
)
121122

tensorflow_data_validation/anomalies/metrics.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,18 @@ limitations under the License.
1717

1818
#include <algorithm>
1919
#include <cmath>
20-
#include <limits>
20+
#include <iterator>
2121
#include <map>
22-
#include <numeric>
23-
#include <string>
22+
#include <set>
23+
#include <tuple>
24+
#include <utility>
2425
#include <vector>
2526

2627
#include "absl/log/check.h"
2728
#include "absl/status/status.h"
29+
#include "absl/types/optional.h"
2830
#include "tensorflow_data_validation/anomalies/map_util.h"
31+
#include "tensorflow_data_validation/anomalies/statistics_view.h"
2932
#include "tensorflow_data_validation/anomalies/status_util.h"
3033
#include "tensorflow_metadata/proto/v0/schema.pb.h"
3134
#include "tensorflow_metadata/proto/v0/statistics.pb.h"
@@ -356,6 +359,8 @@ absl::Status JensenShannonDivergence(Histogram& histogram_1,
356359
KullbackLeiblerDivergence(histogram_2,
357360
average_distribution_histogram)) /
358361
2);
362+
// Due to precision limitations, the result will be capped at 1.0.
363+
result = std::min(result, 1.0);
359364
return absl::OkStatus();
360365
}
361366

@@ -405,7 +410,7 @@ absl::Status JensenShannonDivergence(const std::map<string, double>& map_1,
405410
kl_sum += b_ele_prob * std::log2(b_ele_prob / m);
406411
}
407412
}
408-
result = kl_sum/2;
413+
result = std::min(kl_sum / 2, 1.0);
409414

410415
return absl::OkStatus();
411416
}

0 commit comments

Comments
 (0)