Skip to content

Commit 45e6457

Browse files
author
Guillaume Lemaitre
committed
Merge branch 'refactor'
Conflicts: imblearn/ensemble/balance_cascade.py
2 parents 5f20c3d + 721bc2a commit 45e6457

File tree

78 files changed

+646
-2111
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

78 files changed

+646
-2111
lines changed

imblearn/base.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
"""Base class for sampling"""
1+
"""Base class for sampling"""
22

33
from __future__ import division
44
from __future__ import print_function
55

66
import warnings
7+
import logging
78

89
import numpy as np
910

@@ -19,14 +20,16 @@
1920

2021

2122
class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
23+
2224
"""Mixin class for samplers with abstact method.
2325
2426
Warning: This class should not be used directly. Use the derive classes
2527
instead.
2628
"""
2729

28-
@abstractmethod
29-
def __init__(self, ratio='auto', random_state=None, verbose=True):
30+
_estimator_type = "sampler"
31+
32+
def __init__(self, ratio='auto'):
3033
"""Initialize this object and its instance variables.
3134
3235
Parameters
@@ -37,45 +40,15 @@ def __init__(self, ratio='auto', random_state=None, verbose=True):
3740
of samples in the minority class over the the number of samples
3841
in the majority class.
3942
40-
random_state : int or None, optional (default=None)
41-
Seed for random number generation.
42-
43-
verbose : bool, optional (default=True)
44-
Boolean to either or not print information about the processing
45-
4643
Returns
4744
-------
4845
None
4946
5047
"""
51-
# The ratio correspond to the number of samples in the minority class
52-
# over the number of samples in the majority class. Thus, the ratio
53-
# cannot be greater than 1.0
54-
if isinstance(ratio, float):
55-
if ratio > 1:
56-
raise ValueError('Ration cannot be greater than one.')
57-
elif ratio <= 0:
58-
raise ValueError('Ratio cannot be negative.')
59-
else:
60-
self.ratio = ratio
61-
elif isinstance(ratio, string_types):
62-
if ratio == 'auto':
63-
self.ratio = ratio
64-
else:
65-
raise ValueError('Unknown string for the parameter ratio.')
66-
else:
67-
raise ValueError('Unknown parameter type for ratio.')
6848

69-
self.random_state = random_state
70-
self.verbose = verbose
49+
self.ratio = ratio
50+
self.logger = logging.getLogger(__name__)
7151

