diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index 854844faa..d3e481986 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -529,7 +529,12 @@ def _initialize_theta(self, theta, n_levels, cat_features, cat_kernel): ncomp = 1e5 theta_cont_features = np.zeros((len(theta), 1), dtype=bool) - theta_cat_features = np.zeros((len(theta), len(n_levels)), dtype=bool) + theta_cat_features = np.ones((len(theta), len(n_levels)), dtype=bool) + if cat_kernel in [ + MixIntKernelType.EXP_HOMO_HSPHERE, + MixIntKernelType.HOMO_HSPHERE, + ]: + theta_cat_features = np.zeros((len(theta), len(n_levels)), dtype=bool) i = 0 j = 0 n_theta_cont = 0 @@ -714,17 +719,30 @@ def _matrix_data_corr( theta_cat_kernel[theta_cat_features[1]] -= ( theta_bounds[1] + theta_bounds[0] ) - theta_cat_kernel[theta_cat_features[1]] *= 1 / theta_bounds[1] + theta_cat_kernel[theta_cat_features[1]] *= 1 / ( + 1.000000000001 * theta_bounds[1] + ) for i in range(len(n_levels)): theta_cat = theta_cat_kernel[theta_cat_features[0][i]] - T = matrix_data_corr_levels_cat_matrix( - i, - n_levels, - theta_cat, - theta_bounds, - is_ehh=cat_kernel == MixIntKernelType.EXP_HOMO_HSPHERE, - ) + if cat_kernel == MixIntKernelType.COMPOUND_SYMMETRY: + T = np.zeros((n_levels[i], n_levels[i])) + for tij in range(n_levels[i]): + for tji in range(n_levels[i]): + if tij == tji: + T[tij, tji] = 1 + else: + T[tij, tji] = max( + theta_cat[0], 1e-10 - 1 / (n_levels[i] - 1) + ) + else: + T = matrix_data_corr_levels_cat_matrix( + i, + n_levels, + theta_cat, + theta_bounds, + is_ehh=cat_kernel == MixIntKernelType.EXP_HOMO_HSPHERE, + ) if cat_kernel_comps is not None: # Sampling points X and y @@ -786,6 +804,7 @@ def _matrix_data_corr( in [ MixIntKernelType.EXP_HOMO_HSPHERE, MixIntKernelType.HOMO_HSPHERE, + MixIntKernelType.COMPOUND_SYMMETRY, ], ) diff --git a/smt/utils/design_space.py b/smt/utils/design_space.py index 26ba153e1..2069bdb1d 100644 --- a/smt/utils/design_space.py +++ b/smt/utils/design_space.py @@ -1002,7 +1002,7 @@ def _get_correct_config(self, vector: np.ndarray) -> Configuration: seed = None if isinstance(self.random_state, int): seed = self.random_state - elif isinstance(random_state, np.random.RandomState): + elif isinstance(self.random_state, np.random.RandomState): seed = self.random_state.get_state()[1][0] self.seed = seed