当前位置: 首页 > news >正文

手撕 Decoder

Happy-LLM:从零开始的大语言模型原理与实践教程.pdf P24

Decoder Layer

一个Decoder Layer内的数据流动顺序为:

Input (x)↓
LayerNorm 1↓
Masked Self-Attention↓
x + Attention(x)↓
LayerNorm 2↓
Cross-Attention (with enc_out)↓
h = x + Attention(x, enc_out)↓
LayerNorm 3↓
MLP (Feed-Forward Network)↓
out = h + MLP(h)

该实现依然与标准transformer不一样,以下代码里,先归一化再残差连接(Pre-LayerNorm),而标准transformer则相反

代码

class DecoderLayer(nn.Module):def __init__(self, args):super().__init__()self.attention_norm_1 = LayerNorm(args.n_embd)self.mask_attention = MultiHeadAttention(args, is_causal=True)self.attention_norm_2 = LayerNorm(args.n_embd)self.attention = MultiHeadAttention(args, is_causal=False)self.ffn_norm = LayerNorm(args.n_embd)self.feed_forward = MLP(args)def forward(self, x, enc_out):# Layer Normnorm_x = self.attention_norm_1(x)# 掩码自注意力x = x + self.mask_attention.forward(norm_x, norm_x, norm_x)# 多头注意力norm_x = self.attention_norm_2(x)h = x + self.attention.forward(norm_x, enc_out, enc_out)# 经过前馈神经网络out = h + self.feed_forward.forward(self.ffn_norm(h))return out

初始化了三个子层(Masked Multi-Head Attention, Cross Attention, Feed Forward),每个子层都含有一个Layer Norm,三个层归一化的函数相同,均为LayerNorm(args.n_embd),仅名称不同

代码逻辑实质上是三个重复的x_{out}=x+SubLayer(LN(x))

搭建Decoder

 class Decoder(nn.Module):'''解码器'''def __init__(self, args):super(Decoder, self).__init__() # ⼀个 Decoder 由 N 个 Decoder Layer 组成self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layer)])self.norm = LayerNorm(args.n_embd)def forward(self, x, enc_out):"Pass the input (and mask) through each layer in turn."for layer in self.layers:x = layer(x, enc_out)return self.norm(x)

[DecoderLayer(args) for _ in range(args.n_layer)] 通过循环生成 args.n_layer 个 DecoderLayer 实例;nn.ModuleList 将生成的 DecoderLayer 实例列表包装为 nn.ModuleList,动态创建并注册多个解码器层,构建符合 Transformer 架构的解码器

代码末尾的 self.norm(x) 是对所有层处理后的最终输出进行归一化

参考文章

Happy-LLM:从零开始的大语言模型原理与实践教程.pdf P24

http://www.lqws.cn/news/496369.html

相关文章:

  • 将RESP.app的备份数据转码成AnotherRedisDesktopManager的格式
  • react gsap动画库使用详解之text文本动画
  • 山洪灾害智能监测站系统解决方案
  • 通过apache共享文件
  • 渗透测试指南(CSMSF):Windows 与 Linux 系统中的日志与文件痕迹清理
  • XSD是什么,与XML关系
  • D2554探鸽协议,sensor属性,回调
  • 关于 pdd:anti_content参数分析与逆向
  • 前端面试记录
  • Flask框架index.html里引用的本地的js和css或者图片
  • C#采集电脑硬件(CPU、GPU、硬盘、内存等)温度和使用状况
  • 深入理解PHP中的生成器(Generators)
  • 新高考需求之一
  • 【GNSS定位算法】Chapter.2 导航定位算法软件学习——Ginav(二)SPP算法 [2025年6月]
  • 系统规划与管理师(第2版)第9章思维导图发布
  • Java面试核心考点复习指南
  • 智能交通中的深度学习应用:从理论到实践
  • 深入解析 Windows 文件查找命令(dir、gci)
  • 在cursor中,配置jdk和maven环境,安装拓展插件
  • AngularJS
  • 【笔记】在Cygwin上使用mintty连接wsl
  • 【软考高级系统架构论文】论企业集成架构设计及应用
  • 海拔案例分享-门店业绩管理小程序
  • 【ARM 嵌入式 编译系列 7.4 -- GCC 链接脚本中 ASSERT 函数】
  • 如何利用Charles抓包工具提升API调试与性能优化
  • QT6(46)5.2 QStringListModel 和 QListView :列表的模型与视图的界面搭建与源代码实现
  • Netty内存池分层设计架构
  • 本地文件深度交互新玩法:Obsidian Copilot的深度开发
  • 【streamlit 组件样式定位与修改】
  • 数字孪生:为UI前端设计带来沉浸式交互新体验