keras JointDistribution minimal example not working #2000

Open
zippeurfou opened this issue Apr 30, 2025 · 0 comments

zippeurfou commented Apr 30, 2025

Hi,
I am trying to train a keras model that mixes a JointDistribution with other layers. My goal is to build the model, compile it, and call model.fit on data.
Somehow, I keep getting errors and can't get it working.
I could not find a running minimal example and I think I am missing something.
Here is the minimal example I tried:

import tensorflow as tf
import tensorflow_probability as tfp
import tf_keras as tfk  
import numpy as np

tfd = tfp.distributions
tfpl = tfp.layers

# Generate synthetic data
np.random.seed(42)
x = np.random.randn(2000, 1).astype(np.float32)
true_w, true_b = 2.0, -1.0
y = true_w * x + true_b + np.random.randn(2000, 1).astype(np.float32) * 0.5

# Define the JointDistributionNamed model
def make_joint(x):
    return tfd.JointDistributionNamed({
        "weight": tfd.Normal(loc=0., scale=1.),
        "bias": tfd.Normal(loc=0., scale=1.),
        "obs": lambda weight, bias: tfd.Independent(
            tfd.Normal(loc=weight * x + bias, scale=0.5),
            reinterpreted_batch_ndims=1
        )
    })

# Keras model using DistributionLambda
inputs = tfk.Input(shape=(1,))
dense = tfk.layers.Dense(2)(inputs)  # Not used for parameterization here, just as a placeholder

def posterior_fn(params):
    # For demonstration, use fixed x from outer scope
    jd = make_joint(x)
    return jd

outputs = tfpl.DistributionLambda(posterior_fn)(dense)
model = tfk.Model(inputs=inputs, outputs=outputs)

# Custom loss: negative log-likelihood of observed y under the joint
def nll(y_true, y_pred):
    # y_pred is a JointDistributionNamed, so evaluate log_prob of obs
    return -y_pred.log_prob({"obs": y_true})

model.compile(optimizer=tf.optimizers.Adam(0.01), loss=nll)
model.fit(x, y, epochs=5, batch_size=32)

And I got this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-e7f1317b9e9b> in <cell line: 0>()
     33     return jd
     34 
---> 35 outputs = tfpl.DistributionLambda(posterior_fn)(dense)
     36 model = tfk.Model(inputs=inputs, outputs=outputs)
     37 

3 frames
/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/layers/distribution_layer.py in _fn(*fargs, **fkwargs)
    183       # TODO(b/126056144): Remove silent handle once we identify how/why Keras
    184       # is losing the distribution handle for activity_regularizer.
--> 185       value._tfp_distribution = distribution  # pylint: disable=protected-access
    186       # TODO(b/120153609): Keras is incorrectly presuming everything is a
    187       # `tf.Tensor`. Closing this bug entails ensuring Keras only accesses

AttributeError: Exception encountered when calling layer "distribution_lambda" (type DistributionLambda).

'dict' object has no attribute '_tfp_distribution'

Call arguments received by layer "distribution_lambda" (type DistributionLambda):
  • inputs=tf.Tensor(shape=(None, 2), dtype=float32)
  • args=<class 'inspect._empty'>
  • kwargs={'training': 'None'}
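
If it helps, my reading of the traceback (just an assumption on my part, not a confirmed diagnosis) is that DistributionLambda samples the distribution returned by the lambda and then tags that sample with a _tfp_distribution attribute. A JointDistributionNamed sample is a plain dict, and dicts do not accept attribute assignment:

# Sketch of what I believe is the failure mode (an assumption based on
# the traceback, not on reading all of distribution_layer.py):
jd = make_joint(x)
sample = jd.sample()
print(type(sample))  # a plain dict with keys "weight", "bias", "obs"
# distribution_layer.py then effectively does
#     sample._tfp_distribution = jd
# which raises AttributeError, since plain dicts reject attribute assignment.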

I am posting it here because I suspect this might be a bug.

You can reproduce it via this Colab link.

Please note I also tried tfd.JointDistributionSequential, with no luck either.
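
For completeness, the only workaround I found (a sketch under my own assumptions, not necessarily the intended pattern) is to keep the DistributionLambda output a single distribution whose sample is a Tensor, and let a Dense layer learn the weight and bias instead of placing priors on them:

# Workaround sketch (my assumption, not an official recipe for joint
# distributions inside Keras): only the observation distribution lives
# in the model, so the layer's sample is a Tensor, not a dict.
inputs = tfk.Input(shape=(1,))
mu = tfk.layers.Dense(1)(inputs)  # learns w * x + b directly
outputs = tfpl.DistributionLambda(lambda loc: tfd.Normal(loc=loc, scale=0.5))(mu)
model = tfk.Model(inputs=inputs, outputs=outputs)

# The loss receives the distribution, so the NLL works as before.
model.compile(optimizer=tf.optimizers.Adam(0.01),
              loss=lambda y_true, dist: -dist.log_prob(y_true))
model.fit(x, y, epochs=5, batch_size=32)

This of course drops the priors on weight and bias, so it is not equivalent to the joint model; I would still like to know the intended way to use a JointDistribution with model.fit.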