72-
# Create the member variables regarding the classes statistics
73-
self.min_c_ = None
74-
self.maj_c_ = None
75-
self.stats_c_ = {}
76-
self.X_shape_ = None
77-
78-
@abstractmethod
7952
def fit(self, X, y):
8053
"""Find the classes statistics before to perform sampling.
8154
@@ -97,8 +70,15 @@ def fit(self, X, y):
9770
# Check the consistency of X and y
9871
X, y = check_X_y(X, y)
9972

100-
if self.verbose:
101-
print("Determining classes statistics... ", end="")
73+
self.min_c_ = None
74+
self.maj_c_ = None
75+
self.stats_c_ = {}
76+
self.X_shape_ = None
77+
78+
if hasattr(self, 'ratio'):
79+
self._validate_ratio()
80+
81+
self.logger.info('Compute classes statistics ...')
10282

10383
# Get all the unique elements in the target array
10484
uniques = np.unique(y)
@@ -122,9 +102,8 @@ def fit(self, X, y):
122102
self.min_c_ = min(self.stats_c_, key=self.stats_c_.get)
123103
self.maj_c_ = max(self.stats_c_, key=self.stats_c_.get)
124104

125-
if self.verbose:
126-
print('{} classes detected: {}'.format(uniques.size,
127-
self.stats_c_))
105+
self.logger.info('%s classes detected: %s', uniques.size,
106+
self.stats_c_)
128107

129108
# Check if the ratio provided at initialisation make sense
130109
if isinstance(self.ratio, float):
@@ -136,7 +115,6 @@ def fit(self, X, y):
136115

137116
return self
138117

139-
@abstractmethod
140118
def sample(self, X, y):
141119
"""Resample the dataset.
142120
@@ -158,8 +136,11 @@ def sample(self, X, y):
158136
159137
"""
160138

139+
# Check the consistency of X and y
140+
X, y = check_X_y(X, y)
141+
161142
# Check that the data have been fitted
162-
if not self.stats_c_:
143+
if not hasattr(self, 'stats_c_'):
163144
raise RuntimeError('You need to fit the data, first!!!')
164145

165146
# Check if the size of the data is identical than at fitting
@@ -168,7 +149,10 @@ def sample(self, X, y):
168149
' seem to be the one earlier fitted. Use the'
169150
' fitted data.')
170151

171-
return self
152+
if hasattr(self, 'ratio'):
153+
self._validate_ratio()
154+
155+
return self._sample(X, y)
172156

173157
def fit_sample(self, X, y):
174158
"""Fit the statistics and resample the data directly.
@@ -192,3 +176,53 @@ def fit_sample(self, X, y):
192176
"""
193177

194178
return self.fit(X, y).sample(X, y)
179+
180+
def _validate_ratio(self):
181+
# The ratio correspond to the number of samples in the minority class
182+
# over the number of samples in the majority class. Thus, the ratio
183+
# cannot be greater than 1.0
184+
if isinstance(self.ratio, float):
185+
if self.ratio > 1:
186+
raise ValueError('Ration cannot be greater than one.')
187+
elif self.ratio <= 0:
188+
raise ValueError('Ratio cannot be negative.')
189+
190+
elif isinstance(self.ratio, string_types):
191+
if self.ratio != 'auto':
192+
raise ValueError('Unknown string for the parameter ratio.')
193+
else:
194+
raise ValueError('Unknown parameter type for ratio.')
195+
196+
@abstractmethod
197+
def _sample(self, X, y):
198+
"""Resample the dataset.
199+
200+
Parameters
201+
----------
202+
X : ndarray, shape (n_samples, n_features)
203+
Matrix containing the data which have to be sampled.
204+
205+
y : ndarray, shape (n_samples, )
206+
Corresponding label for each sample in X.
207+
208+
Returns
209+
-------
210+
X_resampled : ndarray, shape (n_samples_new, n_features)
211+
The array containing the resampled data.
212+
213+
y_resampled : ndarray, shape (n_samples_new)
214+
The corresponding label of `X_resampled`
215+
"""
216+
pass
217+
218+
def __getstate__(self):
219+
"""Prevent logger from being pickled."""
220+
object_dictionary = self.__dict__.copy()
221+
del object_dictionary['logger']
222+
return object_dictionary
223+
224+
def __setstate__(self, dict):
225+
"""Re-open the logger."""
226+
logger = logging.getLogger(__name__)
227+
self.__dict__.update(dict)
228+
self.logger = logger

imblearn/combine/smote_enn.py

Lines changed: 14 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from __future__ import print_function
33
from __future__ import division
44

5-
from sklearn.utils import check_X_y
6-
75
from ..over_sampling import SMOTE
86
from ..under_sampling import EditedNearestNeighbours
97
from ..base import SamplerMixin
@@ -22,11 +20,11 @@ class SMOTEENN(SamplerMixin):
2220
number of samples in the minority class over the the number of
2321
samples in the majority class.
2422
25-
random_state : int or None, optional (default=None)
26-
Seed for random number generation.
27-
28-
verbose : bool, optional (default=True)
29-
Whether or not to print information about the processing.
23+
random_state : int, RandomState instance or None, optional (default=None)
24+
If int, random_state is the seed used by the random number generator;
25+
If RandomState instance, random_state is the random number generator;
26+
If None, the random number generator is the RandomState instance used
27+
by np.random.
3028
3129
k : int, optional (default=5)
3230
Number of nearest neighbours to used to construct synthetic
@@ -60,15 +58,6 @@ class SMOTEENN(SamplerMixin):
6058
6159
Attributes
6260
----------
63-
ratio : str or float
64-
If 'auto', the ratio will be defined automatically to balance
65-
the dataset. Otherwise, the ratio is defined as the
66-
number of samples in the minority class over the the number of
67-
samples in the majority class.
68-
69-
random_state : int or None
70-
Seed for random number generation.
71-
7261
min_c_ : str or int
7362
The identifier of the minority class.
7463
@@ -96,81 +85,25 @@ class SMOTEENN(SamplerMixin):
9685
9786
"""
9887

