-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathExperiment.py
executable file
·69 lines (55 loc) · 2.38 KB
/
Experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
from os import path
import numpy as np
import multiprocessing as mp
from lunarlander.framework import Framework
from lunarlander.simulator import LunarLanderSimulator
from lunarlander.policygradientagent import PolicyGradientAgent
import matplotlib.pyplot as plt
def run_experiment(num_episodes, Lambda, alpha, twe, trunc_normal, subspaces):
    """Train one fresh policy-gradient agent and record every episode.

    Args:
        num_episodes: number of episodes to run with the same agent.
        Lambda: passed to the agent as ``Lambda``.
        alpha: step size, used for both ``alpha_u`` and ``alpha_v``.
        twe: passed to the agent as ``tile_weight_exponent``.
        trunc_normal: passed to the agent as ``trunc_normal``.
        subspaces: passed to the agent as ``subspaces``.

    Returns:
        A NumPy array holding the value returned by
        ``Framework.run_episode()`` for each episode, in episode order.
    """
    # Re-seed NumPy's global RNG from fresh OS entropy; presumably so that
    # runs executed in the same (reused) pool worker don't share a stream.
    np.random.seed()

    sim = LunarLanderSimulator()
    learner = PolicyGradientAgent(
        sim,
        Lambda=Lambda,
        alpha_u=alpha,
        alpha_v=alpha,
        tile_weight_exponent=twe,
        trunc_normal=trunc_normal,
        subspaces=subspaces,
    )
    episode_runner = Framework(sim, learner)

    episode_results = []
    for _ in range(num_episodes):
        episode_results.append(episode_runner.run_episode())
    return np.array(episode_results)
def run_experiments(experiments):
    """Run every configured experiment in a process pool and gather results.

    Args:
        experiments: mapping from experiment name to a spec dict with keys
            ``'params'`` (keyword arguments for ``run_experiment``),
            ``'num_runs'`` (independent repetitions) and ``'num_episodes'``.

    Returns:
        Dict from experiment name to ``np.vstack`` of that experiment's run
        results, one row per run. Assuming each episode yields a scalar,
        this is a ``(num_runs, num_episodes)`` array — confirm against
        ``Framework.run_episode``.
    """
    # 'spawn' starts each worker from a clean interpreter rather than
    # forking the parent process.
    context = mp.get_context('spawn')
    with context.Pool() as pool:
        # Submit every run of every experiment before waiting on any of
        # them, so the whole workload is spread across the pool at once.
        pending = {}
        for name, spec in experiments.items():
            episode_count = spec['num_episodes']
            agent_kwargs = spec['params']
            pending[name] = [
                pool.apply_async(run_experiment, (episode_count,), agent_kwargs)
                for _ in range(spec['num_runs'])
            ]

        # Block on each result while the pool is still alive.
        collected = {}
        for name, handles in pending.items():
            collected[name] = np.vstack([handle.get() for handle in handles])
    return collected
def make_plot(results, specs=None):
    """Plot each experiment's cumulative mean result curve on one figure.

    Args:
        results: mapping from experiment name to a 2-D array with one row per
            independent run and one column per episode, as produced by
            ``run_experiments`` or read back from the ``.npz`` cache.
        specs: optional mapping from experiment name to its spec dict (same
            layout as the module-level ``experiments``). It is used only to
            build the legend labels. Defaults to the module-level
            ``experiments``.

    An entry present in ``results`` but absent from ``specs`` is labelled
    with its name rather than raising ``KeyError``. This happens, for
    example, when the cached ``.npz`` predates an edit to the experiment
    configuration.
    """
    if specs is None:
        specs = experiments
    for name, returns in results.items():
        spec = specs.get(name)
        if spec is None:
            label = name
        else:
            p = spec['params']
            label = r'$\lambda={}, \alpha={}$'.format(p['Lambda'], p['alpha'])
        # Average across runs (axis 0), then accumulate along the episodes.
        plt.plot(returns.mean(axis=0).cumsum(), label=label)
    plt.legend(loc='lower left')
    plt.show()
# Experiment configurations, keyed by name. All three share alpha=0.1,
# twe (tile weight exponent) 0.5, truncated-normal sampling and
# subspaces [1, 2, 6], and differ only in Lambda. Each spec's 'params' dict
# is passed as keyword arguments to run_experiment.
experiments = {
    name: {
        'params': {
            'Lambda': lam,
            'alpha': 0.1,
            'twe': 0.5,
            'trunc_normal': True,
            'subspaces': [1, 2, 6],  # a fresh list for every experiment
        },
        'num_runs': 3,
        'num_episodes': 20000,
    }
    for name, lam in (
        ('weighted_trunc_normal', 0.75),
        ('lambda_0.5_weighted_trunc_normal', 0.5),
        ('lambda_0.9_weighted_trunc_normal', 0.9),
    )
}
if __name__ == "__main__":
    from os import makedirs

    # Results are cached here; delete the file to force the experiments to
    # be re-run.
    filename = 'data/experiment.npz'
    if not path.exists(filename):
        # Create the output directory before the long parallel runs start.
        # Otherwise a missing 'data/' is only reported by savez_compressed,
        # after every result has already been computed.
        out_dir = path.dirname(filename)
        if out_dir:
            makedirs(out_dir, exist_ok=True)
        results = run_experiments(experiments)
        np.savez_compressed(filename, **results)
    else:
        # Copy the arrays out and close the archive immediately, instead of
        # leaving the NpzFile (and its file handle) open while plotting.
        with np.load(filename) as archive:
            results = {name: archive[name] for name in archive.files}
    make_plot(results)