生成式人工智能实战 | 变分自编码器(Variational Auto-Encoder, VAE)
生成式人工智能实战 | 变分自编码器
- 0. 前言
- 1. 潜空间运算
- 2. 变分自编码器
- 2.1 VAE 工作原理
- 2.2 VAE 构建策略
- 2.3 KL 散度
- 2.4 重参数化技巧
- 3. 实现 VAE
- 3.1 数据加载
- 3.2 模型构建
- 3.3 模型训练
0. 前言
虽然自编码器 (AutoEncoder, AE) 在重建输入数据方面表现良好,但通常在生成训练集中不存在的新样本时表现不佳。更重要的是,自编码器在输入插值方面同样表现不佳,无法生成两个输入数据点之间的中间表示。这就引出了变分自编码器 (Variational Auto-Encoder
, VAE
),变分自编码器是一种生成模型,结合了深度学习和概率图模型的优点,通过学习数据的潜在概率分布来生成新的数据样本。本节将从零开始构建和训练一个 VAE
,使用 cifar-10
数据集训练 VAE
。
1. 潜空间运算
使用变分自编码器 (Variational Auto-Encoder
, VAE
) 可以进行向量运算和输入插值。操作不同输入的编码表示(潜向量),以在解码时实现特定的结果(例如,图像中是否具有某些特征)。潜向量控制解码图像中的不同特征,如性别、图像中是否有眼镜等。例如,可以首先获得戴眼镜的男性的潜向量 (z1
)、戴眼镜的女性的潜向量 (z2
) 和不戴眼镜的女性的潜向量 (z3
)。然后,计算一个新的潜向量 z4 = z1 – z2 + z3
。由于 z1
和 z2
解码后都会出现眼镜,z1 – z2
会在结果图像中去除眼镜特征。类似地,由于 z2
和 z3
都会解码为女性面孔,z3 – z2
会去除结果图像中的女性特征。因此,如果使用训练好的 VAE
解码 z4
将得到一张没有不戴眼镜的男性图像。
2. 变分自编码器
虽然自编码器 (AutoEncoder
, AE
) 擅长重建原始图像,但它们在生成训练集中没有出现的新图像方面表现不佳。此外,自编码器通常无法将相似的输入映射到潜空间中的相邻点。因此,AE
的潜空间既不连续,也不容易解释。例如,无法通过插值两个输入数据点来生成有意义的中间表示。基于这些原因,我们将学习自编码器的改进模型,变分自编码器 (Variational Auto-Encoder
, VAE
)。
2.1 VAE 工作原理
VAE
使用深度学习构建概率模型,将输入数据映射到一个低维度的潜空间中,并通过解码器将潜空间中的分布转换回数据空间中,以生成与原始数据相似的数据。与传统的自编码器相比,VAE
更加稳定,生成样本的质量更高。
VAE
的核心思想是利用概率模型来描述高维的输入数据,将输入数据采样于一个低维度的潜变量分布中,并通过解码器生成与原始数据相似的输出。具体来说,VAE
同样是由编码器和解码器组成:
- 编码器将数据 x x x 映射到一个潜在空间 z z z 中,该空间定义在低维正态分布中,即 z ∼ N ( 0 , I ) z∼N(0,I) z∼N(0,I),编码器由两个部分组成:一是将数据映射到均值和方差,即 z ∼ N ( μ , σ 2 ) z∼N(μ,σ^2) z∼N(μ,σ2);二是通过重参数化技巧,将均值和方差的采样过程分离出来,并引入随机变量 ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵ∼N(0,I),使得 z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ
- 解码器将潜在变量 z z z 映射回数据空间中,生成与原始数据 x x x 相似的数据 x ′ x′ x′,为了使生成的数据 x ′ x′ x′ 能够与原始数据 x x x 较高的相似度,
VAE
在损失函数中使用重构误差和正则化项,重构误差表示生成数据与原始数据之间的差异,正则化项用于约束潜在变量的分布,使其满足高斯正态分布,使得VAE
从潜空间中生成的样本质量更高
VAE
具有广泛的应用场景,如图像生成、语音、自然语言处理等领域,它能够通过有限的数据样本学习到输入数据背后的潜在规律,生成与原始数据类似的新数据,具有很强的潜数据的可解释性。
2.2 VAE 构建策略
在 VAE
中,基于预定义分布获得的随机向量生成逼真图像,而在传统自编码器中并未指定在网络中生成图像的数据分布。可以通过以下策略,实现 VAE:
- 编码器的输出包括两个向量:
- 输入图像平均值
- 输入图像标准差
- 根据以上两个向量,通过在均值和标准差之和中引入随机变量 ( ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵ∼N(0,I)) 获取随机向量 ( z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ)
- 将上一步得到的随机向量作为输入传递给解码器以重构图像
- 损失函数是均方误差和 KL 散度损失的组合:
KL
散度损失衡量由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的分布与 N ( 0 , I ) N(0,I) N(0,I) 分布的偏差- 均方损失用于优化重建(解码)图像
通过训练网络,指定输入数据满足由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的 N ( 0 , 1 ) N(0,1) N(0,1) 分布,当我们生成均值为 0
且标准差为 1
的随机噪声时,解码器将能够生成逼真的图像。
需要注意的是,如果只最小化 KL
散度,编码器将预测均值向量为 0
,标准差为 1
。因此,需要同时最小化 KL
散度损失和均方损失。在下一节中,让我们介绍 KL
散度,以便将其纳入模型的损失值计算中。
2.3 KL 散度
KL
散度(也称相对熵)可以用于衡量两个概率分布之间的差异:
K L ( P ∣ ∣ Q ) = ∑ x ∈ X P ( x ) l n ( P ( i ) Q ( i ) ) KL(P||Q) = \sum_{x∈X} P(x) ln(\frac {P(i)}{Q(i)}) KL(P∣∣Q)=x∈X∑P(x)ln(Q(i)P(i))
其中, P P P 和 Q Q Q 为两个概率分布,KL
散度的值越小,两个分布的相似性就越高,当且仅当 P P P 和 Q Q Q 两个概率分布完全相同时,KL
散度等于 0
。在 VAE
中,我们希望瓶颈特征值遵循平均值为 0
和标准差为 1
的正态分布。因此,我们可以使用 KL
散度衡量变分自编码器中编码器输出的分布与标准高斯分布 N ( 0 , 1 ) N(0,1) N(0,1) 之间的差异。
可以通过以下公式计算 KL
散度损失:
∑ i = 1 n σ i 2 + μ i 2 − l o g ∗ ( σ i ) − 1 \sum_{i=1}^n\sigma_i^2+\mu_i^2-log*(\sigma_i)-1 i=1∑nσi2+μi2−log∗(σi)−1
在上式中, σ σ σ 和 μ μ μ 表示每个输入图像的均值和标准差值:
- 确保均值向量分布在
0
附近:- 最小化上式中的均方误差 ( μ i 2 \mu_i^2 μi2) 可确保 μ \mu μ 尽可能接近
0
- 最小化上式中的均方误差 ( μ i 2 \mu_i^2 μi2) 可确保 μ \mu μ 尽可能接近
- 确保标准差向量分布在
1
附近:- 上式中其余部分(除了 μ i 2 \mu_i^2 μi2 )用于确保标准差 ( s i g m a sigma sigma) 分布在
1
附近
- 上式中其余部分(除了 μ i 2 \mu_i^2 μi2 )用于确保标准差 ( s i g m a sigma sigma) 分布在
当均值 ( μ μ μ) 为 0
且标准差为 1
时,以上损失函数值达到最小,通过引入标准差的对数,确保 σ \sigma σ 值不为负。通过最小化以上损失可以确保编码器输出遵循预定义分布。
2.4 重参数化技巧
下图左侧显示了 VAE
网络。编码器获取输入 x x x,并估计潜矢量 z z z 的多元高斯分布的均值 μ μ μ 和标准差 σ σ σ,解码器从潜矢量 z z z 采样,以将输入重构为 x x x:
但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:
S a m p l e = μ + ε σ Sample=\mu + εσ Sample=μ+εσ
如果 ε ε ε 和 σ σ σ 以矢量形式表示,则 ε σ εσ εσ 是逐元素乘法,使用上式,令采样好像直接来自于潜空间。 这种技术被称为重参数化技巧 (Reparameterization trick
)。
3. 实现 VAE
在本节中,使用 PyTorch
实现 VAE
模型生成 cifar-10
图像。
3.1 数据加载
(1) 首先导入所需的库:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
(2) 定义数据预处理转换:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将像素值归一化到[-1,1]
])
(3) 加载 CIFAR-10
训练集和测试集:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
(4) 创建数据加载器:
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
3.2 模型构建
(1) 定义 VAE
模型,由编码器和解码器构成:
class VAE(nn.Module):def __init__(self, latent_dim=128):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.encoder = nn.Sequential(nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), # 32x16x16nn.ReLU(),nn.BatchNorm2d(32),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # 64x8x8nn.ReLU(),nn.BatchNorm2d(64),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128x4x4nn.ReLU(),nn.BatchNorm2d(128),nn.Flatten(), # 128*4*4=2048nn.Linear(2048, 1024),nn.ReLU())# 潜在空间的均值和对数方差self.fc_mu = nn.Linear(1024, latent_dim)self.fc_logvar = nn.Linear(1024, latent_dim)# 解码器self.decoder_input = nn.Linear(latent_dim, 1024)self.decoder = nn.Sequential(nn.Linear(1024, 2048),nn.ReLU(),nn.Unflatten(1, (128, 4, 4)), # 128x4x4nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 64x8x8nn.ReLU(),nn.BatchNorm2d(64),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 32x16x16nn.ReLU(),nn.BatchNorm2d(32),nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), # 3x32x32nn.Tanh() # 输出在[-1,1]之间,与输入归一化一致)def encode(self, x):"""编码输入图像x,返回潜在空间的均值和方差"""h = self.encoder(x)mu = self.fc_mu(h)logvar = self.fc_logvar(h)return mu, logvardef reparameterize(self, mu, logvar):"""重参数化技巧,从N(mu, var)采样"""std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):"""从潜在变量z解码重构图像"""h = self.decoder_input(z)x_recon = self.decoder(h)return x_recondef forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)x_recon = self.decode(z)return x_recon, mu, logvar
(2) 定义损失函数,由重建损失和 KL 散度组成:
def vae_loss(recon_x, x, mu, logvar):"""VAE损失函数 = 重构损失 + KL散度"""# 重构损失recon_loss = F.mse_loss(recon_x, x, reduction='sum')# KL散度:-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return recon_loss + kl_loss
3.3 模型训练
(1) 定义模型训练和测试函数:
def train(model, device, train_loader, optimizer, epoch):model.train()train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()# 前向传播recon_batch, mu, logvar = model(data)# 计算损失loss = vae_loss(recon_batch, data, mu, logvar)# 反向传播和优化loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')avg_loss = train_loss / len(train_loader.dataset)print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')return avg_lossdef test(model, device, test_loader):model.eval()test_loss = 0with torch.no_grad():for data, _ in test_loader:data = data.to(device)recon_batch, mu, logvar = model(data)test_loss += vae_loss(recon_batch, data, mu, logvar).item()test_loss /= len(test_loader.dataset)print(f'====> Test set loss: {test_loss:.4f}')return test_loss
(2) 定义可视化函数,用于可视化原始图像和重构图像:
def visualize_reconstruction(model, device, test_loader, num_images=8):model.eval()with torch.no_grad():# 获取一批测试图像data, _ = next(iter(test_loader))data = data[:num_images].to(device)# 重构图像recon_data, _, _ = model(data)# 将图像从[-1,1]转换回[0,1]以便显示data = data.cpu().numpy().transpose(0, 2, 3, 1)data = (data + 1) / 2 # 从[-1,1]到[0,1]recon_data = recon_data.cpu().numpy().transpose(0, 2, 3, 1)recon_data = (recon_data + 1) / 2 # 从[-1,1]到[0,1]# 绘制图像fig, axes = plt.subplots(2, num_images, figsize=(num_images * 2, 4))for i in range(num_images):axes[0, i].imshow(data[i])axes[0, i].axis('off')axes[1, i].imshow(recon_data[i])axes[1, i].axis('off')axes[0, 0].set_ylabel('Original')axes[1, 0].set_ylabel('Reconstructed')plt.show()
(3) 定义 generate_samples()
,从潜空间随机采样生成新图像:
def generate_samples(model, device, latent_dim, num_samples=16):model.eval()with torch.no_grad():# 从标准正态分布采样z = torch.randn(num_samples, latent_dim).to(device)# 生成样本samples = model.decode(z).cpu()samples = samples.numpy().transpose(0, 2, 3, 1)samples = (samples + 1) / 2 # 从[-1,1]到[0,1]# 绘制生成的样本fig, axes = plt.subplots(4, 4, figsize=(8, 8))for i, ax in enumerate(axes.flat):ax.imshow(samples[i])ax.axis('off')plt.show()
(4) 训练模型 50
个 epoch
,训练完成后,可视化模型生成效果,并绘制训练和测试损失变化曲线:
def main():# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 初始化模型latent_dim = 128model = VAE(latent_dim=latent_dim).to(device)# 定义优化器optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练参数epochs = 50train_losses = []test_losses = []# 训练循环for epoch in range(1, epochs + 1):train_loss = train(model, device, train_loader, optimizer, epoch)test_loss = test(model, device, test_loader)train_losses.append(train_loss)test_losses.append(test_loss)# 每5个epoch可视化一次if epoch % 5 == 0:visualize_reconstruction(model, device, test_loader)# 训练完成后可视化generate_samples(model, device, latent_dim)# 绘制训练和测试损失曲线plt.figure(figsize=(10, 5))plt.plot(train_losses, label='Train Loss')plt.plot(test_losses, label='Test Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.title('Training and Test Loss')plt.show()main()
重建效果:
生成结果: