From 1807be84f4d9e388b19710a9282eb6501ce55f80 Mon Sep 17 00:00:00 2001
From: Justin Sing <32938975+singjc@users.noreply.github.com>
Date: Wed, 4 Dec 2024 15:22:30 -0500
Subject: [PATCH] Change/bert encoder public (#2658)

* change: BertEncoder struct to public

* change: make certain fields in Config struct public

* change: all fields in bert config struct to be public

* change: add clone to bert encoder and others

* Clippy fix.

---------

Co-authored-by: Laurent
---
 candle-transformers/src/models/bert.rs | 51 +++++++++++++++----------
 1 file changed, 30 insertions(+), 21 deletions(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index da8734160a..0ff62c4f3e 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -22,6 +22,7 @@ pub enum HiddenAct {
     Relu,
 }
 
+#[derive(Clone)]
 struct HiddenActLayer {
     act: HiddenAct,
     span: tracing::Span,
@@ -46,7 +47,7 @@ impl HiddenActLayer {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
@@ -54,24 +55,24 @@
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
     pub hidden_act: HiddenAct,
-    hidden_dropout_prob: f64,
-    max_position_embeddings: usize,
-    type_vocab_size: usize,
-    initializer_range: f64,
-    layer_norm_eps: f64,
-    pad_token_id: usize,
+    pub hidden_dropout_prob: f64,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub initializer_range: f64,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: usize,
     #[serde(default)]
-    position_embedding_type: PositionEmbeddingType,
+    pub position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
-    use_cache: bool,
-    classifier_dropout: Option<f64>,
-    model_type: Option<String>,
+    pub use_cache: bool,
+    pub classifier_dropout: Option<f64>,
+    pub model_type: Option<String>,
 }
 
 impl Default for Config {
@@ -121,6 +122,7 @@ impl Config {
     }
 }
 
+#[derive(Clone)]
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
@@ -199,6 +201,7 @@ impl BertEmbeddings {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -266,6 +269,7 @@ impl BertSelfAttention {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -299,6 +303,7 @@ impl BertSelfOutput {
     }
 }
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
 struct BertAttention {
     self_attention: BertSelfAttention,
     self_output: BertSelfOutput,
@@ -325,6 +330,7 @@ impl BertAttention {
     }
 }
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
 struct BertIntermediate {
     dense: Linear,
     intermediate_act: HiddenActLayer,
@@ -352,6 +358,7 @@ impl Module for BertIntermediate {
     }
 }
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
 struct BertOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -385,7 +392,8 @@ impl BertOutput {
     }
 }
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+#[derive(Clone)]
+pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
     output: BertOutput,
@@ -420,13 +428,14 @@ impl BertLayer {
     }
 }
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
-    layers: Vec<BertLayer>,
+#[derive(Clone)]
+pub struct BertEncoder {
+    pub layers: Vec<BertLayer>,
     span: tracing::Span,
 }
 
 impl BertEncoder {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
@@ -434,7 +443,7 @@ impl BertEncoder {
         Ok(BertEncoder { layers, span })
     }
 
-    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+    pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
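
Note (not part of the patch): a minimal sketch of what the newly public API enables, namely
driving BertEncoder directly instead of going through BertModel. VarBuilder::zeros stands in
for real safetensors weights, the "encoder" prefix is arbitrary, and the mask is assumed to be
the additive attention mask broadcast over the attention scores (all zeros = attend everywhere).

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertEncoder, Config};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // BERT-base defaults from the Default impl (hidden_size = 768, 12 layers).
    let config = Config::default();
    // Zero-initialized weights keep the sketch self-contained; real code would
    // load weights via VarBuilder::from_mmaped_safetensors instead.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let encoder = BertEncoder::load(vb.pp("encoder"), &config)?;

    // A batch of one sequence of 8 already-embedded tokens.
    let hidden = Tensor::zeros((1, 8, config.hidden_size), DType::F32, &device)?;
    // Additive mask: 0.0 keeps a position, a large negative value would mask it.
    let mask = Tensor::zeros((1, 1, 1, 8), DType::F32, &device)?;
    let out = encoder.forward(&hidden, &mask)?;
    println!("{:?}", out.dims()); // [1, 8, 768]
    Ok(())
}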