Skip to content

Commit

Permalink
fix pva and add a seed for the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Jan 7, 2025
1 parent 5ddace1 commit ad96b92
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 18 deletions.
83 changes: 68 additions & 15 deletions smt/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,84 @@ def compute_relative_error(sm, xe=None, ye=None, kx=None):


def compute_pva(sm, xe, ye):
ye = ye.reshape((xe.shape[0], 1))
N = len(ye)
ye2 = sm.predict_values(xe)
variance = sm.predict_variances(xe)
"""
Compute the Predictive Variance Adequacy (PVA) for a surrogate model.
Parameters:
- sm: The surrogate model object, expected to have `predict_values` and `predict_variances` methods.
- xe: Input data for evaluation (N x d array).
- ye: True output values (N x 1 array or equivalent).
Returns:
- pva: Predictive Variance Adequacy score (float).
"""
ye = ye.reshape((xe.shape[0], 1)) # Ensure `ye` is column vector
Nb = len(ye) # Number of data points

# Predicted values and variances
ye_pred = sm.predict_values(xe) # Predicted values (N x 1 array)
variance = sm.predict_variances(xe) # Predicted variances (N x 1 array)

# Calculate squared error normalized by variance
error = ((ye_pred - ye) ** 2) / variance

# Compute PVA with logarithm
pva = np.abs(np.log(np.sum(error) / Nb))

error = (ye2 - ye) ** 2 / variance
pva = np.sum(error) / N
return pva


def compute_rmse(sm, xe, ye):
ye = ye.reshape((xe.shape[0], 1))
N = len(ye)
ye2 = sm.predict_values(xe)
rmse = np.sqrt(np.sum((ye2 - ye) ** 2) / N)
"""
Compute the Root Mean Square Error (RMSE) for a surrogate model.
Parameters:
- sm: The surrogate model object, expected to have a `predict_values` method.
- xe: Input data for evaluation (N x d array).
- ye: True output values (N x 1 array or equivalent).
Returns:
- rmse: Root Mean Square Error (float).
"""
ye = ye.reshape((xe.shape[0], 1)) # Ensure `ye` is a column vector

# Predicted values
ye_pred = sm.predict_values(xe) # Predicted values (N x 1 array)

# Compute RMSE
mse = np.mean((ye_pred - ye) ** 2) # Mean Squared Error
rmse = np.sqrt(mse) # Root Mean Squared Error

return rmse


def compute_q2(sm, xe, ye):
ye = ye.reshape((xe.shape[0], 1))
N = len(ye)
square_rmse = compute_rmse(sm, xe, ye) ** 2
"""
Compute the Q^2 validation criterion for a surrogate model.
Parameters:
- sm: The surrogate model object, expected to have a `predict_values` method.
- xe: Input data for evaluation (N x d array).
- ye: True output values (N x 1 array or equivalent).
Returns:
- Q2: Predictive coefficient of determination (float).
"""
ye = ye.reshape((xe.shape[0], 1)) # Ensure `ye` is a column vector

# Predicted values
ye_pred = sm.predict_values(xe) # Predicted values (N x 1 array)

# Mean of true output values
ye_mean = np.mean(ye)
variance = np.sum((ye - ye_mean) ** 2) / N
Q2 = 1 - (square_rmse / variance)

# Residual Sum of Squares (RSS) and Total Sum of Squares (TSS)
rss = np.sum((ye - ye_pred) ** 2) # Residual sum of squares
tss = np.sum((ye - ye_mean) ** 2) # Total sum of squares

# Compute Q^2
Q2 = 1 - (rss / tss)

return Q2


Expand Down
6 changes: 3 additions & 3 deletions smt/utils/test/test_misc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def prepare_tests_errors(self):

def test_pva_error(self):
xe, ye = self.prepare_tests_errors()
sm = KRG(print_global=False)
sm = KRG(print_global=False, random_state=42)
sm.set_training_values(xe, ye)
sm.train()

Expand All @@ -56,7 +56,7 @@ def test_pva_error(self):

def test_rmse_error(self):
xe, ye = self.prepare_tests_errors()
sm = KRG(print_global=False)
sm = KRG(print_global=False, random_state=42)
sm.set_training_values(xe, ye)
sm.train()

Expand All @@ -65,7 +65,7 @@ def test_rmse_error(self):

def test_q2_error(self):
xe, ye = self.prepare_tests_errors()
sm = KRG(print_global=False)
sm = KRG(print_global=False, random_state=42)
sm.set_training_values(xe, ye)
sm.train()

Expand Down

0 comments on commit ad96b92

Please sign in to comment.