diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 6734ab1f85..cc9273ca9a 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -109,6 +109,14 @@ impl ConvTranspose1d { pub fn config(&self) -> &ConvTranspose1dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for ConvTranspose1d { @@ -258,6 +266,14 @@ impl ConvTranspose2d { pub fn config(&self) -> &ConvTranspose2dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for ConvTranspose2d {