Skip to content

Commit 8dbb659

Browse files
committed
support more sample weights
1 parent b4f8659 commit 8dbb659

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

R/helper.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ wmean = function(x, w) { # a better stats::weighted.mean
2323
sum(x * (w / sum(w)))
2424
}
2525

26+
wsum = function(x, w) { # sum(w * x) that asserts w and accepts NULL
27+
if (is.null(w)) {
28+
return(sum(x))
29+
}
30+
assert_numeric(w, lower = 0, finite = TRUE, any.missing = FALSE, len = length(x))
31+
sum(x * w)
32+
}
33+
2634
# confusion matrix
2735
cm = function(truth, response, positive = NULL) {
2836
if (!is.null(positive)) {

R/regr_sae.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,27 @@
22
#'
33
#' @details
44
#' The Sum of Absolute Errors is defined as \deqn{
5-
#' \sum_{i=1}^n \left| t_i - r_i \right|.
5+
#' \sum_{i=1}^n w_i \left| t_i - r_i \right|.
66
#' }{
7-
#' sum(abs((t - r))).
7+
#' sum(w * abs((t - r))).
88
#' }
9+
#' where \eqn{w_i} are unnormalized weights for each observation \eqn{x_i}, defaulting to 1.
910
#'
1011
#' @templateVar mid sae
1112
#' @template regr_template
1213
#'
14+
#' @param sample_weights (`numeric()`)\cr
15+
#' Vector of non-negative and finite sample weights.
16+
#' Must have the same length as `truth`.
17+
#' Weights for this function are not normalized.
18+
#' Defaults to sample weights 1.
19+
#'
1320
#' @inheritParams regr_params
1421
#' @template regr_example
1522
#' @export
16-
sae = function(truth, response, ...) {
23+
sae = function(truth, response, sample_weights = NULL, ...) {
1724
assert_regr(truth, response = response)
18-
sum(.ae(truth, response))
25+
sum(.ae(truth, response), sample_weights)
1926
}
2027

2128
#' @include measures.R

R/regr_sse.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,27 @@
22
#'
33
#' @details
44
#' The Sum of Squared Errors is defined as \deqn{
5-
#' \sum_{i=1}^n \left( t_i - r_i \right)^2.
5+
#' \sum_{i=1}^n w_i \left( t_i - r_i \right)^2.
66
#' }{
7-
#' sum((t - r)^2).
7+
#' sum(w * (t - r)^2).
88
#' }
9+
#' where \eqn{w_i} are unnormalized weights for each observation \eqn{x_i}, defaulting to 1.
910
#'
1011
#' @templateVar mid sse
1112
#' @template regr_template
1213
#'
14+
#' @param sample_weights (`numeric()`)\cr
15+
#' Vector of non-negative and finite sample weights.
16+
#' Must have the same length as `truth`.
17+
#' Weights for this function are not normalized.
18+
#' Defaults to sample weights 1.
19+
#'
1320
#' @inheritParams regr_params
1421
#' @template regr_example
1522
#' @export
16-
sse = function(truth, response, ...) {
23+
sse = function(truth, response, sample_weights = NULL, ...) {
1724
assert_regr(truth, response = response)
18-
sum(.se(truth, response))
25+
wsum(.se(truth, response), sample_weights)
1926
}
2027

2128
#' @include measures.R

0 commit comments

Comments
 (0)