
Transformer结构与代码实现详解

参考:
Transformer模型详解(图解最完整版) - 知乎:https://zhuanlan.zhihu.com/p/338817680
GitHub - liaoyanqing666/transformer_pytorch(完整的原版transformer程序,complete origin transformer program):https://github.com/liaoyanqing666/transformer_pytorch
Attention Is All You Need(arXiv 1706.03762):https://arxiv.org/pdf/1706.03762

一. Transformer的整体架构

Transformer 由 Encoder (编码器)和 Decoder (解码器)两个部分组成,Encoder 和 Decoder 都包含 6 个 block(块)。Transformer 的工作流程大体如下:

第一步:获取输入句子的每一个单词的表示向量 X,X由单词本身的 Embedding(Embedding就是从原始数据提取出来的特征(Feature)) 和单词位置的 Embedding 相加得到。

第二步:将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x)传入 Encoder 中,经过 6 个 Encoder block (编码器块)后可以得到句子所有单词的编码信息矩阵 C。如下图,单词向量矩阵用 X_{n\times d}表示, n 是句子中单词个数,d 是表示向量的维度(论文中 d=512)。每一个 Encoder block (编码器块)输出的矩阵维度与输入完全一致。

第三步:将 Encoder (编码器)输出的编码信息矩阵 C传递到 Decoder(解码器)中,Decoder(解码器) 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。

上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 "<Begin>",预测第一个单词 "I";然后输入翻译开始符 "<Begin>" 和单词 "I",预测单词 "have",以此类推。
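下面用一段极简的贪心解码示意代码说明这一"逐词生成"的流程(仅为示意:假设 model 是文末"完整代码实现"中的 Transformer 且已训练好,BEGIN_ID、END_ID 是假设的起止符 id,原文并未给出):

import torch

# 假设:model 为文末实现的 Transformer,begin_id / end_id 为示意用的起止符 id
@torch.no_grad()
def greedy_decode(model, src, begin_id, end_id, max_len=20):
    src_pad = torch.ones_like(src)                   # 假设 src 中没有 padding(1 表示非 padding)
    tgt = torch.full((src.size(0), 1), begin_id)     # 以 "<Begin>" 作为解码起点
    for _ in range(max_len):
        src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)
        logits = model(src, tgt, src_mask, tgt_mask, pad_mask)        # [B, len_tgt, vocabulary]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)    # 取最后一个位置的预测词
        tgt = torch.cat([tgt, next_token], dim=1)    # 已翻译的单词 1~i 拼上新预测的单词 i+1
        if (next_token == end_id).all():             # 所有句子都输出结束符则停止
            break
    return tgt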

二. Transformer 的输入

Transformer 中单词的输入表示由单词本身的 Embedding 和单词位置 Embedding(Positional Encoding)相加得到。

2.1 单词 Embedding(词嵌入层)

单词本身的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。

self.embedding = nn.Embedding(vocabulary, dim)

功能解释:

  1. 作用:将离散的整数索引(单词ID)转换为连续的向量表示

  2. 输入:形状为 [sequence_length] 的整数张量

  3. 输出:形状为 [sequence_length, dim] 的浮点数张量(X_{n\times d},n是序列长度,d是特征维度)

参数详解:

参数          含义          示例值     说明
vocabulary    词汇表大小     10000     表示模型能处理的不同单词/符号总数
dim           嵌入维度       512       每个单词被表示成的向量长度

工作原理:

  1. 创建一个形状为 [vocabulary, dim] 的可学习嵌入矩阵,例如当 vocabulary=10000、dim=512 时,是一个 10000×512 的矩阵;

  2. 每个整数索引对应矩阵中的一行:

# 假设单词"apple"的ID=42
apple_vector = embedding_matrix[42]  # 形状 [512]

在Transformer中的具体作用:

# 输入:src = torch.randint(0, 10000, (2, 10))
# 形状:[batch_size=2, seq_len=10]
src_embedded = self.embedding(src)
# 输出形状变为:[2, 10, 512]
# 每个整数单词ID被替换为512维的向量

可视化表现:

原始输入 (单词ID):
[ [ 25,  198, 3000, ... ],    # 句子1
  [ 1,   42,  999,  ... ] ]   # 句子2

经过嵌入层后 (向量表示):
[ [ [0.2, -0.5, ..., 1.3],    # ID=25的向量
    [0.8, 0.1, ..., -0.9],    # ID=198的向量
    ... ],
  [ [0.9, -0.2, ..., 0.4],    # ID=1的向量
    [0.3, 0.7, ..., -1.2],    # ID=42的向量
    ... ] ]

为什么需要词嵌入:

  • 语义表示:相似的单词会有相似的向量表示

  • 降维:将离散的ID映射到连续空间(one-hot编码需要10000维 → 嵌入只需512维)

  • 可学习:在训练过程中,这些向量会不断调整以更好地表示语义关系
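下面是一个可以直接运行的小例子,用随机初始化的 nn.Embedding 验证上述形状变化(数值仅作示意):

import torch
import torch.nn as nn

embedding = nn.Embedding(10000, 512)       # vocabulary=10000, dim=512
src = torch.randint(0, 10000, (2, 10))     # 两个句子,每句 10 个单词 ID

src_embedded = embedding(src)
print(src.shape)                # torch.Size([2, 10])
print(src_embedded.shape)       # torch.Size([2, 10, 512]),每个 ID 被替换为 512 维向量
print(embedding.weight.shape)   # torch.Size([10000, 512]),即可学习的嵌入矩阵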

2.2  位置 Embedding(位置编码)

Transformer 的位置编码(Positional Encoding,PE)是模型的关键设计之一。Transformer 的自注意力机制本身不具备感知序列位置的能力,位置编码通过向输入嵌入添加位置信息,使模型能够理解序列中元素的顺序关系,同时又不必像传统序列模型(如 RNN)那样顺序处理。位置编码计算之后的输出维度和词嵌入层相同,均为 X_{n\times d}。

位置编码的核心作用:

  1. 注入位置信息:让模型区分不同位置的相同单词(如 "bank" 在句首 vs 句尾)

  2. 保持距离关系:编码相对位置和绝对位置信息

  3. 支持并行计算:避免像 RNN 那样依赖顺序处理

为什么需要位置编码?

  1. 自注意力的位置不变性
    Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V,计算过程不包含位置信息

  2. 序列顺序的重要性

  • 自然语言:"猫追狗" ≠ "狗追猫"
  • 时序数据:股价序列的顺序决定趋势

替代方案对比:

方法         优点                    缺点
正弦/余弦    泛化性好,有理论保证      固定模式,不够灵活
可学习       适应任务特定模式          长度受限,需要训练
相对位置     直接建模相对距离          实现复杂

位置编码的实际效果

  1. 早期层作用:帮助模型建立位置感知

  2. 后期层作用:位置信息被融合到语义表示中

  3. 可视化示例

Input:    [The,     cat,     sat,     on,    mat]
Embed:    [E_The,   E_cat,   E_sat,   E_on,  E_mat]
Position: [P0,      P1,      P2,      P3,    P4]
Final:    [E_The+P0, E_cat+P1, ..., E_mat+P4]

(1)正余弦位置编码(论文采用)

正余弦位置编码的计算公式:

PE_{(pos,2i)}=sin\left ( \frac{pos}{10000^{\frac{2i}{d_{model}}}} \right ),\quad PE_{(pos,2i+1)}=cos\left ( \frac{pos}{10000^{\frac{2i}{d_{model}}}} \right )

其中:

  •  `pos` 是token在序列中的位置(从0开始)
  •  `d_model` 是模型的嵌入维度(即每个token的向量维度)
  •  `i` 是维度的索引(从0到d_model/2-1)

特点:

  • 波长几何级数:覆盖不同频率
  • 相对位置可表示:对于固定偏移 k,PE_{pos+k} 可以表示为 PE_{pos} 的线性函数,便于模型捕捉相对位置
  • 泛化性强:可处理比训练时更长的序列
  • 对称性:sin/cos 组合允许模型学习相对位置

代码实现:

class PositionalEncoding(nn.Module):
    # Sine-cosine positional coding
    def __init__(self, emb_dim, max_len, freq=10000.0):
        super(PositionalEncoding, self).__init__()
        assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'
        self.emb_dim = emb_dim
        self.max_len = max_len
        self.pe = torch.zeros(max_len, emb_dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        # pos: [max_len, 1]
        div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)
        # div: [ceil(emb_dim / 2)]
        self.pe[:, 0::2] = torch.sin(pos / div)
        # torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]
        self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))
        # torch.cos(pos / div): [max_len, floor(emb_dim / 2)]

    def forward(self, x, len=None):
        if len is None:
            len = x.size(-2)
        return x + self.pe[:len, :]

例如,指定emb_dim=512和max_len=100,句子长度为10,则位置embedding的数值计算如下(三角函数取弧度制):

\begin{bmatrix} sin\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{0}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{1}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{2}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{510}{512}}} \right )\\ ... & ... & ... & ... & ... & ... & ...\\ sin\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{7}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{8}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{9}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{510}{512}}} \right )\\ \end{bmatrix}_{10\times 512}=\begin{bmatrix} 0 & 1 & 0 & ... & 1 & 0 & 1\\ 0.8415 & 0.5403 & 0.8219 & ... & 1.0000 & 1.0366\times 10^{-4} & 1.0000\\ 0.9093 & -0.4161 & 0.9364 & ... & 1.0000 & 2.0733\times 10^{-4} & 1.0000\\ ... & ... & ... & ... & ... & ... & ...\\ 0.6570& 0.7539 & 0.4524 & ... & 1.0000 & 7.2564\times 10^{-4} & 1.0000\\ 0.9894 & -0.1455 & 0.9907 & ... & 1.0000 & 8.2931\times 10^{-4} & 1.0000\\ 0.4121 & -0.9111 & 0.6764 & ... & 1.0000 & 9.3297\times 10^{-4} & 1.0000 \end{bmatrix}_{10\times 512}
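可以用上面定义的 PositionalEncoding 类复算这张矩阵中的数值,例如第 1 行前两列的 0.8415 与 0.5403(示意代码,假设该类已按上文定义):

import torch

pe_layer = PositionalEncoding(emb_dim=512, max_len=100)
pe = pe_layer.pe                # [100, 512],只与位置有关,与具体单词无关

print(pe[1, 0].item())          # sin(1 / 10000^(0/512)) ≈ 0.8415
print(pe[1, 1].item())          # cos(1 / 10000^(0/512)) ≈ 0.5403
print(pe[2, 0].item())          # sin(2 / 10000^(0/512)) ≈ 0.9093
print(pe[9, 511].item())        # cos(9 / 10000^(510/512)) ≈ 1.0000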

(2)可学习位置编码
class LearnablePositionalEncoding(nn.Module):
    # Learnable positional encoding
    def __init__(self, emb_dim, len):
        super(LearnablePositionalEncoding, self).__init__()
        assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'
        self.emb_dim = emb_dim
        self.len = len
        self.pe = nn.Parameter(torch.zeros(len, emb_dim))

    def forward(self, x):
        return x + self.pe[:x.size(-2), :]

特性

  • 直接学习位置嵌入:作为模型参数训练
  • 灵活性高:可适应特定任务的位置模式
  • 长度受限:只能处理预定义的最大长度
  • 计算效率高:直接查表无需计算
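两种位置编码的调用接口一致,可以互换使用(与文末 Transformer 中通过 learnable_pos 参数选择的方式相同),下面是一个简单的形状验证示意:

import torch

x = torch.randn(2, 10, 512)     # [batch, seq_len, emb_dim] 的词嵌入(随机数示意)

sin_pe = PositionalEncoding(emb_dim=512, max_len=100)
learn_pe = LearnablePositionalEncoding(emb_dim=512, len=100)

print(sin_pe(x).shape)          # torch.Size([2, 10, 512]),输出维度与输入一致
print(learn_pe(x).shape)        # torch.Size([2, 10, 512])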

三. Self-Attention(自注意力机制)和Multi-Head Attention(多头自注意力)

在 Transformer 的内部结构图中,左侧为 Encoder block(编码器块),右侧为 Decoder block(解码器块)。可以看到:

(1)Encoder block 包含一个 Multi-Head Attention;

(2)Decoder block 包含两个 Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个 Add & Norm 层,Add 表示残差连接(Residual Connection),用于防止网络退化,Norm 表示Layer Normalization,用于对每一层的激活值进行归一化。

Multi-Head Attention 是 Transformer 的重点,它由 Self-Attention 演变而来,我们先从 Self-Attention 讲起。

3.1  Self-Attention(自注意力机制)

Self-Attention(自注意力)是 Transformer 架构的核心创新,它彻底改变了序列建模的方式。与传统的循环神经网络(RNN)和卷积神经网络(CNN)不同,self-attention 能够直接捕捉序列中任意两个元素之间的关系,无论它们之间的距离有多远:

Self-Attention 的输入用矩阵X_{n\times d}(n是序列长度,d是特征维度)进行表示,计算如下:

(1)通过可学习的权重矩阵生成 Q(查询,Query)、K(键,Key)、V(值,Value):

\left\{\begin{matrix} Q = XW^Q \\ K = XW^K \\ V = XW^V \end{matrix}\right.

其中W^Q,W^K,W^V是可学习参数。

(2)计算 Self-Attention 的输出:Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V

步骤分解:

  1. 相似度计算:QK^T 计算所有查询-键对之间的点积相似度,得到的矩阵行列数都为 n(n 为句子单词数),这个矩阵可以表示单词之间的 attention 强度。

  2. 缩放:除以 \sqrt{d_k},防止点积过大使 softmax 进入饱和区、梯度过小

  3. 归一化:softmax 将相似度转换为概率分布

  4. 加权求和:用注意力权重对值向量加权求和,得到最终的输出

输入序列: [x1, x2, x3, x4]

步骤1: 为每个输入生成 Q, K, V 向量
x1 → q1, k1, v1
x2 → q2, k2, v2
x3 → q3, k3, v3
x4 → q4, k4, v4

步骤2: 计算注意力权重 (以x1为例,softmax 对所有得分一起归一化)
[权重1, 权重2, 权重3, 权重4] = softmax([q1·k1, q1·k2, q1·k3, q1·k4] / √d_k)

步骤3: 加权求和
输出1 = 权重1*v1 + 权重2*v2 + 权重3*v3 + 权重4*v4
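把上面的步骤直接翻译成代码,就是一个极简的单头缩放点积注意力(仅作示意,未包含掩码和 dropout,权重矩阵用随机张量代替可学习参数):

import torch

def scaled_dot_product_attention(Q, K, V):
    # Q: [n, d_k], K: [n, d_k], V: [n, d_v]
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # [n, n],单词两两之间的 attention 强度
    weights = torch.softmax(scores, dim=-1)         # 每一行归一化为概率分布
    return weights @ V                              # 加权求和,输出 [n, d_v]

X = torch.randn(4, 512)                             # 4 个单词的输入表示 X
W_Q, W_K, W_V = (torch.randn(512, 64) for _ in range(3))   # 随机代替可学习参数 W^Q, W^K, W^V
out = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)
print(out.shape)                                    # torch.Size([4, 64])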

3.2  Multi-Head Attention(多头注意力)

Transformer 使用多头机制增强模型表达能力:

MultiHead(Q,K,V)=Concat(head_1,head_2...head_h)W^O

其中每个注意力头:

head_i=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

  • h:注意力头的数量

  • W_i^Q, W_i^K, W_i^V:每个头的独立参数

  • W^O:输出投影矩阵
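以 dim=512、h=8 为例,多头机制中的"分割"与"合并"只是张量形状上的重排,可以用下面的小例子验证(与后文代码中的 view/transpose 写法一致,仅作示意):

import torch

B, n, dim, h = 2, 10, 512, 8
d_head = dim // h                                  # 每个头的维度:512 // 8 = 64

x = torch.randn(B, n, dim)
heads = x.view(B, n, h, d_head).transpose(1, 2)    # 分割:[B, h, n, 64]
merged = heads.transpose(1, 2).contiguous().view(B, n, dim)   # 合并:[B, n, 512]

print(heads.shape)              # torch.Size([2, 8, 10, 64])
print(torch.equal(merged, x))   # True,分割与合并互为逆操作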

代码实现:

(1)多头分割处理:使用view将特征维度分割为多个头,确保每个头的维度:dim_head = dim_qk // num_heads

q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
k = ... # 类似处理
v = ... # 类似处理

(2)高效的矩阵运算:使用矩阵乘法并行计算所有位置的注意力分数

attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)

(3)多头合并:使用view合并多头:num_heads * d_v = dim_v

output = output.transpose(1, 2)
output = output.contiguous().view(-1, len_q, self.dim_v)

下面是完整的 Multi-Head Attention(多头注意力)代码实现,其中已经包含掩码(mask)处理,掩码的作用将在后面介绍。

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        dim_qk = dim if dim_qk is None else dim_qk
        dim_v = dim if dim_v is None else dim_v
        assert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, \
            'dim must be divisible by num_heads'
        self.dim = dim
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(dim, dim_qk)
        self.w_k = nn.Linear(dim, dim_qk)
        self.w_v = nn.Linear(dim, dim_v)

    def forward(self, q, k, v, mask=None):
        # q: [B, len_q, D]
        # k: [B, len_kv, D]
        # v: [B, len_kv, D]
        assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'
        assert len_k == len_v, 'len_k and len_v must be equal'
        len_kv = len_v

        q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
        k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)
        v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)
        # q: [B, len_q, num_heads, dim_qk//num_heads]
        # k: [B, len_kv, num_heads, dim_qk//num_heads]
        # v: [B, len_kv, num_heads, dim_v//num_heads]
        # The following 'dim_(qk)//num_heads' is written as d_(qk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # q: [B, num_heads, len_q, d_qk]
        # k: [B, num_heads, len_kv, d_qk]
        # v: [B, num_heads, len_kv, d_v]

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)
        # attn: [B, num_heads, len_q, len_kv]
        if mask is not None:
            attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, v)
        # output: [B, num_heads, len_q, d_v]
        output = output.transpose(1, 2)
        # output: [B, len_q, num_heads, d_v]
        output = output.contiguous().view(-1, len_q, self.dim_v)
        # output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]
        return output
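一个简单的调用示例,验证上面 MultiHeadAttention 的输出形状(自注意力时 q、k、v 传入同一个张量;示意代码):

import torch

mha = MultiHeadAttention(dim=512, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 512)                     # [B, len, dim]

out = mha(x, x, x)                              # 自注意力:q = k = v = x
print(out.shape)                                # torch.Size([2, 10, 512])

mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), 1)   # 上三角为 True,遮住未来位置
out_masked = mha(x, x, x, mask)
print(out_masked.shape)                         # torch.Size([2, 10, 512])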

四.Encoder (编码器)结构

上图红色部分是 Transformer 的 Encoder block(编码器块)结构,可以看到它由 Multi-Head Attention、Add & Norm、Feed Forward、Add & Norm 组成。前面已经了解了 Multi-Head Attention 的计算过程,现在了解一下 Add & Norm 和 Feed Forward 部分。
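其中 Add & Norm 即"残差连接 + LayerNorm",计算方式为 LayerNorm(X + Sublayer(X));Feed Forward 是两层全连接网络(中间维度为 4×dim)。下面用一段示意代码展示 post-norm 的写法(与文末 EncoderLayer 中 pre_norm=False 的分支一致,仅作示意,非原文内容):

import torch
import torch.nn as nn

dim = 512
x = torch.randn(2, 10, dim)             # 子层输入
attn_out = torch.randn(2, 10, dim)      # 用随机数代替 Multi-Head Attention 的输出(示意)

norm1 = nn.LayerNorm(dim)
norm2 = nn.LayerNorm(dim)
ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))

x = norm1(x + attn_out)                 # Add & Norm: LayerNorm(X + Attention(X))
x = norm2(x + ffn(x))                   # Feed Forward 之后再做一次 Add & Norm
print(x.shape)                          # torch.Size([2, 10, 512]),与输入维度一致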


六.完整代码实现

import torch
import torch.nn as nn


class LearnablePositionalEncoding(nn.Module):
    # Learnable positional encoding
    def __init__(self, emb_dim, len):
        super(LearnablePositionalEncoding, self).__init__()
        assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'
        self.emb_dim = emb_dim
        self.len = len
        self.pe = nn.Parameter(torch.zeros(len, emb_dim))

    def forward(self, x):
        return x + self.pe[:x.size(-2), :]


class PositionalEncoding(nn.Module):
    # Sine-cosine positional coding
    def __init__(self, emb_dim, max_len, freq=10000.0):
        super(PositionalEncoding, self).__init__()
        assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'
        self.emb_dim = emb_dim
        self.max_len = max_len
        self.pe = torch.zeros(max_len, emb_dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        # pos: [max_len, 1]
        div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)
        # div: [ceil(emb_dim / 2)]
        self.pe[:, 0::2] = torch.sin(pos / div)
        # torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]
        self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))
        # torch.cos(pos / div): [max_len, floor(emb_dim / 2)]

    def forward(self, x, len=None):
        if len is None:
            len = x.size(-2)
        return x + self.pe[:len, :]


class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        dim_qk = dim if dim_qk is None else dim_qk
        dim_v = dim if dim_v is None else dim_v
        assert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, \
            'dim must be divisible by num_heads'
        self.dim = dim
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(dim, dim_qk)
        self.w_k = nn.Linear(dim, dim_qk)
        self.w_v = nn.Linear(dim, dim_v)

    def forward(self, q, k, v, mask=None):
        # q: [B, len_q, D]
        # k: [B, len_kv, D]
        # v: [B, len_kv, D]
        assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'
        assert len_k == len_v, 'len_k and len_v must be equal'
        len_kv = len_v

        q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
        k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)
        v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)
        # q: [B, len_q, num_heads, dim_qk//num_heads]
        # k: [B, len_kv, num_heads, dim_qk//num_heads]
        # v: [B, len_kv, num_heads, dim_v//num_heads]
        # The following 'dim_(qk)//num_heads' is written as d_(qk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # q: [B, num_heads, len_q, d_qk]
        # k: [B, num_heads, len_kv, d_qk]
        # v: [B, num_heads, len_kv, d_v]

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)
        # attn: [B, num_heads, len_q, len_kv]
        if mask is not None:
            attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, v)
        # output: [B, num_heads, len_q, d_v]
        output = output.transpose(1, 2)
        # output: [B, len_q, num_heads, d_v]
        output = output.contiguous().view(-1, len_q, self.dim_v)
        # output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]
        return output


class Feedforward(nn.Module):
    def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):
        super(Feedforward, self).__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = activate

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def attn_mask(len):
    """
    :param len: length of sequence
    :return: mask tensor, False for not replaced, True for replaced as -inf
    e.g. attn_mask(3) =
        tensor([[False,  True,  True],
                [False, False,  True],
                [False, False, False]])
    """
    mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)
    return mask


def padding_mask(pad_q, pad_k):
    """
    :param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]
    :param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]
    :return: mask tensor, False for not replaced, True for replaced as -inf
    e.g. pad_q = tensor([[1, 1, 0], [1, 0, 1]])
        padding_mask(pad_q, pad_q) =
        tensor([[[False, False,  True],
                 [False, False,  True],
                 [ True,  True,  True]],

                [[False,  True, False],
                 [ True,  True,  True],
                 [False,  True, False]]])
    """
    assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'
    assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'
    mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)
    mask = ~mask
    # mask: [B, len_q, len_k]
    return mask


class EncoderLayer(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):
        super(EncoderLayer, self).__init__()
        self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(dim, dim * 4, dropout)
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        if self.pre_norm:
            res1 = self.norm1(x)
            x = x + self.attn(res1, res1, res1, mask)
            res2 = self.norm2(x)
            x = x + self.ffn(res2)
        else:
            x = self.attn(x, x, x, mask) + x
            x = self.norm1(x)
            x = self.ffn(x) + x
            x = self.norm2(x)
        return x


class Encoder(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(
            [EncoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):
        super(DecoderLayer, self).__init__()
        self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(dim, dim * 4, dropout)
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, enc, self_mask=None, pad_mask=None):
        if self.pre_norm:
            res1 = self.norm1(x)
            x = x + self.attn1(res1, res1, res1, self_mask)
            res2 = self.norm2(x)
            x = x + self.attn2(res2, enc, enc, pad_mask)
            res3 = self.norm3(x)
            x = x + self.ffn(res3)
        else:
            x = self.attn1(x, x, x, self_mask) + x
            x = self.norm1(x)
            x = self.attn2(x, enc, enc, pad_mask) + x
            x = self.norm2(x)
            x = self.ffn(x) + x
            x = self.norm3(x)
        return x


class Decoder(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(
            [DecoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])

    def forward(self, x, enc, self_mask=None, pad_mask=None):
        for layer in self.layers:
            x = layer(x, enc, self_mask, pad_mask)
        return x


class Transformer(nn.Module):
    def __init__(self, dim, vocabulary, num_heads=1, num_layers=1, dropout=0., learnable_pos=False, pre_norm=False):
        super(Transformer, self).__init__()
        self.dim = dim
        self.vocabulary = vocabulary
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.learnable_pos = learnable_pos
        self.pre_norm = pre_norm
        self.embedding = nn.Embedding(vocabulary, dim)
        self.pos_enc = LearnablePositionalEncoding(dim, 100) if learnable_pos else PositionalEncoding(dim, 100)
        self.encoder = Encoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)
        self.decoder = Decoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)
        self.linear = nn.Linear(dim, vocabulary)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, pad_mask=None):
        # src.shape: torch.Size([2, 10])
        src = self.embedding(src)
        # src.shape: torch.Size([2, 10, 512])
        src = self.pos_enc(src)
        # src.shape: torch.Size([2, 10, 512])
        src = self.encoder(src, src_mask)
        # src.shape: torch.Size([2, 10, 512])

        # tgt.shape: torch.Size([2, 8])
        tgt = self.embedding(tgt)
        # tgt.shape: torch.Size([2, 8, 512])
        tgt = self.pos_enc(tgt)
        # tgt.shape: torch.Size([2, 8, 512])
        tgt = self.decoder(tgt, src, tgt_mask, pad_mask)
        # tgt.shape: torch.Size([2, 8, 512])

        output = self.linear(tgt)
        # output.shape: torch.Size([2, 8, 10000])
        return output

    def get_mask(self, tgt, src_pad=None):
        # Under normal circumstances, tgt_pad will perform mask processing when calculating loss,
        # and it isn't necessarily needed in the decoder
        if src_pad is not None:
            src_mask = padding_mask(src_pad, src_pad)
        else:
            src_mask = None
        tgt_mask = attn_mask(tgt.size(1))
        if src_pad is not None:
            pad_mask = padding_mask(torch.zeros_like(tgt), src_pad)
        else:
            pad_mask = None
        # src_mask: [B, len_src, len_src]
        # tgt_mask: [len_tgt, len_tgt]
        # pad_mask: [B, len_tgt, len_src]
        return src_mask, tgt_mask, pad_mask


if __name__ == '__main__':
    model = Transformer(dim=512, vocabulary=10000, num_heads=8, num_layers=6, dropout=0.1,
                        learnable_pos=False, pre_norm=True)
    src = torch.randint(0, 10000, (2, 10))  # torch.Size([2, 10])
    tgt = torch.randint(0, 10000, (2, 8))   # torch.Size([2, 8])
    src_pad = torch.randint(0, 2, (2, 10))  # torch.Size([2, 10])
    src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)
    model(src, tgt, src_mask, tgt_mask, pad_mask)
    # output.shape: torch.Size([2, 8, 10000])