99-
def __init__(self, ratio='auto', random_state=None, verbose=True,
88+
def __init__(self, ratio='auto', random_state=None,
10089
k=5, m=10, out_step=0.5, kind_smote='regular',
10190
size_ngh=3, kind_enn='all', n_jobs=-1, **kwargs):
10291

103-
"""Initialise the SMOTE ENN object.
104-
105-
Parameters
106-
----------
107-
ratio : str or float, optional (default='auto')
108-
If 'auto', the ratio will be defined automatically to balance
109-
the dataset. Otherwise, the ratio is defined as the
110-
number of samples in the minority class over the the number of
111-
samples in the majority class.
112-
113-
random_state : int or None, optional (default=None)
114-
Seed for random number generation.
115-
116-
verbose : bool, optional (default=True)
117-
Whether or not to print information about the processing.
118-
119-
k : int, optional (default=5)
120-
Number of nearest neighbours to used to construct synthetic
121-
samples.
122-
123-
m : int, optional (default=10)
124-
Number of nearest neighbours to use to determine if a minority
125-
sample is in danger.
126-
127-
out_step : float, optional (default=0.5)
128-
Step size when extrapolating.
129-
130-
kind_smote : str, optional (default='regular')
131-
The type of SMOTE algorithm to use one of the following
132-
options: 'regular', 'borderline1', 'borderline2', 'svm'.
133-
134-
size_ngh : int, optional (default=3)
135-
Size of the neighbourhood to consider to compute the average
136-
distance to the minority point samples.
137-
138-
kind_sel : str, optional (default='all')
139-
Strategy to use in order to exclude samples.
140-
141-
- If 'all', all neighbours will have to agree with the samples of
142-
interest to not be excluded.
143-
- If 'mode', the majority vote of the neighbours will be used in
144-
order to exclude a sample.
145-
146-
n_jobs : int, optional (default=-1)
147-
The number of threads to open if possible.
148-
149-
Returns
150-
-------
151-
None
152-
153-
"""
154-
super(SMOTEENN, self).__init__(ratio=ratio, random_state=random_state,
155-
verbose=verbose)
156-
92+
super(SMOTEENN, self).__init__(ratio=ratio)
93+
self.random_state = random_state
15794
self.k = k
15895
self.m = m
15996
self.out_step = out_step
16097
self.kind_smote = kind_smote
98+
self.size_ngh = size_ngh
99+
self.kind_enn = kind_enn
161100
self.n_jobs = n_jobs
162101
self.kwargs = kwargs
163-
164102
self.sm = SMOTE(ratio=self.ratio, random_state=self.random_state,
165-
verbose=self.verbose, k=self.k, m=self.m,
166-
out_step=self.out_step, kind=self.kind_smote,
167-
n_jobs=self.n_jobs, **self.kwargs)
168-
169-
self.size_ngh = size_ngh
170-
self.kind_enn = kind_enn
171-
103+
k=self.k, m=self.m, out_step=self.out_step,
104+
kind=self.kind_smote, n_jobs=self.n_jobs,
105+
**self.kwargs)
172106
self.enn = EditedNearestNeighbours(random_state=self.random_state,
173-
verbose=self.verbose,
174107
size_ngh=self.size_ngh,
175108
kind_sel=self.kind_enn,
176109
n_jobs=self.n_jobs)
@@ -192,8 +125,6 @@ def fit(self, X, y):
192125
Return self.
193126
194127
"""
195-
# Check the consistency of X and y
196-
X, y = check_X_y(X, y)
197128

198129
super(SMOTEENN, self).fit(X, y)
199130

@@ -202,7 +133,7 @@ def fit(self, X, y):
202133

203134
return self
204135

205-
def sample(self, X, y):
136+
def _sample(self, X, y):
206137
"""Resample the dataset.
207138
208139
Parameters
@@ -222,10 +153,6 @@ def sample(self, X, y):
222153
The corresponding label of `X_resampled`
223154
224155
"""
225-
# Check the consistency of X and y
226-
X, y = check_X_y(X, y)
227-
228-
super(SMOTEENN, self).sample(X, y)
229156

230157
# Transform using SMOTE
231158
X, y = self.sm.sample(X, y)

0 commit comments

Comments
 (0)