第四章 Stable Diffusion
在前一章中,我们介绍了扩散模型及其迭代优化的基本思想。学完该章,我们已经能够生成图像,但训练模型非常耗时,而且我们无法控制生成的图像。在本章中,我们将学习如何从这一阶段走向基于文本条件的模型,这些模型可以根据文本描述高效地生成图像,研究的是一个名为Stable Diffusion(SD)的模型。不过在介绍SD之前,我们会先了解条件模型如何工作,并回顾一些产生当今文生图像模型的创新。
本文相关代码请见GitHub仓库。
增强控制:条件扩散模型
在解决从文本描述生成图像这个具有挑战性的任务之前,先聚焦于一个略简单一点的任务上。我们将了解如何引导我们的模型输出特定类型或类别的 图像。可以使用一种称为条件化的方法,其思想是要求模型生成的不是宽泛的图像,而是属于预定义类别的图像。
模型条件化是一个简单但有效的想法。我们将通过上一章使用的扩散模型,仅做一些小改动。首先,不再使用蝴蝶数据集,而是切换到一个有类别的数据集。我们使用Fashion MNIST,这是一个包含成千上万张衣服图像的数据集,每张图像都带有一个来自10个不同类别的标签。然后是关键,我们将通过模型运行两个输入。不仅向模型展示真实的图像,还会告诉它每张图像所属的类别。我们期望模型学会关联图像和标签,以理解毛衣、靴子等的自有特征。
请注意,我们并不想解决分类问题——不是希望模型告诉我们图像属于哪个类别。我们仍然希望它执行与上一章相同的任务:生成看起来像来自这个数据集的图像。唯一的区别是给了它关于这些图像的附加信息。我们将使用相同的损失函数和训练策略,因为任务是相同的。
准备数据
需要一个具有不同图像组的数据集。适用于计算机视觉分类任务的数据集是理想的选择。我们可以从类似ImageNet的数据集开始,该数据集包含数百万张涵盖1000个类别的图像。然而,在这个数据集上训练模型将花费极长的时间。在处理新问题时,最好先从较小的数据集入手,以确保一切按预期进展。这可以缩短反馈环,使我们能够快速迭代并确保方向正确。
我们可以选择MNIST作为这个例子,像上一章那样。为略显不同,我们将选择Fashion MNIST。Fashion MNIST由Zalando开发并开源,是MNIST的替代品,具有类似的特征:压缩大小、黑白图像和十个类别。主要区别在于类别对应于不同类型的衣物,而不是数字,且图像比简单的手写数字包含更多细节。
下面来看一些例子。
from datasets import load_dataset
from utils.utils import show_images
fashion_mnist = load_dataset("fashion_mnist")
clothes = fashion_mnist["train"]["image"][:8]
classes = fashion_mnist["train"]["label"][:8]
show_images(clothes, titles=classes, figsize=(4, 2.5))
因此,类别0
对应于T恤,类别2
是毛衣,而类别9
是靴子(Fashion MNIST的十个类别)。我们准备数据集和数据加载器的方式与上一章相似,主要的区别是类别信息也作为输入。与上一章的调整大小操作不同,这次我们会将图像输入(28 × 28
像素)填充到32 × 32
像素。这将保持原始图像的质量,有助于UNet做出更高质量的预测。
注:上一章中图像非常大(512,283
),我们得将其缩放至更小的尺寸。
import torch
from torchvision import transforms
preprocess = transforms.Compose(
[
transforms.RandomHorizontalFlip(), # Randomly flip (data augmentation)
transforms.ToTensor(), # Convert to tensor (0, 1)
transforms.Pad(2), # Add 2 pixels on all sides
transforms.Normalize([0.5], [0.5]), # Map to (-1, 1)
]
)
def transform(examples):
images = [preprocess(image) for image in examples["image"]]
return {"images": images, "labels": examples["label"]}
train_dataset = fashion_mnist["train"].with_transform(transform)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=256, shuffle=True
)
创建类别条件模型
diffusers库中的UNet
允许提供自定义条件信息。在这里,我们创建一个与上一章中使用的模型类似的模型,但在UNet
构造函数中添加了一个num_class_embeds
参数。这个参数告诉模型我们希望使用类别标签作为额外条件。我们将使用10,因为这是Fashion MNIST中的类别数。
from diffusers import UNet2DModel
model = UNet2DModel(
in_channels=1, # 1 channel for grayscale images
out_channels=1,
sample_size=32,
block_out_channels=(32, 64, 128, 256),
num_class_embeds=10, # Enable class conditioning
)
要使用这个模型进行预测,我们必须将类别标签作为额外输入传递给forward()
方法:
x = torch.randn((1, 1, 32, 32))
with torch.no_grad():
out = model(x, timestep=7, class_labels=torch.tensor([2])).sample
out.shape
torch.Size([1, 1, 32, 32])
注:我们还将另一样参数作为条件传递给模型: 时间步!没错,即使是上一章中的模型也可以视作一个条件扩散模型。根据时间步对其进行条件化,期望了解我们在扩散过程中的进展程度将有助于生成更真实的图像。
在内部,时间步和类别标签会被转换为模型在前向传播过程中使用的嵌入。在UNet的多个阶段,这些嵌入会被投射到与给定层中的通道数匹配的维度。然后,这些嵌入会被添加到该层的输出中。这意味着条件信息会被传递到UNet的每个块中,给模型充分的机会来学习如何有效地使用它。
训练模型
在灰度图像上添加噪声与上一章的蝴蝶图像效果一样好。来看看在更多噪声时间步时噪声的影响。
from diffusers import DDPMScheduler
scheduler = DDPMScheduler(
num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02
)
timesteps = torch.linspace(0, 999, 8).long()
batch = next(iter(train_dataloader))
x = batch["images"][0].expand([8, 1, 32, 32])
noise = torch.rand_like(x)
noised_x = scheduler.add_noise(x, noise, timesteps)
show_images((noised_x * 0.5 + 0.5).clip(0, 1))
我们的训练方式也几乎与上一章相同,只是现在传递类别标签作为条件。注意,这只是为模型提供的附加信息,并不影响定义损失函数的方式。开启训练,可以泡杯茶、咖啡或其他饮料。
注:我们还将使用Python包tqdm在训练过程中显示进度。作者忍不住分享他们文档中的这句话(https://tqdm.github.io):
tqdm在阿拉伯语中意为“进展”(taqadum, تقدّم),在西班牙语中是“我非常爱你”(te quiero demasiado)的缩写。
不要被下面的代码吓到:它与我们进行无条件生成时的类似(建议将这段代码与上一章的代码进行对比。你能找到所有的不同之处吗?)。
- 加载一批图像及其相应的标签。
- 根据时间步为图像添加噪声。
- 将带噪声的图像和类别标签喂给模型。
- 计算损失。
- 反向传播损失并用优化器更新模型权重。
注:epoch和学习率的数量不同、AdamW
的epsilon不一样,还使用了_tqdm加载数据、标签并将标签传递给模型。最重要的是条件为2-line diff。
from torch.nn import functional as F
from tqdm import tqdm
scheduler = DDPMScheduler(
num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02
)
num_epochs = 25
lr = 3e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-5)
losses = [] # Somewhere to store the loss values for later plotting
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Train the model (this takes a while!)
for epoch in (progress := tqdm(range(num_epochs))):
for step, batch in (
inner := tqdm(
enumerate(train_dataloader),
position=0,
leave=True,
total=len(train_dataloader),
)
):
# Load the input images and classes
clean_images = batch["images"].to(device)
class_labels = batch["labels"].to(device)
# Sample noise to add to the images
noise = torch.randn(clean_images.shape).to(device)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
scheduler.config.num_train_timesteps,
(clean_images.shape[0],),
device=device,
).long()
# Add noise to the clean images according to the timestep
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
# Get the model prediction for the noise - note the use of class_labels
noise_pred = model(
noisy_images,
timesteps,
class_labels=class_labels,
return_dict=False,
)[0]
# Compare the prediction with the actual noise:
loss = F.mse_loss(noise_pred, noise)
# Display loss
inner.set_postfix(loss=f"{loss.cpu().item():.3f}")
# Store the loss for later plotting
losses.append(loss.item())
# Update the model parameters with the optimizer based on this loss
loss.backward(loss)
optimizer.step()
optimizer.zero_grad()
import matplotlib.pyplot as plt
plt.plot(losses)
采样
现在有了一个在做预测时需要两个输入即图像和类别标签的模型。可以通过从随机噪声开始,然后逐步去噪,传入我们想生成的类别标签来创建样本:
def generate_from_class(class_to_generate, n_samples=8):
sample = torch.randn(n_samples, 1, 32, 32).to(device)
class_labels = [class_to_generate] * n_samples
class_labels = torch.tensor(class_labels).to(device)
for _, t in tqdm(enumerate(scheduler.timesteps)):
# Get model pred
with torch.no_grad():
noise_pred = model(sample, t, class_labels=class_labels).sample
# Update sample with step
sample = scheduler.step(noise_pred, t, sample).prev_sample
return sample.clip(-1, 1) * 0.5 + 0.5
# Generate t-shirts (class 0)
images = generate_from_class(0)
show_images(images, nrows=2)
1000it [00:13, 75.05it/s]
# Now generate some sneakers (class 7)
images = generate_from_class(7)
show_images(images, nrows=2)
1000it [00:13, 76.91it/s]
# ...or boots (class 9)
images = generate_from_class(9)
show_images(images, nrows=2)
1000it [00:13, 76.29it/s]
可以看到,生成的图像远称不上完美。如果进一步探索架构并延长训练时间,图像质量将大大提升。但令人惊奇的是,模型仅通过对训练数据发送这些信息,就学习到了不同类型衣物的形状,并意识到形状9
与形状0
不同。换句话说,模型习惯于看到数字9
与靴子一起出现。当我们要求它生成图像并提供数字9时,它会生成一双靴子。