models.py
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from attention import *  # provides clones, LayerNorm, SublayerConnection used below
from torch.nn import CrossEntropyLoss


class Encoder(nn.Module):
    "Stack of N identical encoder layers with a final projection to the vocabulary."
    def __init__(self, layer, N, d_model, vocab_size, word_embed):
        super(Encoder, self).__init__()
        self.word_embed = word_embed
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, inputs, mask):
        break_probs = []
        x = self.word_embed(inputs)
        group_prob = 0.
        for layer in self.layers:
            # Each layer updates the grouping distribution and reports its
            # per-position break probabilities.
            x, group_prob, break_prob = layer(x, mask, group_prob)
            break_probs.append(break_prob)
        x = self.norm(x)
        break_probs = torch.stack(break_probs, dim=1)  # one entry per layer
        return self.proj(x), break_probs

    def masked_lm_loss(self, out, y):
        # Cross-entropy over the flattened sequence; positions labeled -1 are ignored.
        fn = CrossEntropyLoss(ignore_index=-1)
        return fn(out.view(-1, out.size()[-1]), y.view(-1))

    def next_sentence_loss(self):
        pass


class EncoderLayer(nn.Module):
    "EncoderLayer is made up of group attention, self-attention, and feed forward."
    def __init__(self, size, self_attn, feed_forward, group_attn, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.group_attn = group_attn
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask, group_prob):
        # Refine the grouping first, then let it constrain the self-attention weights.
        group_prob, break_prob = self.group_attn(x, mask, group_prob)
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, group_prob, mask))
        return self.sublayer[1](x, self.feed_forward), group_prob, break_prob
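

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hedged example of how Encoder and EncoderLayer might be wired
# together. The sub-modules `self_attn`, `feed_forward`, and `group_attn` are
# whatever attention.py provides; no specific class names or constructor
# signatures are assumed here, so they are passed in as arguments. A plain
# nn.Embedding stands in for the project's word_embed module.
def build_encoder(self_attn, feed_forward, group_attn,
                  N=6, d_model=512, vocab_size=32000, dropout=0.1):
    word_embed = nn.Embedding(vocab_size, d_model)
    # Encoder calls clones(layer, N), so this single layer acts as the template
    # for the whole N-layer stack.
    layer = EncoderLayer(d_model, self_attn, feed_forward, group_attn, dropout)
    return Encoder(layer, N, d_model, vocab_size, word_embed)

# Typical call pattern (shapes only, kept as comments because the concrete
# sub-modules live in attention.py):
#   logits, break_probs = model(inputs, mask)     # logits: (batch, seq, vocab_size)
#   loss = model.masked_lm_loss(logits, targets)  # targets use -1 for ignored positions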