Skip to content

Commit 0c3d179

Browse files
committed
Update readme, small simplifications
1 parent 593a864 commit 0c3d179

File tree

5 files changed

+120
-58
lines changed

5 files changed

+120
-58
lines changed

.gitignore

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Created by .ignore support plugin (hsz.mobi)
2+
### Python template
3+
# Byte-compiled / optimized / DLL files
4+
__pycache__/
5+
*.py[cod]
6+
*$py.class
7+
8+
# C extensions
9+
*.so
10+
11+
# Distribution / packaging
12+
.Python
13+
build/
14+
develop-eggs/
15+
dist/
16+
downloads/
17+
eggs/
18+
.eggs/
19+
lib/
20+
lib64/
21+
parts/
22+
sdist/
23+
var/
24+
wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.coverage
44+
.coverage.*
45+
.cache
46+
nosetests.xml
47+
coverage.xml
48+
*.cover
49+
.hypothesis/
50+
.pytest_cache/
51+
52+
# Translations
53+
*.mo
54+
*.pot
55+
56+
# Django stuff:
57+
*.log
58+
local_settings.py
59+
db.sqlite3
60+
61+
# Flask stuff:
62+
instance/
63+
.webassets-cache
64+
65+
# Scrapy stuff:
66+
.scrapy
67+
68+
# Sphinx documentation
69+
docs/_build/
70+
71+
# PyBuilder
72+
target/
73+
74+
# Jupyter Notebook
75+
.ipynb_checkpoints
76+
77+
# pyenv
78+
.python-version
79+
80+
# celery beat schedule file
81+
celerybeat-schedule
82+
83+
# SageMath parsed files
84+
*.sage.py
85+
86+
# Environments
87+
.env
88+
.venv
89+
env/
90+
venv/
91+
ENV/
92+
env.bak/
93+
venv.bak/
94+
95+
# Spyder project settings
96+
.spyderproject
97+
.spyproject
98+
99+
# Rope project settings
100+
.ropeproject
101+
102+
# mkdocs documentation
103+
/site
104+
105+
# mypy
106+
.mypy_cache/
107+
108+
.idea/

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ This is a TensorFlow based implementation for our [paper on large-scale study of
2525
}
2626

2727
### Installation and Usage
28-
Stay tuned! To be released soon.
28+
The following command should train a pure exploration agent on Breakout with default experiment parameters.
29+
```bash
30+
python run.py
31+
```
32+
To use more than one gpu/machine, use MPI (e.g. `mpiexec -n 8 python run.py` should use 1024 parallel environments to collect experience instead of the default 128 on an 8 gpu machine).
2933

3034
### Other helpful pointers
3135
- [Paper](https://pathak22.github.io/large-scale-curiosity/resources/largeScaleCuriosity2018.pdf)

run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@
44
except:
55
print("no OpenGL.GLU")
66
import functools
7-
import gym
87
import os.path as osp
8+
from functools import partial
9+
10+
import gym
911
import tensorflow as tf
1012
from baselines import logger
1113
from baselines.bench import Monitor
1214
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
13-
from functools import partial
1415
from mpi4py import MPI
1516

1617
from auxiliary_tasks import FeatureExtractor, InverseDynamics, VAE, JustPixels
1718
from cnn_policy import CnnPolicy
1819
from cppo_agent import PpoOptimizer
1920
from dynamics import Dynamics, UNet
20-
from utils import random_agent_ob_mean_std, save_exp_details
21+
from utils import random_agent_ob_mean_std
2122
from wrappers import MontezumaInfoWrapper, make_mario_env, make_robo_pong, make_robo_hockey, \
2223
make_multi_pong, AddRandomStateToInfo, MaxAndSkipEnv, ProcessFrame84, ExtraTimeLimit
2324

@@ -32,7 +33,6 @@ def start_experiment(**args):
3233
with log, tf_sess:
3334
logdir = logger.get_dir()
3435
print("results will be saved to ", logdir)
35-
save_exp_details(filename=__file__, savedir=logdir, args=args)
3636
trainer.train()
3737

3838

utils.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import multiprocessing
2-
import numpy as np
32
import os
4-
import pickle
53
import platform
6-
import subprocess
7-
import sys
4+
from functools import partial
5+
6+
import numpy as np
87
import tensorflow as tf
98
from baselines.common.tf_util import normc_initializer
10-
from functools import partial
119
from mpi4py import MPI
1210

1311

@@ -227,24 +225,3 @@ def row(i):
227225

228226
return np.concatenate([row(i) for i in range(n_rows)], axis=0)
229227

230-
231-
def save_exp_details(filename, savedir, args):
232-
source_dirname = os.path.dirname(os.path.abspath(filename))
233-
git_hash = subprocess.run("git log --pretty=format:%H -n 1".split(' '), cwd=source_dirname,
234-
stdout=subprocess.PIPE).stdout.decode('utf-8')
235-
git_diff = subprocess.run("git diff {} --full-index".format(git_hash).split(' '), cwd=source_dirname,
236-
stdout=subprocess.PIPE).stdout.decode('utf-8')
237-
ordered_arg_names = sorted(list(args.keys()))
238-
sorted_args = [(k, args[k]) for k in ordered_arg_names]
239-
240-
rank_zero_savedir = savedir if MPI.COMM_WORLD.Get_rank() == 0 else None
241-
rank_zero_savedir = MPI.COMM_WORLD.bcast(rank_zero_savedir, root=0)
242-
243-
with open(os.path.join(savedir, "exp_details.pkl"), 'wb') as f:
244-
obj = {'git_hash': git_hash,
245-
'git_diff': git_diff,
246-
'args': sorted_args,
247-
'argv': ' '.join(sys.argv),
248-
'name': args['exp_name'],
249-
'rank_zero_savedir': rank_zero_savedir}
250-
pickle.dump(obj, f, protocol=0)

wrappers.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -119,40 +119,13 @@ class MontezumaInfoWrapper(gym.Wrapper):
119119
ram_map = {
120120
"room": dict(
121121
index=3,
122-
values=range(24),
123-
value_type="range",
124122
),
125123
"x": dict(
126124
index=42,
127-
values=range(0, 152),
128-
value_type="range",
129125
),
130126
"y": dict(
131127
index=43,
132-
values=range(148, 256),
133-
value_type="range",
134128
),
135-
# "objects": dict(
136-
# index=67,
137-
# values=range(16, 32),
138-
# value_type="categorical",
139-
# ), # 1st level: doors, skeleton, key
140-
# "skeleton_location": dict(
141-
# index=47,
142-
# values=range(20, 80), # not exactly the min/max, but good enough
143-
# value_type="range",
144-
# ),
145-
# "beam_wall": dict(
146-
# index=27,
147-
# values=[253, 209],
148-
# value_type="categorical",
149-
# meanings=["off", "on"]
150-
# ),
151-
# "beam_countdown": dict(
152-
# index=83,
153-
# values=range(37),
154-
# value_type="range",
155-
# ),
156129
}
157130

158131
def __init__(self, env):

0 commit comments

Comments
 (0)