Task1: BackPropagation
BP: 符号定义
| 符号 | 维度 | 含义 |
|---|
| X | 10×784 | 输入 |
| W1 | 784×16 | 第一层权重 |
| W2 | 16×1 | 第二层权重 |
| Z1 | 10×16 | 第一层线性输出 |
| A1 | 10×16 | 第一层激活输出 |
| Z2 | 10×1 | 第二层线性输出 |
| y^ | 10×1 | 预测值 |
| y | 10×1 | 真实标签 |
| L | 标量 | 损失函数 |
| α | 标量 | 学习率 |
BP: 前向传播
Z1=XW1,Z1∈R10×16
A1=σ(Z1)=1+e−Z11,A1∈R10×16
Z2=A1W2,Z2∈R10×1
y^=σ(Z2)=1+e−Z21,y^∈R10×1
L=−[yTlog(y^)+(1−y)Tlog(1−y^)],L∈R
注:此处 sigmoid 函数为 element-wise 的。
BP: 反向传播
∂Z2∂L=y^−y,∂Z2∂L∈R10×1
∂W2∂L=A1T∂Z2∂L,∂W2∂L∈R16×1
∂A1∂L=∂Z2∂LW2T,∂A1∂L∈R10×16
∂Z1∂L=∂A1∂L⊙A1⊙(1−A1),∂Z1∂L∈R10×16
∂W1∂L=XT∂Z1∂L,∂W1∂L∈R784×16
注:⊙ 表示逐元素相乘。
BP: 参数更新
W1←W1−α∂W1∂L
W2←W2−α∂W2∂L
Task2: BatchNorm in MLP
BN: 符号定义
| 符号 | 维度 | 含义 |
|---|
| X | N×D | BatchNorm 层输入 |
| γ | D | 缩放参数 |
| β | D | 平移参数 |
| μB | D | batch 均值 |
| σB2 | D | batch 方差 |
| X^ | N×D | 归一化后的值 |
| Y | N×D | BatchNorm 输出 |
| ∂Y∂L | N×D | 上游梯度 |
| ϵ | 标量 | 小常数 |
BN: 前向传播
μB=N1n=1∑NXn,:,μB∈RD
σB2=N1n=1∑N(Xn,:−μB)2,σB2∈RD
X^=σB2+ϵX−1μB,X^∈RN×D
Y=γ⊙X^+1β,Y∈RN×D
注:1∈RN×1 为全1列向量,运算为广播机制。
BN: 反向传播
令 s=σB2+ϵ,则:
∂γ∂L=n=1∑N∂Yn,:∂L⊙X^n,:,∂γ∂L∈RD
∂β∂L=n=1∑N∂Yn,:∂L,∂β∂L∈RD
∂X^∂L=∂Y∂L⊙γ,∂X^∂L∈RN×D
∂X∂L=Ns1⊙(N∂X^∂L−1s1−X^s2),∂X∂L∈RN×D
其中:
s1=n=1∑N∂X^∂Ln,:∈RD
s2=n=1∑N∂X^∂Ln,:⊙X^n,:∈RD
注:⊙ 表示逐元素相乘,除法为逐元素除法,1∈RN×1 为全1列向量。
BN: 参数更新
γ←γ−α∂γ∂L
β←β−α∂β∂L
反向传播推导
考虑单个特征维度(D=1),有 N 个样本:x1,x2,...,xN。
已知 ∂x^i∂L(i=1..N),求 ∂xi∂L。
第一步:前向传播表达式
μ=N1k=1∑Nxk(1)
σ2=N1k=1∑N(xk−μ)2(2)
s=σ2+ϵ(3)
x^i=sxi−μ(4)
第二步:链式法则
∂xi∂L=j=1∑N∂x^j∂L⋅∂xi∂x^j(5)
第三步:计算 ∂xi∂x^j
由 (4) 对 xi 求偏导:
∂xi∂x^j=s2∂xi∂(xj−μ)⋅s−(xj−μ)⋅∂xi∂s(6)
第四步:计算 ∂xi∂(xj−μ)
由 (1):∂xi∂μ=N1,所以:
∂xi∂(xj−μ)=δij−N1(7)
其中 δij=1 当 i=j,否则 0。
第五步:计算 ∂xi∂s
由 (3):∂xi∂s=2s1⋅∂xi∂σ2
由 (2) 计算 ∂xi∂σ2:
∂xi∂σ2=N2k=1∑N(xk−μ)(δik−N1)=N2(xi−μ)(8)
因此:
∂xi∂s=Nsxi−μ(9)
第六步:代入 (6)
将 (7)(9) 代入 (6):
∂xi∂x^j=s2(δij−N1)s−(xj−μ)⋅Nsxi−μ
整理得:
∂xi∂x^j=s1(δij−N1)−Ns31(xj−μ)(xi−μ)(10)
第七步:代入链式法则 (5)
∂xi∂L=j=1∑N∂x^j∂L[s1(δij−N1)−Ns31(xj−μ)(xi−μ)]
拆开三项:
∂xi∂L=s1j=1∑N∂x^j∂Lδij−Ns1j=1∑N∂x^j∂L−Ns31(xi−μ)j=1∑N∂x^j∂L(xj−μ)
第一项中 ∑j∂x^j∂Lδij=∂x^i∂L,所以:
∂xi∂L=s1∂x^i∂L−Ns1j=1∑N∂x^j∂L−Ns31(xi−μ)j=1∑N∂x^j∂L(xj−μ)(11)
第八步:用 x^ 替换 (x−μ)
由 (4):xi−μ=sx^i,代入 (11) 第三项:
Ns31(xi−μ)j∑∂x^j∂L(xj−μ)=Ns31(sx^i)j∑∂x^j∂L(sx^j)=Ns1x^ij∑∂x^j∂Lx^j
代回 (11):
∂xi∂L=s1∂x^i∂L−Ns1j∑∂x^j∂L−Ns1x^ij∑∂x^j∂Lx^j
提取公因子 Ns1:
∂xi∂L=Ns1(N∂x^i∂L−j∑∂x^j∂L−x^ij∑∂x^j∂Lx^j)(12)