diff --git a/deepfake_detection.py b/deepfake_detection.py index aa8ea08..c822434 100644 --- a/deepfake_detection.py +++ b/deepfake_detection.py @@ -66,7 +66,7 @@ test_loader = torch.utils.data.DataLoader(test_set,batch_size=32,shuffle =True, num_workers=1) -model=fen.DnCNN().to(device) +model=fen.DnCNN(num_layers=31).to(device) model_params = list(model.parameters()) optimizer = torch.optim.Adam(model_params, lr=opt.lr)