
Commit 9d6d8b5

Change network interface from RNNCell to function.
1 parent d8a2881 commit 9d6d8b5

10 files changed, +123 -182 lines changed
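
For orientation before the per-file diffs: networks were previously `tf.contrib.rnn.RNNCell` subclasses driven through `tf.nn.dynamic_rnn`; after this commit they are plain functions returning a `NetworkOutput` tuple, and the algorithm wraps them in `tf.make_template` so repeated calls share variables. The following is a minimal sketch of that interface under TensorFlow 1.x; `tiny_network` is a hypothetical stand-in for the real definitions in `agents/scripts/networks.py` shown below.

import collections
import functools

import tensorflow as tf

NetworkOutput = collections.namedtuple(
    'NetworkOutput', 'policy, mean, logstd, value, state')


def tiny_network(config, action_size, observations, length, state=None):
  # Hypothetical stand-in for feed_forward_gaussian: one hidden layer for the
  # policy mean and the value, and a fixed log standard deviation.
  x = tf.contrib.layers.fully_connected(observations, 32, tf.nn.relu)
  mean = tf.contrib.layers.fully_connected(x, action_size, tf.tanh)
  logstd = tf.zeros_like(mean)
  value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
  policy = tf.contrib.distributions.MultivariateNormalDiag(
      mean, tf.exp(logstd))
  return NetworkOutput(policy, mean, logstd, value, state)


# The algorithm binds config and action size with functools.partial and wraps
# the function in tf.make_template, so every call reuses the same variables.
action_size = 3
network = tf.make_template(
    'network', functools.partial(tiny_network, None, action_size))
observations = tf.zeros([8, 1, 4])  # batch x time x observation
length = tf.ones([8], tf.int32)
output = network(observations, length)
print(output._fields)  # ('policy', 'mean', 'logstd', 'value', 'state')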

README.md  (+1 -3)

@@ -68,7 +68,7 @@ modifying the code:
 | File | Content |
 | ---- | ------- |
 | `scripts/configs.py` | Experiment configurations specifying the tasks and algorithms. |
-| `scripts/networks.py` | Neural network models defined as [TensorFlow RNNCells][tf-rnn-cell]. |
+| `scripts/networks.py` | Neural network models. |
 | `scripts/train.py` | The executable file containing the training setup. |
 | `ppo/algorithm.py` | The TensorFlow graph for the PPO algorithm. |
 
@@ -80,8 +80,6 @@ python3 -m unittest discover -p "*_test.py"
 
 For further questions, please open an issue on Github.
 
-[tf-rnn-cell]: https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/RNNCell
-
 Implementation
 --------------

agents/ppo/algorithm.py  (+25 -44)

@@ -23,6 +23,7 @@
 from __future__ import print_function
 
 import collections
+import functools
 
 import tensorflow as tf
 
@@ -31,10 +32,6 @@
 from agents.ppo import utility
 
 
-_NetworkOutput = collections.namedtuple(
-    'NetworkOutput', 'policy, mean, logstd, value, state')
-
-
 class PPOAlgorithm(object):
   """A vectorized implementation of the PPO algorithm by John Schulman."""
 
@@ -70,15 +67,25 @@ def __init__(self, batch_env, step, is_training, should_log, config):
     use_gpu = self._config.use_gpu and utility.available_gpus()
     with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
       # Create network variables for later calls to reuse.
-      self._network(
+      action_size = self._batch_env.action.shape[1].value
+      self._network = tf.make_template(
+          'network', functools.partial(config.network, config, action_size))
+      output = self._network(
           tf.zeros_like(self._batch_env.observ)[:, None],
-          tf.ones(len(self._batch_env)), reuse=None)
-      cell = self._config.network(self._batch_env.action.shape[1].value)
+          tf.ones(len(self._batch_env)))
       with tf.variable_scope('ppo_temporary'):
         self._episodes = memory.EpisodeMemory(
             template, len(batch_env), config.max_length, 'episodes')
-        self._last_state = utility.create_nested_vars(
-            cell.zero_state(len(batch_env), tf.float32))
+        if output.state is None:
+          self._last_state = None
+        else:
+          # Ensure the batch dimension is set.
+          tf.contrib.framework.nest.map_structure(
+              lambda x: x.set_shape([len(batch_env)] + x.shape.as_list()[1:]),
+              output.state)
+          self._last_state = tf.contrib.framework.nest.map_structure(
+              lambda x: tf.Variable(lambda: tf.zeros_like(x), False),
+              output.state)
         self._last_action = tf.Variable(
             tf.zeros_like(self._batch_env.action), False, name='last_action')
         self._last_mean = tf.Variable(
@@ -102,7 +109,10 @@ def begin_episode(self, agent_indices):
       Summary tensor.
     """
     with tf.name_scope('begin_episode/'):
-      reset_state = utility.reinit_nested_vars(self._last_state, agent_indices)
+      if self._last_state is None:
+        reset_state = tf.no_op()
+      else:
+        reset_state = utility.reinit_nested_vars(self._last_state, agent_indices)
       reset_buffer = self._episodes.clear(agent_indices)
       with tf.control_dependencies([reset_state, reset_buffer]):
         return tf.constant('')
@@ -130,8 +140,12 @@ def perform(self, observ):
           tf.summary.histogram('action', action[:, 0]),
           tf.summary.histogram('logprob', logprob)]), str)
       # Remember current policy to append to memory in the experience callback.
+      if self._last_state is None:
+        assign_state = tf.no_op()
+      else:
+        assign_state = utility.assign_nested_vars(self._last_state, network.state)
       with tf.control_dependencies([
-          utility.assign_nested_vars(self._last_state, network.state),
+          assign_state,
          self._last_action.assign(action[:, 0]),
          self._last_mean.assign(network.mean[:, 0]),
          self._last_logstd.assign(network.logstd[:, 0])]):
@@ -523,36 +537,3 @@ def _mask(self, tensor, length):
       mask = tf.cast(range_[None, :] < length[:, None], tf.float32)
       masked = tensor * mask
       return tf.check_numerics(masked, 'masked')
-
-  def _network(self, observ, length=None, state=None, reuse=True):
-    """Compute the network output for a batched sequence of observations.
-
-    Optionally, the initial state can be specified. The weights should be
-    reused for all calls, except for the first one. Output is a named tuple
-    containing the policy as a TensorFlow distribution, the policy mean and log
-    standard deviation, the approximated state value, and the new recurrent
-    state.
-
-    Args:
-      observ: Sequences of observations.
-      length: Batch of sequence lengths.
-      state: Batch of initial recurrent states.
-      reuse: Python boolean whether to reuse previous variables.
-
-    Returns:
-      NetworkOutput tuple.
-    """
-    with tf.variable_scope('network', reuse=reuse):
-      observ = tf.convert_to_tensor(observ)
-      use_gpu = self._config.use_gpu and utility.available_gpus()
-      with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
-        observ = tf.check_numerics(observ, 'observ')
-        cell = self._config.network(self._batch_env.action.shape[1].value)
-        (mean, logstd, value), state = tf.nn.dynamic_rnn(
-            cell, observ, length, state, tf.float32, swap_memory=True)
-      mean = tf.check_numerics(mean, 'mean')
-      logstd = tf.check_numerics(logstd, 'logstd')
-      value = tf.check_numerics(value, 'value')
-      policy = tf.contrib.distributions.MultivariateNormalDiag(
-          mean, tf.exp(logstd))
-      return _NetworkOutput(policy, mean, logstd, value, state)
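
The hunks above also change how recurrent state is stored: instead of `utility.create_nested_vars` over `cell.zero_state`, the state variables are now derived from the network output with `nest.map_structure`, and stateless networks (which return `state=None`) skip them via `tf.no_op()`. A small sketch of that pattern, with a made-up `example_state` standing in for `output.state`:

import tensorflow as tf

# Made-up nested recurrent state; in the commit this is output.state.
example_state = (tf.zeros([4, 100]), tf.zeros([4, 20]))

# One non-trainable variable per state tensor, initialized to zeros,
# mirroring the map_structure call in the diff above.
last_state = tf.contrib.framework.nest.map_structure(
    lambda x: tf.Variable(lambda: tf.zeros_like(x), trainable=False),
    example_state)

# A feed-forward network reports state=None instead of such a structure;
# begin_episode() and perform() then fall back to tf.no_op().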

agents/ppo/utility.py  (-14)

@@ -26,20 +26,6 @@
 from tensorflow.python.client import device_lib
 
 
-def create_nested_vars(tensors):
-  """Create variables matching a nested tuple of tensors.
-
-  Args:
-    tensors: Nested tuple of list of tensors.
-
-  Returns:
-    Nested tuple or list of variables.
-  """
-  if isinstance(tensors, (tuple, list)):
-    return type(tensors)(create_nested_vars(tensor) for tensor in tensors)
-  return tf.Variable(tensors, False)
-
-
 def reinit_nested_vars(variables, indices=None):
   """Reset all variables in a nested tuple to zeros.

agents/scripts/configs.py  (+2 -4)

@@ -32,11 +32,9 @@ def default():
   eval_episodes = 25
   use_gpu = False
   # Network
-  network = networks.ForwardGaussianPolicy
+  network = networks.feed_forward_gaussian
   weight_summaries = dict(
-      all=r'.*',
-      policy=r'.*/policy/.*',
-      value=r'.*/value/.*')
+      all=r'.*', policy=r'.*/policy/.*', value=r'.*/value/.*')
   policy_layers = 200, 100
   value_layers = 200, 100
   init_mean_factor = 0.05

agents/scripts/networks.py  (+90 -85)

@@ -12,109 +12,114 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Networks for the PPO algorithm defined as recurrent cells."""
+"""Network definitions for the PPO algorithm."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import functools
+import operator
+
 import tensorflow as tf
 
 
-_MEAN_WEIGHTS_INITIALIZER = tf.contrib.layers.variance_scaling_initializer(
-    factor=0.1)
-_LOGSTD_INITIALIZER = tf.random_normal_initializer(-1, 1e-10)
+NetworkOutput = collections.namedtuple(
+    'NetworkOutput', 'policy, mean, logstd, value, state')
 
 
-class ForwardGaussianPolicy(tf.contrib.rnn.RNNCell):
+def feed_forward_gaussian(
+    config, action_size, observations, length, state=None):
   """Independent feed forward networks for policy and value.
 
   The policy network outputs the mean action and the log standard deviation
   is learned as independent parameter vector.
-  """
 
-  def __init__(
-      self, policy_layers, value_layers, action_size,
-      mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
-      logstd_initializer=_LOGSTD_INITIALIZER):
-    self._policy_layers = policy_layers
-    self._value_layers = value_layers
-    self._action_size = action_size
-    self._mean_weights_initializer = mean_weights_initializer
-    self._logstd_initializer = logstd_initializer
-
-  @property
-  def state_size(self):
-    unused_state_size = 1
-    return unused_state_size
-
-  @property
-  def output_size(self):
-    return (self._action_size, self._action_size, tf.TensorShape([]))
-
-  def __call__(self, observation, state):
-    with tf.variable_scope('policy'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._policy_layers:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      mean = tf.contrib.layers.fully_connected(
-          x, self._action_size, tf.tanh,
-          weights_initializer=self._mean_weights_initializer)
-      logstd = tf.get_variable(
-          'logstd', mean.shape[1:], tf.float32, self._logstd_initializer)
-      logstd = tf.tile(
-          logstd[None, ...], [tf.shape(mean)[0]] + [1] * logstd.shape.ndims)
-    with tf.variable_scope('value'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._value_layers:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
-    return (mean, logstd, value), state
-
-
-class RecurrentGaussianPolicy(tf.contrib.rnn.RNNCell):
+  Args:
+    config: Configuration object.
+    action_size: Length of the action vector.
+    observations: Sequences of observations.
+    length: Batch of sequence lengths.
+    state: Batch of initial recurrent states.
+
+  Returns:
+    NetworkOutput tuple.
+  """
+  mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
+      factor=config.init_mean_factor)
+  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
+  flat_observations = tf.reshape(observations, [
+      tf.shape(observations)[0], tf.shape(observations)[1],
+      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+  with tf.variable_scope('policy'):
+    x = flat_observations
+    for size in config.policy_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    mean = tf.contrib.layers.fully_connected(
+        x, action_size, tf.tanh,
+        weights_initializer=mean_weights_initializer)
+    logstd = tf.tile(tf.get_variable(
+        'logstd', mean.shape[2:], tf.float32, logstd_initializer)[None, None],
+        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+  with tf.variable_scope('value'):
+    x = flat_observations
+    for size in config.value_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
+  mean = tf.check_numerics(mean, 'mean')
+  logstd = tf.check_numerics(logstd, 'logstd')
+  value = tf.check_numerics(value, 'value')
+  policy = tf.contrib.distributions.MultivariateNormalDiag(
+      mean, tf.exp(logstd))
+  return NetworkOutput(policy, mean, logstd, value, state)
+
+
+def recurrent_gaussian(
+    config, action_size, observations, length, state=None):
   """Independent recurrent policy and feed forward value networks.
 
   The policy network outputs the mean action and the log standard deviation
   is learned as independent parameter vector. The last policy layer is recurrent
   and uses a GRU cell.
-  """
 
-  def __init__(
-      self, policy_layers, value_layers, action_size,
-      mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
-      logstd_initializer=_LOGSTD_INITIALIZER):
-    self._policy_layers = policy_layers
-    self._value_layers = value_layers
-    self._action_size = action_size
-    self._mean_weights_initializer = mean_weights_initializer
-    self._logstd_initializer = logstd_initializer
-    self._cell = tf.contrib.rnn.GRUBlockCell(100)
-
-  @property
-  def state_size(self):
-    return self._cell.state_size
-
-  @property
-  def output_size(self):
-    return (self._action_size, self._action_size, tf.TensorShape([]))
-
-  def __call__(self, observation, state):
-    with tf.variable_scope('policy'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._policy_layers[:-1]:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      x, state = self._cell(x, state)
-      mean = tf.contrib.layers.fully_connected(
-          x, self._action_size, tf.tanh,
-          weights_initializer=self._mean_weights_initializer)
-      logstd = tf.get_variable(
-          'logstd', mean.shape[1:], tf.float32, self._logstd_initializer)
-      logstd = tf.tile(
-          logstd[None, ...], [tf.shape(mean)[0]] + [1] * logstd.shape.ndims)
-    with tf.variable_scope('value'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._value_layers:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
-    return (mean, logstd, value), state
+  Args:
+    config: Configuration object.
+    action_size: Length of the action vector.
+    observations: Sequences of observations.
+    length: Batch of sequence lengths.
+    state: Batch of initial recurrent states.
+
+  Returns:
+    NetworkOutput tuple.
+  """
+  mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
+      factor=config.init_mean_factor)
+  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
+  cell = tf.contrib.rnn.GRUBlockCell(config.policy_layers[-1])
+  flat_observations = tf.reshape(observations, [
+      tf.shape(observations)[0], tf.shape(observations)[1],
+      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+  with tf.variable_scope('policy'):
+    x = flat_observations
+    for size in config.policy_layers[:-1]:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    x, state = tf.nn.dynamic_rnn(cell, x, length, state, tf.float32)
+    mean = tf.contrib.layers.fully_connected(
+        x, action_size, tf.tanh,
+        weights_initializer=mean_weights_initializer)
+    logstd = tf.tile(tf.get_variable(
+        'logstd', mean.shape[2:], tf.float32, logstd_initializer)[None, None],
+        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+  with tf.variable_scope('value'):
+    x = flat_observations
+    for size in config.value_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
+  mean = tf.check_numerics(mean, 'mean')
+  logstd = tf.check_numerics(logstd, 'logstd')
+  value = tf.check_numerics(value, 'value')
+  policy = tf.contrib.distributions.MultivariateNormalDiag(
+      mean, tf.exp(logstd))
+  # assert state.shape.as_list()[0] is not None
+  return NetworkOutput(policy, mean, logstd, value, state)
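
As a usage sketch (not part of the commit), the new functions can be called directly once a config object with the fields they read is available; the `Config` namedtuple and hyperparameter values below are made up for illustration, and the import assumes the `agents` package is on the Python path.

import collections

import tensorflow as tf

from agents.scripts import networks

# Made-up minimal config carrying only the fields the network reads.
Config = collections.namedtuple(
    'Config', 'policy_layers value_layers init_mean_factor init_logstd')
config = Config(
    policy_layers=(200, 100), value_layers=(200, 100),
    init_mean_factor=0.05, init_logstd=-1.0)

observations = tf.zeros([8, 10, 4])  # batch x time x observation
length = tf.fill([8], 10)            # every sequence uses all 10 steps
output = networks.feed_forward_gaussian(config, 2, observations, length)

print(output.mean.shape[-1])  # action dimension: 2
print(output.state)           # None: the feed-forward network keeps no state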

agents/scripts/train.py  (-2)

@@ -101,8 +101,6 @@ def train(config, env_processes):
   """
   tf.reset_default_graph()
   with config.unlocked:
-    config.network = functools.partial(
-        utility.define_network, config.network, config)
     config.policy_optimizer = getattr(tf.train, config.policy_optimizer)
     config.value_optimizer = getattr(tf.train, config.value_optimizer)
   if config.update_every % config.num_agents:
