Skip to content

Commit

Permalink
add CS
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Feb 4, 2024
1 parent d59e6b7 commit 4a1b405
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
37 changes: 28 additions & 9 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,12 @@ def _initialize_theta(self, theta, n_levels, cat_features, cat_kernel):
ncomp = 1e5

theta_cont_features = np.zeros((len(theta), 1), dtype=bool)
theta_cat_features = np.zeros((len(theta), len(n_levels)), dtype=bool)
theta_cat_features = np.ones((len(theta), len(n_levels)), dtype=bool)
if cat_kernel in [
MixIntKernelType.EXP_HOMO_HSPHERE,
MixIntKernelType.HOMO_HSPHERE,
]:
theta_cat_features = np.zeros((len(theta), len(n_levels)), dtype=bool)
i = 0
j = 0
n_theta_cont = 0
Expand Down Expand Up @@ -714,17 +719,30 @@ def _matrix_data_corr(
theta_cat_kernel[theta_cat_features[1]] -= (
theta_bounds[1] + theta_bounds[0]
)
theta_cat_kernel[theta_cat_features[1]] *= 1 / theta_bounds[1]
theta_cat_kernel[theta_cat_features[1]] *= 1 / (
1.000000000001 * theta_bounds[1]
)

for i in range(len(n_levels)):
theta_cat = theta_cat_kernel[theta_cat_features[0][i]]
T = matrix_data_corr_levels_cat_matrix(
i,
n_levels,
theta_cat,
theta_bounds,
is_ehh=cat_kernel == MixIntKernelType.EXP_HOMO_HSPHERE,
)
if cat_kernel == MixIntKernelType.COMPOUND_SYMMETRY:
T = np.zeros((n_levels[i], n_levels[i]))
for tij in range(n_levels[i]):
for tji in range(n_levels[i]):
if tij == tji:
T[tij, tji] = 1
else:
T[tij, tji] = max(
theta_cat[0], 1e-10 - 1 / (n_levels[i] - 1)
)
else:
T = matrix_data_corr_levels_cat_matrix(
i,
n_levels,
theta_cat,
theta_bounds,
is_ehh=cat_kernel == MixIntKernelType.EXP_HOMO_HSPHERE,
)

if cat_kernel_comps is not None:
# Sampling points X and y
Expand Down Expand Up @@ -786,6 +804,7 @@ def _matrix_data_corr(
in [
MixIntKernelType.EXP_HOMO_HSPHERE,
MixIntKernelType.HOMO_HSPHERE,
MixIntKernelType.COMPOUND_SYMMETRY,
],
)

Expand Down
2 changes: 1 addition & 1 deletion smt/utils/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ def _get_correct_config(self, vector: np.ndarray) -> Configuration:
seed = None
if isinstance(self.random_state, int):
seed = self.random_state
elif isinstance(random_state, np.random.RandomState):
elif isinstance(self.random_state, np.random.RandomState):
seed = self.random_state.get_state()[1][0]
self.seed = seed

Expand Down

0 comments on commit 4a1b405

Please sign in to comment.