 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Networks for the PPO algorithm defined as recurrent cells."""
+"""Network definitions for the PPO algorithm."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+import functools
+import operator
+
 import tensorflow as tf


-_MEAN_WEIGHTS_INITIALIZER = tf.contrib.layers.variance_scaling_initializer(
-    factor=0.1)
-_LOGSTD_INITIALIZER = tf.random_normal_initializer(-1, 1e-10)
+NetworkOutput = collections.namedtuple(
+    'NetworkOutput', 'policy, mean, logstd, value, state')


-class ForwardGaussianPolicy(tf.contrib.rnn.RNNCell):
+def feed_forward_gaussian(
+    config, action_size, observations, length, state=None):
   """Independent feed forward networks for policy and value.

   The policy network outputs the mean action and the log standard deviation
   is learned as independent parameter vector.
-  """

-  def __init__(
-      self, policy_layers, value_layers, action_size,
-      mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
-      logstd_initializer=_LOGSTD_INITIALIZER):
-    self._policy_layers = policy_layers
-    self._value_layers = value_layers
-    self._action_size = action_size
-    self._mean_weights_initializer = mean_weights_initializer
-    self._logstd_initializer = logstd_initializer
-
-  @property
-  def state_size(self):
-    unused_state_size = 1
-    return unused_state_size
-
-  @property
-  def output_size(self):
-    return (self._action_size, self._action_size, tf.TensorShape([]))
-
-  def __call__(self, observation, state):
-    with tf.variable_scope('policy'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._policy_layers:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      mean = tf.contrib.layers.fully_connected(
-          x, self._action_size, tf.tanh,
-          weights_initializer=self._mean_weights_initializer)
-      logstd = tf.get_variable(
-          'logstd', mean.shape[1:], tf.float32, self._logstd_initializer)
-      logstd = tf.tile(
-          logstd[None, ...], [tf.shape(mean)[0]] + [1] * logstd.shape.ndims)
-    with tf.variable_scope('value'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._value_layers:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
-    return (mean, logstd, value), state
-
-
-class RecurrentGaussianPolicy(tf.contrib.rnn.RNNCell):
+  Args:
+    config: Configuration object.
+    action_size: Length of the action vector.
+    observations: Sequences of observations.
+    length: Batch of sequence lengths.
+    state: Batch of initial recurrent states.
+
+  Returns:
+    NetworkOutput tuple.
+  """
+  mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
+      factor=config.init_mean_factor)
+  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
+  flat_observations = tf.reshape(observations, [
+      tf.shape(observations)[0], tf.shape(observations)[1],
+      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+  with tf.variable_scope('policy'):
+    x = flat_observations
+    for size in config.policy_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    mean = tf.contrib.layers.fully_connected(
+        x, action_size, tf.tanh,
+        weights_initializer=mean_weights_initializer)
+    logstd = tf.tile(tf.get_variable(
+        'logstd', mean.shape[2:], tf.float32, logstd_initializer)[None, None],
+        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+  with tf.variable_scope('value'):
+    x = flat_observations
+    for size in config.value_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
+  mean = tf.check_numerics(mean, 'mean')
+  logstd = tf.check_numerics(logstd, 'logstd')
+  value = tf.check_numerics(value, 'value')
+  policy = tf.contrib.distributions.MultivariateNormalDiag(
+      mean, tf.exp(logstd))
+  return NetworkOutput(policy, mean, logstd, value, state)
+
+
+def recurrent_gaussian(
+    config, action_size, observations, length, state=None):
   """Independent recurrent policy and feed forward value networks.

   The policy network outputs the mean action and the log standard deviation
   is learned as independent parameter vector. The last policy layer is recurrent
   and uses a GRU cell.
-  """

-  def __init__(
-      self, policy_layers, value_layers, action_size,
-      mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
-      logstd_initializer=_LOGSTD_INITIALIZER):
-    self._policy_layers = policy_layers
-    self._value_layers = value_layers
-    self._action_size = action_size
-    self._mean_weights_initializer = mean_weights_initializer
-    self._logstd_initializer = logstd_initializer
-    self._cell = tf.contrib.rnn.GRUBlockCell(100)
-
-  @property
-  def state_size(self):
-    return self._cell.state_size
-
-  @property
-  def output_size(self):
-    return (self._action_size, self._action_size, tf.TensorShape([]))
-
-  def __call__(self, observation, state):
-    with tf.variable_scope('policy'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._policy_layers[:-1]:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      x, state = self._cell(x, state)
-      mean = tf.contrib.layers.fully_connected(
-          x, self._action_size, tf.tanh,
-          weights_initializer=self._mean_weights_initializer)
-      logstd = tf.get_variable(
-          'logstd', mean.shape[1:], tf.float32, self._logstd_initializer)
-      logstd = tf.tile(
-          logstd[None, ...], [tf.shape(mean)[0]] + [1] * logstd.shape.ndims)
-    with tf.variable_scope('value'):
-      x = tf.contrib.layers.flatten(observation)
-      for size in self._value_layers:
-        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
-    return (mean, logstd, value), state
+  Args:
+    config: Configuration object.
+    action_size: Length of the action vector.
+    observations: Sequences of observations.
+    length: Batch of sequence lengths.
+    state: Batch of initial recurrent states.
+
+  Returns:
+    NetworkOutput tuple.
+  """
+  mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
+      factor=config.init_mean_factor)
+  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
+  cell = tf.contrib.rnn.GRUBlockCell(config.policy_layers[-1])
+  flat_observations = tf.reshape(observations, [
+      tf.shape(observations)[0], tf.shape(observations)[1],
+      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+  with tf.variable_scope('policy'):
+    x = flat_observations
+    for size in config.policy_layers[:-1]:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    x, state = tf.nn.dynamic_rnn(cell, x, length, state, tf.float32)
+    mean = tf.contrib.layers.fully_connected(
+        x, action_size, tf.tanh,
+        weights_initializer=mean_weights_initializer)
+    logstd = tf.tile(tf.get_variable(
+        'logstd', mean.shape[2:], tf.float32, logstd_initializer)[None, None],
+        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+  with tf.variable_scope('value'):
+    x = flat_observations
+    for size in config.value_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
+  mean = tf.check_numerics(mean, 'mean')
+  logstd = tf.check_numerics(logstd, 'logstd')
+  value = tf.check_numerics(value, 'value')
+  policy = tf.contrib.distributions.MultivariateNormalDiag(
+      mean, tf.exp(logstd))
+  # assert state.shape.as_list()[0] is not None
+  return NetworkOutput(policy, mean, logstd, value, state)
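After this change the networks are plain functions rather than RNNCell subclasses: each takes a configuration object and batches of observation sequences, builds its graph, and returns a NetworkOutput tuple. A minimal usage sketch of the new interface follows; the namedtuple stand-in for the configuration object, the layer sizes, the observation width of 24, and the action size of 6 are illustrative assumptions, not values taken from this commit.

import collections

import tensorflow as tf

# Illustrative stand-in for the configuration object; it carries only the
# fields that feed_forward_gaussian and recurrent_gaussian read.
Config = collections.namedtuple(
    'Config', 'policy_layers, value_layers, init_mean_factor, init_logstd')
config = Config(policy_layers=(200, 100), value_layers=(200, 100),
                init_mean_factor=0.1, init_logstd=-1.0)

# Placeholders for batches of observation sequences and their lengths.
observations = tf.placeholder(tf.float32, [None, None, 24])
length = tf.placeholder(tf.int32, [None])

output = feed_forward_gaussian(config, 6, observations, length)
actions = output.policy.sample()  # MultivariateNormalDiag over actions
values = output.value             # estimates from the value head

For recurrent_gaussian the call is the same, except that an initial recurrent state can be passed in and the state returned in the NetworkOutput tuple is carried over to the next call.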