diffusers 源码解析(五十五)
.\diffusers\pipelines\wuerstchen\modeling_wuerstchen_common.py
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 从指定路径导入 Attention 处理模块
from ...models.attention_processor import Attention
# 定义自定义的层归一化类,继承自 nn.LayerNorm
class WuerstchenLayerNorm(nn.LayerNorm):
# 初始化方法,接收可变参数
def __init__(self, *args, **kwargs):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 前向传播方法
def forward(self, x):
# 调整输入张量的维度顺序
x = x.permute(0, 2, 3, 1)
# 调用父类的前向传播方法进行归一化
x = super().forward(x)
# 恢复输入张量的维度顺序并返回
return x.permute(0, 3, 1, 2)
# 定义时间步块类,继承自 nn.Module
class TimestepBlock(nn.Module):
# 初始化方法,接收通道数和时间步数
def __init__(self, c, c_timestep):
# 调用父类的初始化方法
super().__init__()
# 定义线性映射层,将时间步数映射到两倍的通道数
self.mapper = nn.Linear(c_timestep, c * 2)
# 前向传播方法
def forward(self, x, t):
# 使用映射层处理时间步,并将结果分割为两个部分
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
# 根据公式更新输入张量并返回
return x * (1 + a) + b
# 定义残差块类,继承自 nn.Module
class ResBlock(nn.Module):
# 初始化方法,接收通道数、跳过连接的通道数、卷积核大小和丢弃率
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
# 调用父类的初始化方法
super().__init__()
# 定义深度可分离卷积层
self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
# 定义自定义层归一化
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
# 定义通道处理的顺序模块
self.channelwise = nn.Sequential(
nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
)
# 前向传播方法
def forward(self, x, x_skip=None):
# 保存输入张量以便后续残差连接
x_res = x
# 如果有跳过连接的张量,则进行拼接
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
# 对输入张量进行深度卷积和归一化
x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
# 通过通道处理模块
x = self.channelwise(x).permute(0, 3, 1, 2)
# 返回残差连接后的结果
return x + x_res
# 从外部库导入的全局响应归一化类
class GlobalResponseNorm(nn.Module):
# 初始化方法,接收特征维度
def __init__(self, dim):
# 调用父类的初始化方法
super().__init__()
# 定义可学习参数 gamma 和 beta
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
# 前向传播方法
def forward(self, x):
# 计算输入张量的聚合范数
agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
# 计算标准化范数
stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
# 返回经过归一化后的结果
return self.gamma * (x * stand_div_norm) + self.beta + x
# 定义注意力块类,继承自 nn.Module
class AttnBlock(nn.Module):
# 初始化方法,接收通道数、条件通道数、头数、是否自注意力及丢弃率
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
# 调用父类的初始化方法
super().__init__()
# 设置是否使用自注意力
self.self_attn = self_attn
# 定义归一化层
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
# 定义注意力机制
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
# 定义键值映射层
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
# 前向传播方法
def forward(self, x, kv):
# 使用键值映射层处理 kv
kv = self.kv_mapper(kv)
# 对输入张量进行归一化
norm_x = self.norm(x)
# 如果使用自注意力,则拼接归一化后的 x 和 kv
if self.self_attn:
batch_size, channel, _, _ = x.shape
kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
# 将注意力机制的输出与原输入相加
x = x + self.attention(norm_x, encoder_hidden_states=kv)
# 返回处理后的张量
return x
.\diffusers\pipelines\wuerstchen\modeling_wuerstchen_diffnext.py
# 版权信息,标明该代码的版权所有者及许可证
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 在 Apache 许可证 2.0("许可证")下获得许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件按"原样"提供,
# 不提供任何形式的保证或条件,无论是明示还是暗示。
# 请参见许可证以获取特定于许可证的权限和限制。
# 导入数学库
import math
# 导入 NumPy 库以进行数组处理
import numpy as np
# 导入 PyTorch 库及其神经网络模块
import torch
import torch.nn as nn
# 从配置工具模块导入 ConfigMixin 和注册配置的方法
from ...configuration_utils import ConfigMixin, register_to_config
# 从模型工具模块导入 ModelMixin
from ...models.modeling_utils import ModelMixin
# 从本地模块导入模型组件
from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm
# 定义 WuerstchenDiffNeXt 类,继承自 ModelMixin 和 ConfigMixin
class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
# 注册初始化方法到配置
@register_to_config
def __init__(
self,
c_in=4, # 输入通道数,默认为 4
c_out=4, # 输出通道数,默认为 4
c_r=64, # 嵌入维度,默认为 64
patch_size=2, # 补丁大小,默认为 2
c_cond=1024, # 条件通道数,默认为 1024
c_hidden=[320, 640, 1280, 1280], # 隐藏层通道数配置
nhead=[-1, 10, 20, 20], # 注意力头数配置
blocks=[4, 4, 14, 4], # 各级块数配置
level_config=["CT", "CTA", "CTA", "CTA"], # 各级配置
inject_effnet=[False, True, True, True], # 是否注入 EfficientNet
effnet_embd=16, # EfficientNet 嵌入维度
clip_embd=1024, # CLIP 嵌入维度
kernel_size=3, # 卷积核大小
dropout=0.1, # dropout 比率
):
# 初始化权重的方法
def _init_weights(self, m):
# 对卷积层和线性层进行通用初始化
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_uniform_(m.weight) # 使用 Xavier 均匀分布初始化权重
if m.bias is not None:
nn.init.constant_(m.bias, 0) # 偏置初始化为 0
# 对 EfficientNet 映射器进行初始化
for mapper in self.effnet_mappers:
if mapper is not None:
nn.init.normal_(mapper.weight, std=0.02) # 条件初始化为正态分布
nn.init.normal_(self.clip_mapper.weight, std=0.02) # CLIP 映射器初始化
nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # 输入嵌入初始化
nn.init.constant_(self.clf[1].weight, 0) # 输出分类器初始化为 0
# 初始化块中的权重
for level_block in self.down_blocks + self.up_blocks:
for block in level_block:
if isinstance(block, ResBlockStageB):
block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks)) # 权重缩放
elif isinstance(block, TimestepBlock):
nn.init.constant_(block.mapper.weight, 0) # 将时间步映射器的权重初始化为 0
# 生成位置嵌入的方法
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions # 将位置 r 乘以最大位置
half_dim = self.c_r // 2 # 计算半维度
emb = math.log(max_positions) / (half_dim - 1) # 计算嵌入尺度
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() # 生成嵌入
emb = r[:, None] * emb[None, :] # 扩展 r 的维度并进行乘法
emb = torch.cat([emb.sin(), emb.cos()], dim=1) # 计算正弦和余弦嵌入并拼接
if self.c_r % 2 == 1: # 如果 c_r 为奇数,则进行零填充
emb = nn.functional.pad(emb, (0, 1), mode="constant") # 用常数进行填充
return emb.to(dtype=r.dtype) # 返回与 r 数据类型相同的嵌入
# 生成 CLIP 嵌入
def gen_c_embeddings(self, clip):
# 将输入 clip 通过映射转换
clip = self.clip_mapper(clip)
# 对 clip 进行序列归一化处理
clip = self.seq_norm(clip)
# 返回处理后的 clip
return clip
# 下采样编码过程
def _down_encode(self, x, r_embed, effnet, clip=None):
# 初始化层级输出列表
level_outputs = []
# 遍历每一个下采样块
for i, down_block in enumerate(self.down_blocks):
effnet_c = None # 初始化有效网络通道为 None
# 遍历每个下采样块中的组件
for block in down_block:
# 如果是残差块阶段 B
if isinstance(block, ResBlockStageB):
# 检查有效网络通道是否为 None
if effnet_c is None and self.effnet_mappers[i] is not None:
dtype = effnet.dtype # 获取 effnet 的数据类型
# 进行双线性插值并创建有效网络通道
effnet_c = self.effnet_mappers[i](
nn.functional.interpolate(
effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
).to(dtype)
)
# 设置跳跃连接为有效网络通道
skip = effnet_c if self.effnet_mappers[i] is not None else None
# 通过当前块处理输入 x 和跳跃连接
x = block(x, skip)
# 如果是注意力块
elif isinstance(block, AttnBlock):
# 通过当前块处理输入 x 和 clip
x = block(x, clip)
# 如果是时间步块
elif isinstance(block, TimestepBlock):
# 通过当前块处理输入 x 和 r_embed
x = block(x, r_embed)
else:
# 通过当前块处理输入 x
x = block(x)
# 将当前层输出插入到层级输出列表的开头
level_outputs.insert(0, x)
# 返回所有层级输出
return level_outputs
# 上采样解码过程
def _up_decode(self, level_outputs, r_embed, effnet, clip=None):
# 使用层级输出的第一个元素初始化 x
x = level_outputs[0]
# 遍历每一个上采样块
for i, up_block in enumerate(self.up_blocks):
effnet_c = None # 初始化有效网络通道为 None
# 遍历每个上采样块中的组件
for j, block in enumerate(up_block):
# 如果是残差块阶段 B
if isinstance(block, ResBlockStageB):
# 检查有效网络通道是否为 None
if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None:
dtype = effnet.dtype # 获取 effnet 的数据类型
# 进行双线性插值并创建有效网络通道
effnet_c = self.effnet_mappers[len(self.down_blocks) + i](
nn.functional.interpolate(
effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
).to(dtype)
)
# 设置跳跃连接为当前层级输出的第 i 个元素
skip = level_outputs[i] if j == 0 and i > 0 else None
# 如果有效网络通道不为 None
if effnet_c is not None:
# 如果跳跃连接不为 None,将其与有效网络通道拼接
if skip is not None:
skip = torch.cat([skip, effnet_c], dim=1)
else:
# 否则直接设置为有效网络通道
skip = effnet_c
# 通过当前块处理输入 x 和跳跃连接
x = block(x, skip)
# 如果是注意力块
elif isinstance(block, AttnBlock):
# 通过当前块处理输入 x 和 clip
x = block(x, clip)
# 如果是时间步块
elif isinstance(block, TimestepBlock):
# 通过当前块处理输入 x 和 r_embed
x = block(x, r_embed)
else:
# 通过当前块处理输入 x
x = block(x)
# 返回最终处理后的 x
return x
# 定义前向传播函数,接受多个输入参数
def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True):
# 如果 x_cat 不为 None,将 x 和 x_cat 沿着维度 1 拼接
if x_cat is not None:
x = torch.cat([x, x_cat], dim=1)
# 处理条件嵌入
r_embed = self.gen_r_embedding(r)
# 如果 clip 不为 None,生成条件嵌入
if clip is not None:
clip = self.gen_c_embeddings(clip)
# 模型块
x_in = x # 保存输入 x 以备后用
x = self.embedding(x) # 将输入 x 转换为嵌入表示
# 下采样编码
level_outputs = self._down_encode(x, r_embed, effnet, clip)
# 上采样解码
x = self._up_decode(level_outputs, r_embed, effnet, clip)
# 将输出分成两个部分 a 和 b
a, b = self.clf(x).chunk(2, dim=1)
# 对 b 进行 sigmoid 激活,并进行缩放
b = b.sigmoid() * (1 - eps * 2) + eps
# 如果返回噪声,计算并返回
if return_noise:
return (x_in - a) / b
else:
return a, b # 否则返回 a 和 b
# 定义一个残差块阶段 B,继承自 nn.Module
class ResBlockStageB(nn.Module):
# 初始化函数,设置输入通道、跳跃连接通道、卷积核大小和丢弃率
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
# 调用父类的初始化方法
super().__init__()
# 创建深度卷积层,使用指定的卷积核大小和填充
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
# 创建层归一化层,设置元素可学习性为 False 和小的 epsilon 值
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
# 创建一个顺序容器,包含线性层、GELU 激活、全局响应归一化、丢弃层和另一线性层
self.channelwise = nn.Sequential(
nn.Linear(c + c_skip, c * 4),
nn.GELU(),
GlobalResponseNorm(c * 4),
nn.Dropout(dropout),
nn.Linear(c * 4, c),
)
# 定义前向传播函数
def forward(self, x, x_skip=None):
# 保存输入以进行残差连接
x_res = x
# 先进行深度卷积和层归一化
x = self.norm(self.depthwise(x))
# 如果有跳跃连接,则将其与当前输出连接
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
# 变换输入维度并通过通道层,最后恢复维度
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
# 返回残差输出
return x + x_res
.\diffusers\pipelines\wuerstchen\modeling_wuerstchen_prior.py
# 版权声明,说明文件的版权所有者和许可证信息
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License, Version 2.0 许可证许可使用本文件
# 仅在遵循许可证的情况下使用此文件
# 可以在此获取许可证的副本
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件在许可证下按 "现状" 基础提供
# 不提供任何形式的担保或条件
# 查看许可证以获取特定语言的权限和限制
# 导入数学库
import math
# 从 typing 模块导入字典和联合类型
from typing import Dict, Union
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 导入配置工具和适配器相关的类
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
# 导入注意力处理器相关的类
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
# 导入模型相关的基类
from ...models.modeling_utils import ModelMixin
# 导入工具函数以检查 PyTorch 版本
from ...utils import is_torch_version
# 导入模型组件
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
# 定义 WuerstchenPrior 类,继承自多个基类
class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
# 设置 UNet 名称为 "prior"
unet_name = "prior"
# 启用梯度检查点功能
_supports_gradient_checkpointing = True
# 注册初始化方法,定义类的构造函数
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
# 调用父类的构造函数
super().__init__()
# 设置压缩通道数
self.c_r = c_r
# 定义一个卷积层用于输入到中间通道的映射
self.projection = nn.Conv2d(c_in, c, kernel_size=1)
# 定义条件映射层,由两个线性层和一个激活函数组成
self.cond_mapper = nn.Sequential(
nn.Linear(c_cond, c), # 将条件输入映射到中间通道
nn.LeakyReLU(0.2), # 应用 Leaky ReLU 激活函数
nn.Linear(c, c), # 再次映射到中间通道
)
# 创建一个模块列表用于存储多个块
self.blocks = nn.ModuleList()
# 根据深度参数添加多个残差块、时间步块和注意力块
for _ in range(depth):
self.blocks.append(ResBlock(c, dropout=dropout)) # 添加残差块
self.blocks.append(TimestepBlock(c, c_r)) # 添加时间步块
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) # 添加注意力块
# 定义输出层,由归一化层和卷积层组成
self.out = nn.Sequential(
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6), # 归一化
nn.Conv2d(c, c_in * 2, kernel_size=1), # 输出卷积层
)
# 默认禁用梯度检查点
self.gradient_checkpointing = False
# 设置默认的注意力处理器
self.set_default_attn_processor()
# 定义一个只读属性
@property
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的属性
# 定义一个返回注意力处理器字典的方法
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回值:
`dict` 的注意力处理器: 一个字典,包含模型中使用的所有注意力处理器,并以其权重名称索引。
"""
# 初始化一个空字典以存储处理器
processors = {}
# 定义递归添加处理器的内部函数
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模块具有获取处理器的方法,则将其添加到字典中
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍历模块的子模块
for sub_name, child in module.named_children():
# 递归调用以添加子模块的处理器
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新后的处理器字典
return processors
# 遍历当前模块的所有子模块
for name, module in self.named_children():
# 调用内部函数以添加所有子模块的处理器
fn_recursive_add_processors(name, module, processors)
# 返回所有处理器的字典
return processors
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制的方法
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的注意力处理器。
参数:
processor (`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
实例化的处理器类或将作为所有 `Attention` 层的处理器设置的处理器类字典。
如果 `processor` 是一个字典,则键需要定义对应的交叉注意力处理器的路径。
在设置可训练的注意力处理器时,强烈建议这样做。
"""
# 计算当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 检查传入的处理器字典的大小是否与注意力层数量匹配
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"传入了处理器的字典,但处理器的数量 {len(processor)} 与注意力层的数量: {count} 不匹配。"
f" 请确保传入 {count} 个处理器类。"
)
# 定义递归设置处理器的内部函数
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模块具有设置处理器的方法,则设置处理器
if hasattr(module, "set_processor"):
# 如果处理器不是字典,则直接设置
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中弹出对应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历子模块
for sub_name, child in module.named_children():
# 递归调用以设置子模块的处理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前模块的所有子模块
for name, module in self.named_children():
# 调用内部函数以设置所有子模块的处理器
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 复制的方法
# 定义一个方法,用于设置默认的注意力处理器
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器,并设置默认的注意力实现。
"""
# 检查所有注意力处理器是否属于新增的键值注意力处理器
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 如果是,使用新增的键值注意力处理器
processor = AttnAddedKVProcessor()
# 检查所有注意力处理器是否属于交叉注意力处理器
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 如果是,使用标准的注意力处理器
processor = AttnProcessor()
else:
# 否则,抛出一个值错误,说明无法设置默认注意力处理器
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 调用设置方法,将选择的处理器应用于当前对象
self.set_attn_processor(processor)
# 定义一个私有方法,用于设置梯度检查点
def _set_gradient_checkpointing(self, module, value=False):
# 将梯度检查点的值设置为传入的布尔值
self.gradient_checkpointing = value
# 定义生成位置嵌入的方法
def gen_r_embedding(self, r, max_positions=10000):
# 将输入的 r 乘以最大位置数
r = r * max_positions
# 计算嵌入的半维度
half_dim = self.c_r // 2
# 计算嵌入的缩放因子
emb = math.log(max_positions) / (half_dim - 1)
# 创建一个张量,并根据半维度生成指数嵌入
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
# 根据 r 生成最终的嵌入
emb = r[:, None] * emb[None, :]
# 将正弦和余弦嵌入拼接在一起
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
# 如果 c_r 是奇数,则进行零填充
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode="constant")
# 返回嵌入,确保数据类型与 r 一致
return emb.to(dtype=r.dtype)
# 定义前向传播函数,接收输入张量 x、条件 r 和 c
def forward(self, x, r, c):
# 保存输入张量的原始值
x_in = x
# 对输入张量进行投影处理
x = self.projection(x)
# 将条件 c 转换为嵌入表示
c_embed = self.cond_mapper(c)
# 生成条件 r 的嵌入表示
r_embed = self.gen_r_embedding(r)
# 如果处于训练模式并且开启梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向传播函数的辅助函数
def create_custom_forward(module):
# 定义接受任意输入的自定义前向函数
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 检查 PyTorch 版本是否大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 遍历所有块进行处理
for block in self.blocks:
# 如果块是注意力块
if isinstance(block, AttnBlock):
# 使用检查点来保存内存
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, c_embed, use_reentrant=False
)
# 如果块是时间步块
elif isinstance(block, TimestepBlock):
# 使用检查点来保存内存
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
else:
# 处理其他类型的块
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
else:
# 对于旧版本的 PyTorch
for block in self.blocks:
# 如果块是注意力块
if isinstance(block, AttnBlock):
# 使用检查点来保存内存
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
# 如果块是时间步块
elif isinstance(block, TimestepBlock):
# 使用检查点来保存内存
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
else:
# 处理其他类型的块
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
else:
# 如果不在训练模式下
for block in self.blocks:
# 如果块是注意力块
if isinstance(block, AttnBlock):
# 直接进行前向传播
x = block(x, c_embed)
# 如果块是时间步块
elif isinstance(block, TimestepBlock):
# 直接进行前向传播
x = block(x, r_embed)
else:
# 处理其他类型的块
x = block(x)
# 将输出分割为两个部分 a 和 b
a, b = self.out(x).chunk(2, dim=1)
# 返回经过归一化处理的结果
return (x_in - a) / ((1 - b).abs() + 1e-5)
.\diffusers\pipelines\wuerstchen\pipeline_wuerstchen.py
# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,按许可证分发的软件
# 是在“按原样”基础上提供的,没有任何形式的保证或条件,
# 明示或暗示。有关许可的特定权限和
# 限制,请参阅许可证。
from typing import Callable, Dict, List, Optional, Union # 从 typing 模块导入类型注解工具
import numpy as np # 导入 NumPy 库,常用于数值计算
import torch # 导入 PyTorch 库,支持深度学习
from transformers import CLIPTextModel, CLIPTokenizer # 从 transformers 库导入 CLIP 模型和分词器
from ...schedulers import DDPMWuerstchenScheduler # 从调度器模块导入 DDPMWuerstchenScheduler
from ...utils import deprecate, logging, replace_example_docstring # 从 utils 模块导入实用工具
from ...utils.torch_utils import randn_tensor # 从 PyTorch 工具模块导入 randn_tensor 函数
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput # 从管道工具模块导入 DiffusionPipeline 和 ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel # 从 Paella VQ 模型模块导入 PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt # 从 Wuerstchen DiffNeXt 模型模块导入 WuerstchenDiffNeXt
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器,禁用 pylint 对无效名称的警告
EXAMPLE_DOC_STRING = """ # 示例文档字符串,提供用法示例
Examples:
```py
>>> import torch # 导入 PyTorch 库
>>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline # 导入 Wuerstchen 管道
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( # 从预训练模型创建 WuerstchenPriorPipeline 实例
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 # 指定模型名称和数据类型
... ).to("cuda") # 将管道移动到 CUDA 设备
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to( # 从预训练模型创建 WuerstchenDecoderPipeline 实例
... "cuda" # 将生成管道移动到 CUDA 设备
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" # 定义生成图像的提示
>>> prior_output = pipe(prompt) # 使用提示生成先前输出
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) # 使用生成管道从图像嵌入生成图像
```py
"""
class WuerstchenDecoderPipeline(DiffusionPipeline): # 定义 WuerstchenDecoderPipeline 类,继承自 DiffusionPipeline
"""
Pipeline for generating images from the Wuerstchen model. # 类文档字符串,说明该管道的功能
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) # 说明该模型继承自 DiffusionPipeline,并提醒用户查看父类文档以获取通用方法
# 参数说明
Args:
tokenizer (`CLIPTokenizer`): # CLIP 模型使用的分词器
The CLIP tokenizer.
text_encoder (`CLIPTextModel`): # CLIP 模型使用的文本编码器
The CLIP text encoder.
decoder ([`WuerstchenDiffNeXt`]): # WuerstchenDiffNeXt 解码器
The WuerstchenDiffNeXt unet decoder.
vqgan ([`PaellaVQModel`]): # VQGAN 模型,用于图像生成
The VQGAN model.
scheduler ([`DDPMWuerstchenScheduler`]): # 调度器,用于图像嵌入生成
A scheduler to be used in combination with `prior` to generate image embedding.
latent_dim_scale (float, `optional`, defaults to 10.67): # 用于确定 VQ 潜在空间大小的乘数
Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
width=int(24*10.67)=256 in order to match the training conditions.
"""
# 定义模型的 CPU 卸载顺序
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
# 定义需要回调的张量输入列表
_callback_tensor_inputs = [
"latents", # 潜在变量
"text_encoder_hidden_states", # 文本编码器的隐藏状态
"negative_prompt_embeds", # 负面提示的嵌入
"image_embeddings", # 图像嵌入
]
# 构造函数
def __init__(
self,
tokenizer: CLIPTokenizer, # 初始化时传入的分词器
text_encoder: CLIPTextModel, # 初始化时传入的文本编码器
decoder: WuerstchenDiffNeXt, # 初始化时传入的解码器
scheduler: DDPMWuerstchenScheduler, # 初始化时传入的调度器
vqgan: PaellaVQModel, # 初始化时传入的 VQGAN 模型
latent_dim_scale: float = 10.67, # 可选参数,默认值为 10.67
) -> None:
super().__init__() # 调用父类构造函数
# 注册模型的各个模块
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
)
# 将潜在维度缩放因子注册到配置中
self.register_to_config(latent_dim_scale=latent_dim_scale)
# 从 diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline 复制的方法,准备潜在变量
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
# 如果潜在变量为 None,则生成随机张量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 检查传入的潜在变量形状是否与预期形状匹配
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
# 将潜在变量移动到指定设备
latents = latents.to(device)
# 用调度器的初始噪声标准差调整潜在变量
latents = latents * scheduler.init_noise_sigma
# 返回调整后的潜在变量
return latents
# 编码提示的方法
def encode_prompt(
self,
prompt, # 输入的提示
device, # 目标设备
num_images_per_prompt, # 每个提示生成的图像数量
do_classifier_free_guidance, # 是否进行无分类器引导
negative_prompt=None, # 负面提示(可选)
@property
# 获取引导缩放比例的属性
def guidance_scale(self):
return self._guidance_scale
@property
# 判断是否使用无分类器引导
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
# 获取时间步数的属性
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() # 不计算梯度
@replace_example_docstring(EXAMPLE_DOC_STRING) # 替换示例文档字符串
# 定义可调用对象的 __call__ 方法,允许实例像函数一样被调用
def __call__(
self,
# 输入图像的嵌入,支持单个张量或张量列表
image_embeddings: Union[torch.Tensor, List[torch.Tensor]],
# 提示文本,可以是单个字符串或字符串列表
prompt: Union[str, List[str]] = None,
# 推理步骤的数量,默认值为 12
num_inference_steps: int = 12,
# 指定时间步的列表,默认为 None
timesteps: Optional[List[float]] = None,
# 指导比例,控制生成的多样性,默认值为 0.0
guidance_scale: float = 0.0,
# 负提示文本,可以是单个字符串或字符串列表,默认为 None
negative_prompt: Optional[Union[str, List[str]]] = None,
# 每个提示生成的图像数量,默认值为 1
num_images_per_prompt: int = 1,
# 随机数生成器,可选,支持单个或多个生成器
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 潜在变量,可选,默认为 None
latents: Optional[torch.Tensor] = None,
# 输出类型,默认值为 "pil"
output_type: Optional[str] = "pil",
# 返回字典标志,默认为 True
return_dict: bool = True,
# 结束步骤回调函数,可选,接收步骤信息
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
# 结束步骤回调函数使用的张量输入列表,默认为 ["latents"]
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
# 其他可选参数,以关键字参数形式传递
**kwargs,
.\diffusers\pipelines\wuerstchen\pipeline_wuerstchen_combined.py
# 版权信息,表明版权所有者和许可信息
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证版本 2.0 进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 本文件只能在遵循许可证的情况下使用
# you may not use this file except in compliance with the License.
# 可以在以下地址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“原样”基础分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件,无论是明示或暗示
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 请参见许可证以了解管理权限和限制的具体语言
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入所需的类型提示
from typing import Callable, Dict, List, Optional, Union
# 导入 PyTorch 库
import torch
# 从 transformers 库导入 CLIP 文本模型和分词器
from transformers import CLIPTextModel, CLIPTokenizer
# 从自定义调度器导入 DDPMWuerstchenScheduler
from ...schedulers import DDPMWuerstchenScheduler
# 从自定义工具导入去除过时函数和替换示例文档字符串的函数
from ...utils import deprecate, replace_example_docstring
# 从管道工具导入 DiffusionPipeline 基类
from ..pipeline_utils import DiffusionPipeline
# 从模型模块导入 PaellaVQModel
from .modeling_paella_vq_model import PaellaVQModel
# 从模型模块导入 WuerstchenDiffNeXt
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
# 从模型模块导入 WuerstchenPrior
from .modeling_wuerstchen_prior import WuerstchenPrior
# 从管道模块导入 WuerstchenDecoderPipeline
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
# 从管道模块导入 WuerstchenPriorPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
# 文档字符串示例,用于展示如何使用文本转图像的管道
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusions import WuerstchenCombinedPipeline
>>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
... "cuda"
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
```
"""
# 定义一个结合文本到图像生成的管道类
class WuerstchenCombinedPipeline(DiffusionPipeline):
"""
使用 Wuerstchen 进行文本到图像生成的组合管道
该模型继承自 [`DiffusionPipeline`]。查看父类文档以了解库为所有管道实现的通用方法
(如下载或保存,运行在特定设备等)。
参数:
tokenizer (`CLIPTokenizer`):
用于文本输入的解码器分词器。
text_encoder (`CLIPTextModel`):
用于文本输入的解码器文本编码器。
decoder (`WuerstchenDiffNeXt`):
用于图像生成管道的解码器模型。
scheduler (`DDPMWuerstchenScheduler`):
用于图像生成管道的调度器。
vqgan (`PaellaVQModel`):
用于图像生成管道的 VQGAN 模型。
prior_tokenizer (`CLIPTokenizer`):
用于文本输入的先前分词器。
prior_text_encoder (`CLIPTextModel`):
用于文本输入的先前文本编码器。
prior_prior (`WuerstchenPrior`):
用于先前管道的先前模型。
prior_scheduler (`DDPMWuerstchenScheduler`):
用于先前管道的调度器。
"""
# 标志,表示是否加载连接的管道
_load_connected_pipes = True
# 初始化类的构造函数,接收多个模型和调度器作为参数
def __init__(
self,
tokenizer: CLIPTokenizer, # 词汇处理器
text_encoder: CLIPTextModel, # 文本编码器
decoder: WuerstchenDiffNeXt, # 解码器模型
scheduler: DDPMWuerstchenScheduler, # 调度器
vqgan: PaellaVQModel, # VQGAN模型
prior_tokenizer: CLIPTokenizer, # 先验词汇处理器
prior_text_encoder: CLIPTextModel, # 先验文本编码器
prior_prior: WuerstchenPrior, # 先验模型
prior_scheduler: DDPMWuerstchenScheduler, # 先验调度器
):
super().__init__() # 调用父类的构造函数
# 注册各个模型和调度器到当前实例
self.register_modules(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
prior_prior=prior_prior,
prior_text_encoder=prior_text_encoder,
prior_tokenizer=prior_tokenizer,
prior_scheduler=prior_scheduler,
)
# 初始化先验管道,用于处理先验相关操作
self.prior_pipe = WuerstchenPriorPipeline(
prior=prior_prior, # 先验模型
text_encoder=prior_text_encoder, # 先验文本编码器
tokenizer=prior_tokenizer, # 先验词汇处理器
scheduler=prior_scheduler, # 先验调度器
)
# 初始化解码器管道,用于处理解码相关操作
self.decoder_pipe = WuerstchenDecoderPipeline(
text_encoder=text_encoder, # 文本编码器
tokenizer=tokenizer, # 词汇处理器
decoder=decoder, # 解码器
scheduler=scheduler, # 调度器
vqgan=vqgan, # VQGAN模型
)
# 启用节省内存的高效注意力机制
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
# 在解码器管道中启用高效注意力机制
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
# 启用模型的CPU卸载,减少内存使用
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
使用accelerate将所有模型卸载到CPU,减少内存使用且对性能影响较小。
此方法在调用模型的`forward`方法时将整个模型移到GPU,模型将在下一个模型运行之前保持在GPU上。
相比于`enable_sequential_cpu_offload`,内存节省较少,但性能更佳。
"""
# 在先验管道中启用模型的CPU卸载
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
# 在解码器管道中启用模型的CPU卸载
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
# 启用顺序CPU卸载,显著减少内存使用
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
使用🤗Accelerate将所有模型卸载到CPU,显著减少内存使用。
模型被移动到`torch.device('meta')`,并仅在调用特定子模块的`forward`方法时加载到GPU。
卸载是基于子模块进行的,内存节省比使用`enable_model_cpu_offload`高,但性能较低。
"""
# 在先验管道中启用顺序CPU卸载
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
# 在解码器管道中启用顺序CPU卸载
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
# 定义进度条方法,接受可迭代对象和总计数作为参数
def progress_bar(self, iterable=None, total=None):
# 在 prior_pipe 上更新进度条,传入可迭代对象和总计数
self.prior_pipe.progress_bar(iterable=iterable, total=total)
# 在 decoder_pipe 上更新进度条,传入可迭代对象和总计数
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
# 定义设置进度条配置的方法,接收任意关键字参数
def set_progress_bar_config(self, **kwargs):
# 在 prior_pipe 上设置进度条配置,传入关键字参数
self.prior_pipe.set_progress_bar_config(**kwargs)
# 在 decoder_pipe 上设置进度条配置,传入关键字参数
self.decoder_pipe.set_progress_bar_config(**kwargs)
# 使用 torch.no_grad() 装饰器,表示在此上下文中不计算梯度
@torch.no_grad()
# 替换示例文档字符串的装饰器
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
# 定义调用方法,处理文本到图像的转换
def __call__(
# 接受提示文本,支持字符串或字符串列表,默认为 None
prompt: Optional[Union[str, List[str]]] = None,
# 图像高度,默认为 512
height: int = 512,
# 图像宽度,默认为 512
width: int = 512,
# prior 阶段推理步骤数,默认为 60
prior_num_inference_steps: int = 60,
# prior 阶段时间步,默认为 None
prior_timesteps: Optional[List[float]] = None,
# prior 阶段引导比例,默认为 4.0
prior_guidance_scale: float = 4.0,
# decoder 阶段推理步骤数,默认为 12
num_inference_steps: int = 12,
# decoder 阶段时间步,默认为 None
decoder_timesteps: Optional[List[float]] = None,
# decoder 阶段引导比例,默认为 0.0
decoder_guidance_scale: float = 0.0,
# 负提示文本,支持字符串或字符串列表,默认为 None
negative_prompt: Optional[Union[str, List[str]]] = None,
# 提示嵌入,默认为 None
prompt_embeds: Optional[torch.Tensor] = None,
# 负提示嵌入,默认为 None
negative_prompt_embeds: Optional[torch.Tensor] = None,
# 每个提示生成的图像数量,默认为 1
num_images_per_prompt: int = 1,
# 随机数生成器,默认为 None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 潜在表示,默认为 None
latents: Optional[torch.Tensor] = None,
# 输出类型,默认为 "pil"
output_type: Optional[str] = "pil",
# 是否返回字典格式,默认为 True
return_dict: bool = True,
# prior 阶段的回调函数,默认为 None
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
# prior 阶段回调函数输入的张量名称列表,默认为 ["latents"]
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
# decoder 阶段的回调函数,默认为 None
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
# decoder 阶段回调函数输入的张量名称列表,默认为 ["latents"]
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
# 接受其他任意关键字参数
**kwargs,
.\diffusers\pipelines\wuerstchen\pipeline_wuerstchen_prior.py
# 版权信息,声明此代码归 HuggingFace 团队所有,保留所有权利
# 许可证声明,使用此文件需遵守 Apache 许可证 2.0
# 提供许可证的获取地址
# 许可证说明,未按适用法律或书面协议另行约定的情况下,软件在“按现状”基础上分发
# 提供许可证详细信息的地址
from dataclasses import dataclass # 导入数据类装饰器,用于简化类的定义
from math import ceil # 导入向上取整函数
from typing import Callable, Dict, List, Optional, Union # 导入类型注解
import numpy as np # 导入 NumPy 库,用于数值计算
import torch # 导入 PyTorch 库,用于深度学习
from transformers import CLIPTextModel, CLIPTokenizer # 导入 CLIP 模型和分词器
from ...loaders import StableDiffusionLoraLoaderMixin # 导入加载 LoRA 权重的混合类
from ...schedulers import DDPMWuerstchenScheduler # 导入调度器
from ...utils import BaseOutput, deprecate, logging, replace_example_docstring # 导入工具类和函数
from ...utils.torch_utils import randn_tensor # 导入生成随机张量的工具函数
from ..pipeline_utils import DiffusionPipeline # 导入扩散管道基类
from .modeling_wuerstchen_prior import WuerstchenPrior # 导入 Wuerstchen 先验模型
logger = logging.get_logger(__name__) # 创建日志记录器
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] # 设置默认的时间步,分段线性生成
EXAMPLE_DOC_STRING = """ # 示例文档字符串,提供使用示例
Examples:
```py # Python 代码块开始
>>> import torch # 导入 PyTorch 库
>>> from diffusers import WuerstchenPriorPipeline # 导入 WuerstchenPriorPipeline 类
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( # 从预训练模型加载管道
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 # 指定模型路径和数据类型
... ).to("cuda") # 将管道移动到 GPU
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" # 定义生成图像的提示
>>> prior_output = pipe(prompt) # 生成图像并返回结果
```py # Python 代码块结束
"""
@dataclass # 使用数据类装饰器定义输出类
class WuerstchenPriorPipelineOutput(BaseOutput): # 定义 WuerstchenPriorPipeline 的输出类
"""
输出类用于 WuerstchenPriorPipeline。
Args:
image_embeddings (`torch.Tensor` or `np.ndarray`) # 图像嵌入数据的类型说明
Prior image embeddings for text prompt # 为文本提示生成的图像嵌入
"""
image_embeddings: Union[torch.Tensor, np.ndarray] # 定义图像嵌入属性
class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): # 定义 WuerstchenPriorPipeline 类,继承自扩散管道和加载器
"""
用于生成 Wuerstchen 图像先验的管道。
此模型继承自 [`DiffusionPipeline`]。查看超类文档以获取库实现的所有管道的通用方法(例如下载、保存、在特定设备上运行等)
该管道还继承以下加载方法:
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
```py # 文档结束
``` # 文档结束
# 文档字符串,说明构造函数参数及其作用
Args:
prior ([`Prior`]):
# 指定用于从文本嵌入近似图像嵌入的标准 unCLIP 先验
text_encoder ([`CLIPTextModelWithProjection`]):
# 冻结的文本编码器
tokenizer (`CLIPTokenizer`):
# 用于文本处理的标记器,详细信息见 CLIPTokenizer 文档
scheduler ([`DDPMWuerstchenScheduler`]):
# 与 `prior` 结合使用的调度器,用于生成图像嵌入
latent_mean ('float', *optional*, defaults to 42.0):
# 潜在扩散器的均值
latent_std ('float', *optional*, defaults to 1.0):
# 潜在扩散器的标准差
resolution_multiple ('float', *optional*, defaults to 42.67):
# 生成多个图像时的默认分辨率
"""
# 定义 unet 的名称为 "prior"
unet_name = "prior"
# 定义文本编码器的名称
text_encoder_name = "text_encoder"
# 定义模型的 CPU 卸载顺序
model_cpu_offload_seq = "text_encoder->prior"
# 定义回调张量输入的列表
_callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
# 定义可加载的 LoRA 模块
_lora_loadable_modules = ["prior", "text_encoder"]
# 初始化函数,设置类的属性
def __init__(
self,
# 初始化所需的标记器
tokenizer: CLIPTokenizer,
# 初始化所需的文本编码器
text_encoder: CLIPTextModel,
# 初始化所需的 unCLIP 先验
prior: WuerstchenPrior,
# 初始化所需的调度器
scheduler: DDPMWuerstchenScheduler,
# 设置潜在均值,默认值为 42.0
latent_mean: float = 42.0,
# 设置潜在标准差,默认值为 1.0
latent_std: float = 1.0,
# 设置生成图像的默认分辨率倍数,默认值为 42.67
resolution_multiple: float = 42.67,
) -> None:
# 调用父类的初始化方法
super().__init__()
# 注册所需的模块
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
prior=prior,
scheduler=scheduler,
)
# 将配置注册到类中
self.register_to_config(
latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple
)
# 从指定的管道准备潜在张量
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
# 如果未提供潜在张量,则生成随机张量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 检查潜在张量的形状是否匹配
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
# 将潜在张量移动到指定设备
latents = latents.to(device)
# 将潜在张量乘以调度器的初始噪声标准差
latents = latents * scheduler.init_noise_sigma
# 返回准备好的潜在张量
return latents
# 编码提示信息,处理正向和负向提示
def encode_prompt(
self,
# 指定设备
device,
# 每个提示生成的图像数量
num_images_per_prompt,
# 是否进行无分类器自由引导
do_classifier_free_guidance,
# 正向提示文本
prompt=None,
# 负向提示文本
negative_prompt=None,
# 提示的嵌入张量,若有则提供
prompt_embeds: Optional[torch.Tensor] = None,
# 负向提示的嵌入张量,若有则提供
negative_prompt_embeds: Optional[torch.Tensor] = None,
# 检查输入的有效性
def check_inputs(
self,
# 正向提示文本
prompt,
# 负向提示文本
negative_prompt,
# 推理步骤的数量
num_inference_steps,
# 是否进行无分类器自由引导
do_classifier_free_guidance,
# 提示的嵌入张量,若有则提供
prompt_embeds=None,
# 负向提示的嵌入张量,若有则提供
negative_prompt_embeds=None,
# 检查 prompt 和 prompt_embeds 是否同时存在
):
if prompt is not None and prompt_embeds is not None:
# 抛出异常,提示不能同时提供 prompt 和 prompt_embeds
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 检查 prompt 和 prompt_embeds 是否都未定义
elif prompt is None and prompt_embeds is None:
# 抛出异常,提示必须提供 prompt 或 prompt_embeds
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 检查 prompt 是否为有效类型
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 抛出异常,提示 prompt 类型不正确
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查 negative_prompt 和 negative_prompt_embeds 是否同时存在
if negative_prompt is not None and negative_prompt_embeds is not None:
# 抛出异常,提示不能同时提供 negative_prompt 和 negative_prompt_embeds
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 检查 prompt_embeds 和 negative_prompt_embeds 是否同时存在
if prompt_embeds is not None and negative_prompt_embeds is not None:
# 验证这两个张量的形状是否一致
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 抛出异常,提示形状不匹配
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 检查 num_inference_steps 是否为整数
if not isinstance(num_inference_steps, int):
# 抛出异常,提示 num_inference_steps 类型不正确
raise TypeError(
f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
# 定义属性 guidance_scale,返回该类的 _guidance_scale 值
@property
def guidance_scale(self):
return self._guidance_scale
# 定义属性 do_classifier_free_guidance,判断是否执行无分类器引导
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
# 定义属性 num_timesteps,返回该类的 _num_timesteps 值
@property
def num_timesteps(self):
return self._num_timesteps
# 定义可调用方法,执行主要功能
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
# 定义方法的参数,包括 prompt 和其他配置
self,
prompt: Optional[Union[str, List[str]]] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 60,
timesteps: List[float] = None,
guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pt",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
.\diffusers\pipelines\wuerstchen\__init__.py
# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING
# 从 utils 模块导入各种工具和常量
from ...utils import (
DIFFUSERS_SLOW_IMPORT, # 指示是否进行慢速导入
OptionalDependencyNotAvailable, # 处理可选依赖项不可用的异常
_LazyModule, # 用于延迟加载模块
get_objects_from_module, # 从模块中获取对象的函数
is_torch_available, # 检查 PyTorch 是否可用的函数
is_transformers_available, # 检查 Transformers 是否可用的函数
)
# 初始化一个空字典,用于存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典,用于存储导入结构
_import_structure = {}
# 尝试检查依赖项的可用性
try:
# 如果 Transformers 和 Torch 不可用,则抛出异常
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖项不可用的异常
except OptionalDependencyNotAvailable:
# 从 utils 模块导入虚拟对象
from ...utils import dummy_torch_and_transformers_objects
# 更新虚拟对象字典
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依赖项可用,更新导入结构
else:
_import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"] # 添加 PaellaVQModel
_import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"] # 添加 WuerstchenDiffNeXt
_import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"] # 添加 WuerstchenPrior
_import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"] # 添加 WuerstchenDecoderPipeline
_import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"] # 添加 WuerstchenCombinedPipeline
_import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"] # 添加相关管道
# 根据类型检查或慢速导入的标志进行条件判断
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
# 尝试检查依赖项的可用性
try:
if not (is_transformers_available() and is_torch_available()): # 同样检查可用性
raise OptionalDependencyNotAvailable() # 抛出异常
# 捕获可选依赖项不可用的异常
except OptionalDependencyNotAvailable:
# 从虚拟对象模块导入所有内容
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
# 从各个模块导入必要的类和函数
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline
else:
# 如果不是类型检查或慢速导入,导入 sys 模块
import sys
# 将当前模块替换为一个延迟加载模块
sys.modules[__name__] = _LazyModule(
__name__, # 模块名称
globals()["__file__"], # 当前文件
_import_structure, # 导入结构
module_spec=__spec__, # 模块规格
)
# 将虚拟对象添加到当前模块
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value) # 设置属性
.\diffusers\pipelines\__init__.py
# 导入类型检查的模块
from typing import TYPE_CHECKING
# 从父级目录的 utils 模块中导入多个对象和函数
from ..utils import (
DIFFUSERS_SLOW_IMPORT, # 导入一个慢加载的功能
OptionalDependencyNotAvailable, # 导入可选依赖不可用的异常
_LazyModule, # 导入懒加载模块的工具
get_objects_from_module, # 导入从模块中获取对象的函数
is_flax_available, # 导入检查 Flax 库是否可用的函数
is_k_diffusion_available, # 导入检查 K-Diffusion 库是否可用的函数
is_librosa_available, # 导入检查 Librosa 库是否可用的函数
is_note_seq_available, # 导入检查 NoteSeq 库是否可用的函数
is_onnx_available, # 导入检查 ONNX 库是否可用的函数
is_sentencepiece_available, # 导入检查 SentencePiece 库是否可用的函数
is_torch_available, # 导入检查 PyTorch 库是否可用的函数
is_torch_npu_available, # 导入检查 NPU 版 PyTorch 是否可用的函数
is_transformers_available, # 导入检查 Transformers 库是否可用的函数
)
# 初始化一个空字典以存储假对象
_dummy_objects = {}
# 定义一个字典以组织导入的模块结构
_import_structure = {
"controlnet": [], # 控制网模块
"controlnet_hunyuandit": [], # 控制网相关模块
"controlnet_sd3": [], # 控制网 SD3 模块
"controlnet_xs": [], # 控制网 XS 模块
"deprecated": [], # 存放弃用模块
"latent_diffusion": [], # 潜在扩散模块
"ledits_pp": [], # LEDITS PP 模块
"marigold": [], # 万寿菊模块
"pag": [], # PAG 模块
"stable_diffusion": [], # 稳定扩散模块
"stable_diffusion_xl": [], # 稳定扩散 XL 模块
}
try:
# 检查 PyTorch 是否可用
if not is_torch_available():
# 如果不可用,抛出可选依赖不可用的异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 如果捕获到异常,从 utils 模块导入假对象(PyTorch 相关)
from ..utils import dummy_pt_objects # noqa F403
# 将获取的假对象更新到字典中
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
# 如果 PyTorch 可用,更新导入结构,添加自动管道类
_import_structure["auto_pipeline"] = [
"AutoPipelineForImage2Image", # 图像到图像的自动管道
"AutoPipelineForInpainting", # 图像修复的自动管道
"AutoPipelineForText2Image", # 文本到图像的自动管道
]
# 添加一致性模型管道
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
# 添加舞蹈扩散管道
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
# 添加 DDIM 管道
_import_structure["ddim"] = ["DDIMPipeline"]
# 添加 DDPM 管道
_import_structure["ddpm"] = ["DDPMPipeline"]
# 添加 DiT 管道
_import_structure["dit"] = ["DiTPipeline"]
# 扩展潜在扩散模块,添加超分辨率管道
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
# 添加管道工具的输出类型
_import_structure["pipeline_utils"] = [
"AudioPipelineOutput", # 音频管道输出
"DiffusionPipeline", # 扩散管道
"StableDiffusionMixin", # 稳定扩散混合类
"ImagePipelineOutput", # 图像管道输出
]
# 扩展弃用模块,添加弃用的管道
_import_structure["deprecated"].extend(
[
"PNDMPipeline", # PNDM 管道
"LDMPipeline", # LDM 管道
"RePaintPipeline", # 重绘管道
"ScoreSdeVePipeline", # Score SDE VE 管道
"KarrasVePipeline", # Karras VE 管道
]
)
try:
# 检查 PyTorch 和 Librosa 是否都可用
if not (is_torch_available() and is_librosa_available()):
# 如果其中一个不可用,抛出可选依赖不可用的异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 捕获异常,从 utils 模块导入假对象(PyTorch 和 Librosa 相关)
from ..utils import dummy_torch_and_librosa_objects # noqa F403
# 更新假对象字典
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
# 如果两个库都可用,扩展弃用模块,添加音频扩散管道和 Mel 类
_import_structure["deprecated"].extend(["AudioDiffusionPipeline", "Mel"])
try:
# 检查 Transformers、PyTorch 和 NoteSeq 是否都可用
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
# 如果其中一个不可用,抛出可选依赖不可用的异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 捕获异常,从 utils 模块导入假对象(Transformers、PyTorch 和 NoteSeq 相关)
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
# 更新假对象字典
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
# 如果三个库都可用,扩展弃用模块,添加 MIDI 处理器和谱图扩散管道
_import_structure["deprecated"].extend(
[
"MidiProcessor", # MIDI 处理器
"SpectrogramDiffusionPipeline", # 谱图扩散管道
]
)
try:
# 检查 PyTorch 和 Transformers 库是否可用
if not (is_torch_available() and is_transformers_available()):
# 如果任一库不可用,抛出异常表示可选依赖不可用
raise OptionalDependencyNotAvailable()
# 捕获未满足可选依赖项的异常
except OptionalDependencyNotAvailable:
# 从上层模块导入虚拟的 Torch 和 Transformers 对象
from ..utils import dummy_torch_and_transformers_objects # noqa F403
# 更新虚拟对象的字典,以获取导入的虚拟对象
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
# 将过时的导入结构中添加一组管道名称
_import_structure["deprecated"].extend(
[
"VQDiffusionPipeline",
"AltDiffusionPipeline",
"AltDiffusionImg2ImgPipeline",
"CycleDiffusionPipeline",
"StableDiffusionInpaintPipelineLegacy",
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionModelEditingPipeline",
"VersatileDiffusionDualGuidedPipeline",
"VersatileDiffusionImageVariationPipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
]
)
# 为“amused”添加相关管道名称
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
# 为“animatediff”添加相关管道名称
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
]
# 为“flux”添加相关管道名称
_import_structure["flux"] = ["FluxPipeline"]
# 为“audioldm”添加相关管道名称
_import_structure["audioldm"] = ["AudioLDMPipeline"]
# 为“audioldm2”添加相关管道名称
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
]
# 为“blip_diffusion”添加相关管道名称
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
# 为“cogvideo”添加相关管道名称
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
# 为“controlnet”扩展相关管道名称
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
]
)
# 为“pag”扩展相关管道名称
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"KolorsPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"PixArtSigmaPAGPipeline",
]
)
# 为“controlnet_xs”扩展相关管道名称
_import_structure["controlnet_xs"].extend(
[
"StableDiffusionControlNetXSPipeline",
"StableDiffusionXLControlNetXSPipeline",
]
)
# 为“controlnet_hunyuandit”扩展相关管道名称
_import_structure["controlnet_hunyuandit"].extend(
[
"HunyuanDiTControlNetPipeline",
]
)
# 将 "StableDiffusion3ControlNetPipeline" 添加到 "controlnet_sd3" 的导入结构中
_import_structure["controlnet_sd3"].extend(
[
"StableDiffusion3ControlNetPipeline",
]
)
# 定义 "deepfloyd_if" 的导入结构,包含多个管道
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
"IFInpaintingPipeline",
"IFInpaintingSuperResolutionPipeline",
"IFPipeline",
"IFSuperResolutionPipeline",
]
# 设置 "hunyuandit" 的导入结构,仅包含一个管道
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
# 定义 "kandinsky" 的导入结构,包含多个相关管道
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
"KandinskyInpaintCombinedPipeline",
"KandinskyInpaintPipeline",
"KandinskyPipeline",
"KandinskyPriorPipeline",
]
# 定义 "kandinsky2_2" 的导入结构,包含多个管道
_import_structure["kandinsky2_2"] = [
"KandinskyV22CombinedPipeline",
"KandinskyV22ControlnetImg2ImgPipeline",
"KandinskyV22ControlnetPipeline",
"KandinskyV22Img2ImgCombinedPipeline",
"KandinskyV22Img2ImgPipeline",
"KandinskyV22InpaintCombinedPipeline",
"KandinskyV22InpaintPipeline",
"KandinskyV22Pipeline",
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
# 定义 "kandinsky3" 的导入结构,包含两个管道
_import_structure["kandinsky3"] = [
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
]
# 定义 "latent_consistency_models" 的导入结构,包含两个管道
_import_structure["latent_consistency_models"] = [
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
]
# 将 "LDMTextToImagePipeline" 添加到 "latent_diffusion" 的导入结构中
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
# 将稳定扩散相关的管道添加到 "ledits_pp" 的导入结构中
_import_structure["ledits_pp"].extend(
[
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
]
)
# 设置 "latte" 的导入结构,仅包含一个管道
_import_structure["latte"] = ["LattePipeline"]
# 设置 "lumina" 的导入结构,仅包含一个管道
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
# 将 "MarigoldDepthPipeline" 和 "MarigoldNormalsPipeline" 添加到 "marigold" 的导入结构中
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
]
)
# 设置 "musicldm" 的导入结构,仅包含一个管道
_import_structure["musicldm"] = ["MusicLDMPipeline"]
# 设置 "paint_by_example" 的导入结构,仅包含一个管道
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
# 设置 "pia" 的导入结构,仅包含一个管道
_import_structure["pia"] = ["PIAPipeline"]
# 设置 "pixart_alpha" 的导入结构,包含两个管道
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
# 设置 "semantic_stable_diffusion" 的导入结构,仅包含一个管道
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
# 设置 "shap_e" 的导入结构,包含两个管道
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
# 定义 "stable_audio" 的导入结构,包含两个管道
_import_structure["stable_audio"] = [
"StableAudioProjectionModel",
"StableAudioPipeline",
]
# 定义 "stable_cascade" 的导入结构,包含多个管道
_import_structure["stable_cascade"] = [
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
]
# 向 stable_diffusion 的导入结构中添加多个相关的管道名称
_import_structure["stable_diffusion"].extend(
[
# 添加 CLIP 图像投影管道
"CLIPImageProjection",
# 添加稳定扩散深度到图像管道
"StableDiffusionDepth2ImgPipeline",
# 添加稳定扩散图像变体管道
"StableDiffusionImageVariationPipeline",
# 添加稳定扩散图像到图像管道
"StableDiffusionImg2ImgPipeline",
# 添加稳定扩散图像修复管道
"StableDiffusionInpaintPipeline",
# 添加稳定扩散指令图像到图像管道
"StableDiffusionInstructPix2PixPipeline",
# 添加稳定扩散潜在上采样管道
"StableDiffusionLatentUpscalePipeline",
# 添加稳定扩散主管道
"StableDiffusionPipeline",
# 添加稳定扩散上采样管道
"StableDiffusionUpscalePipeline",
# 添加稳定 UnCLIP 图像到图像管道
"StableUnCLIPImg2ImgPipeline",
# 添加稳定 UnCLIP 管道
"StableUnCLIPPipeline",
# 添加稳定扩散 LDM 3D 管道
"StableDiffusionLDM3DPipeline",
]
)
# 为 aura_flow 设置导入结构,包括其管道
_import_structure["aura_flow"] = ["AuraFlowPipeline"]
# 为 stable_diffusion_3 设置相关管道
_import_structure["stable_diffusion_3"] = [
# 添加稳定扩散 3 管道
"StableDiffusion3Pipeline",
# 添加稳定扩散 3 图像到图像管道
"StableDiffusion3Img2ImgPipeline",
# 添加稳定扩散 3 图像修复管道
"StableDiffusion3InpaintPipeline",
]
# 为 stable_diffusion_attend_and_excite 设置导入结构
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
# 为 stable_diffusion_safe 设置安全管道
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
# 为 stable_diffusion_sag 设置导入结构
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
# 为 stable_diffusion_gligen 设置导入结构
_import_structure["stable_diffusion_gligen"] = [
# 添加稳定扩散 GLIGEN 管道
"StableDiffusionGLIGENPipeline",
# 添加稳定扩散 GLIGEN 文本图像管道
"StableDiffusionGLIGENTextImagePipeline",
]
# 为 stable_video_diffusion 设置导入结构
_import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
# 向 stable_diffusion_xl 的导入结构中添加多个管道
_import_structure["stable_diffusion_xl"].extend(
[
# 添加稳定扩散 XL 图像到图像管道
"StableDiffusionXLImg2ImgPipeline",
# 添加稳定扩散 XL 图像修复管道
"StableDiffusionXLInpaintPipeline",
# 添加稳定扩散 XL 指令图像到图像管道
"StableDiffusionXLInstructPix2PixPipeline",
# 添加稳定扩散 XL 主管道
"StableDiffusionXLPipeline",
]
)
# 为 stable_diffusion_diffedit 设置导入结构
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
# 为 stable_diffusion_ldm3d 设置导入结构
_import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
# 为 stable_diffusion_panorama 设置导入结构
_import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
# 为 t2i_adapter 设置导入结构,包括适配管道
_import_structure["t2i_adapter"] = [
# 添加稳定扩散适配器管道
"StableDiffusionAdapterPipeline",
# 添加稳定扩散 XL 适配器管道
"StableDiffusionXLAdapterPipeline",
]
# 为 text_to_video_synthesis 设置多个视频合成相关管道
_import_structure["text_to_video_synthesis"] = [
# 添加文本到视频稳定扩散管道
"TextToVideoSDPipeline",
# 添加文本到视频零管道
"TextToVideoZeroPipeline",
# 添加文本到视频零稳定扩散 XL 管道
"TextToVideoZeroSDXLPipeline",
# 添加视频到视频稳定扩散管道
"VideoToVideoSDPipeline",
]
# 为 i2vgen_xl 设置导入结构
_import_structure["i2vgen_xl"] = ["I2VGenXLPipeline"]
# 为 unclip 设置相关管道
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
# 为 unidiffuser 设置多个管道
_import_structure["unidiffuser"] = [
# 添加图像文本管道输出
"ImageTextPipelineOutput",
# 添加 UniDiffuser 模型
"UniDiffuserModel",
# 添加 UniDiffuser 管道
"UniDiffuserPipeline",
# 添加 UniDiffuser 文本解码器
"UniDiffuserTextDecoder",
]
# 为 wuerstchen 设置多个管道
_import_structure["wuerstchen"] = [
# 添加 Wuerstchen 组合管道
"WuerstchenCombinedPipeline",
# 添加 Wuerstchen 解码器管道
"WuerstchenDecoderPipeline",
# 添加 Wuerstchen 先验管道
"WuerstchenPriorPipeline",
]
# 尝试检查 ONNX 是否可用
try:
# 如果 ONNX 不可用,抛出异常
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具模块导入假 ONNX 对象,防止导入错误
from ..utils import dummy_onnx_objects # noqa F403
# 更新虚拟对象字典,添加假 ONNX 对象
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
# 如果 ONNX 可用,更新导入结构
else:
_import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
# 尝试检查 PyTorch、Transformers 和 ONNX 是否都可用
try:
# 如果任何一个不可用,抛出异常
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具模块导入假 PyTorch、Transformers 和 ONNX 对象
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
# 更新虚拟对象字典,添加假对象
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
# 如果都可用,扩展导入结构
else:
_import_structure["stable_diffusion"].extend(
[
"OnnxStableDiffusionImg2ImgPipeline",
"OnnxStableDiffusionInpaintPipeline",
"OnnxStableDiffusionPipeline",
"OnnxStableDiffusionUpscalePipeline",
"StableDiffusionOnnxPipeline",
]
)
# 尝试检查 PyTorch、Transformers 和 K-Diffusion 是否都可用
try:
# 如果任何一个不可用,抛出异常
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具模块导入假 PyTorch、Transformers 和 K-Diffusion 对象
from ..utils import (
dummy_torch_and_transformers_and_k_diffusion_objects,
)
# 更新虚拟对象字典,添加假对象
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
# 如果都可用,更新导入结构
else:
_import_structure["stable_diffusion_k_diffusion"] = [
"StableDiffusionKDiffusionPipeline",
"StableDiffusionXLKDiffusionPipeline",
]
# 尝试检查 PyTorch、Transformers 和 SentencePiece 是否都可用
try:
# 如果任何一个不可用,抛出异常
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具模块导入假 PyTorch、Transformers 和 SentencePiece 对象
from ..utils import (
dummy_torch_and_transformers_and_sentencepiece_objects,
)
# 更新虚拟对象字典,添加假对象
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
# 如果都可用,更新导入结构
else:
_import_structure["kolors"] = [
"KolorsPipeline",
"KolorsImg2ImgPipeline",
]
# 尝试检查 Flax 是否可用
try:
# 如果 Flax 不可用,抛出异常
if not is_flax_available():
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具模块导入假 Flax 对象,防止导入错误
from ..utils import dummy_flax_objects # noqa F403
# 更新虚拟对象字典,添加假 Flax 对象
_dummy_objects.update(get_objects_from_module(dummy_flax_objects))
# 如果 Flax 可用,更新导入结构
else:
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
# 尝试检查 Flax 和 Transformers 是否都可用
try:
# 如果任何一个不可用,抛出异常
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具模块导入假 Flax 和 Transformers 对象
from ..utils import dummy_flax_and_transformers_objects # noqa F403
# 更新虚拟对象字典,添加假对象
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
# 如果都可用,扩展导入结构
else:
_import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
# 将稳定扩散模型相关的类名添加到导入结构中
_import_structure["stable_diffusion"].extend(
[
# 添加图像到图像转换管道类名
"FlaxStableDiffusionImg2ImgPipeline",
# 添加图像修复管道类名
"FlaxStableDiffusionInpaintPipeline",
# 添加基础稳定扩散管道类名
"FlaxStableDiffusionPipeline",
]
)
# 将稳定扩散 XL 模型相关的类名添加到导入结构中
_import_structure["stable_diffusion_xl"].extend(
[
# 添加稳定扩散 XL 管道类名
"FlaxStableDiffusionXLPipeline",
]
)
# 检查是否为类型检查或慢导入条件
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
# 检查是否可用 PyTorch
if not is_torch_available():
# 如果不可用,则引发可选依赖项不可用异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 导入占位符对象以避免运行时错误
from ..utils.dummy_pt_objects import * # noqa F403
else:
# 导入自动图像到图像管道相关类
from .auto_pipeline import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
)
# 导入一致性模型管道
from .consistency_models import ConsistencyModelPipeline
# 导入舞蹈扩散管道
from .dance_diffusion import DanceDiffusionPipeline
# 导入 DDIM 管道
from .ddim import DDIMPipeline
# 导入 DDPM 管道
from .ddpm import DDPMPipeline
# 导入已弃用的管道
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
# 导入 DIT 管道
from .dit import DiTPipeline
# 导入潜在扩散超分辨率管道
from .latent_diffusion import LDMSuperResolutionPipeline
# 导入管道工具类
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
ImagePipelineOutput,
StableDiffusionMixin,
)
try:
# 检查是否可用 PyTorch 和 librosa
if not (is_torch_available() and is_librosa_available()):
# 如果不可用,则引发可选依赖项不可用异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 导入占位符对象以避免运行时错误
from ..utils.dummy_torch_and_librosa_objects import *
else:
# 导入已弃用的音频扩散管道和 Mel 类
from .deprecated import AudioDiffusionPipeline, Mel
try:
# 检查是否可用 PyTorch 和 transformers
if not (is_torch_available() and is_transformers_available()):
# 如果不可用,则引发可选依赖项不可用异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 导入占位符对象以避免运行时错误
from ..utils.dummy_torch_and_transformers_objects import *
else:
# 导入 sys 模块
import sys
# 创建懒加载模块实例
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
# 设置占位符对象到当前模块
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# 版权所有 2024 NVIDIA 和 The HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0("许可证")许可;
# 除非遵循该许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件在许可证下分发是以“按原样”基础进行的,
# 不提供任何形式的明示或暗示的担保或条件。
# 有关许可证下权限和限制的具体语言,请参见许可证。
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入可选类型、元组和联合类型
from typing import Optional, Tuple, Union
# 导入 NumPy 库,通常用于数值计算
import numpy as np
# 导入 PyTorch 库,通常用于深度学习
import torch
# 从配置工具中导入 ConfigMixin 和 register_to_config
from ...configuration_utils import ConfigMixin, register_to_config
# 从 utils 模块导入 BaseOutput 基类
from ...utils import BaseOutput
# 从 utils.torch_utils 导入生成随机张量的函数
from ...utils.torch_utils import randn_tensor
# 从调度工具中导入 SchedulerMixin
from ..scheduling_utils import SchedulerMixin
# 定义 KarrasVeOutput 类,继承自 BaseOutput
@dataclass
class KarrasVeOutput(BaseOutput):
"""
调度器步骤函数输出的输出类。
参数:
prev_sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`,用于图像):
先前时间步的计算样本 (x_{t-1})。`prev_sample` 应作为下一个模型输入使用
在去噪循环中。
derivative (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`,用于图像):
预测的原始图像样本的导数 (x_0)。
pred_original_sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`,用于图像):
基于当前时间步模型输出的预测去噪样本 (x_{0})。
`pred_original_sample` 可用于预览进度或进行引导。
"""
# 先前样本,类型为 torch.Tensor
prev_sample: torch.Tensor
# 导数,类型为 torch.Tensor
derivative: torch.Tensor
# 可选的预测原始样本,类型为 torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
# 定义 KarrasVeScheduler 类,继承自 SchedulerMixin 和 ConfigMixin
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
针对方差扩展模型的随机调度器。
该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。有关库为所有调度器实现的通用
方法的详细信息,请查看超类文档,例如加载和保存。
<Tip>
有关参数的更多详细信息,请参见 [附录 E](https://arxiv.org/abs/2206.00364)。用于查找特定模型的
最优 `{s_noise, s_churn, s_min, s_max}` 的网格搜索值在论文的表 5 中进行了描述。
</Tip>
# 参数说明部分,描述每个参数的含义和默认值
Args:
sigma_min (`float`, defaults to 0.02):
# 最小噪声幅度
The minimum noise magnitude.
sigma_max (`float`, defaults to 100):
# 最大噪声幅度
The maximum noise magnitude.
s_noise (`float`, defaults to 1.007):
# 额外噪声量,抵消采样时的细节损失,合理范围为 [1.000, 1.011]
The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
1.011].
s_churn (`float`, defaults to 80):
# 控制整体随机性程度的参数,合理范围为 [0, 100]
The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
s_min (`float`, defaults to 0.05):
# 添加噪声的起始 sigma 范围值,合理范围为 [0, 10]
The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
s_max (`float`, defaults to 50):
# 添加噪声的结束 sigma 范围值,合理范围为 [0.2, 80]
The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
"""
# 定义阶数为 2
order = 2
# 初始化方法,注册到配置
@register_to_config
def __init__(
self,
# 最小噪声幅度,默认值为 0.02
sigma_min: float = 0.02,
# 最大噪声幅度,默认值为 100
sigma_max: float = 100,
# 额外噪声量,默认值为 1.007
s_noise: float = 1.007,
# 随机性控制参数,默认值为 80
s_churn: float = 80,
# sigma 范围起始值,默认值为 0.05
s_min: float = 0.05,
# sigma 范围结束值,默认值为 50
s_max: float = 50,
):
# 设置初始噪声分布的标准差
self.init_noise_sigma = sigma_max
# 可设置值
# 推理步骤的数量,初始为 None
self.num_inference_steps: int = None
# 时间步的张量,初始为 None
self.timesteps: np.IntTensor = None
# sigma(t_i) 的张量,初始为 None
self.schedule: torch.Tensor = None # sigma(t_i)
# 处理模型输入以确保与调度器的互换性
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器的互换性。
Args:
sample (`torch.Tensor`):
# 输入样本
The input sample.
timestep (`int`, *optional*):
# 当前扩散链中的时间步
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
# 返回缩放后的输入样本
A scaled input sample.
"""
# 返回未改变的样本
return sample
# 设置扩散链使用的离散时间步
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
设置用于扩散链的离散时间步(在推理之前运行)。
Args:
num_inference_steps (`int`):
# 生成样本时使用的扩散步骤数量
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
# 将时间步移动到的设备,如果为 None,则不移动
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# 设置推理步骤数量
self.num_inference_steps = num_inference_steps
# 创建时间步数组并反转
timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
# 将时间步转换为张量并移动到指定设备
self.timesteps = torch.from_numpy(timesteps).to(device)
# 计算调度的 sigma 值
schedule = [
(
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps
]
# 将调度值转换为张量
self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
# 定义添加噪声到输入样本的函数
def add_noise_to_input(
self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[torch.Tensor, float]:
"""
显式的 Langevin 类似的“搅动”步骤,根据 `gamma_i ≥ 0` 添加噪声,以达到更高的噪声水平 `sigma_hat = sigma_i + gamma_i*sigma_i`。
参数:
sample (`torch.Tensor`):
输入样本。
sigma (`float`):
generator (`torch.Generator`, *可选*):
随机数生成器。
"""
# 检查 sigma 是否在配置的最小值和最大值之间
if self.config.s_min <= sigma <= self.config.s_max:
# 计算 gamma,确保不会超过最大值
gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
else:
# 如果不在范围内,gamma 为 0
gamma = 0
# 从标准正态分布中采样噪声 eps
eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
# 计算新的噪声水平
sigma_hat = sigma + gamma * sigma
# 更新样本,添加噪声
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
# 返回更新后的样本和新的噪声水平
return sample_hat, sigma_hat
# 定义从上一个时间步预测样本的步骤函数
def step(
self,
model_output: torch.Tensor,
sigma_hat: float,
sigma_prev: float,
sample_hat: torch.Tensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
通过反转 SDE 从学习的模型输出中传播扩散过程(通常是预测的噪声)。
参数:
model_output (`torch.Tensor`):
学习扩散模型的直接输出。
sigma_hat (`float`):
sigma_prev (`float`):
sample_hat (`torch.Tensor`):
return_dict (`bool`, *可选*, 默认为 `True`):
是否返回一个 [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] 或 `tuple`。
返回:
[`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] 或 `tuple`:
如果 return_dict 为 `True`,返回 [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`],
否则返回一个元组,第一个元素是样本张量。
"""
# 根据模型输出和 sigma_hat 计算预测的原始样本
pred_original_sample = sample_hat + sigma_hat * model_output
# 计算样本的导数
derivative = (sample_hat - pred_original_sample) / sigma_hat
# 计算上一个时间步的样本
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
# 如果不返回字典,返回样本和导数
if not return_dict:
return (sample_prev, derivative)
# 返回包含样本、导数和预测原始样本的 KarrasVeOutput 对象
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
# 定义带有修正步骤的函数
def step_correct(
self,
model_output: torch.Tensor,
sigma_hat: float,
sigma_prev: float,
sample_hat: torch.Tensor,
sample_prev: torch.Tensor,
derivative: torch.Tensor,
return_dict: bool = True,
# 处理网络的模型输出,纠正预测样本
) -> Union[KarrasVeOutput, Tuple]:
"""
# 根据网络的模型输出修正预测样本
Args:
model_output (`torch.Tensor`):
# 从学习的扩散模型直接输出的张量
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.Tensor`): TODO
sample_prev (`torch.Tensor`): TODO
derivative (`torch.Tensor`): TODO
return_dict (`bool`, *optional*, defaults to `True`):
# 是否返回 [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] 或 `tuple`
Returns:
prev_sample (TODO): # 在扩散链中的更新样本。 derivative (TODO): TODO
"""
# 通过前一个样本和模型输出计算预测的原始样本
pred_original_sample = sample_prev + sigma_prev * model_output
# 计算修正后的导数
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
# 更新前一个样本,根据当前和预测的导数
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
# 如果不返回字典,则返回更新的样本和导数
if not return_dict:
return (sample_prev, derivative)
# 返回 KarrasVeOutput 对象,包含更新的样本和导数
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
# 声明未实现的方法,用于添加噪声
def add_noise(self, original_samples, noise, timesteps):
# 引发未实现错误
raise NotImplementedError()