-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathjax_mcts.yml
48 lines (42 loc) · 1.18 KB
/
jax_mcts.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
---
# Weights & Biases settings.
wandb:
  use: true
work_dir: runs
layerwise_logging: true
log_interval: 200
validation_interval: 1250
use_cuda: true
scale_observation: false
reposition: false
gumbel_scale: 0.3
version_string: 'mcts'
net_type: custom_dense
dimension: 3
max_num_points: 5
max_length_game: 40
max_value: 20
max_grad_norm: 0.5
eval_batch_size: 512
num_evaluations: 100
num_evaluations_as_opponent: 100
eval_on_cpu: false
max_num_considered_actions: 10
# The rollout process repeats `rollout_size // eval_batch_size` times.
# NOTE(review): `rollout_size` is not set in this file; confirm where it is
# defined before relying on this comment.
discount: 0.98
host:
  batch_size: 128
  optim:
    name: 'adam'
    args:  # Pass optimizer parameters here.
      learning_rate: 0.001
    # Learning-rate scheduling is built in (via a schedule function and
    # `chain`), so no schedule is specified here.
  # Alternative, wider architecture (disabled):
  # net_arch: [256, 256, 256, 256, 256, 256, 256, 256]
  # NOTE(review): confirm the consumer reads `net_arch` from `host`,
  # not from `host.optim`.
  net_arch: [128, 128, 128, 128, 128, 128, 128, 128]
agent:
  batch_size: 128
  optim:
    name: 'adam'
    args:  # Pass optimizer parameters here.
      learning_rate: 0.001
    # Learning-rate scheduling is built in (via a schedule function and
    # `chain`), so no schedule is specified here.
  # Alternative, wider architecture (disabled):
  # net_arch: [256, 256, 256, 256, 256, 256, 256, 256]
  # NOTE(review): confirm the consumer reads `net_arch` from `agent`,
  # not from `agent.optim`.
  net_arch: [128, 128, 128, 128, 128, 128, 128, 128]