Trainable variables created by tfp.experimental.vi.util.build_trainable_linear_operator_block are lost after the resulting bijector is wrapped in a tfb.Chain #1997

Open
bwalker1 opened this issue Apr 8, 2025 · 0 comments

bwalker1 commented Apr 8, 2025

Creating a linear operator with tfp.experimental.vi.util.build_trainable_linear_operator_block and passing it to tfb.ScaleMatvecLinearOperatorBlock produces a bijector that exposes trainable variables. However, once that bijector is wrapped in a tfb.Chain(), the chain's trainable_variables no longer includes them, i.e. tf.Module's attribute reflection does not find them. This does not happen when an equivalent linear operator is built manually (see the working example below). My use of tfp.experimental.vi.util.build_trainable_linear_operator_block follows the tutorial Variational_Inference_and_Joint_Distributions.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

print(tf.__version__)
# 2.18.0
print(tfp.__version__)
# 0.25.0


# Broken example: the final print should output True, but it outputs False.
# A single diagonal block of size 1 (block_dims=(1,)).
operators = ((tf.linalg.LinearOperatorDiag,),)
block_tril_linop = tfp.experimental.vi.util.build_trainable_linear_operator_block(
    operators, (1,)
)
scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(block_tril_linop)
assert len(scale_bijector.trainable_variables) > 0  # variables are visible here
c = tfb.Chain([scale_bijector])

print(len(c.trainable_variables) > 0)
# False


# Working example: a comparable operator built manually with an explicit tf.Variable.
LO = tf.linalg.LinearOperatorBlockDiag([tf.linalg.LinearOperatorDiag(tf.Variable([1.0]))])
scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(LO)
assert len(scale_bijector.trainable_variables) > 0
c = tfb.Chain([scale_bijector])

print(len(c.trainable_variables) > 0)
# True
```
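For context on the "reflection" mentioned above: tf.Module, which TFP bijectors subclass, discovers trainable_variables by recursively walking an object's attributes, including those of nested tf.Module instances and containers such as lists. The sketch below is a simplified, pure-Python model of that idea. It uses hypothetical stand-in classes (Variable, Module, Scale, Chain, Opaque), is not TensorFlow's implementation, and does not claim to identify the cause of this bug. It only shows how a variable stored inside an object that the walk does not descend into becomes invisible to a parent's trainable_variables.

```python
# Simplified, pure-Python model of attribute-based ("reflective") variable
# discovery. Illustration only: this is NOT TensorFlow's implementation, and
# every class here is a hypothetical stand-in.


class Variable:
    """Hypothetical stand-in for tf.Variable."""

    def __init__(self, value, trainable=True):
        self.value = value
        self.trainable = trainable


class Module:
    """Hypothetical stand-in for tf.Module."""

    @property
    def trainable_variables(self):
        found = []
        _collect(self, found, seen=set())
        return [v for v in found if v.trainable]


def _collect(obj, found, seen):
    """Recursively gather Variables reachable through attributes and containers."""
    if id(obj) in seen:
        return
    seen.add(id(obj))
    if isinstance(obj, Variable):
        found.append(obj)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            _collect(item, found, seen)
    elif isinstance(obj, dict):
        for item in obj.values():
            _collect(item, found, seen)
    elif isinstance(obj, Module):
        # Only Module instances have their attributes walked.
        for item in vars(obj).values():
            _collect(item, found, seen)
    # Any other object is opaque: the walk does not look inside it.


class Scale(Module):
    def __init__(self, scale):
        self.scale = scale


class Chain(Module):
    def __init__(self, bijectors):
        self.bijectors = list(bijectors)


class Opaque:
    """A plain (non-Module) holder; the walk does not descend into it."""

    def __init__(self, variable):
        self.variable = variable


visible = Chain([Scale(Variable(1.0))])
hidden = Chain([Scale(Opaque(Variable(1.0)))])
print(len(visible.trainable_variables) > 0)  # True
print(len(hidden.trainable_variables) > 0)   # False
```

In this toy model, whether a variable is found depends only on whether every object on the path from the root to the variable is something the walk knows how to enter.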