diff --git a/requirements.txt b/requirements.txt
index 7df715fc9..28cf6dd78 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,5 +14,5 @@ scikit-learn==0.21.2
 scipy==1.3.0
 seaborn==0.9.0
 six==1.12.0
-torch==1.1.0
-torchvision==0.3.0
+torch>=1.1.0
+torchvision>=0.3.0
diff --git a/src/DeepSAD.py b/src/DeepSAD.py
index 1f002040e..d1823f350 100644
--- a/src/DeepSAD.py
+++ b/src/DeepSAD.py
@@ -60,7 +60,7 @@ def set_network(self, net_name):
 
     def train(self, dataset: BaseADDataset, optimizer_name: str = 'adam', lr: float = 0.001, n_epochs: int = 50,
               lr_milestones: tuple = (), batch_size: int = 128, weight_decay: float = 1e-6, device: str = 'cuda',
-              n_jobs_dataloader: int = 0):
+              n_jobs_dataloader: int = 0, validate: bool = False):
         """Trains the Deep SAD model on the training data."""
 
         self.optimizer_name = optimizer_name
@@ -68,8 +68,9 @@ def train(self, dataset: BaseADDataset, optimizer_name: str = 'adam', lr: float
                                       lr_milestones=lr_milestones, batch_size=batch_size, weight_decay=weight_decay,
                                       device=device, n_jobs_dataloader=n_jobs_dataloader)
         # Get the model
-        self.net = self.trainer.train(dataset, self.net)
+        self.net = self.trainer.train(dataset, self.net, validate=validate)
         self.results['train_time'] = self.trainer.train_time
+        self.train_loss = self.trainer.train_loss
         self.c = self.trainer.c.cpu().data.numpy().tolist()  # get as list
 
     def test(self, dataset: BaseADDataset, device: str = 'cuda', n_jobs_dataloader: int = 0):
@@ -130,7 +131,7 @@ def save_model(self, export_model, save_ae=True):
         """Save Deep SAD model to export_model."""
 
         net_dict = self.net.state_dict()
-        ae_net_dict = self.ae_net.state_dict() if save_ae else None
+        ae_net_dict = self.ae_net.state_dict() if (save_ae and self.ae_net is not None) else None
 
         torch.save({'c': self.c,
                     'net_dict': net_dict,
diff --git a/src/optim/DeepSAD_trainer.py b/src/optim/DeepSAD_trainer.py
index 44b1118de..486303b9d 100644
--- a/src/optim/DeepSAD_trainer.py
+++ b/src/optim/DeepSAD_trainer.py
@@ -28,11 +28,12 @@ def __init__(self, c, eta: float, optimizer_name: str = 'adam', lr: float = 0.00
 
         # Results
         self.train_time = None
+        self.train_loss = None
         self.test_auc = None
         self.test_time = None
         self.test_scores = None
 
-    def train(self, dataset: BaseADDataset, net: BaseNet):
+    def train(self, dataset: BaseADDataset, net: BaseNet, validate: bool = False):
         logger = logging.getLogger()
 
         # Get train data loader
@@ -45,25 +46,26 @@ def train(self, dataset: BaseADDataset, net: BaseNet):
         optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
 
         # Set learning rate scheduler
-        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)
+        self.scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)
 
         # Initialize hypersphere center c (if c not loaded)
         if self.c is None:
             logger.info('Initializing center c...')
             self.c = self.init_center_c(train_loader, net)
-            logger.info('Center c initialized.')
+            logger.info('Center c initialized to {}.'.format(self.c))
 
         # Training
         logger.info('Starting training...')
         start_time = time.time()
         net.train()
+        self.train_loss = []
         for epoch in range(self.n_epochs):
 
-            scheduler.step()
+            self.scheduler.step()
             if epoch in self.lr_milestones:
-                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
+                logger.info('  LR scheduler: new learning rate is %g' % float(self.scheduler.get_lr()[0]))
 
-            epoch_loss = 0.0
+            train_epoch_loss = 0.0
             n_batches = 0
             epoch_start_time = time.time()
             for data in train_loader:
@@ -81,13 +83,40 @@ def train(self, dataset: BaseADDataset, net: BaseNet):
                 loss.backward()
                 optimizer.step()
 
-                epoch_loss += loss.item()
+                train_epoch_loss += loss.item()
                 n_batches += 1
+            train_loss = train_epoch_loss / n_batches
+            epoch_loss_history = (epoch + 1, train_loss)
+
+            if validate:
+                n_batches = 0
+                valid_epoch_loss = 0.0
+                valid_loader = dataset.validation_loader(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)
+                with torch.set_grad_enabled(False):
+                    for data in valid_loader:
+                        inputs, _, semi_targets, _ = data
+                        inputs, semi_targets = inputs.to(self.device), semi_targets.to(self.device)
+
+                        outputs = net(inputs)
+                        dist = torch.sum((outputs - self.c) ** 2, dim=1)
+                        losses = torch.where(semi_targets == 0, dist, self.eta * ((dist + self.eps) ** semi_targets.float()))
+                        loss = torch.mean(losses)
+
+                        valid_epoch_loss += loss.item()
+                        n_batches += 1
+                valid_loss = valid_epoch_loss / n_batches
+                epoch_loss_history = (epoch + 1, train_loss, valid_loss)
+
+            self.train_loss.append(epoch_loss_history)
 
             # log epoch statistics
             epoch_train_time = time.time() - epoch_start_time
-            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
-                        f'| Train Loss: {epoch_loss / n_batches:.6f} |')
+
+            stats = f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s ' \
+                    f'| Train Loss: {train_loss:.6f}'
+            if validate:
+                stats = stats + f' | Valid Loss: {valid_loss:.6f}'
+            logger.info(stats)
 
         self.train_time = time.time() - start_time
         logger.info('Training Time: {:.3f}s'.format(self.train_time))
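
For reviewers, a minimal usage sketch of the new `validate` flag (illustrative, not part of the diff). The dataset and network names are placeholders, and the snippet assumes `BaseADDataset` gains the `validation_loader(batch_size, num_workers)` method that the trainer now calls, since that method is not defined in this patch:

    # Sketch: training Deep SAD with per-epoch validation-loss tracking.
    # Assumes `dataset` is a BaseADDataset subclass that implements
    # validation_loader(), added elsewhere in this change set.
    from DeepSAD import DeepSAD

    deep_SAD = DeepSAD(eta=1.0)
    deep_SAD.set_network('mnist_LeNet')  # placeholder network name
    deep_SAD.train(dataset, optimizer_name='adam', lr=0.001, n_epochs=50,
                   validate=True)  # new flag introduced by this diff

    # With validate=True each history entry is (epoch, train_loss, valid_loss);
    # with validate=False it is (epoch, train_loss).
    for entry in deep_SAD.train_loss:
        print(entry)

Note that the entries of `train_loss` change arity depending on `validate`, so downstream plotting or logging code should handle both the 2-tuple and 3-tuple forms. Also, the validation pass disables gradients via torch.set_grad_enabled(False) but leaves the network in training mode; networks with dropout or batch normalization may warrant a net.eval()/net.train() pair around the validation loop.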