基于 LoRA+Stable Diffusion 的 100 种动物图像生成
一个基于 Stable Diffusion 和 LoRA 技术的动物图像生成系统。项目包含完整的训练流程和 PyQt5 图形界面,支持文本生成高质量动物图像。内容涵盖模型架构解析(VAE、CLIP、U-Net)、LoRA 微调原理、训练代码实现(参数配置、数据处理、早停机制)及 UI 交互设计。通过 CLIP 分数评估生成质量,实现了跨平台运行和参数自定义调整。

一个基于 Stable Diffusion 和 LoRA 技术的动物图像生成系统。项目包含完整的训练流程和 PyQt5 图形界面,支持文本生成高质量动物图像。内容涵盖模型架构解析(VAE、CLIP、U-Net)、LoRA 微调原理、训练代码实现(参数配置、数据处理、早停机制)及 UI 交互设计。通过 CLIP 分数评估生成质量,实现了跨平台运行和参数自定义调整。

代码详见:https://github.com/xiaozhou-alt/Animals_Generation
这是一个基于 Stable Diffusion 和 LoRA 技术的动物图像生成系统,能够通过文本描述生成高质量的动物图像,包含完整的训练流程和用户友好的图形界面,支持自定义参数调整和实时图像生成。
主要特性
生成的部分动物图像:
Animals_Creation/
├── README.md
├── demo.gif # 演示动画
├── demo.mp4 # 演示视频
├── demo.py # 主演示脚本
├── icons/ # 图标资源目录
├── train.py
├── log/ # 日志目录
├── model/
│ └── LCM-runwayml-stable-diffusion-v1-5/ # Stable Diffusion 模型
│ ├── feature_extractor/ # 特征提取器
│ ├── model_index.json # 模型索引文件
│ ├── safety_checker/ # 安全检查器
│ ├── scheduler/ # 调度器
│ ├── text_encoder/ # 文本编码器
│ ├── tokenizer/ # 分词器
│ ├── unet/ # UNet 模型
│ └── vae/ # 变分自编码器
├── output/
│ ├── evaluation_results.xlsx # 评估结果 Excel 文件
│ ├── lora_models/ # LoRA 模型权重
│ │ └── clip-31.475.safetensors
│ ├── training_history.xlsx # 训练历史记录
│ └── pic/
└── requirements.txt
本项目使用的动物数据集包含 100 个不同类别的动物图片,因为使用网页图片提取下载,清洗由个人完全进行,数据集数据量较大,所以部分动物文件夹存在 1%-1.5% 的噪声图片,数据集组织结构如下:
在模型训练过程中,通过数据增强技术扩充了训练样本,包括旋转、平移、缩放、亮度调整等操作,以提高模型的泛化能力。
动物的类别信息请查看 class.txt:
antelope badger bat …
数据集下载:100 种动物识别数据集 (ScienceDB)
引用 如果您使用了本项目的数据集,请使用如下方式进行引用:
Haojing ZHOU.100 种动物识别数据集 [DS/OL]. V1. Science Data Bank,2025[2025-08-30]. https://cstr.cn/31253.11.sciencedb.29221. CSTR:31253.11.sciencedb.29221.
或
@misc{动物识别,author ={Haojing ZHOU}, title ={100 种动物识别数据集}, year ={2025}, doi ={10.57760/sciencedb.29221}, url ={https://doi.org/10.57760/sciencedb.29221}, note ={CSTR:31253.11.sciencedb.29221}, publisher ={ScienceDB}}
Stable Diffusion 采用潜在扩散模型(Latent Diffusion Model)架构,通过将高维图像压缩到低维潜在空间进行扩散过程,显著提升了计算效率。该模型主要由四个核心组件构成:变分自编码器(VAE) + CLIP 文本编码器 + U-Net 模型 + 噪声调度器(DDPMScheduler)
VAE 在 Stable Diffusion 中承担着图像与潜在空间的双向转换任务。其编码器将输入图像 x 压缩为潜在表示 z,解码器则将潜在表示重建为图像 x^。在代码实现中,我们使用预训练的 AutoencoderKL 模型:
vae = AutoencoderKL.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="vae"
)
VAE 的核心工作原理是通过变分推断学习数据的潜在分布。对于输入图像 x,编码器输出潜在分布的均值 μ 和方差 σ²,通过重参数化技巧采样得到潜在表示: z = μ + ε · σ, ε ~ N(0, I)
在项目中,我们将编码得到的潜在表示进行缩放:
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * 0.18215 # 缩放因子
这里的缩放因子 0.18215 是 Stable Diffusion 模型预训练时确定的常数,用于将 VAE 输出的潜在空间分布标准化到更适合扩散过程的范围。
文本引导 是 Stable Diffusion 的核心特性,这一功能由 CLIP(Contrastive Language-Image Pretraining)文本编码器实现。它将文本描述转换为固定维度的向量表示,建立文本与图像之间的语义关联:
text_encoder = CLIPTextModel.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="tokenizer"
)
CLIP 文本编码器通过对比学习训练,其输出的文本嵌入 t 与图像嵌入在同一语义空间中。对于输入文本 w(如 "a photo of a cat"),经过分词和编码后得到文本特征: t = text_encoder(tokenizer(w))
在项目中,我们使用多样化的提示词模板增强文本嵌入的鲁棒性:
self.prompt_templates = [
"a photo of a {}",
"a high quality image of a {}",
# 更多模板...
]
U-Net 是 Stable Diffusion 的 核心扩散模块,负责在潜在空间中 预测噪声。它以带噪声的潜在表示 z_t、时间步 t 和文本嵌入 c 作为输入,输出噪声预测 ε_θ(z_t, t, c):
unet = UNet2DConditionModel.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="unet"
)
# 噪声预测 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
U-Net 采用编码器 - 解码器结构,通过跳跃连接保留细节信息,同时引入时间步嵌入和文本条件嵌入,实现条件生成。损失函数采用预测噪声与真实噪声的均方误差: L = E[z_0, ε, t] [ ||ε - ε_θ(z_t, t, c)||² ]
噪声调度器控制着扩散过程中的 噪声添加和去除 策略。在训练阶段,它按照特定 schedule 向干净样本添加噪声;在推理阶段,则逐步从纯噪声中生成图像:
noise_scheduler = DDPMScheduler.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="scheduler"
)
扩散过程遵循 马尔可夫链,前向过程中噪声逐步增加: z_t = √α_t z_{t-1} + √(1-α_t) ε, ε ~ N(0, I)
其中 α_t 是调度器预定义的噪声系数。在项目中,我们通过调度器添加噪声:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
在大规模预训练模型的微调任务中,全参数微调需要巨大的计算资源。LoRA(Low-Rank Adaptation)技术通过 冻结预训练模型权重,仅训练低秩矩阵参数,实现高效微调:
def prepare_unet_for_lora(unet, rank=2, alpha=16):
lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
target_modules=["to_q","to_k","to_v","to_out.0"],
lora_dropout=0.0,
bias="none",
)
unet = get_peft_model(unet, lora_config)
return unet
LoRA 的核心思想是将权重更新表示为低秩矩阵分解的形式。对于预训练权重 W ∈ R^{d×k},LoRA 通过学习两个低秩矩阵 W_A ∈ R^{d×r} 和 W_B ∈ R^{r×k}(r << min(d,k))来近似权重更新: W' = W + W_B W_A
在项目中,我们将 LoRA 应用于 U-Net 的注意力模块,具体是查询(to_q)、键(to_k)、值(to_v)投影层和输出投影层(to_out.0):
Attention(Q + ΔQ, K + ΔK, V + ΔV)
其中 ΔQ = W_B^Q W_A^Q,ΔK 和 ΔV 类似。这种设计使模型能够在保持预训练知识的同时,高效学习特定任务的知识。
在项目配置中,我们选择了较小的 秩(rank=2)和 alpha 值(lora_alpha=16):
# LoRA 参数
rank = 2
lora_alpha = 16
这种配置大大减少了可训练参数数量。通过 print_trainable_parameters() 可以发现,仅约 0.1% 的参数参与训练,显著降低了内存需求和计算成本。同时,LoRA 权重文件体积小(通常只有几 MB),便于存储和分享。
本文项目实现部分分割较细,篇幅较长,读者不想深究代码原理可见 GitHub 项目中 README.md 进行项目实现;若是愿意深究代码逻辑,下文对于 训练代码 和 UI 界面使用代码 进行了详细说明。
Config 类集中管理了所有关键参数,体现了资源受限情况下的优化策略:
256x256)、减少 LoRA 秩(rank=2)等措施,显著降低显存占用,使训练在普通 GPU 上成为可能。max_grad_norm=0.5),有效防止训练过程中的 梯度爆炸 问题。max_samples_per_class 限制每类样本数量,解决动物数据集类别不平衡问题,避免模型对样本多的类别过拟合。# 参数配置 - 关键优化点
class Config:
# 数据参数 - 减少数据量
data_root = "/kaggle/input/animals/Animal/Animal" # 动物数据集根路径
output_dir = "/kaggle/working/output" # 所有输出文件的目录
lora_model_dir = os.path.join(output_dir, "lora_models") # 保存 LoRA 模型的目录
history_file = os.path.join(output_dir, "training_history.xlsx") # 训练历史记录文件
sample_output_dir = os.path.join(output_dir, "validation_samples") # 验证样本输出目录
evaluation_file = os.path.join(output_dir, "evaluation_results.xlsx") # 评估结果文件
comparison_dir = os.path.join(output_dir, "comparison_samples") # 对比样本目录
# 模型参数 - 降低分辨率
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" # 使用 SD 1.5 作为基础模型
resolution = 256 # 降低分辨率以减少计算量 (原为 512)
center_crop = True # 中心裁剪
random_flip = True # 随机水平翻转 (数据增强)
# LoRA 参数 - 简化 LoRA
rank = 2 # 降低 LoRA 的秩 (原为 4)
lora_alpha = 16 # 降低 LoRA 的 alpha 值 (原为 32)
# 训练参数 - 关键优化
train_batch_size = 1 # 批大小
gradient_accumulation_steps = 4 # 梯度累积步数
num_train_epochs = 10 # 训练轮数
learning_rate = 1e-5 # 学习率
lr_scheduler_type = "cosine_with_warmup" # 学习率调度器类型
lr_warmup_steps =
max_grad_norm =
use_ema =
gradient_checkpointing =
mixed_precision =
early_stopping_patience =
early_stopping_delta =
validation_split =
num_validation_samples =
num_inference_steps =
num_final_inference_steps =
guidance_scale =
max_samples_per_class =
num_evaluation_samples =
clip_model_name =
AnimalDataset 类实现了动物图像数据集的 加载和预处理 功能,核心特点包括:
"根目录 / 动物类别 / 图像文件" 的层级结构,通过扫描子文件夹自动识别类别名称。max_samples_per_class 的类别进行随机采样,确保各类别样本量相对均衡。特写、自然栖息地)和属性(可爱、野生)的描述,丰富了模型的条件学习信号。# 1. 数据处理与准备 - 添加样本限制
class AnimalDataset(Dataset):
def __init__(self, data_root, tokenizer, size=384, center_crop=True, random_flip=True, max_samples_per_class=100):
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size # 使用新的分辨率
self.center_crop = center_crop
self.random_flip = random_flip
self.max_samples_per_class = max_samples_per_class
# 获取所有图像路径和对应的类别(动物名称)
self.image_paths = []
self.class_names = []
# 假设子文件夹以动物英文名称命名
subfolders = [f.name for f in os.scandir(data_root) if f.is_dir()]
for class_name in subfolders:
class_dir = os.path.join(data_root, class_name)
image_files = glob.glob(os.path.join(class_dir, "*.jpg")) + \
glob.glob(os.path.join(class_dir, "*.png")) + \
glob.glob(os.path.join(class_dir, "*.jpeg"))
# 限制每类样本数量
if len(image_files) > max_samples_per_class:
image_files = random.sample(image_files, max_samples_per_class)
for img_path in image_files:
self.image_paths.append(img_path)
self.class_names.append(class_name)
# 为每个类别创建提示词模板
self.prompt_templates = [
,
,
,
,
,
,
]
LANCZOS 重采样方法进行缩放,保证图像质量CLIP tokenizer 对文本进行编码,生成模型可理解的输入 _idspadding 和 truncation),便于批量处理def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image_path = self.image_paths[idx]
class_name = self.class_names[idx]
# 加载和预处理图像
image = Image.open(image_path).convert("RGB")
# 调整大小和中心裁剪
if self.center_crop:
# 保持宽高比的调整大小和中心裁剪
image = self._center_crop(image)
else:
image = image.resize((self.size, self.size), Image.Resampling.LANCZOS)
# 随机水平翻转 (数据增强)
if self.random_flip and random.random() < 0.5:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
# 将图像转换为模型输入的张量 (-1 to 1)
image_tensor = (torch.tensor(np.array(image).astype(np.float32)/127.5)-1.0).permute(2,0,1)
# 为图像生成随机的提示词
prompt_template = random.choice(self.prompt_templates)
prompt = prompt_template.format(class_name)
# 对提示词进行标记化
tokenized_input = self.tokenizer(
prompt,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
input_ids = tokenized_input.input_ids.squeeze(0)
return {
: image_tensor,
: input_ids,
: prompt,
: class_name
}
():
width, height = image.size
new_size = (width, height)
left = (width - new_size)/
top = (height - new_size)/
right = (width + new_size)/
bottom = (height + new_size)/
image = image.crop((left, top, right, bottom))
image = image.resize((.size, .size), Image.Resampling.LANCZOS)
image
传统的基于损失的早停可能无法准确反映生成模型的质量,本项目采用 基于 CLIP 分数 的早停策略,具有以下特点:
图像 - 文本匹配度)不再显著提升时停止训练,更符合生成任务的质量目标。patience: 允许分数不提升的轮数,设置为 5 给予模型足够的优化空间delta: 最小改善阈值,0.02 确保只有显著提升才被认可# 早停机制 (PyTorch 实现) - 使用 CLIP 分数作为指标
class EarlyStopping:
def __init__(self, patience=3, delta=0.05, verbose=False):
self.patience = patience
self.delta = delta
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
def __call__(self, clip_score):
if self.best_score is None:
self.best_score = clip_score
elif clip_score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = clip_score
self.counter = 0
函数实现了对 Stable Diffusion 核心组件UNet 的 LoRA 适配,是参数高效微调的关键:
r(rank):低秩矩阵的秩,设置为 2 大幅减少可训练参数lora_alpha:缩放因子,与秩配合控制更新幅度target_modules:指定需要注入 LoRA 的模块,选择注意力层的查询、键、值投影和输出层get_peft_model 函数将 LoRA 适配器注入UNet,仅训练少量适配器参数而非整个模型。print_trainable_parameters() 会输出可训练参数比例,通常仅为原始模型的 0.1% 左右。# 为 UNet 准备 LoRA 的函数 - 使用 peft 库
def prepare_unet_for_lora(unet, rank=2, alpha=16):
# 配置 LoRA
lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
target_modules=["to_q","to_k","to_v","to_out.0"],
lora_dropout=0.0,
bias="none",
)
# 应用 LoRA 到 UNet
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()
return unet
CLIP(Contrastive Language-Image Pretraining)分数用于 量化评估生成图像与文本描述的匹配程度,是生成质量的重要指标:
logits_per_image 表示图像与文本的匹配分数,值越高表示匹配度越好torch.autocast 和 torch.no_grad 优化计算效率# 计算验证 CLIP 分数的函数
def compute_validation_clip_score(config, unet, text_encoder, vae, tokenizer, device):
# 获取所有动物类别
animal_classes = [f.name for f in os.scandir(config.data_root) if f.is_dir()]
# 随机选择验证用的动物
selected_animals = random.sample(animal_classes, min(config.num_validation_samples, len(animal_classes)))
print(f"Selected animals for validation CLIP score: {selected_animals}")
# 创建生成管道
pipe = StableDiffusionPipeline.from_pretrained(
config.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
safety_checker=None,
torch_dtype=torch.float16 if config.mixed_precision == "fp16" else torch.float32
).to(device)
# 加载 CLIP 模型和处理器
clip_model = CLIPModel.from_pretrained(config.clip_model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(config.clip_model_name)
clip_scores = []
for animal in selected_animals:
prompt = f"a high quality photo of a {animal}"
# 生成图像
with torch.autocast(device.type):
image = pipe(
prompt,
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
height=config.resolution,
width=config.resolution
).images[0]
# 计算 CLIP Score
with torch.no_grad():
# 处理图像和文本
inputs = clip_processor(
text=[prompt],
images=image,
return_tensors="pt",
padding=True
).to(device)
outputs = clip_model(**inputs)
logits_per_image = outputs.logits_per_image
clip_score = logits_per_image.item()
()
clip_scores.append(clip_score)
avg_clip_score = np.mean(clip_scores)
()
avg_clip_score
模型初始化: 训练函数的初始部分负责加载和配置 Stable Diffusion 的核心组件:
gradient checkpointing)牺牲少量计算时间换取显存节省# 2. 训练函数 (包含早停和历史记录)
def train_lora_with_earlystopping(config):
# 初始化模型组件
tokenizer = CLIPTokenizer.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="unet"
)
# 添加 LoRA 适配器到 UNet
unet = prepare_unet_for_lora(unet, config.rank, config.lora_alpha)
# 设置噪声调度器
noise_scheduler = DDPMScheduler.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="scheduler"
)
# 启用梯度检查点以节省显存
if config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# 将模型移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_encoder.to(device)
vae.to(device)
unet.to(device)
优化器与数据加载配置 :
AdamW 优化器,配合合理的权重衰减(0.01)防止过拟合AnimalDataset 加载数据DataLoader 实现批量加载和多进程预处理# 设置优化器 (只优化 LoRA 参数)
lora_params = []
for name, param in unet.named_parameters():
if param.requires_grad:
# 只优化需要梯度的参数
lora_params.append(param)
# 使用更稳定的优化器配置
optimizer = torch.optim.AdamW(
lora_params,
lr=config.learning_rate,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.01
)
# 准备数据集和数据加载器
full_dataset = AnimalDataset(
config.data_root,
tokenizer,
size=config.resolution,
center_crop=config.center_crop,
random_flip=config.random_flip,
max_samples_per_class=config.max_samples_per_class
)
# 分割训练集和验证集
val_size = int(len(full_dataset) * config.validation_split)
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_dataloader = DataLoader(
train_dataset,
batch_size=config.train_batch_size,
shuffle=True,
num_workers=2
)
val_dataloader = DataLoader(
val_dataset,
batch_size=config.train_batch_size,
shuffle=False,
num_workers=2
)
训练调度与记录配置:
# 计算总训练步数
num_update_steps_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps
max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
# 学习率调度器
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=config.lr_warmup_steps,
num_training_steps=max_train_steps
)
# 初始化早停 (使用 CLIP 分数作为指标)
early_stopping = EarlyStopping(
patience=config.early_stopping_patience,
delta=config.early_stopping_delta,
verbose=True
)
# 创建 Excel 工作簿用于记录历史
history_wb = Workbook()
history_ws = history_wb.active
history_ws.title = "Training History"
history_ws.append(["Epoch", "Step", "Train Loss", "Validation Loss", "CLIP Score", "Learning Rate", "Best CLIP Score", "Gradient Norm"])
核心训练循环:
训练循环实现了 Stable Diffusion 的噪声预测训练过程,关键步骤包括:
with torch.no_grad() 冻结 VAE 参数0.18215 缩放因子,这是 Stable Diffusion 的标准处理流程# 训练循环
global_step = 0
best_clip_score = 0.0
# 训练循环部分
for epoch in range(config.num_train_epochs):
unet.train()
total_loss = 0
optimizer.zero_grad()
current_grad_norm = 0.0
# 初始化梯度范数
for step, batch in enumerate(train_dataloader):
# 将批次数据移动到设备
pixel_values = batch["pixel_values"].to(device)
input_ids = batch["input_ids"].to(device)
# 将图像编码到潜在空间
with torch.no_grad():
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * 0.18215 # 缩放因子
# 采样噪声
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device).long()
# 向潜在表示添加噪声
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# 获取文本嵌入
with torch.no_grad():
encoder_hidden_states = text_encoder(input_ids)[0]
# 预测噪声残差
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# 计算损失
loss = F.mse_loss(noise_pred, noise, reduction="mean") / config.gradient_accumulation_steps
# 反向传播
loss.backward()
# 梯度累积
if (step + 1) % config.gradient_accumulation_steps == 0:
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(lora_params, config.max_grad_norm)
# 计算梯度范数用于监控
current_grad_norm =
p lora_params:
p.grad :
param_norm = p.grad.data.norm()
current_grad_norm += param_norm.item() **
current_grad_norm = current_grad_norm **
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
global_step +=
total_loss += loss.item() * config.gradient_accumulation_steps
global_step % == :
avg_loss = total_loss / (step + )
current_lr = lr_scheduler.get_last_lr()[]
()
epoch 后处理流程:
每个训练 epoch 结束后执行的关键操作:
# 每个 epoch 结束后计算验证损失和 CLIP 分数
val_loss = compute_validation_loss(unet, vae, text_encoder, val_dataloader, noise_scheduler, device)
avg_train_loss = total_loss / len(train_dataloader)
print(f"Epoch {epoch} completed. Train Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}")
# 计算 CLIP 分数
clip_score = compute_validation_clip_score(config, unet, text_encoder, vae, tokenizer, device)
# 记录到历史
current_lr = lr_scheduler.get_last_lr()[0]
history_ws.append([epoch, global_step, avg_train_loss, val_loss, clip_score, current_lr, best_clip_score, current_grad_norm])
# 早停检查 (基于 CLIP 分数)
early_stopping(clip_score)
# 保存最佳模型
if clip_score > best_clip_score:
best_clip_score = clip_score
# 保存 LoRA 权重
save_path = os.path.join(config.lora_model_dir, f"lora_weights_epoch_{epoch}.safetensors")
save_lora_weights(unet, save_path)
print(f"Saved best model with CLIP score: {best_clip_score:.4f}")
# 保存训练历史
history_wb.save(config.history_file)
# 检查早停
if early_stopping.early_stop:
print("Early stopping triggered")
break
print("Training completed!")
return unet, text_encoder, vae, tokenizer
训练过程示例输出如下所示:
Starting LoRA training… trainable params: 398,592 || all params: 859,919,556 || trainable%: 0.0464 Epoch 0, Step 0, Loss: 0.0044, LR: 0.000000, Grad Norm: 0.000000 Epoch 0, Step 0, Loss: 0.0077, LR: 0.000000, Grad Norm: 0.000000 Epoch 0, Step 0, Loss: 0.0216, LR: 0.000000, Grad Norm: 0.000000 Epoch 0, Step 50, Loss: 0.2194, LR: 0.000003, Grad Norm: 0.227759 Epoch 0, Step 50, Loss: 0.2198, LR: 0.000003, Grad Norm: 0.227759 … Epoch 0, Step 2250, Loss: 0.1937, LR: 0.000010, Grad Norm: 0.048456 Epoch 0 completed. Train Loss: 0.1937, Validation Loss: 0.1842 Selected animals for validation CLIP score: ['dolphin', 'zebra', 'sandpiper', 'swan', 'pig'] Animal: dolphin, CLIP Score: 29.2128 Animal: zebra, CLIP Score: 32.8276 Animal: sandpiper, CLIP Score: 30.2431 Animal: swan, CLIP Score: 28.6533 Animal: pig, CLIP Score: 29.8067 Average Validation CLIP Score: 30.1487 Saved best model with CLIP score: 30.1487 … Evaluation results saved to /kaggle/working/output/evaluation_results.xlsx All done! Average CLIP Score: 31.1877
验证样本生成:
该函数在训练完成后生成代表性样本用于可视化评估:
100 步)生成高质量样本# 3. 验证和生成样本
def generate_validation_samples(config, unet, text_encoder, vae, tokenizer):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 获取所有动物类别
animal_classes = [f.name for f in os.scandir(config.data_root) if f.is_dir()]
# 随机选择 5 种动物
selected_animals = random.sample(animal_classes, config.num_validation_samples)
print(f"Selected animals for validation: {selected_animals}")
# 创建生成管道
pipe = StableDiffusionPipeline.from_pretrained(
config.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
safety_checker=None, # 禁用安全检查器以加快生成速度
torch_dtype=torch.float16 if config.mixed_precision == "fp16" else torch.float32
).to(device)
# 生成每种动物的图像
all_images = []
all_titles = []
for animal in selected_animals:
prompt = f"a high quality photo of a {animal}"
# 生成图像 (使用更多推理步数)
with torch.autocast(device.type):
image = pipe(
prompt,
num_inference_steps=config.num_final_inference_steps,
guidance_scale=config.guidance_scale,
height=config.resolution,
width=config.resolution
).images[0]
# 保存图像
save_path = os.path.join(config.sample_output_dir, f"{animal}.png")
image.save(save_path)
print(f"Generated image for {animal} saved at ")
all_images.append(image)
all_titles.append(animal)
comparison_path = os.path.join(config.comparison_dir, )
create_comparison_image(all_images, all_titles, comparison_path)
selected_animals
量化评估实现:
该函数提供训练后的全面量化评估:
# 4. 评估函数 - 使用 CLIP Score 评估生成质量
def evaluate_with_clip_score(config, unet, text_encoder, vae, tokenizer):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载 CLIP 模型和处理器
clip_model = CLIPModel.from_pretrained(config.clip_model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(config.clip_model_name)
# 获取所有动物类别
animal_classes = [f.name for f in os.scandir(config.data_root) if f.is_dir()]
# 随机选择评估用的动物
selected_animals = random.sample(animal_classes, config.num_evaluation_samples)
print(f"Selected animals for evaluation: {selected_animals}")
# 创建生成管道
pipe = StableDiffusionPipeline.from_pretrained(
config.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
safety_checker=None,
torch_dtype=torch.float16 if config.mixed_precision == "fp16" else torch.float32
).to(device)
# 存储评估结果
evaluation_results = []
for animal in selected_animals:
prompt = f"a high quality photo of a {animal}"
# 生成图像 (使用更多推理步数)
with torch.autocast(device.type):
image = pipe(
prompt,
num_inference_steps=config.num_final_inference_steps,
guidance_scale=config.guidance_scale,
height=config.resolution,
width=config.resolution
).images[0]
# 保存图像
save_path = os.path.join(config.sample_output_dir, f"eval_{animal}.png")
image.save(save_path)
# 计算 CLIP Score
torch.no_grad():
inputs = clip_processor(
text=[prompt],
images=image,
return_tensors=,
padding=
).to(device)
outputs = clip_model(**inputs)
logits_per_image = outputs.logits_per_image
clip_score = logits_per_image.item()
()
evaluation_results.append({
: animal,
: prompt,
: clip_score,
: save_path
})
avg_clip_score = np.mean([result[] result evaluation_results])
()
evaluation_wb = Workbook()
evaluation_ws = evaluation_wb.active
evaluation_ws.title =
evaluation_ws.append([, , , ])
result evaluation_results:
evaluation_ws.append([result[], result[], result[], result[]])
evaluation_ws.append([])
evaluation_ws.append([, avg_clip_score])
evaluation_wb.save(config.evaluation_file)
()
evaluation_results, avg_clip_score
pretrained_model_name_or_path 指定本地 Stable Diffusion 基础模型目录,需包含 text_encoder、unet、vae 等核心组件;contrast/saturation/brightness_factor 为后期图像校正参数,通过 PIL 库实现实时调整,解决扩散模型生成图像可能偏灰、色彩暗淡的问题。# 配置类 - 增加色彩相关参数
class Config:
pretrained_model_name_or_path = "model/LCM-runwayml-stable-diffusion-v1-5" # 本地基础模型路径
resolution = 512 # 生成图像分辨率(默认 512x512)
rank = 2 # LoRA 微调秩(控制微调强度)
lora_alpha = 16 # LoRA 缩放因子
device = "cpu" # 运行设备(cpu/cuda,cuda 需安装 GPU 版本 PyTorch)
num_final_inference_steps = 100 # 默认推理步数(步数越多生成越精细,但耗时更长)
guidance_scale = 5.0 # 引导尺度(控制提示词对生成的影响,值越低色彩越自然)
contrast_factor = 1.0 # 对比度调整因子(1.0 为默认,<1 降低对比度,>1 增强)
saturation_factor = 1.0 # 饱和度调整因子(同上,影响色彩鲜艳度)
brightness_factor = 1.0 # 亮度调整因子(同上,影响图像明暗)
包含 3 个核心工具函数,分别解决 LoRA 权重加载、Tokenizer 加载异常、图像色彩校正 三大关键问题,是模型正常运行与生成效果优化的基础
LoRA(Low-Rank Adaptation)是轻量级微调技术,通过加载预训练的 LoRA 权重,可让基础模型快速适配 '动物图像生成' 场景(无需重新训练整个模型)
# 加载 LoRA 权重的函数
def load_lora_weights(unet, load_path):
# 从本地文件加载 LoRA 权重,指定设备(与模型一致)
lora_state_dict = torch.load(load_path, map_location=torch.device(Config.device))
# 非严格模式加载(LoRA 权重仅覆盖 unet 部分层,无需匹配所有参数)
unet.load_state_dict(lora_state_dict, strict=False)
return unet
Tokenizer(文本分词器)是将 '提示词' 转换为模型可识别向量的组件,该函数解决了 '本地模型 Tokenizer 路径异常' 的常见问题,提供降级加载方案
# 修复 tokenizer 加载问题的函数
def load_tokenizer_with_fix(model_path):
try:
# 尝试正常加载(默认路径:模型目录下的 tokenizer 文件夹)
tokenizer = CLIPTokenizer.from_pretrained(
os.path.join(model_path, "tokenizer")
)
return tokenizer
except Exception as e:
print(f"加载 tokenizer 时出错:{e}")
print("尝试修复 tokenizer 配置...")
# 降级方案:手动指定 vocab.json 和 merges.txt 文件(Tokenizer 核心文件)
from transformers import CLIPTokenizerFast
vocab_file = os.path.join(model_path, "tokenizer", "vocab.json")
merges_file = os.path.join(model_path, "tokenizer", "merges.txt")
if os.path.exists(vocab_file) and os.path.exists(merges_file):
tokenizer = CLIPTokenizerFast(
vocab_file=vocab_file,
merges_file=merges_file,
max_length=77, # CLIP 模型固定输入长度(超过截断,不足补全)
pad_token="!", # 填充 token(统一输入长度)
additional_special_tokens=["<startoftext|>","<endoftext|>"] # 特殊分隔符
)
return tokenizer
else:
raise Exception(f"找不到 tokenizer 文件:{vocab_file} 或 {merges_file}")
扩散模型生成的图像常存在 '对比度不足、色彩暗淡' 问题,该函数通过 PIL 的 ImageEnhance 模块,对生成图像进行后处理,提升视觉效果
# 图像色彩校正函数
def adjust_image_colors(image):
"""调整图像的色彩、对比度和饱和度,使其更自然"""
# 1. 调整对比度(增强细节层次)
enhancer = ImageEnhance.Contrast(image)
image = enhancer.enhance(Config.contrast_factor)
# 2. 调整饱和度(提升色彩鲜艳度,避免偏灰)
enhancer = ImageEnhance.Color(image)
image = enhancer.enhance(Config.saturation_factor)
# 3. 调整亮度(平衡整体明暗,避免过暗/过曝)
enhancer = ImageEnhance.Brightness(image)
image = enhancer.enhance(Config.brightness_factor)
return image
Stable Diffusion 由 Tokenizer + Text Encoder + UNet + VAE + Scheduler 五大组件构成,ModelLoader 类负责将这些组件从本地加载、组装为可直接调用的 StableDiffusionPipeline(生成流水线),并集成 LoRA 权重
Text Encoder 负责文本→向量,UNet 负责向量→latent(隐空间向量),VAE 负责latent→图像,Scheduler 控制去噪步数节奏;UNet 加载 LoRA 权重 —— 因为 UNet 是扩散模型的核心生成层,微调 UNet 即可快速改变生成风格(动物图像),无需调整其他组件;# 模型加载类
class ModelLoader:
def __init__(self, config, lora_model_path):
self.config = config # 全局配置
self.lora_model_path = lora_model_path # LoRA 模型路径
# 初始化各组件(后续加载)
self.tokenizer = None # 文本分词器
self.text_encoder = None # 文本编码器(将分词结果转为向量)
self.vae = None # 变分自编码器(负责图像解码:latent→像素)
self.unet = None # 核心生成网络(扩散过程核心,更新 latent)
self.pipe = None # 最终生成流水线
def load_models(self):
# 1. 加载 Tokenizer(调用修复函数,避免路径异常)
self.tokenizer = load_tokenizer_with_fix(self.config.pretrained_model_name_or_path)
# 2. 加载 Text Encoder(CLIP 模型,将文本转为语义向量)
text_encoder_path = os.path.join(self.config.pretrained_model_name_or_path, "text_encoder")
self.text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
# 3. 加载 VAE(将扩散过程的 latent 向量解码为图像像素)
vae_path = os.path.join(self.config.pretrained_model_name_or_path, "vae")
self.vae = AutoencoderKL.from_pretrained(vae_path)
# 4. 加载 UNet(扩散核心,通过迭代去噪生成 latent)
unet_path = os.path.join(self.config.pretrained_model_name_or_path, "unet")
.unet = UNet2DConditionModel.from_pretrained(unet_path)
.unet = load_lora_weights(.unet, .lora_model_path)
.text_encoder.to(.config.device)
.vae.to(.config.device)
.unet.to(.config.device)
scheduler_path = os.path.join(.config.pretrained_model_name_or_path, )
scheduler = DDPMScheduler.from_pretrained(scheduler_path)
.pipe = StableDiffusionPipeline(
vae=.vae,
text_encoder=.text_encoder,
tokenizer=.tokenizer,
unet=.unet,
scheduler=scheduler,
safety_checker=,
feature_extractor=,
requires_safety_checker=
)
.pipe
图像生成是耗时操作(尤其 CPU 运行时),若直接在主线程执行会导致 UI 卡死。GenerateThread 继承 QThread,将生成逻辑放入子线程,通过信号机制与主线程(UI)交互,实时反馈进度
pyqtSignal 定义 3 类信号,实现子线程与 UI 的 '无阻塞通信'—— 进度更新实时反馈,完成 / 错误信号触发 UI 后续操作;# 生成线程类 - 增加色彩校正步骤
class GenerateThread(QThread):
# 定义信号:生成完成(返回 PIL 图像)、错误(返回错误信息)、进度更新(进度百分比 + 剩余时间)
finished = pyqtSignal(Image.Image)
error = pyqtSignal(str)
progress_updated = pyqtSignal(int, float)
def __init__(self, pipe, animal_name, num_inference_steps, guidance_scale, contrast_factor, saturation_factor, brightness_factor):
super().__init__()
self.pipe = pipe # 生成流水线
self.animal_name = animal_name # 目标动物名称(用户输入)
self.num_inference_steps = num_inference_steps # 推理步数
self.guidance_scale = guidance_scale # 引导尺度
# 色彩调整参数(从 UI 获取,覆盖全局配置)
self.contrast_factor = contrast_factor
self.saturation_factor = saturation_factor
self.brightness_factor = brightness_factor
self.start_time = 0 # 生成开始时间(计算总耗时)
self.step_times = [] # 每步耗时(估算剩余时间)
def run(self):
try:
# 1. 优化提示词(增加环境/光照描述,提升生成质量)
prompt = (
f"a high quality photo of a {self.animal_name}, natural lighting, "
f"realistic colors, in natural habitat, detailed texture"
)
# 2. 文本编码(生成'条件嵌入'和'无条件嵌入',用于引导生成)
with torch.no_grad():
# 禁用梯度计算,减少内存占用
text_inputs = .pipe.tokenizer(
prompt,
padding=,
max_length=.pipe.tokenizer.model_max_length,
truncation=,
return_tensors=,
)
text_input_ids = text_inputs.input_ids
text_embeddings = .pipe.text_encoder(text_input_ids.to(.pipe.device))[]
max_length = text_input_ids.shape[-]
uncond_input = .pipe.tokenizer([], padding=, max_length=max_length, return_tensors=,)
uncond_embeddings = .pipe.text_encoder(uncond_input.input_ids.to(.pipe.device))[]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents = torch.randn(
(, .pipe.unet.config.in_channels, Config.resolution // , Config.resolution // ),
generator=torch.Generator(device=Config.device),
device=Config.device,
)
.pipe.scheduler.set_timesteps(.num_inference_steps, device=Config.device)
.start_time = time.time()
.step_times = []
i, t (.pipe.scheduler.timesteps):
step_start_time = time.time()
latent_model_input = torch.cat([latents]*)
latent_model_input = .pipe.scheduler.scale_model_input(latent_model_input, t)
noise_pred = .pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk()
noise_pred = noise_pred_uncond + .guidance_scale *(noise_pred_text - noise_pred_uncond)
latents = .pipe.scheduler.step(noise_pred, t, latents).prev_sample
step_time = time.time() - step_start_time
.step_times.append(step_time)
progress = ((i + ) / .num_inference_steps * )
steps_remaining = .num_inference_steps -(i + )
(.step_times) >= :
avg_step_time = (.step_times[-:]) /
:
avg_step_time = (.step_times) / (.step_times) .step_times
remaining_time = avg_step_time * steps_remaining
.progress_updated.emit(progress, remaining_time)
latents = /* latents
torch.no_grad():
image = .pipe.vae.decode(latents).sample
image = (image / + ).clamp(, )
image = image.cpu().permute(, , , ).().numpy()
image = (image[] * ).().astype()
image = Image.fromarray(image)
image = adjust_image_colors(image)
image.size != (Config.resolution, Config.resolution):
image = image.resize((Config.resolution, Config.resolution), Image.LANCZOS)
.finished.emit(image)
Exception e:
.error.emit((e))
AnimalGeneratorApp 继承 QMainWindow,是整个工具的 '用户交互中心',负责构建 UI 布局、绑定按钮事件、处理线程信号(显示进度 / 图像 / 错误)。核心分为左侧控制面板和右侧图像显示区两部分,以下重点讲解核心功能逻辑:
class AnimalGeneratorApp(QMainWindow):
def __init__(self):
super().__init__()
self.pipe = None # 生成流水线(加载模型后赋值)
self.current_image = None # 当前生成的图像
self.initUI() # 初始化 UI
def initUI(self):
# 1. 基础设置(字体、窗口标题、尺寸)
font = QFont("SimHei") # 支持中文显示(避免乱码)
font.setPointSize(10)
self.setFont(font)
self.setWindowTitle('动物图像生成器')
self.setGeometry(100, 100, 1100, 800) # 窗口位置与尺寸
# 2. 中心部件与主布局(左右分栏:控制面板 + 图像显示区)
central_widget = QWidget()
self.setCentralWidget(central_widget)
main_layout = QHBoxLayout(central_widget)
main_layout.setContentsMargins(15, 15, 15, 15)
main_layout.setSpacing(20)
# 3. 左侧控制面板(模型设置、生成参数、进度)
control_panel = self.create_control_panel()
main_layout.addWidget(control_panel, 3) # 占 3 份宽度
# 4. 右侧图像显示区(默认图、生成图、水印)
image_panel = self.create_image_panel()
main_layout.addWidget(image_panel, 5) # 占 5 份宽度(图像区更宽,提升体验)
def generate_image(self):
# 前置校验(避免无效操作)
if not self.pipe:
QMessageBox.warning(self, "错误", "请先加载模型")
return
animal_name = self.animal_edit.text().strip()
if not animal_name:
QMessageBox.warning(self, "错误", "请输入动物名称")
return
# 1. 获取 UI 参数(覆盖全局配置)
Config.resolution = self.resolution_spin.value()
num_inference_steps = self.steps_spin.value()
guidance_scale = self.guidance_spin.value()
contrast_factor = self.contrast_spin.value()
saturation_factor = self.saturation_spin.value()
brightness_factor = self.brightness_spin.value()
# 2. UI 状态更新(禁用生成/保存按钮,显示进度条)
self.generate_btn.setEnabled(False)
self.save_btn.setEnabled(False)
self.progress_bar.setVisible(True)
self.progress_bar.setRange(0, 100)
self.progress_bar.setValue(0)
self.progress_label.setText("准备生成 (第一次加载请耐心等待哦)...")
self.statusBar().showMessage("正在生成图像,请稍候...")
# 3. 启动生成线程(传入参数,绑定信号)
self.gen_thread = GenerateThread(
self.pipe, animal_name, num_inference_steps, guidance_scale,
contrast_factor, saturation_factor, brightness_factor
)
.gen_thread.finished.connect(.on_generation_finished)
.gen_thread.error.connect(.on_generation_error)
.gen_thread.progress_updated.connect(.on_progress_updated)
.gen_thread.start()
():
.current_image = image
pixmap = .pil2pixmap(image)
.image_label.setPixmap(pixmap.scaled(
.image_label.width(),
.image_label.height(),
Qt.KeepAspectRatio,
Qt.SmoothTransformation
))
.generate_btn.setEnabled()
.save_btn.setEnabled()
.progress_bar.setValue()
.progress_label.setText()
.statusBar().showMessage()
右侧图像显示区采用三层叠加设计,兼顾 '默认提示'+'生成结果'+'版权水印':
def create_image_panel(self):
panel = QWidget()
layout = QVBoxLayout(panel)
# 图像显示容器(带阴影,提升美观度)
image_container = QWidget()
image_container.setStyleSheet("""
background-color: white;
border-radius: 8px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
padding: 15px;
""")
image_layout = QVBoxLayout(image_container)
# 1. 底层:默认图(70% 透明度)
self.default_image_label = QLabel()
self.default_image_label.setAlignment(Qt.AlignCenter)
self.default_image_label.setMinimumSize(512, 512)
self.load_default_image() # 加载默认提示图(如'请生成动物图像')
# 2. 中层:生成图(初始为空,生成后显示)
self.image_label = QLabel()
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setMinimumSize(512, 512)
self.image_label.setStyleSheet("background-color: transparent;") # 透明背景,避免遮挡底层
# 3. 顶层:水印(右下角对齐)
self.watermark_label = QLabel("制作者:热心市民小周")
self.watermark_label.setStyleSheet("""
color: rgba(100, 100, 100, 150); /* 半透明灰色 */
font-size: 12px;
padding: 5px;
background-color: rgba(255, 255, 255, 100);
border-radius: 2px;
""")
self.watermark_label.setAlignment(Qt.AlignRight | Qt.AlignBottom)
# 网格布局实现层级叠加(同一单元格内,后添加的控件在顶层)
grid_layout = QGridLayout()
grid_layout.addWidget(self.default_image_label, 0, 0) # 底层
grid_layout.addWidget(self.image_label, 0, 0) # 中层
grid_layout.addWidget(.watermark_label, , )
image_layout.addLayout(grid_layout)
layout.addWidget(image_container, )
panel
最终的界面图如下所示:
梯度范数 GN: 衡量参数更新规模
| 推理步数 | 推理平均时间(CPU) | 平均总时间 |
|---|---|---|
| 20 | 68.64s | 456.65s |
| 100 | 349.35s | 823.45s |
| 200 | 683.96s | 1209.65s |
| 400 | 1356.86s | 1863.25s |
系统能够生成多种动物的高质量图像,包括:
更多验证输出样例可见
output/pic
使用不同推理步数得到的 tiger 图像如下所示:
UI 界面的使用实例如下:
100 类动物图像生成

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online