
Gradients of the loss function L with respect to W, X, and b of a fully connected layer

Assumptions: $X$ has shape $s\times n$, where $s$ is the number of samples and each sample is flattened into a $1\times n$ row vector; $W$ has shape $n\times o$, where $o$ is the output dimension of the fully connected layer; $b$ has shape $1\times o$; $Z$ has shape $s\times o$, and $Z=X\times W+b$, where $b$ is added to every row of $X\times W$.

$$Z=\begin{bmatrix} z_{11} & \cdots & z_{1o} \\ \vdots & \ddots & \vdots \\ z_{s1} & \cdots & z_{so} \end{bmatrix}$$

$$X=\begin{bmatrix} x_{11} & \cdots & x_{1n} \\ \vdots & \ddots & \vdots \\ x_{s1} & \cdots & x_{sn} \end{bmatrix}$$

$$W=\begin{bmatrix} w_{11} & \cdots & w_{1o} \\ \vdots & \ddots & \vdots \\ w_{n1} & \cdots & w_{no} \end{bmatrix}$$

$$b=\begin{bmatrix} b_{1} & \cdots & b_{o} \end{bmatrix}$$
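The shapes above can be made concrete with a small forward pass. The following is a minimal sketch in plain Python (lists of rows, no third-party libraries); the helper names `matmul` and `forward` and the numeric values are assumptions chosen for illustration, with $s=2$, $n=3$, $o=2$.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def forward(X, W, b):
    """Compute Z = X W + b, adding the 1 x o row b to every row of X W."""
    return [[z + bj for z, bj in zip(row, b)] for row in matmul(X, W)]

# Illustrative values (assumed): s = 2 samples, n = 3 inputs, o = 2 outputs.
X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # s x n
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n x o
b = [0.5, -0.5]                           # 1 x o

Z = forward(X, W, b)                      # s x o
print(Z)  # [[4.5, 4.5], [10.5, 10.5]]
```

Each entry of `Z` is the dot product of one sample with one column of `W`, plus the bias of that output column.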

Gradient of the loss function L with respect to W:

$$\frac{\partial L}{\partial W}=X^T\times\frac{\partial L}{\partial Z}$$

Proof:

Because

$$z_{ij}=\sum_{t=1}^{n}x_{it}w_{tj}+b_j$$

we have

$$\frac{\partial z_{ij}}{\partial w_{kl}}=\begin{cases} x_{ik}, & \text{if } j=l \\ 0, & \text{if } j\neq l \end{cases}$$

By the chain rule,

$$\frac{\partial L}{\partial w_{kl}}=\sum_{i=1}^{s}\sum_{j=1}^{o}\frac{\partial L}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial w_{kl}}=\sum_{i=1}^{s}\frac{\partial L}{\partial z_{il}}x_{ik}$$

Moreover,

$$X^T\times\frac{\partial L}{\partial Z}=\begin{bmatrix} x_{11} & \cdots & x_{s1} \\ \vdots & \ddots & \vdots \\ x_{1n} & \cdots & x_{sn} \end{bmatrix}\times\begin{bmatrix} \frac{\partial L}{\partial z_{11}} & \cdots & \frac{\partial L}{\partial z_{1o}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial z_{s1}} & \cdots & \frac{\partial L}{\partial z_{so}} \end{bmatrix}$$

so

$$\left(X^T\times\frac{\partial L}{\partial Z}\right)_{kl}=\sum_{i=1}^{s}\frac{\partial L}{\partial z_{il}}x_{ik}=\frac{\partial L}{\partial w_{kl}}$$

and therefore

$$\frac{\partial L}{\partial W}=X^T\times\frac{\partial L}{\partial Z}$$
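The result for $W$ can be checked numerically. The sketch below is plain Python under stated assumptions: the toy loss $L=\sum_{i,j}G_{ij}z_{ij}$ (so that $\partial L/\partial Z=G$ exactly), the helper names, and all numeric values are illustrative choices, not part of the derivation. It compares $X^T\times G$ against central finite differences.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def loss(X, W, b, G):
    """Toy loss (assumed): L = sum_ij G[i][j] * Z[i][j], hence dL/dZ = G."""
    Z = [[z + bj for z, bj in zip(row, b)] for row in matmul(X, W)]
    return sum(g * z for g_row, z_row in zip(G, Z) for g, z in zip(g_row, z_row))

X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # s x n
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n x o
b = [0.5, -0.5]                           # 1 x o
G = [[1.0, -2.0], [0.5, 3.0]]             # dL/dZ, s x o

# Analytic gradient from the formula: dL/dW = X^T x dL/dZ
dW = matmul(transpose(X), G)

# Central finite differences, perturbing one entry of W at a time.
eps = 1e-4
dW_numeric = [[0.0] * len(W[0]) for _ in W]
for k in range(len(W)):
    for l in range(len(W[0])):
        old = W[k][l]
        W[k][l] = old + eps
        plus = loss(X, W, b, G)
        W[k][l] = old - eps
        minus = loss(X, W, b, G)
        W[k][l] = old  # restore the original entry
        dW_numeric[k][l] = (plus - minus) / (2 * eps)

print(dW)  # [[3.0, 10.0], [4.5, 11.0], [6.0, 12.0]]
```

The gradient has the same $n\times o$ shape as $W$, and `dW_numeric` agrees with `dW` up to floating-point rounding.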

Gradient of the loss function L with respect to X:

$$\frac{\partial L}{\partial X}=\frac{\partial L}{\partial Z}\times W^T$$

Proof:

By the chain rule,

$$\frac{\partial L}{\partial x_{ij}}=\sum_{l=1}^{s}\sum_{k=1}^{o}\frac{\partial L}{\partial z_{lk}}\frac{\partial z_{lk}}{\partial x_{ij}}$$

Because

$$z_{lk}=\sum_{t=1}^{n}x_{lt}w_{tk}+b_k$$

we have

$$\frac{\partial z_{lk}}{\partial x_{ij}}=\begin{cases} w_{jk}, & \text{if } l=i \\ 0, & \text{if } l\neq i \end{cases}$$

so

$$\frac{\partial L}{\partial x_{ij}}=\sum_{k=1}^{o}\frac{\partial L}{\partial z_{ik}}w_{jk}$$

Moreover,

$$\frac{\partial L}{\partial Z}\times W^T=\begin{bmatrix} \frac{\partial L}{\partial z_{11}} & \cdots & \frac{\partial L}{\partial z_{1o}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial z_{s1}} & \cdots & \frac{\partial L}{\partial z_{so}} \end{bmatrix}\times\begin{bmatrix} w_{11} & \cdots & w_{n1} \\ \vdots & \ddots & \vdots \\ w_{1o} & \cdots & w_{no} \end{bmatrix}$$

so

$$\left(\frac{\partial L}{\partial Z}\times W^T\right)_{ij}=\sum_{k=1}^{o}\frac{\partial L}{\partial z_{ik}}w_{jk}=\frac{\partial L}{\partial x_{ij}}$$

and therefore

$$\frac{\partial L}{\partial X}=\frac{\partial L}{\partial Z}\times W^T$$
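The same kind of numerical check works for $X$. This is again a plain-Python sketch under assumptions: the toy loss $L=\sum_{i,j}G_{ij}z_{ij}$ (so $\partial L/\partial Z=G$), the helper names, and the numeric values are all illustrative.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def loss(X, W, b, G):
    """Toy loss (assumed): L = sum_ij G[i][j] * Z[i][j], hence dL/dZ = G."""
    Z = [[z + bj for z, bj in zip(row, b)] for row in matmul(X, W)]
    return sum(g * z for g_row, z_row in zip(G, Z) for g, z in zip(g_row, z_row))

X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # s x n
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n x o
b = [0.5, -0.5]                           # 1 x o
G = [[1.0, -2.0], [0.5, 3.0]]             # dL/dZ, s x o

# Analytic gradient from the formula: dL/dX = dL/dZ x W^T
dX = matmul(G, transpose(W))

# Central finite differences, perturbing one entry of X at a time.
eps = 1e-4
dX_numeric = [[0.0] * len(X[0]) for _ in X]
for i in range(len(X)):
    for j in range(len(X[0])):
        old = X[i][j]
        X[i][j] = old + eps
        plus = loss(X, W, b, G)
        X[i][j] = old - eps
        minus = loss(X, W, b, G)
        X[i][j] = old  # restore the original entry
        dX_numeric[i][j] = (plus - minus) / (2 * eps)

print(dX)  # [[1.0, -2.0, -1.0], [0.5, 3.0, 3.5]]
```

The gradient has the same $s\times n$ shape as $X$, and row $i$ of `dX` depends only on row $i$ of `G`, matching the case $l=i$ in the proof.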

Gradient of the loss function L with respect to b:

$$\frac{\partial L}{\partial b}=\operatorname{sum}\left(\frac{\partial L}{\partial Z},\ \text{axis}=0\right)$$

that is, the column-wise sum of $\frac{\partial L}{\partial Z}$.

Proof:

Because $z_{ij}=\sum_{t=1}^{n}x_{it}w_{tj}+b_j$, the derivative $\frac{\partial z_{ij}}{\partial b_k}$ equals $1$ if $j=k$ and $0$ otherwise, so

$$\frac{\partial L}{\partial b_{k}}=\sum_{i=1}^{s}\sum_{j=1}^{o}\frac{\partial L}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial b_{k}}=\sum_{i=1}^{s}\frac{\partial L}{\partial z_{ik}}$$

Therefore

$$\frac{\partial L}{\partial b}=\mathbf{1}_s\frac{\partial L}{\partial Z}=\operatorname{sum}\left(\frac{\partial L}{\partial Z},\ \text{axis}=0\right)$$

where $\mathbf{1}_s$ is the all-ones row vector with $s$ columns (shape $1\times s$).
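The bias gradient can be sketched the same way. In the plain-Python example below, the toy loss $L=\sum_{i,j}G_{ij}z_{ij}$ (so $\partial L/\partial Z=G$), the helper names, and the numeric values are assumptions for illustration. It computes the column-wise sum directly, as the product $\mathbf{1}_s\times G$, and by central finite differences.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def loss(X, W, b, G):
    """Toy loss (assumed): L = sum_ij G[i][j] * Z[i][j], hence dL/dZ = G."""
    Z = [[z + bj for z, bj in zip(row, b)] for row in matmul(X, W)]
    return sum(g * z for g_row, z_row in zip(G, Z) for g, z in zip(g_row, z_row))

X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # s x n
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n x o
b = [0.5, -0.5]                           # 1 x o
G = [[1.0, -2.0], [0.5, 3.0]]             # dL/dZ, s x o

# Column-wise sum of dL/dZ, i.e. sum(dL/dZ, axis=0).
db = [sum(col) for col in zip(*G)]

# Equivalent form: all-ones 1 x s row vector times dL/dZ.
ones_row = [[1.0] * len(G)]
db_via_ones = matmul(ones_row, G)[0]

# Central finite differences, perturbing one entry of b at a time.
eps = 1e-4
db_numeric = []
for k in range(len(b)):
    old = b[k]
    b[k] = old + eps
    plus = loss(X, W, b, G)
    b[k] = old - eps
    minus = loss(X, W, b, G)
    b[k] = old  # restore the original entry
    db_numeric.append((plus - minus) / (2 * eps))

print(db)  # [1.5, 1.0]
```

Because every sample shares the same bias, the contributions of all $s$ rows add up, which is why the sum runs over the sample axis.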
