# torch.toml
# Base runtime settings.
[base]
# A Python expression stored as a string; presumably evaluated by the code
# that loads this file rather than used as literal text — TODO confirm.
cuda_availability = "torch.cuda.is_available()"
accelerator = "cuda"
# NOTE(review): presumably forwarded to torch.set_float32_matmul_precision — confirm.
matmul_precision = "medium"
# Top-level model settings; the per-architecture tables follow as [model.*].
[model]
# The values below are Python expressions stored as strings. The names they use
# (FSDPStrategy, ShardingStrategy, torch) are not defined in this file and must
# be in scope wherever the strings are evaluated — presumably the loader; TODO confirm.
strategy = "FSDPStrategy(sharding_strategy=ShardingStrategy.NO_SHARD)"
loss = "torch.nn.BCELoss()"
# Currently an empty list expression; the trailing comment keeps an
# alternative value (an MSELoss metric) for reference.
metrics = "[]" # "[torch.nn.MSELoss()]"
# Output directory settings for a run.
[model.dir]
# Python expression stored as a string. EXPS_DIR and _today are not defined in
# this file; they must be supplied by whatever evaluates it — TODO confirm.
RUN_DIR = "os.path.join(EXPS_DIR, _today)"
# Refers to RUN_DIR by name, so RUN_DIR presumably has to be resolved first —
# verify the evaluation order in the loader.
CHECKPOINTS_DIR = "os.path.join(RUN_DIR, 'checkpoints')"
# NOTE(review): unlike the two entries above this looks like a plain directory
# name, not an expression — confirm how the loader tells the two apart.
SKIP_DAYS_DIRNAME = "skip_days_00"
# UNet++ model definition.
[model.unetpp]
# Dotted path of the model class.
cls = "Fires._models.unetpp.UnetPlusPlus"

# Presumably passed as keyword arguments to `cls` — TODO confirm.
[model.unetpp.args]
# NOTE(review): 720 x 1440 matches `shape` and 8 matches `in_channels` in the
# other model tables, so this is presumably (height, width, channels) — confirm.
input_shape = [720, 1440, 8]
num_classes = 1
depth = 4
base_filter_dim = 64
deep_supervision = false
# Earthformer model definition.
[model.earthformer]
# Dotted path of the model class.
cls = "Fires._models.pangu.EarthTransformer3D"

# Presumably passed as keyword arguments to `cls` — TODO confirm.
[model.earthformer.args]
in_channels = 8
out_channels = 1
shape = [720, 1440]
patch_size = 4
win_size = 6
# NOTE(review): `depth` and `heads` have two entries each, presumably one per
# stage of the network — confirm against the class signature.
depth = [2, 5]
heads = [5, 10]
attention_dim = 32
# TOML has no null type; the string "None" is presumably evaluated to
# Python's None by the loader — TODO confirm.
num_wins = "None"
activation = "torch.nn.ReLU()"  # alternative: "torch.nn.Identity()"
# Swin Transformer model definition.
[model.swintransformer]
# Dotted path of the model class.
cls = "Fires._models.swin_transformer.SwinTransformerEarth2D"

# Presumably passed as keyword arguments to `cls` — TODO confirm.
[model.swintransformer.args]
in_channels = 8
out_channels = 1
shape = [720, 1440]
patch_size = 4
win_size = 6
# NOTE(review): `depth` and `heads` have two entries each, presumably one per
# stage of the network — confirm against the class signature.
depth = [2, 6]
heads = [6, 12]
attention_dim = 32
# Training-run settings.
[trainer]
accumulation_steps = 1
# NOTE(review): presumably 2 devices on each of the `num_nodes` nodes — confirm.
devices = 2
epochs = 75
num_nodes = 4
precision = '32-true'
batch_size = 2
# NOTE(review): the key name looks like a misspelling of "drop_remainder".
# Renaming it requires updating whatever code reads this key.
drop_reminder = true
# Python expression stored as a string; MPIEnvironment must be in scope
# wherever it is evaluated — presumably the loader; TODO confirm.
plugins = "[MPIEnvironment()]"
# NOTE(review): this is a quoted string, unlike the native booleans used
# elsewhere in this file. If the consumer does not evaluate it, the non-empty
# string "False" is truthy in Python — confirm how it is read.
use_distributed_sampler = "False"
# Optimizer selection.
[trainer.optim]
# Dotted path of the optimizer class.
cls = "torch.optim.Adam"
# Python `dict(...)` expression stored as a string; presumably evaluated and
# passed as keyword arguments to `cls` — TODO confirm.
args = "dict(lr=1e-4, betas=[0.9, 0.999])"
# Learning-rate scheduler selection.
[trainer.scheduler]
# TOML has no null type, so both values are the string "None"; presumably this
# means no scheduler is configured — TODO confirm how the loader interprets it.
cls = "None"
args = "None"
# Checkpoint settings.
[trainer.checkpoint]
# The string "None" (TOML has no null type); presumably means there is no
# checkpoint to load — TODO confirm how the loader interprets it.
ckpt = "None"