
stateful=True RNN silently broadcasts input when batch_size mismatch occurs #21183

Description

When using a Keras SimpleRNN layer with stateful=True, the documentation states that a fixed batch size is required. However, when a model is built with batch_size=N (e.g. 4), and later called with an input of batch_size=1, no error is raised. Instead, the RNN broadcasts the single input trajectory across all internal state slots, silently overwriting the state for all batches.

This violates the expected behavior of stateful=True, where the batch size must remain fixed, and state slot i must map to input sample i.

This behavior is not documented, produces incorrect results, and can silently corrupt stateful models such as echo state networks (ESNs) or any RNN that carries memory across batches.

Standalone code to reproduce:

import tensorflow as tf
from tensorflow.keras import layers, Input, Model

# Build stateful RNN with batch_size=4
inputs = Input(shape=(5, 3), batch_size=4)
rnn = layers.SimpleRNN(10, stateful=True, return_sequences=False, name="rnn")
x = rnn(inputs)
model = Model(inputs, x)

# Manually set initial state to distinguishable values
state = rnn.states[0]
for i in range(4):
    state[i].assign(tf.ones_like(state[i]) * i)

print("Initial state:")
print(state.numpy())

# Call model with a different batch size (1)
print("\nCalling model with input of batch size 1...")
_ = model(tf.random.normal((1, 5, 3)))

# Print new state
print("\nState after call:")
print(state.numpy())
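
To make the corruption visible, one can compare the state slots after the call. This is a minimal check, assuming the broadcasting behavior described above (only slot 0 received a real sample, yet all four slots end up overwritten):

import numpy as np

# Under stateful=True semantics, only state slot 0 should have changed;
# slots 1-3 should still hold their initial values (1.0, 2.0, 3.0).
after = state.numpy()
for i in range(1, 4):
    if not np.allclose(after[i], float(i)):
        print(f"State slot {i} was overwritten even though no sample {i} was provided.")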

Expected behaviour

An exception should be raised when the input batch size does not match the model's fixed batch_size, particularly when stateful=True.
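
Until such a check exists in Keras, a guard can be placed in front of the call. Below is a minimal sketch, assuming the fixed batch size is recoverable from the model's input shape; the helper name check_stateful_batch_size is hypothetical and not part of the Keras API:

def check_stateful_batch_size(model, inputs):
    # Hypothetical helper: raise instead of silently broadcasting.
    expected = model.input_shape[0]  # fixed batch size the model was built with
    actual = inputs.shape[0]
    if expected is not None and actual != expected:
        raise ValueError(
            f"Stateful model was built with batch_size={expected}, "
            f"but received a batch of size {actual}."
        )

check_stateful_batch_size(model, tf.random.normal((1, 5, 3)))  # raises ValueError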

Additional Info

  • TensorFlow version: 2.19.0
  • Keras version: 3.9.2
  • OS: KDE Neon 6.3 (Ubuntu 24.04 based)
  • GPU: NVIDIA GeForce RTX 3060
