Change/bert encoder public (#2658)
* change: BertEncoder struct to public

* change: make certain fields in Config struct public

* change: all fields in bert config struct to be public

* change: add clone to bert encoder and others

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
singjc and LaurentMazare authored Dec 4, 2024
1 parent 145aa71 commit 1807be8
Showing 1 changed file with 30 additions and 21 deletions.
candle-transformers/src/models/bert.rs: 30 additions & 21 deletions
@@ -22,6 +22,7 @@ pub enum HiddenAct {
     Relu,
 }
 
+#[derive(Clone)]
 struct HiddenActLayer {
     act: HiddenAct,
     span: tracing::Span,
@@ -46,32 +47,32 @@ impl HiddenActLayer
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
     pub hidden_act: HiddenAct,
-    hidden_dropout_prob: f64,
-    max_position_embeddings: usize,
-    type_vocab_size: usize,
-    initializer_range: f64,
-    layer_norm_eps: f64,
-    pad_token_id: usize,
+    pub hidden_dropout_prob: f64,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub initializer_range: f64,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: usize,
     #[serde(default)]
-    position_embedding_type: PositionEmbeddingType,
+    pub position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
-    use_cache: bool,
-    classifier_dropout: Option<f64>,
-    model_type: Option<String>,
+    pub use_cache: bool,
+    pub classifier_dropout: Option<f64>,
+    pub model_type: Option<String>,
 }
 
 impl Default for Config {
@@ -121,6 +122,7 @@ impl Config {
     }
 }
 
+#[derive(Clone)]
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
@@ -199,6 +201,7 @@ impl BertEmbeddings {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -266,6 +269,7 @@ impl BertSelfAttention {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -299,6 +303,7 @@ impl BertSelfOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
 struct BertAttention {
     self_attention: BertSelfAttention,
     self_output: BertSelfOutput,
@@ -325,6 +330,7 @@ impl BertAttention {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
 struct BertIntermediate {
     dense: Linear,
     intermediate_act: HiddenActLayer,
@@ -352,6 +358,7 @@ impl Module for BertIntermediate {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
 struct BertOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -385,7 +392,8 @@ impl BertOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+#[derive(Clone)]
+pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
     output: BertOutput,
@@ -420,21 +428,22 @@ impl BertLayer {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
-    layers: Vec<BertLayer>,
+#[derive(Clone)]
+pub struct BertEncoder {
+    pub layers: Vec<BertLayer>,
     span: tracing::Span,
 }
 
 impl BertEncoder {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
 
-    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+    pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
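The net effect of the commit is that BertEncoder (with its public layers field, load, and forward), BertLayer, and the Config fields can now be used directly from outside the crate, without going through BertModel. The sketch below illustrates one way that might look; it is not part of the commit, and several details are assumptions: the crate names (candle_core, candle_nn, candle_transformers), the zero-initialized VarBuilder standing in for real safetensors weights, and the (batch, 1, 1, seq_len) additive attention-mask shape.

// Hedged sketch of downstream usage enabled by this change (assumptions noted above).
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertEncoder, Config};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Config fields are public now, so a configuration can be tweaked in code
    // instead of being deserialized from a config.json.
    let mut config = Config::default();
    config.num_hidden_layers = 2;

    // Zero-initialized weights are enough to exercise the public `load`;
    // a real model would build the VarBuilder from safetensors instead.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let encoder = BertEncoder::load(vb, &config)?;

    // `layers` is public as well, so the encoder depth can be inspected directly.
    assert_eq!(encoder.layers.len(), config.num_hidden_layers);

    // `forward` takes pre-computed embeddings plus an attention mask. The
    // all-zeros (batch, 1, 1, seq_len) additive mask below (i.e. attend
    // everywhere) is an assumption of this sketch, not taken from the diff.
    let (batch, seq_len) = (1, 8);
    let hidden = Tensor::zeros((batch, seq_len, config.hidden_size), DType::F32, &device)?;
    let mask = Tensor::zeros((batch, 1, 1, seq_len), DType::F32, &device)?;
    let out = encoder.forward(&hidden, &mask)?;
    println!("encoder output shape: {:?}", out.dims());
    Ok(())
}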
