diffusers 源码解析(六十一)
# 版权声明,包含版权年份、作者及其团队信息
# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# 根据 Apache License 2.0 进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 只有在遵守许可证的情况下才能使用此文件
# you may not use this file except in compliance with the License.
# 可以在以下网址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有约定,软件按照 "AS IS" 原则分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件,包括明示或暗示
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 请参阅许可证以了解具体的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入数学库
import math
# 从 typing 模块导入类型提示
from typing import List, Optional, Tuple, Union
# 导入 NumPy 库
import numpy as np
# 导入 PyTorch 库
import torch
# 从配置实用工具中导入 ConfigMixin 和 register_to_config
from ..configuration_utils import ConfigMixin, register_to_config
# 从调度实用工具中导入 KarrasDiffusionSchedulers、SchedulerMixin 和 SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# 从 diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 复制的函数
def betas_for_alpha_bar(
num_diffusion_timesteps, # 定义扩散时间步数
max_beta=0.999, # 定义最大 beta 值,默认为 0.999
alpha_transform_type="cosine", # 定义 alpha 转换类型,默认为 'cosine'
):
"""
创建一个 beta 调度,以离散化给定的 alpha_t_bar 函数,该函数定义了时间 t = [0,1] 的
(1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受一个参数 t,并将其转换为扩散过程的该部分
(1-beta) 的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用小于 1 的值以防止奇点。
alpha_transform_type (`str`, *可选*, 默认值为 `cosine`): alpha_bar 的噪声调度类型。
从 `cosine` 或 `exp` 中选择
返回:
betas (`np.ndarray`): 调度器用于更新模型输出的 betas
"""
# 如果 alpha 转换类型为 'cosine'
if alpha_transform_type == "cosine":
# 定义 alpha_bar 函数,基于余弦函数计算
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
# 如果 alpha 转换类型为 'exp'
elif alpha_transform_type == "exp":
# 定义 alpha_bar 函数,基于指数函数计算
def alpha_bar_fn(t):
return math.exp(t * -12.0)
# 如果 alpha 转换类型不受支持,抛出错误
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = [] # 初始化一个空列表用于存储 beta 值
# 遍历每个扩散时间步
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps # 当前时间步
t2 = (i + 1) / num_diffusion_timesteps # 下一个时间步
# 计算 beta 值并添加到列表,限制在最大 beta 范围内
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 返回 beta 值的张量
return torch.tensor(betas, dtype=torch.float32)
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
KDPM2DiscreteScheduler 的灵感来自 DPMSolver2 和论文
[Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364)。
此模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。查看超类文档以了解库为所有调度器实现的通用方法,如加载和保存。
"""
# 参数列表说明
Args:
num_train_timesteps (`int`, defaults to 1000): # 模型训练的扩散步骤数,默认为1000
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.00085): # 推理的起始 `beta` 值,默认为0.00085
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.012): # 最终的 `beta` 值,默认为0.012
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`): # beta 的调度方式,默认为线性,可以选择线性或缩放线性
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*): # 可选参数,直接传入 beta 数组以绕过 beta_start 和 beta_end
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`): # 可选参数,指示是否使用 Karras sigmas 进行噪声调度
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*): # 预测类型,可选值包括 epsilon、sample 或 v_prediction
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
timestep_spacing (`str`, defaults to `"linspace"`): # 时间步长的缩放方式,默认为线性空间
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): # 推理步骤的偏移量,默认为0
An offset added to the inference steps, as required by some model families.
""" # 参数说明文档结束
_compatibles = [e.name for e in KarrasDiffusionSchedulers] # 从 KarrasDiffusionSchedulers 中提取兼容的名称列表
order = 2 # 设置调度的顺序为2
@register_to_config # 将此方法注册到配置中
def __init__( # 初始化方法
self,
num_train_timesteps: int = 1000, # 默认训练步骤数为1000
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012, # 默认最终 beta 值
beta_schedule: str = "linear", # 默认调度方式为线性
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, # 可选的训练 beta 数组
use_karras_sigmas: Optional[bool] = False, # 默认不使用 Karras sigmas
prediction_type: str = "epsilon", # 默认预测类型为 epsilon
timestep_spacing: str = "linspace", # 默认时间步长缩放方式为线性空间
steps_offset: int = 0, # 默认步骤偏移量为0
):
# 检查是否有训练好的 beta 值
if trained_betas is not None:
# 将训练好的 beta 值转换为张量,数据类型为 float32
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 检查 beta 调度是否为线性
elif beta_schedule == "linear":
# 生成从 beta_start 到 beta_end 的线性序列,长度为 num_train_timesteps
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 检查 beta 调度是否为缩放线性
elif beta_schedule == "scaled_linear":
# 该调度特定于潜在扩散模型
# 生成从 beta_start 的平方根到 beta_end 的平方根的线性序列,再平方
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 检查 beta 调度是否为平方余弦
elif beta_schedule == "squaredcos_cap_v2":
# Glide 余弦调度
# 使用 betas_for_alpha_bar 函数生成 beta 值
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
# 如果 beta 调度不在已实现的范围内,抛出未实现错误
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算 alpha 值,等于 1 减去 beta 值
self.alphas = 1.0 - self.betas
# 计算 alpha 值的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 设置所有时间步的值
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
# 初始化步骤索引和开始索引
self._step_index = None
self._begin_index = None
# 将 sigma 值移动到 CPU,避免过多的 CPU/GPU 通信
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
# 返回初始噪声分布的标准差
if self.config.timestep_spacing in ["linspace", "trailing"]:
# 返回 sigma 的最大值
return self.sigmas.max()
# 返回 sigma 最大值的平方加 1 的平方根
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
第一个时间步的索引。应该通过 `set_begin_index` 方法从管道设置。
"""
return self._begin_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 复制的
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的开始索引。此函数应在推理之前从管道运行。
参数:
begin_index (`int`):
调度器的开始索引。
"""
# 设置调度器的开始索引
self._begin_index = begin_index
def scale_model_input(
self,
# 输入的样本张量
sample: torch.Tensor,
# 当前时间步,可以是浮点数或张量
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
确保与需要根据当前时间步调整去噪模型输入的调度器互换性。
参数:
sample (`torch.Tensor`):
输入样本。
timestep (`int`, *可选*):
当前扩散链中的时间步。
返回:
`torch.Tensor`:
一个经过缩放的输入样本。
"""
# 如果步骤索引尚未初始化,则根据时间步初始化它
if self.step_index is None:
self._init_step_index(timestep)
# 根据状态决定使用哪个 sigma 值
if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
else:
sigma = self.sigmas_interpol[self.step_index]
# 将输入样本除以 sigma 的平方加一的平方根,进行缩放
sample = sample / ((sigma**2 + 1) ** 0.5)
# 返回经过缩放的样本
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
@property
# 判断是否处于一阶状态,即样本是否为 None
def state_in_first_order(self):
return self.sample is None
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep 复制而来
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果没有提供调度时间步,则使用默认时间步
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 找到与当前时间步相匹配的索引
indices = (schedule_timesteps == timestep).nonzero()
# 对于第一个 `step`,选择第二个索引(或只有一个时选择最后一个索引)
pos = 1 if len(indices) > 1 else 0
# 返回对应的索引值
return indices[pos].item()
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 复制而来
def _init_step_index(self, timestep):
# 如果开始索引为 None,则初始化步骤索引
if self.begin_index is None:
# 如果时间步是张量,则转移到相同设备
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 根据时间步索引初始化步骤索引
self._step_index = self.index_for_timestep(timestep)
else:
# 否则将步骤索引设置为开始索引
self._step_index = self._begin_index
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 复制而来
def _sigma_to_t(self, sigma, log_sigmas):
# 计算 sigma 的对数
log_sigma = np.log(np.maximum(sigma, 1e-10))
# 计算对数 sigma 与给定对数的差异
dists = log_sigma - log_sigmas[:, np.newaxis]
# 确定 sigma 的范围
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取低高对数 sigma 值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 进行 sigma 的插值
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
# 返回时间 t
return t
# 复制自 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构造 Karras 等人(2022)提出的噪声调度。"""
# 确保其他调度器复制此函数时不会出现问题的黑客修复
# TODO: 将此逻辑添加到其他调度器中
if hasattr(self.config, "sigma_min"):
# 如果配置中有 sigma_min,则使用它
sigma_min = self.config.sigma_min
else:
# 否则设置为 None
sigma_min = None
if hasattr(self.config, "sigma_max"):
# 如果配置中有 sigma_max,则使用它
sigma_max = self.config.sigma_max
else:
# 否则设置为 None
sigma_max = None
# 如果 sigma_min 为 None,则使用输入信号的最后一个值
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
# 如果 sigma_max 为 None,则使用输入信号的第一个值
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 论文中使用的值 7.0
# 生成一个从 0 到 1 的线性 ramp,长度为 num_inference_steps
ramp = np.linspace(0, 1, num_inference_steps)
# 计算 sigma_min 的逆 rho 次方
min_inv_rho = sigma_min ** (1 / rho)
# 计算 sigma_max 的逆 rho 次方
max_inv_rho = sigma_max ** (1 / rho)
# 根据最大和最小的逆值以及 ramp 生成 sigma 序列
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 返回生成的 sigma 序列
return sigmas
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
# 复制自 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# 确保 sigmas 和 timesteps 与 original_samples 具有相同的设备和数据类型
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
# 检查设备是否为 MPS 且 timesteps 是否为浮点数
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# MPS 不支持 float64 数据类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
# 将 timesteps 转换为相同设备和 float32 数据类型
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 将 schedule_timesteps 转换为 original_samples 设备
schedule_timesteps = self.timesteps.to(original_samples.device)
# 将 timesteps 转换为 original_samples 设备
timesteps = timesteps.to(original_samples.device)
# 当 scheduler 用于训练时,self.begin_index 为 None,或者管道未实现 set_begin_index
if self.begin_index is None:
# 根据 timesteps 计算步索引
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# 在第一个去噪步骤后调用 add_noise(用于图像修复)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# 在第一个去噪步骤之前调用 add_noise 以创建初始潜在图像(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
# 根据步索引提取 sigma,并展平为一维
sigma = sigmas[step_indices].flatten()
# 如果 sigma 的维度少于 original_samples,则在最后一个维度添加维度
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 生成带噪声的样本,通过原始样本与噪声和 sigma 的乘积相加
noisy_samples = original_samples + noise * sigma
# 返回带噪声的样本
return noisy_samples
# 定义 __len__ 方法以返回训练时间步的数量
def __len__(self):
return self.config.num_train_timesteps
# 版权所有 2024 斯坦福大学团队与 HuggingFace 团队,保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件按“原样”提供,
# 不附带任何明示或暗示的担保或条件。
# 有关许可证所涵盖的权限和限制的具体信息,请参见许可证。
# 免责声明:此代码受 https://github.com/pesser/pytorch_diffusion
# 和 https://github.com/hojonathanho/diffusion 的强烈影响
import math # 导入数学库以执行数学运算
from dataclasses import dataclass # 从数据类模块导入数据类装饰器
from typing import List, Optional, Tuple, Union # 从 typing 模块导入类型提示
import numpy as np # 导入 NumPy 库以进行数值计算
import torch # 导入 PyTorch 库以使用张量和深度学习功能
from ..configuration_utils import ConfigMixin, register_to_config # 从配置工具导入配置混合和注册函数
from ..utils import BaseOutput, logging # 从工具模块导入基础输出类和日志功能
from ..utils.torch_utils import randn_tensor # 从工具模块导入随机张量生成函数
from .scheduling_utils import SchedulerMixin # 从调度工具导入调度混合器
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
@dataclass
class LCMSchedulerOutput(BaseOutput):
"""
调度器 `step` 函数输出的输出类。
参数:
prev_sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)` 的图像):
先前时间步的计算样本 `(x_{t-1})`。`prev_sample` 应作为下一个模型输入用于
去噪循环。
pred_original_sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)` 的图像):
基于当前时间步的模型输出的预测去噪样本 `(x_{0})`。
`pred_original_sample` 可用于预览进度或指导。
"""
prev_sample: torch.Tensor # 存储先前样本的张量
denoised: Optional[torch.Tensor] = None # 可选的去噪样本,默认为 None
# 从 diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 复制的函数
def betas_for_alpha_bar(
num_diffusion_timesteps, # 参数:扩散时间步的数量
max_beta=0.999, # 参数:使用的最大 beta 值,避免奇异性应低于 1
alpha_transform_type="cosine", # 参数:alpha_bar 的噪声调度类型,默认为“余弦”
):
"""
创建一个 beta 调度,离散化给定的 alpha_t_bar 函数,该函数定义了
(1-beta) 随时间的累积乘积,范围从 t = [0,1]。
包含一个 alpha_bar 函数,该函数接受参数 t,并将其转换为
扩散过程的 (1-beta) 累积乘积。
参数:
num_diffusion_timesteps (`int`): 要生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用小于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
可选择 `cosine` 或 `exp`
返回:
betas (`np.ndarray`): 调度器用于步骤模型输出的 beta 值
"""
# 检查给定的 alpha 转换类型是否为“cosine”
if alpha_transform_type == "cosine":
# 定义用于计算 alpha_bar 的函数,基于余弦函数
def alpha_bar_fn(t):
# 计算并返回余弦函数值的平方
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
# 检查 alpha 转换类型是否为“exp”
elif alpha_transform_type == "exp":
# 定义用于计算 alpha_bar 的函数,基于指数函数
def alpha_bar_fn(t):
# 计算并返回指数衰减值
return math.exp(t * -12.0)
# 如果 alpha 转换类型不支持,则抛出异常
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
# 初始化一个空列表,用于存储 beta 值
betas = []
# 遍历每个扩散时间步
for i in range(num_diffusion_timesteps):
# 计算当前时间步的归一化值
t1 = i / num_diffusion_timesteps
# 计算下一个时间步的归一化值
t2 = (i + 1) / num_diffusion_timesteps
# 计算 beta 值并添加到列表中,确保不超过 max_beta
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 返回 beta 值的张量,数据类型为 float32
return torch.tensor(betas, dtype=torch.float32)
# 从 diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr 中复制
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
将 betas 重新缩放为零终端 SNR,基于 https://arxiv.org/pdf/2305.08891.pdf (算法 1)
参数:
betas (`torch.Tensor`):
初始化调度器时使用的 betas。
返回:
`torch.Tensor`: 具有零终端 SNR 的重新缩放的 betas
"""
# 将 betas 转换为 alphas_bar_sqrt
alphas = 1.0 - betas # 计算 alphas,即 1 减去 betas
alphas_cumprod = torch.cumprod(alphas, dim=0) # 计算 alphas 的累积乘积
alphas_bar_sqrt = alphas_cumprod.sqrt() # 计算累积乘积的平方根
# 存储旧值。
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() # 记录第一个 alphas_bar_sqrt 的值
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() # 记录最后一个 alphas_bar_sqrt 的值
# 将最后一个时间步的值移为零。
alphas_bar_sqrt -= alphas_bar_sqrt_T # 从所有值中减去最后一个值,使其为零
# 缩放以将第一个时间步恢复到旧值。
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # 进行缩放以恢复初始值
# 将 alphas_bar_sqrt 转换回 betas
alphas_bar = alphas_bar_sqrt**2 # 将平方根恢复为平方
alphas = alphas_bar[1:] / alphas_bar[:-1] # 通过累积乘积的逆运算恢复 alphas
alphas = torch.cat([alphas_bar[0:1], alphas]) # 将第一个 alphas_bar 的值添加到结果中
betas = 1 - alphas # 计算 betas,即 1 减去 alphas
return betas # 返回重新缩放的 betas
class LCMScheduler(SchedulerMixin, ConfigMixin):
"""
`LCMScheduler` 扩展了在去噪扩散概率模型 (DDPM) 中引入的去噪程序,并实现了
非马尔可夫引导。
此模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。[`~ConfigMixin`] 负责存储在调度器的
`__init__` 函数中传入的所有配置属性,例如 `num_train_timesteps`。它们可以通过
`scheduler.config.num_train_timesteps` 访问。[`SchedulerMixin`] 提供通用的加载和保存
功能,通过 [`SchedulerMixin.save_pretrained`] 和 [`~SchedulerMixin.from_pretrained`] 函数。
"""
order = 1 # 定义调度器的顺序
@register_to_config # 注册到配置中,允许该函数的参数在配置中存储
def __init__(
self,
num_train_timesteps: int = 1000, # 训练时间步的数量,默认为 1000
beta_start: float = 0.00085, # beta 的起始值,默认为 0.00085
beta_end: float = 0.012, # beta 的结束值,默认为 0.012
beta_schedule: str = "scaled_linear", # beta 的调度方式,默认为 "scaled_linear"
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, # 经过训练的 betas,默认为 None
original_inference_steps: int = 50, # 原始推理步骤数,默认为 50
clip_sample: bool = False, # 是否剪裁样本,默认为 False
clip_sample_range: float = 1.0, # 剪裁样本的范围,默认为 1.0
set_alpha_to_one: bool = True, # 是否将 alpha 设置为 1,默认为 True
steps_offset: int = 0, # 步骤偏移量,默认为 0
prediction_type: str = "epsilon", # 预测类型,默认为 "epsilon"
thresholding: bool = False, # 是否应用阈值处理,默认为 False
dynamic_thresholding_ratio: float = 0.995, # 动态阈值处理的比例,默认为 0.995
sample_max_value: float = 1.0, # 样本最大值,默认为 1.0
timestep_spacing: str = "leading", # 时间步间距,默认为 "leading"
timestep_scaling: float = 10.0, # 时间步缩放因子,默认为 10.0
rescale_betas_zero_snr: bool = False, # 是否重新缩放 betas 以实现零 SNR,默认为 False
):
# 检查已训练的 beta 值是否存在
if trained_betas is not None:
# 将训练好的 beta 值转换为张量,数据类型为浮点32
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 检查 beta 调度是否为线性
elif beta_schedule == "linear":
# 生成从 beta_start 到 beta_end 的线性间隔的 beta 值
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 检查 beta 调度是否为 scaled_linear
elif beta_schedule == "scaled_linear":
# 该调度非常特定于潜在扩散模型
# 生成 beta_start 和 beta_end 的平方根线性间隔的 beta 值,然后平方
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 检查 beta 调度是否为 squaredcos_cap_v2
elif beta_schedule == "squaredcos_cap_v2":
# Glide 余弦调度
# 使用 betas_for_alpha_bar 函数生成 beta 值
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
# 抛出未实现错误,如果 beta_schedule 未被实现
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 如果需要,重新缩放为零 SNR
if rescale_betas_zero_snr:
# 调用 rescale_zero_terminal_snr 函数重新缩放 beta 值
self.betas = rescale_zero_terminal_snr(self.betas)
# 计算 alpha 值,alpha = 1 - beta
self.alphas = 1.0 - self.betas
# 计算 alpha 的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 在每个 ddim 步骤中,查看之前的 alphas_cumprod
# 最后一步时,因为已经在 0,所以没有前一个 alphas_cumprod
# set_alpha_to_one 决定是否将该参数简单地设置为 1,或使用“非前一个”的最终 alpha
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# 初始化噪声分布的标准差
self.init_noise_sigma = 1.0
# 可设置值
self.num_inference_steps = None
# 生成反向的时间步长张量
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
# 标记是否自定义时间步长
self.custom_timesteps = False
# 初始化步骤索引和开始索引
self._step_index = None
self._begin_index = None
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep 复制而来
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果未提供调度时间步长,则使用默认时间步长
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 找到与给定时间步长相等的索引
indices = (schedule_timesteps == timestep).nonzero()
# 对于第一个“步骤”,选择的 sigma 索引始终是第二个索引
# (如果只有一个,则为最后一个索引),确保不会意外跳过 sigma
pos = 1 if len(indices) > 1 else 0
# 返回指定位置的索引值
return indices[pos].item()
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 复制而来
def _init_step_index(self, timestep):
# 如果 begin_index 为空,初始化步骤索引
if self.begin_index is None:
# 如果时间步长是张量,将其转换为相同设备
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 使用 index_for_timestep 方法设置步骤索引
self._step_index = self.index_for_timestep(timestep)
else:
# 否则将步骤索引设置为开始索引
self._step_index = self._begin_index
# 定义一个属性
# 定义一个方法,返回当前的步骤索引
def step_index(self):
# 返回私有属性 _step_index
return self._step_index
# 定义一个只读属性,表示起始时间步的索引
@property
def begin_index(self):
"""
起始时间步的索引。应通过管道使用 `set_begin_index` 方法设置。
"""
# 返回私有属性 _begin_index
return self._begin_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 复制的
# 定义一个方法,设置调度器的起始索引
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的起始索引。此函数应在推理前通过管道运行。
参数:
begin_index (`int`):
调度器的起始索引。
"""
# 将传入的起始索引值赋给私有属性 _begin_index
self._begin_index = begin_index
# 定义一个方法,确保与需要根据当前时间步缩放去噪模型输入的调度器的互换性
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器的互换性。
参数:
sample (`torch.Tensor`):
输入样本。
timestep (`int`, *可选*):
扩散链中的当前时间步。
返回:
`torch.Tensor`:
缩放后的输入样本。
"""
# 直接返回输入样本,未进行任何缩放处理
return sample
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 复制的
# 定义私有方法 _threshold_sample,输入为一个张量 sample,返回处理后的张量
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"动态阈值处理:在每个采样步骤中,我们将 s 设置为 xt0(在时间步 t 时 x_0 的预测)的某个百分位绝对像素值,
如果 s > 1,则将 xt0 阈值化到范围 [-s, s],然后除以 s。动态阈值处理将饱和像素(接近 -1 和 1 的像素)向内推,
从而主动防止每个步骤的像素饱和。我们发现动态阈值处理显著改善了照片真实感以及图像-文本对齐,特别是在使用非常大的
引导权重时。"
https://arxiv.org/abs/2205.11487
"""
# 获取样本的数据类型
dtype = sample.dtype
# 获取样本的批量大小、通道数及其余维度
batch_size, channels, *remaining_dims = sample.shape
# 如果样本的类型不是 float32 或 float64,则将其转换为 float
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # 为了进行分位数计算而上升精度,且 CPU 半精度不支持 clamping
# 将样本展平,以便沿每个图像进行分位数计算
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
# 计算样本的绝对值
abs_sample = sample.abs() # "某个百分位绝对像素值"
# 计算每个样本的动态阈值 s
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
# 将 s 限制在指定范围内
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # 当下限为 1 时,相当于标准裁剪到 [-1, 1]
# 将 s 的维度扩展为 (batch_size, 1),以便在维度 0 上广播
s = s.unsqueeze(1) # (batch_size, 1)
# 对样本进行裁剪并标准化
sample = torch.clamp(sample, -s, s) / s # "将 xt0 阈值化到范围 [-s, s],然后除以 s"
# 将样本的形状还原为原来的维度
sample = sample.reshape(batch_size, channels, *remaining_dims)
# 将样本的类型转换回原来的数据类型
sample = sample.to(dtype)
# 返回处理后的样本
return sample
# 定义设置时间步的函数
def set_timesteps(
self,
num_inference_steps: Optional[int] = None, # 可选参数:推理步骤数
device: Union[str, torch.device] = None, # 可选参数:设备类型
original_inference_steps: Optional[int] = None, # 可选参数:原始推理步骤数
timesteps: Optional[List[int]] = None, # 可选参数:时间步列表
strength: int = 1.0, # 强度参数,默认为 1.0
# 定义用于边界条件离散的缩放函数
def get_scalings_for_boundary_condition_discrete(self, timestep):
# 设置默认的 sigma 数据值
self.sigma_data = 0.5 # 默认值:0.5
# 根据时间步和配置进行缩放
scaled_timestep = timestep * self.config.timestep_scaling
# 计算跳过的系数
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
# 计算输出的系数
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
# 返回跳过的系数和输出的系数
return c_skip, c_out
# 定义步进方法
def step(
self,
model_output: torch.Tensor, # 模型输出的张量
timestep: int, # 当前的时间步
sample: torch.Tensor, # 当前的样本张量
generator: Optional[torch.Generator] = None, # 可选的随机数生成器
return_dict: bool = True, # 是否返回字典格式的结果
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise 复制的方法
def add_noise(
self,
original_samples: torch.Tensor, # 原始样本的张量
noise: torch.Tensor, # 噪声张量
timesteps: torch.IntTensor, # 时间步的张量
) -> torch.Tensor:
# 确保 alphas_cumprod 和 timestep 的设备和数据类型与 original_samples 一致
# 将 self.alphas_cumprod 移动到目标设备,以避免后续 add_noise 调用时重复的 CPU 到 GPU 数据移动
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
# 将 alphas_cumprod 转换为与 original_samples 相同的数据类型
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
# 将 timesteps 移动到与 original_samples 相同的设备
timesteps = timesteps.to(original_samples.device)
# 计算 sqrt_alpha_prod 为 alphas_cumprod 在 timesteps 索引处的平方根
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
# 将 sqrt_alpha_prod 展平为一维数组
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
# 如果 sqrt_alpha_prod 的维度少于 original_samples 的维度,添加新的维度
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
# 计算 sqrt_one_minus_alpha_prod 为 (1 - alphas_cumprod 在 timesteps 索引处) 的平方根
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
# 将 sqrt_one_minus_alpha_prod 展平为一维数组
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
# 如果 sqrt_one_minus_alpha_prod 的维度少于 original_samples 的维度,添加新的维度
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
# 根据 sqrt_alpha_prod、original_samples 和 sqrt_one_minus_alpha_prod 计算 noisy_samples
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
# 返回噪声样本
return noisy_samples
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity 复制的代码
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
# 确保 alphas_cumprod 和 timestep 的设备和数据类型与 sample 一致
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
# 将 alphas_cumprod 转换为与 sample 相同的数据类型
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
# 将 timesteps 移动到与 sample 相同的设备
timesteps = timesteps.to(sample.device)
# 计算 sqrt_alpha_prod 为 alphas_cumprod 在 timesteps 索引处的平方根
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
# 将 sqrt_alpha_prod 展平为一维数组
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
# 如果 sqrt_alpha_prod 的维度少于 sample 的维度,添加新的维度
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
# 计算 sqrt_one_minus_alpha_prod 为 (1 - alphas_cumprod 在 timesteps 索引处) 的平方根
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
# 将 sqrt_one_minus_alpha_prod 展平为一维数组
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
# 如果 sqrt_one_minus_alpha_prod 的维度少于 sample 的维度,添加新的维度
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
# 根据 sqrt_alpha_prod 和 noise 计算 velocity,并减去 sqrt_one_minus_alpha_prod 和 sample 的乘积
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
# 返回速度
return velocity
# 返回训练时间步的数量
def __len__(self):
return self.config.num_train_timesteps
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep 复制的代码
# 定义一个方法,用于获取给定时间步的前一个时间步
def previous_timestep(self, timestep):
# 如果有自定义的时间步
if self.custom_timesteps:
# 找到当前时间步在时间步数组中的索引位置
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
# 如果当前索引是最后一个时间步的索引
if index == self.timesteps.shape[0] - 1:
# 设置前一个时间步为 -1,表示没有前一个时间步
prev_t = torch.tensor(-1)
else:
# 否则,取当前索引后一个时间步的值作为前一个时间步
prev_t = self.timesteps[index + 1]
else:
# 如果没有自定义时间步,计算推理步骤的数量
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
# 计算前一个时间步,基于当前时间步和推理步骤
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
# 返回计算得到的前一个时间步
return prev_t
# 版权声明,列出版权所有者及其保留的权利
# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版("许可证")进行许可;
# 你不得在未遵守许可证的情况下使用此文件。
# 你可以在以下网址获得许可证的副本:
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件根据许可证分发,按"原样"提供,
# 不提供任何形式的明示或暗示的担保或条件。
# 详见许可证中关于权限和限制的具体条款。
import math # 导入数学模块以使用数学函数
import warnings # 导入警告模块以发出警告信息
from dataclasses import dataclass # 导入数据类装饰器以简化类的定义
from typing import List, Optional, Tuple, Union # 导入类型注解以增强代码可读性
import numpy as np # 导入NumPy库以处理数组和数值计算
import torch # 导入PyTorch库以进行张量计算
from scipy import integrate # 从SciPy库导入积分功能
from ..configuration_utils import ConfigMixin, register_to_config # 从配置工具导入混合类和注册功能
from ..utils import BaseOutput # 从工具模块导入基础输出类
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin # 从调度工具导入相关类
@dataclass
# 从diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput复制,重命名为LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
"""
调度器`step`函数输出的输出类。
参数:
prev_sample (`torch.Tensor`形状为`(batch_size, num_channels, height, width)`的图像):
上一时间步的计算样本`(x_{t-1})`。`prev_sample`应作为下一个模型输入用于去噪循环。
pred_original_sample (`torch.Tensor`形状为`(batch_size, num_channels, height, width)`的图像):
基于当前时间步模型输出的预测去噪样本`(x_{0})`。
`pred_original_sample`可用于预览进度或进行指导。
"""
prev_sample: torch.Tensor # 存储上一时间步的样本
pred_original_sample: Optional[torch.Tensor] = None # 存储预测的原始样本,默认为None
# 从diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar复制
def betas_for_alpha_bar(
num_diffusion_timesteps, # 传入的扩散时间步数
max_beta=0.999, # 最大的beta值,防止奇异性使用小于1的值
alpha_transform_type="cosine", # 噪声调度的类型,默认为"cosine"
):
"""
创建一个beta调度器,该调度器离散化给定的alpha_t_bar函数,
该函数定义了从t = [0,1]开始的(1-beta)的累积乘积。
包含一个alpha_bar函数,该函数接受一个参数t,并将其转换为在扩散过程的这一部分
中(1-beta)的累积乘积。
参数:
num_diffusion_timesteps (`int`): 要生成的beta数量。
max_beta (`float`): 使用的最大beta值;使用小于1的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认为`cosine`): alpha_bar的噪声调度类型。
可选择`cosine`或`exp`
返回:
betas (`np.ndarray`): 调度器用于更新模型输出的beta值
"""
if alpha_transform_type == "cosine": # 检查alpha调度类型是否为"cosine"
def alpha_bar_fn(t): # 定义alpha_bar函数
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 # 返回经过变换的alpha值
# 检查 alpha_transform_type 是否为 "exp"
elif alpha_transform_type == "exp":
# 定义 alpha_bar_fn 函数,接受参数 t
def alpha_bar_fn(t):
# 计算指数衰减函数,返回 e 的 (t * -12.0) 次方
return math.exp(t * -12.0)
# 如果 alpha_transform_type 不是预期的类型,则抛出异常
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
# 初始化一个空列表 betas,用于存储计算出的 beta 值
betas = []
# 遍历从 0 到 num_diffusion_timesteps - 1 的每个整数
for i in range(num_diffusion_timesteps):
# 计算当前时间步 t1
t1 = i / num_diffusion_timesteps
# 计算下一个时间步 t2
t2 = (i + 1) / num_diffusion_timesteps
# 计算 beta 值并添加到列表,确保不超过 max_beta
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 将 betas 列表转换为 PyTorch 的浮点张量并返回
return torch.tensor(betas, dtype=torch.float32)
# 定义一个线性多步调度器类,继承自 SchedulerMixin 和 ConfigMixin
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
一个用于离散 beta 计划的线性多步调度器。
该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看超类文档以了解库为所有调度器实现的通用方法,如加载和保存。
参数:
num_train_timesteps (`int`, 默认值为 1000):
用于训练模型的扩散步骤数量。
beta_start (`float`, 默认值为 0.0001):
推断的起始 `beta` 值。
beta_end (`float`, 默认值为 0.02):
最终的 `beta` 值。
beta_schedule (`str`, 默认值为 `"linear"`):
beta 计划,将 beta 范围映射到一系列用于模型步进的 betas。可以选择 `linear` 或 `scaled_linear`。
trained_betas (`np.ndarray`, *可选*):
直接将 beta 数组传递给构造函数,以绕过 `beta_start` 和 `beta_end`。
use_karras_sigmas (`bool`, *可选*, 默认值为 `False`):
是否在采样过程中使用 Karras sigmas 作为噪声计划中的步长。如果为 `True`,则根据噪声水平序列 {σi} 确定 sigmas。
prediction_type (`str`, 默认值为 `epsilon`, *可选*):
调度器函数的预测类型;可以是 `epsilon`(预测扩散过程的噪声)、`sample`(直接预测带噪声的样本)或 `v_prediction`(参见 [Imagen Video](https://imagen.research.google/video/paper.pdf) 论文的第 2.4 节)。
timestep_spacing (`str`, 默认值为 `"linspace"`):
时间步的缩放方式。有关更多信息,请参阅 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) 的表 2。
steps_offset (`int`, 默认值为 0):
添加到推断步骤的偏移量,某些模型系列需要该偏移量。
"""
# 兼容的调度器名称列表,从 KarrasDiffusionSchedulers 获取
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
# 设置调度器的顺序
order = 1
# 用于注册配置的初始化方法
@register_to_config
def __init__(
# 训练步骤数量,默认为 1000
self,
num_train_timesteps: int = 1000,
# 起始 beta 值,默认为 0.0001
beta_start: float = 0.0001,
# 结束 beta 值,默认为 0.02
beta_end: float = 0.02,
# beta 计划,默认为 "linear"
beta_schedule: str = "linear",
# 经过训练的 beta 值,默认为 None
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
# 是否使用 Karras sigmas,默认为 False
use_karras_sigmas: Optional[bool] = False,
# 预测类型,默认为 "epsilon"
prediction_type: str = "epsilon",
# 时间步缩放方式,默认为 "linspace"
timestep_spacing: str = "linspace",
# 步骤偏移量,默认为 0
steps_offset: int = 0,
):
# 检查是否提供了训练好的贝塔值
if trained_betas is not None:
# 将训练好的贝塔值转换为 PyTorch 张量,数据类型为 float32
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 检查贝塔调度类型是否为线性
elif beta_schedule == "linear":
# 生成一个从 beta_start 到 beta_end 的线性序列,包含 num_train_timesteps 个值
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 检查贝塔调度类型是否为 scaled_linear
elif beta_schedule == "scaled_linear":
# 此调度特定于潜在扩散模型
# 生成从 beta_start^0.5 到 beta_end^0.5 的线性序列,然后平方
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 检查贝塔调度类型是否为 squaredcos_cap_v2
elif beta_schedule == "squaredcos_cap_v2":
# Glide 余弦调度
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
# 如果提供的调度类型未实现,抛出未实现错误
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算 alphas,等于 1 减去 betas
self.alphas = 1.0 - self.betas
# 计算 alphas 的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 计算 sigmas,基于 alphas_cumprod 的公式
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
# 反转 sigmas 数组并添加一个 0.0 值
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
# 将 sigmas 转换为 PyTorch 张量
self.sigmas = torch.from_numpy(sigmas)
# 可设置的值
self.num_inference_steps = None # 推理步骤数初始化为 None
self.use_karras_sigmas = use_karras_sigmas # 使用 Karras sigmas 的标志
# 设置时间步长
self.set_timesteps(num_train_timesteps, None)
self.derivatives = [] # 初始化导数列表
self.is_scale_input_called = False # 标志,表示是否调用过缩放输入
self._step_index = None # 当前步骤索引初始化为 None
self._begin_index = None # 起始步骤索引初始化为 None
# 将 sigmas 移动到 CPU,避免过多的 CPU/GPU 通信
self.sigmas = self.sigmas.to("cpu")
@property
def init_noise_sigma(self):
# 返回初始噪声分布的标准差
if self.config.timestep_spacing in ["linspace", "trailing"]:
# 返回 sigmas 的最大值
return self.sigmas.max()
# 返回 sigmas 最大值的平方加 1 再开平方
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
第一个时间步的索引。应该通过 `set_begin_index` 方法从管道设置。
"""
return self._begin_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 复制的代码
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的起始索引。此函数应该在推理之前从管道运行。
参数:
begin_index (`int`):
调度器的起始索引。
"""
self._begin_index = begin_index # 设置起始索引
# 定义一个方法,用于根据当前时间步缩放模型输入
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器的互换性。
参数:
sample (`torch.Tensor`):
输入样本。
timestep (`float` or `torch.Tensor`):
扩散链中的当前时间步。
返回:
`torch.Tensor`:
缩放后的输入样本。
"""
# 如果步索引为空,则初始化步索引
if self.step_index is None:
self._init_step_index(timestep)
# 根据当前步索引获取相应的 sigma 值
sigma = self.sigmas[self.step_index]
# 将样本缩放,缩放因子为 sqrt(sigma^2 + 1)
sample = sample / ((sigma**2 + 1) ** 0.5)
# 标记输入缩放方法已被调用
self.is_scale_input_called = True
# 返回缩放后的样本
return sample
# 定义一个方法,用于计算线性多步系数
def get_lms_coefficient(self, order, t, current_order):
"""
计算线性多步系数。
参数:
order ():
t ():
current_order ():
"""
# 定义一个内部函数,用于计算 LMS 导数
def lms_derivative(tau):
prod = 1.0
# 遍历到 order 的每一个步骤
for k in range(order):
# 跳过当前阶数
if current_order == k:
continue
# 计算导数的乘积
prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
# 返回导数的结果
return prod
# 通过数值积分计算集成系数,范围从 self.sigmas[t] 到 self.sigmas[t + 1]
integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]
# 返回集成后的系数
return integrated_coeff
# 定义设置离散时间步长的方法,接受推理步骤数和设备参数
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
设置用于扩散链的离散时间步长(在推理之前运行)。
参数:
num_inference_steps (`int`):
生成样本时使用的扩散步骤数。
device (`str` 或 `torch.device`, *可选*):
要将时间步长移动到的设备。如果为 `None`,则不移动时间步长。
"""
# 将推理步骤数赋值给实例变量
self.num_inference_steps = num_inference_steps
# 根据配置的时间步长间隔选择不同的处理方式
if self.config.timestep_spacing == "linspace":
# 创建线性间隔的时间步长,并反转顺序
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
# 计算步长比率
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# 创建整数时间步长,通过乘以比率生成
# 转换为整数以避免 num_inference_step 为 3 的幂时出现问题
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset # 添加偏移量
elif self.config.timestep_spacing == "trailing":
# 计算步长比率
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# 创建整数时间步长,通过乘以比率生成
# 转换为整数以避免 num_inference_step 为 3 的幂时出现问题
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1 # 减去1以调整步长
else:
# 如果选择的时间步长间隔不被支持,抛出异常
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
# 根据累积 alpha 值计算 sigma 值
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas) # 计算 sigma 的对数
# 插值计算 sigma 值以对应时间步长
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
# 如果使用 Karras sigma,则转换 sigma
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas)
# 根据 sigma 计算对应的时间步长
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
# 将 sigma 数组与 0.0 连接,并转换为浮点数类型
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
# 将 sigma 和时间步长转换为张量并移动到指定设备
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None # 初始化步骤索引为 None
self._begin_index = None # 初始化开始索引为 None
# 将 sigma 移动到 CPU,以避免过多的 CPU/GPU 通信
self.sigmas = self.sigmas.to("cpu")
# 初始化导数列表
self.derivatives = []
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep 复制
# 根据给定的时间步获取对应的索引
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果没有提供时间步调度,则使用类的时间步
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 找到与指定时间步相等的所有索引
indices = (schedule_timesteps == timestep).nonzero()
# 对于第一个步骤,选取的 sigma 索引总是第二个索引
# 如果只有一个索引,则选择最后一个
# 这样可以确保在从去噪调度的中间开始时不会意外跳过 sigma
pos = 1 if len(indices) > 1 else 0
# 返回所选索引的项
return indices[pos].item()
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 复制
# 初始化步骤索引
def _init_step_index(self, timestep):
# 如果开始索引为 None,则计算步骤索引
if self.begin_index is None:
# 如果时间步是一个张量,则将其转换到相应设备
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 调用 index_for_timestep 方法获取步骤索引
self._step_index = self.index_for_timestep(timestep)
else:
# 如果有开始索引,则使用它
self._step_index = self._begin_index
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 复制
# 将 sigma 转换为时间 t
def _sigma_to_t(self, sigma, log_sigmas):
# 获取 sigma 的对数值,避免对数为负
log_sigma = np.log(np.maximum(sigma, 1e-10))
# 计算 log_sigma 与 log_sigmas 之间的分布
dists = log_sigma - log_sigmas[:, np.newaxis]
# 获取 sigma 范围的索引
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取低和高的 log_sigma 值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 插值计算 sigma
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape) # 确保形状与 sigma 一致
return t
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras 复制
# 将 sigma 转换为 Karras 的噪声调度
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
"""构建 Karras 等人 (2022) 的噪声调度。"""
# 获取输入 sigma 的最小和最大值
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# 设置 rho 的值
rho = 7.0 # 论文中使用的值
# 创建从 0 到 1 的线性变化
ramp = np.linspace(0, 1, self.num_inference_steps)
# 根据 rho 计算最小和最大反 rho
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
# 根据反 rho 和线性变化计算 sigmas
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# 定义步骤函数,处理模型输出、时间步和样本
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
order: int = 4,
return_dict: bool = True,
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise 复制
# 添加噪声到原始样本
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# 确保 sigmas 和 timesteps 的设备和数据类型与 original_samples 相同
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
# 检查设备类型,如果是 mps 并且 timesteps 是浮点数
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps 不支持 float64 类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 否则,将 timesteps 转换为原样本设备的类型
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# 当 scheduler 用于训练时,self.begin_index 为 None,或者管道未实现 set_begin_index
if self.begin_index is None:
# 计算每个时间步的索引
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# 在第一个去噪步骤后调用 add_noise(用于修复)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# 在第一个去噪步骤之前调用 add_noise,以创建初始潜在图像(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
# 获取对应时间步的 sigma 值并展平
sigma = sigmas[step_indices].flatten()
# 如果 sigma 的维度小于 original_samples,进行维度扩展
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 添加噪声到原样本,生成噪声样本
noisy_samples = original_samples + noise * sigma
# 返回生成的噪声样本
return noisy_samples
# 返回训练时间步的数量
def __len__(self):
return self.config.num_train_timesteps
# 版权声明,说明文件的版权归2024年Katherine Crowson及HuggingFace团队所有
# 授权信息,指出该文件受Apache许可证2.0版的保护
# 用户只能在遵守许可证的情况下使用该文件
# 用户可以在以下网址获取许可证副本
# http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面协议,软件在"现状"基础上分发
# 不提供任何形式的明示或暗示的担保或条件
# 查看许可证以获取特定权限和限制信息
# 从dataclass模块导入dataclass装饰器
from dataclasses import dataclass
# 从typing模块导入Optional、Tuple和Union类型
from typing import Optional, Tuple, Union
# 导入flax库
import flax
# 从jax.numpy导入jnp模块,用于数值计算
import jax.numpy as jnp
# 从scipy库导入integrate模块,用于数值积分
from scipy import integrate
# 从configuration_utils模块导入ConfigMixin和register_to_config
from ..configuration_utils import ConfigMixin, register_to_config
# 从scheduling_utils_flax模块导入相关的调度器类
from .scheduling_utils_flax import (
CommonSchedulerState, # 通用调度器状态
FlaxKarrasDiffusionSchedulers, # Karras扩散调度器
FlaxSchedulerMixin, # 调度器混合类
FlaxSchedulerOutput, # 调度器输出类
broadcast_to_shape_from_left, # 从左侧广播形状的函数
)
# 定义LMSDiscreteSchedulerState类,表示调度器的状态
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
common: CommonSchedulerState # 通用调度器状态的实例
# 可设置的属性
init_noise_sigma: jnp.ndarray # 初始化噪声标准差
timesteps: jnp.ndarray # 时间步数组
sigmas: jnp.ndarray # 噪声标准差数组
num_inference_steps: Optional[int] = None # 可选的推理步骤数量
# 运行时的属性
derivatives: Optional[jnp.ndarray] = None # 可选的导数数组
# 类方法,用于创建LMSDiscreteSchedulerState实例
@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
# 返回一个新的LMSDiscreteSchedulerState实例
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
# 定义FlaxLMSSchedulerOutput类,表示调度器的输出
@dataclass
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
state: LMSDiscreteSchedulerState # LMSDiscreteSchedulerState的实例
# 定义FlaxLMSDiscreteScheduler类,表示线性多步调度器
class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
线性多步调度器,用于离散beta调度。基于Katherine Crowson的原始k-diffusion实现:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
[`~ConfigMixin`]负责存储传递给调度器`__init__`函数的所有配置属性,例如`num_train_timesteps`。
可以通过`scheduler.config.num_train_timesteps`访问。
[`SchedulerMixin`]提供通过[`SchedulerMixin.save_pretrained`]和
[`~SchedulerMixin.from_pretrained`]函数进行的通用加载和保存功能。
"""
# 参数说明
Args:
num_train_timesteps (`int`): 训练模型时使用的扩散步骤数。
beta_start (`float`): 推理时的起始 `beta` 值。
beta_end (`float`): 最终 `beta` 值。
beta_schedule (`str`):
beta 调度,表示从 beta 范围到一系列 beta 的映射,用于模型的步进。可选择
`linear` 或 `scaled_linear`。
trained_betas (`jnp.ndarray`, optional):
直接传递 beta 数组到构造函数的选项,以绕过 `beta_start`、`beta_end` 等。
prediction_type (`str`, default `epsilon`, optional):
调度函数的预测类型,可能值有 `epsilon`(预测扩散过程的噪声)、`sample`(直接预测带噪声的样本)或 `v_prediction`(见第 2.4 节
https://imagen.research.google/video/paper.pdf)。
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
用于参数和计算的 `dtype` 类型。
"""
# 创建一个包含 FlaxKarrasDiffusionSchedulers 中每个调度器名称的列表
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
# 定义一个数据类型属性
dtype: jnp.dtype
# 定义属性,指示是否有状态
@property
def has_state(self):
return True
# 注册构造函数到配置
@register_to_config
def __init__(
# 初始化时的参数,设定默认值
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
# 将传入的数据类型参数赋值给实例变量
self.dtype = dtype
# 创建状态的方法,接受一个可选的公共调度器状态
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
# 如果没有传入公共状态,则创建一个新的公共状态
if common is None:
common = CommonSchedulerState.create(self)
# 生成一个从 0 到 num_train_timesteps 的时间步数组,并反转
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
# 计算每个时间步的标准差,使用公式
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
# 初始噪声分布的标准差
init_noise_sigma = sigmas.max()
# 创建并返回一个 LMSDiscreteSchedulerState 实例,传入相关参数
return LMSDiscreteSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
# 定义一个方法用于缩放模型输入以匹配 K-LMS 算法
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
"""
通过 `(sigma**2 + 1) ** 0.5` 缩放去噪模型输入以匹配 K-LMS 算法。
参数:
state (`LMSDiscreteSchedulerState`):
`FlaxLMSDiscreteScheduler` 状态数据类实例。
sample (`jnp.ndarray`):
当前由扩散过程创建的样本实例。
timestep (`int`):
扩散链中的当前离散时间步。
返回:
`jnp.ndarray`: 缩放后的输入样本
"""
# 找到与当前时间步相等的索引
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
# 获取索引的第一个值
step_index = step_index[0]
# 获取当前时间步对应的 sigma 值
sigma = state.sigmas[step_index]
# 将样本按缩放因子进行缩放
sample = sample / ((sigma**2 + 1) ** 0.5)
# 返回缩放后的样本
return sample
# 定义一个方法用于计算线性多步系数
def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
"""
计算线性多步系数。
参数:
order (TODO):
t (TODO):
current_order (TODO):
"""
# 定义一个内部函数用于计算 LMS 导数
def lms_derivative(tau):
prod = 1.0
# 遍历所有步长,计算导数的乘积
for k in range(order):
# 跳过当前的阶数
if current_order == k:
continue
# 计算导数乘积
prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
# 返回导数值
return prod
# 使用数值积分计算集成系数
integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
# 返回集成系数
return integrated_coeff
# 定义一个方法用于设置扩散链使用的时间步
def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> LMSDiscreteSchedulerState:
"""
设置用于扩散链的时间步。在推理之前运行的辅助函数。
参数:
state (`LMSDiscreteSchedulerState`):
`FlaxLMSDiscreteScheduler` 状态数据类实例。
num_inference_steps (`int`):
在生成样本时使用的扩散步骤数。
"""
# 生成从最大训练时间步到 0 的线性时间步数组
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
# 计算时间步的低索引和高索引
low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)
# 计算时间步的分数部分
frac = jnp.mod(timesteps, 1.0)
# 计算 sigma 值
sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
# 插值计算 sigma 值
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
# 在 sigma 数组末尾添加 0.0
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
# 将时间步转换为整型
timesteps = timesteps.astype(jnp.int32)
# 初始化导数的值
derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)
# 返回更新后的状态
return state.replace(
timesteps=timesteps,
sigmas=sigmas,
num_inference_steps=num_inference_steps,
derivatives=derivatives,
)
# 定义一个方法,用于在扩散过程中预测上一个时间步的样本
def step(
self,
state: LMSDiscreteSchedulerState, # 当前调度器状态实例
model_output: jnp.ndarray, # 从学习到的扩散模型得到的直接输出
timestep: int, # 当前扩散链中的离散时间步
sample: jnp.ndarray, # 当前正在通过扩散过程生成的样本实例
order: int = 4, # 多步推理的系数
return_dict: bool = True, # 是否返回元组而非 FlaxLMSSchedulerOutput 类
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
"""
通过逆转 SDE 预测上一个时间步的样本。核心函数从学习到的模型输出(通常是预测噪声)传播扩散过程。
Args:
state (`LMSDiscreteSchedulerState`): FlaxLMSDiscreteScheduler 的状态数据类实例。
model_output (`jnp.ndarray`): 从学习到的扩散模型直接输出。
timestep (`int`): 当前离散时间步。
sample (`jnp.ndarray`):
当前通过扩散过程创建的样本实例。
order: 多步推理的系数。
return_dict (`bool`): 是否返回元组而非 FlaxLMSSchedulerOutput 类。
Returns:
[`FlaxLMSSchedulerOutput`] or `tuple`: 如果 `return_dict` 为 True,返回 [`FlaxLMSSchedulerOutput`],否则返回一个元组。当返回元组时,第一个元素是样本张量。
"""
# 检查推理步骤是否为 None,如果是则抛出错误
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# 获取当前时间步的 sigma 值
sigma = state.sigmas[timestep]
# 1. 从 sigma 缩放的预测噪声计算预测的原始样本 (x_0)
if self.config.prediction_type == "epsilon":
# 计算预测的原始样本
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# 使用 v 预测公式计算预测的原始样本
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
# 如果 prediction_type 不符合预期,抛出错误
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. 转换为 ODE 导数
derivative = (sample - pred_original_sample) / sigma # 计算导数
# 将新的导数添加到状态中
state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
# 如果导数长度超过了设定的 order,删除最早的导数
if len(state.derivatives) > order:
state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
# 3. 计算线性多步系数
order = min(timestep + 1, order) # 确保 order 不超过当前时间步
# 生成多步系数
lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
# 4. 基于导数路径计算上一个样本
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
) # 计算上一个样本
# 如果不需要返回字典,返回元组
if not return_dict:
return (prev_sample, state)
# 返回 FlaxLMSSchedulerOutput 类实例
return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
# 定义添加噪声的函数,接受调度状态、原始样本、噪声和时间步
def add_noise(
self,
state: LMSDiscreteSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
# 从调度状态中获取指定时间步的 sigma 值,并扁平化
sigma = state.sigmas[timesteps].flatten()
# 将 sigma 的形状广播到噪声的形状
sigma = broadcast_to_shape_from_left(sigma, noise.shape)
# 将噪声与原始样本结合,生成带噪声的样本
noisy_samples = original_samples + noise * sigma
# 返回带噪声的样本
return noisy_samples
# 定义获取对象长度的方法
def __len__(self):
# 返回训练时间步的数量
return self.config.num_train_timesteps
# 版权信息,声明版权所有者和使用条款
# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;在遵循许可证的前提下使用此文件。
# 您可以在以下网址获得许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件按“原样”分发,没有任何形式的保证或条件,
# 无论是明示还是暗示。
# 请参阅许可证以获取有关权限和限制的具体说明。
# 声明:该文件受到 https://github.com/ermongroup/ddim 的强烈影响
# 导入数学模块
import math
# 从类型提示导入必要的类型
from typing import List, Optional, Tuple, Union
# 导入 numpy 和 torch 库
import numpy as np
import torch
# 从配置工具导入所需的混合类和注册函数
from ..configuration_utils import ConfigMixin, register_to_config
# 从调度工具导入调度器和输出类
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# 定义生成 beta 调度的函数,基于 alpha_t_bar 函数
def betas_for_alpha_bar(
num_diffusion_timesteps, # 生成的 beta 数量
max_beta=0.999, # 最大 beta 值
alpha_transform_type="cosine", # alpha 转换类型
):
"""
创建一个 beta 调度,离散化给定的 alpha_t_bar 函数,定义了时间 t=[0,1] 上 (1-beta) 的累积乘积。
包含一个 alpha_bar 函数,接受 t 参数并将其转换为扩散过程中的 (1-beta) 的累积乘积。
参数:
num_diffusion_timesteps (`int`): 要生成的 beta 数量。
max_beta (`float`): 要使用的最大 beta 值;使用小于 1 的值以
防止奇点。
alpha_transform_type (`str`, *可选*,默认为 `cosine`): alpha_bar 的噪声调度类型。
选择 `cosine` 或 `exp`
返回:
betas (`np.ndarray`): 调度器用于更新模型输出的 betas
"""
# 根据选择的 alpha 转换类型定义 alpha_bar 函数
if alpha_transform_type == "cosine":
def alpha_bar_fn(t): # 定义基于余弦的 alpha_bar 函数
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t): # 定义基于指数的 alpha_bar 函数
return math.exp(t * -12.0)
else:
# 如果 alpha 转换类型不被支持,抛出错误
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = [] # 初始化空列表以存储 beta 值
# 遍历每个扩散时间步,计算相应的 beta 值
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps # 当前时间点
t2 = (i + 1) / num_diffusion_timesteps # 下一个时间点
# 计算 beta 值并添加到列表中
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 返回作为 PyTorch 张量的 beta 值
return torch.tensor(betas, dtype=torch.float32)
# 定义 PNDMScheduler 类,使用伪数值方法进行扩散模型调度
class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
`PNDMScheduler` 使用伪数值方法进行扩散模型的调度,如龙格-库塔和线性多步方法。
此模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。有关所有调度器的通用方法的库文档,请检查超类文档,如加载和保存。
# 参数说明
Args:
num_train_timesteps (`int`, defaults to 1000):
# 模型训练的扩散步骤数量,默认为 1000
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
# 推理的起始 `beta` 值,默认为 0.0001
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
# 最终的 `beta` 值,默认为 0.02
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
# beta 调度策略,从 beta 范围到模型步进的 beta 序列的映射。可选值包括 `linear`、`scaled_linear` 或 `squaredcos_cap_v2`
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
# 直接将 beta 数组传递给构造函数,以绕过 `beta_start` 和 `beta_end`
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
skip_prk_steps (`bool`, defaults to `False`):
# 允许调度器跳过原始论文中定义的 Runge-Kutta 步骤,这些步骤在 PLMS 步骤之前是必需的
Allows the scheduler to skip the Runge-Kutta steps defined in the original paper as being required before
PLMS steps.
set_alpha_to_one (`bool`, defaults to `False`):
# 每个扩散步骤使用该步骤和前一步的 alpha 乘积值。对于最后一步没有前一个 alpha。当选项为 `True` 时,前一个 alpha 乘积固定为 1, 否则使用第 0 步的 alpha 值
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
prediction_type (`str`, defaults to `epsilon`, *optional*):
# 调度函数的预测类型;可以是 `epsilon`(预测扩散过程的噪声)或 `v_prediction`(参见 [Imagen Video](https://imagen.research.google/video/paper.pdf) 论文的 2.4 节)
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
paper).
timestep_spacing (`str`, defaults to `"leading"`):
# 时间步的缩放方式。有关更多信息,请参见 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) 的表 2
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
# 添加到推理步骤的偏移量,一些模型家族需要这个偏移
An offset added to the inference steps, as required by some model families.
"""
# 定义与 KarrasDiffusionSchedulers 兼容的调度器名称列表
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
# 设置默认的调度器顺序
order = 1
# 装饰器,用于将初始化函数注册到配置
@register_to_config
# 初始化函数
def __init__(
# 训练扩散步骤数量,默认为 1000
self,
num_train_timesteps: int = 1000,
# 起始 beta 值,默认为 0.0001
beta_start: float = 0.0001,
# 最终 beta 值,默认为 0.02
beta_end: float = 0.02,
# beta 调度策略,默认为 "linear"
beta_schedule: str = "linear",
# 可选参数,直接传递的 beta 数组
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
# 是否跳过 Runge-Kutta 步骤,默认为 False
skip_prk_steps: bool = False,
# 是否将 alpha 设置为 1,默认为 False
set_alpha_to_one: bool = False,
# 预测类型,默认为 "epsilon"
prediction_type: str = "epsilon",
# 时间步的缩放方式,默认为 "leading"
timestep_spacing: str = "leading",
# 推理步骤的偏移量,默认为 0
steps_offset: int = 0,
# 该部分代码为类的方法的一部分
):
# 检查已训练的 beta 参数是否为 None
if trained_betas is not None:
# 将训练好的 beta 参数转换为浮点型张量
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 如果 beta_schedule 为线性调度
elif beta_schedule == "linear":
# 生成从 beta_start 到 beta_end 的线性空间,数量为 num_train_timesteps
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 如果 beta_schedule 为 scaled_linear
elif beta_schedule == "scaled_linear":
# 该调度特定于潜在扩散模型
# 生成从 beta_start**0.5 到 beta_end**0.5 的线性空间,并平方
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 如果 beta_schedule 为 squaredcos_cap_v2
elif beta_schedule == "squaredcos_cap_v2":
# 使用 Glide 的余弦调度生成 beta
self.betas = betas_for_alpha_bar(num_train_timesteps)
# 如果以上条件都不满足,抛出未实现错误
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算 alphas,等于 1 减去 betas
self.alphas = 1.0 - self.betas
# 计算 alphas 的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 如果 set_alpha_to_one 为真,则 final_alpha_cumprod 为 1.0,否则取 alphas_cumprod 的第一个值
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# 初始化噪声分布的标准差
self.init_noise_sigma = 1.0
# 当前只支持 F-PNDM,即龙格-库塔方法
# 更多算法信息请参考论文:https://arxiv.org/pdf/2202.09778.pdf
# 主要关注公式 (9), (12), (13) 和算法 2
self.pndm_order = 4
# 运行时的值初始化
self.cur_model_output = 0
self.counter = 0
self.cur_sample = None
self.ets = []
# 可设置的值初始化
self.num_inference_steps = None
# 创建倒序的时间步数组
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
# 定义 step 方法
def step(
# 接收模型输出张量
model_output: torch.Tensor,
# 当前时间步
timestep: int,
# 当前样本张量
sample: torch.Tensor,
# 返回字典标志,默认为 True
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
预测前一个时间步的样本,通过逆向 SDE 进行。这一函数从学习模型的输出(通常是预测的噪声)中传播扩散过程,
并根据内部变量 `counter` 调用 [`~PNDMScheduler.step_prk`] 或 [`~PNDMScheduler.step_plms`]。
参数:
model_output (`torch.Tensor`):
来自学习扩散模型的直接输出。
timestep (`int`):
当前扩散链中的离散时间步。
sample (`torch.Tensor`):
当前通过扩散过程生成的样本实例。
return_dict (`bool`):
是否返回 [`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`。
返回:
[`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`:
如果 return_dict 为 `True`,返回 [`~schedulers.scheduling_utils.SchedulerOutput`],否则返回一个
元组,其中第一个元素是样本张量。
"""
# 检查当前计数器是否小于 PRK 时间步的长度,且配置中是否跳过 PRK 步骤
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
# 调用 step_prk 方法,传递模型输出、时间步、样本和返回字典标志
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
else:
# 调用 step_plms 方法,传递模型输出、时间步、样本和返回字典标志
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
# 定义 step_prk 方法,接收模型输出、时间步、样本和可选的返回字典标志
def step_prk(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
return_dict: bool = True,
# 返回一个调度输出或元组,表示通过逆向SDE预测样本
) -> Union[SchedulerOutput, Tuple]:
"""
通过逆向SDE预测前一个时间步的样本。该函数使用Runge-Kutta方法传播样本。
进行四次前向传递以逼近微分方程的解。
参数:
model_output (`torch.Tensor`):
来自学习的扩散模型的直接输出。
timestep (`int`):
扩散链中的当前离散时间步。
sample (`torch.Tensor`):
通过扩散过程创建的当前样本实例。
return_dict (`bool`):
是否返回一个[`~schedulers.scheduling_utils.SchedulerOutput`]或元组。
返回:
[`~schedulers.scheduling_utils.SchedulerOutput`]或`tuple`:
如果return_dict为`True`,返回[`~schedulers.scheduling_utils.SchedulerOutput`],否则返回一个
元组,其第一个元素是样本张量。
"""
# 检查推断步骤是否被设置
if self.num_inference_steps is None:
raise ValueError(
"推断步骤数为'None',创建调度器后需要运行'set_timesteps'"
)
# 计算到前一个时间步的差异
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
# 确定前一个时间步
prev_timestep = timestep - diff_to_prev
# 从预先计算的时间步中获取当前时间步
timestep = self.prk_timesteps[self.counter // 4 * 4]
# 根据counter的值更新当前模型输出和样本
if self.counter % 4 == 0:
self.cur_model_output += 1 / 6 * model_output
self.ets.append(model_output)
self.cur_sample = sample
elif (self.counter - 1) % 4 == 0:
self.cur_model_output += 1 / 3 * model_output
elif (self.counter - 2) % 4 == 0:
self.cur_model_output += 1 / 3 * model_output
elif (self.counter - 3) % 4 == 0:
# 更新模型输出并重置当前模型输出
model_output = self.cur_model_output + 1 / 6 * model_output
self.cur_model_output = 0
# 确保cur_sample不为`None`
cur_sample = self.cur_sample if self.cur_sample is not None else sample
# 获取前一个样本
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
# 增加计数器
self.counter += 1
# 根据return_dict返回不同的结果
if not return_dict:
return (prev_sample,)
# 返回调度输出
return SchedulerOutput(prev_sample=prev_sample)
# 定义PLMS步骤方法
def step_plms(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
预测从上一个时间步生成的样本,通过逆转SDE。该函数使用线性多步法传播样本。
它多次执行一次前向传递以近似解决方案。
参数:
model_output (`torch.Tensor`):
学习的扩散模型的直接输出。
timestep (`int`):
当前扩散链中的离散时间步。
sample (`torch.Tensor`):
通过扩散过程生成的当前样本实例。
return_dict (`bool`):
是否返回 [`~schedulers.scheduling_utils.SchedulerOutput`] 或元组。
返回:
[`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`:
如果 return_dict 为 `True`,返回 [`~schedulers.scheduling_utils.SchedulerOutput`],否则返回一个元组,元组的第一个元素是样本张量。
"""
# 检查推理步骤数是否为 None
if self.num_inference_steps is None:
# 抛出错误提示需要设置时间步
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# 检查是否跳过 PRK 步骤,并确保 ETS 列表至少有 3 个元素
if not self.config.skip_prk_steps and len(self.ets) < 3:
# 抛出错误提示需要进行至少 12 次迭代
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
"for more information."
)
# 计算前一个时间步
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 如果计数器不为 1
if self.counter != 1:
# 只保留最近 3 个 ETS 值
self.ets = self.ets[-3:]
# 添加当前模型输出到 ETS 列表
self.ets.append(model_output)
else:
# 如果计数器为 1,设置时间步为当前时间步
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
# 如果 ETS 列表只有 1 个元素且计数器为 0
if len(self.ets) == 1 and self.counter == 0:
# 模型输出不变,当前样本为输入样本
model_output = model_output
self.cur_sample = sample
# 如果 ETS 列表只有 1 个元素且计数器为 1
elif len(self.ets) == 1 and self.counter == 1:
# 取当前模型输出和最后一个 ETS 的平均值
model_output = (model_output + self.ets[-1]) / 2
sample = self.cur_sample
self.cur_sample = None
# 如果 ETS 列表有 2 个元素
elif len(self.ets) == 2:
# 根据 ETS 计算模型输出
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
# 如果 ETS 列表有 3 个元素
elif len(self.ets) == 3:
# 使用更复杂的公式计算模型输出
model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
# 如果 ETS 列表有 4 个或更多元素
else:
# 使用公式计算模型输出,考虑更多的 ETS 值
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
# 获取前一个样本
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
# 增加计数器
self.counter += 1
# 如果不返回字典
if not return_dict:
# 返回前一个样本的元组
return (prev_sample,)
# 返回包含前一个样本的调度器输出
return SchedulerOutput(prev_sample=prev_sample)
# 定义缩放模型输入的函数,接收样本和其他参数,返回缩放后的样本
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
# 文档字符串,说明该函数的作用、参数和返回值
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
# 直接返回输入样本,不做任何处理
return sample
# 定义获取前一个样本的函数,基于当前样本、时间步及模型输出
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# 参考 PNDM 论文公式 (9),计算 x_(t−δ)
# 该函数使用公式 (9) 计算前一个样本
# 注意需要在方程两边加上 x_t
# 变量注释映射到论文中的符号
# alpha_prod_t -> α_t
# alpha_prod_t_prev -> α_(t−δ)
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(t−δ))
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
# 获取当前时间步的累计 alpha 值
alpha_prod_t = self.alphas_cumprod[timestep]
# 获取前一时间步的累计 alpha 值,若无,则使用最终的累计 alpha 值
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
# 计算当前时间步的 beta 值
beta_prod_t = 1 - alpha_prod_t
# 计算前一时间步的 beta 值
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 如果预测类型为 "v_prediction",根据公式更新模型输出
if self.config.prediction_type == "v_prediction":
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
# 若预测类型不为 "epsilon",则抛出异常
elif self.config.prediction_type != "epsilon":
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
)
# 计算样本系数,对应于公式 (9) 的分子部分
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
# 计算模型输出的分母系数,对应于公式 (9) 的分母部分
model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
) ** (0.5)
# 根据公式 (9) 计算前一个样本
prev_sample = (
sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
)
# 返回计算得到的前一个样本
return prev_sample
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise 复制的函数
def add_noise(
# 定义添加噪声的函数,接收原始样本、噪声和时间步
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# 确保 alphas_cumprod 和 timestep 具有与 original_samples 相同的设备和数据类型
# 将 self.alphas_cumprod 移动到指定设备,以避免后续 add_noise 调用中的 CPU 到 GPU 的冗余数据移动
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
# 将 alphas_cumprod 转换为与 original_samples 相同的数据类型
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
# 将 timesteps 移动到与 original_samples 相同的设备
timesteps = timesteps.to(original_samples.device)
# 计算 alphas_cumprod 在 timesteps 位置的平方根
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
# 将 sqrt_alpha_prod 扁平化为一维张量
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
# 如果 sqrt_alpha_prod 的维度少于 original_samples,则在最后一个维度增加一个维度
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
# 计算 1 - alphas_cumprod 在 timesteps 位置的平方根
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
# 将 sqrt_one_minus_alpha_prod 扁平化为一维张量
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
# 如果 sqrt_one_minus_alpha_prod 的维度少于 original_samples,则在最后一个维度增加一个维度
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
# 根据加权公式生成带噪声的样本
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
# 返回带噪声的样本
return noisy_samples
# 定义获取对象长度的方法
def __len__(self):
# 返回配置中训练时间步的数量
return self.config.num_train_timesteps
# 版权所有 2024 浙江大学团队与 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非符合许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,
# 否则根据许可证分发的软件按“原样”提供,
# 不附带任何形式的保证或条件,无论是明示或暗示的。
# 请参阅许可证以获取有关权限和
# 限制的具体语言。
# 免责声明:此文件受到 https://github.com/ermongroup/ddim 的强烈影响
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入可选类型、元组和联合类型
from typing import Optional, Tuple, Union
# 导入 flax 库
import flax
# 导入 jax 库
import jax
# 导入 jax.numpy 模块并重命名为 jnp
import jax.numpy as jnp
# 从配置工具中导入 ConfigMixin 和 register_to_config
from ..configuration_utils import ConfigMixin, register_to_config
# 从调度工具中导入多个调度相关的类和函数
from .scheduling_utils_flax import (
CommonSchedulerState, # 导入通用调度器状态
FlaxKarrasDiffusionSchedulers, # 导入 Flax Karras 扩散调度器
FlaxSchedulerMixin, # 导入 Flax 调度混合器
FlaxSchedulerOutput, # 导入 Flax 调度输出类
add_noise_common, # 导入通用添加噪声函数
)
# 定义 PNDMSchedulerState 类,使用 flax 的数据类装饰器
@flax.struct.dataclass
class PNDMSchedulerState:
common: CommonSchedulerState # 公共调度状态
final_alpha_cumprod: jnp.ndarray # 最终 alpha 的累积乘积
# 可设置的值
init_noise_sigma: jnp.ndarray # 初始噪声标准差
timesteps: jnp.ndarray # 时间步数组
num_inference_steps: Optional[int] = None # 可选的推理步骤数
prk_timesteps: Optional[jnp.ndarray] = None # 可选的 Runge-Kutta 时间步
plms_timesteps: Optional[jnp.ndarray] = None # 可选的 PLMS 时间步
# 运行时值
cur_model_output: Optional[jnp.ndarray] = None # 当前模型输出
counter: Optional[jnp.int32] = None # 计数器
cur_sample: Optional[jnp.ndarray] = None # 当前样本
ets: Optional[jnp.ndarray] = None # 可选的扩散状态数组
# 定义一个类方法,用于创建 PNDMSchedulerState 实例
@classmethod
def create(
cls, # 类本身
common: CommonSchedulerState, # 传入的公共调度状态
final_alpha_cumprod: jnp.ndarray, # 传入的最终 alpha 累积乘积
init_noise_sigma: jnp.ndarray, # 传入的初始噪声标准差
timesteps: jnp.ndarray, # 传入的时间步数组
):
# 返回一个 PNDMSchedulerState 实例
return cls(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
# 定义 FlaxPNDMSchedulerOutput 类,继承 FlaxSchedulerOutput
@dataclass
class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
state: PNDMSchedulerState # PNDMSchedulerState 状态
# 定义 FlaxPNDMScheduler 类,继承 FlaxSchedulerMixin 和 ConfigMixin
class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely Runge-Kutta method and a linear multi-step method.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
""" # 类的文档字符串,描述了调度器的功能和来源
# 参数说明
Args:
num_train_timesteps (`int`): 训练模型所使用的扩散步骤数量。
beta_start (`float`): 推理的起始 `beta` 值。
beta_end (`float`): 最终的 `beta` 值。
beta_schedule (`str`):
beta 调度,表示从一个 beta 范围到一系列 beta 的映射,用于模型的步骤选择。可选值为
`linear`、`scaled_linear` 或 `squaredcos_cap_v2`。
trained_betas (`jnp.ndarray`, optional):
可选参数,直接将 beta 数组传递给构造函数,以跳过 `beta_start`、`beta_end` 等设置。
skip_prk_steps (`bool`):
允许调度器跳过原论文中定义的 Runge-Kutta 步骤,这些步骤在 plms 步骤之前是必要的;默认为 `False`。
set_alpha_to_one (`bool`, default `False`):
每个扩散步骤使用该步骤和前一个步骤的 alpha 乘积的值。对于最后一步没有前一个 alpha。当此选项为 `True` 时,前一个 alpha 乘积固定为 `1`,否则使用步骤 0 的 alpha 值。
steps_offset (`int`, default `0`):
添加到推理步骤的偏移量,某些模型系列需要此偏移。
prediction_type (`str`, default `epsilon`, optional):
调度函数的预测类型,选项包括 `epsilon`(预测扩散过程中的噪声)、`sample`(直接预测带噪声的样本)或 `v_prediction`(见文献 2.4 https://imagen.research.google/video/paper.pdf)。
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
用于参数和计算的 `dtype` 类型。
"""
# 获取 FlaxKarrasDiffusionSchedulers 中所有兼容的名称
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
# 定义数据类型
dtype: jnp.dtype
# 定义 PNDM 的阶数
pndm_order: int
# 定义属性以检查是否具有状态
@property
def has_state(self):
# 返回 True,表示该对象具有状态
return True
# 注册到配置中,并定义初始化函数
@register_to_config
def __init__(
# 设置训练时的扩散步骤数量,默认为 1000
num_train_timesteps: int = 1000,
# 设置推理的起始 beta 值,默认为 0.0001
beta_start: float = 0.0001,
# 设置最终的 beta 值,默认为 0.02
beta_end: float = 0.02,
# 设置 beta 调度类型,默认为 "linear"
beta_schedule: str = "linear",
# 可选参数,传递已训练的 beta 数组
trained_betas: Optional[jnp.ndarray] = None,
# 设置是否跳过 Runge-Kutta 步骤,默认为 False
skip_prk_steps: bool = False,
# 设置是否将 alpha 固定为 1,默认为 False
set_alpha_to_one: bool = False,
# 设置推理步骤的偏移量,默认为 0
steps_offset: int = 0,
# 设置预测类型,默认为 "epsilon"
prediction_type: str = "epsilon",
# 设置数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32,
):
# 将数据类型赋值给实例变量
self.dtype = dtype
# 当前仅支持 F-PNDM,即 Runge-Kutta 方法
# 有关算法的更多信息,请参见论文:https://arxiv.org/pdf/2202.09778.pdf
# 主要查看公式 (9)、(12)、(13) 和算法 2。
# 将 PNDM 阶数设置为 4
self.pndm_order = 4
# 创建状态的方法,接受一个可选的 CommonSchedulerState 参数
def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
# 如果 common 参数为 None,则创建一个新的 CommonSchedulerState 实例
if common is None:
common = CommonSchedulerState.create(self)
# 在每个 ddim 步骤中,我们查看前一个 alphas_cumprod
# 对于最后一步,由于我们已经处于 0,因此没有前一个 alphas_cumprod
# `set_alpha_to_one` 决定我们是否将该参数简单设置为 1,还是
# 使用“非前一个”的最终 alpha。
final_alpha_cumprod = (
jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
)
# 初始噪声分布的标准差
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
# 创建一个反向的时间步数组,从 num_train_timesteps 开始
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
# 返回一个新的 PNDMSchedulerState 实例,包含 common、final_alpha_cumprod、init_noise_sigma 和 timesteps
return PNDMSchedulerState.create(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
# 设置用于扩散链的离散时间步,推理前运行的辅助函数
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
"""
设置用于扩散链的离散时间步,推理前运行的辅助函数。
参数:
state (`PNDMSchedulerState`):
`FlaxPNDMScheduler` 状态数据类实例。
num_inference_steps (`int`):
生成样本时使用的扩散步骤数量。
shape (`Tuple`):
要生成的样本形状。
"""
# 计算每个推理步骤的步长比
step_ratio = self.config.num_train_timesteps // num_inference_steps
# 通过乘以比率生成整数时间步
# 四舍五入以避免 num_inference_step 为 3 的幂时出现问题
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset
if self.config.skip_prk_steps:
# 对于某些模型(如稳定扩散),可以/应该跳过 prk 步骤以产生更好的结果。
# 使用 PNDM 时,如果配置跳过 prk 步骤,基于 crowsonkb 的 PLMS 采样实现
prk_timesteps = jnp.array([], dtype=jnp.int32)
# 生成 plms 时间步,将最后的时间步反转并添加到前面
plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
else:
# 生成 prk 时间步,重复并添加偏移
prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
self.pndm_order,
)
# 反转并去掉边界的 prk 时间步
prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
# 反转 plms 时间步
plms_timesteps = _timesteps[:-3][::-1]
# 合并 prk 和 plms 时间步
timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])
# 初始化运行值
# 创建当前模型输出的零数组,形状为传入的 shape
cur_model_output = jnp.zeros(shape, dtype=self.dtype)
# 初始化计数器为 0
counter = jnp.int32(0)
# 创建当前样本的零数组,形状为传入的 shape
cur_sample = jnp.zeros(shape, dtype=self.dtype)
# 创建一个额外的数组,用于存储中间结果
ets = jnp.zeros((4,) + shape, dtype=self.dtype)
# 返回更新后的状态,包含新的时间步和运行值
return state.replace(
timesteps=timesteps,
num_inference_steps=num_inference_steps,
prk_timesteps=prk_timesteps,
plms_timesteps=plms_timesteps,
cur_model_output=cur_model_output,
counter=counter,
cur_sample=cur_sample,
ets=ets,
)
# 定义缩放模型输入的函数
def scale_model_input(
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
# 声明函数返回类型为 jnp.ndarray(JAX 的 ndarray)
"""
# 函数文档字符串,说明该函数的用途
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
# 参数 state,类型为 PNDMSchedulerState,表示调度器的状态数据类实例
sample (`jnp.ndarray`): input sample
# 参数 sample,类型为 jnp.ndarray,表示输入样本
timestep (`int`, optional): current timestep
# 可选参数 timestep,类型为 int,表示当前时间步
Returns:
`jnp.ndarray`: scaled input sample
# 返回类型为 jnp.ndarray,表示缩放后的输入样本
"""
return sample
# 返回输入样本,当前未进行任何处理
def step(
# 定义 step 方法
self,
state: PNDMSchedulerState,
# 参数 state,类型为 PNDMSchedulerState,表示调度器的状态数据类实例
model_output: jnp.ndarray,
# 参数 model_output,类型为 jnp.ndarray,表示模型的输出
timestep: int,
# 参数 timestep,类型为 int,表示当前时间步
sample: jnp.ndarray,
# 参数 sample,类型为 jnp.ndarray,表示输入样本
return_dict: bool = True,
# 参数 return_dict,类型为 bool,默认为 True,表示是否返回字典格式的结果
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
预测在上一个时间步的样本,通过反转 SDE。核心功能是从学习的模型输出传播扩散过程
(通常是预测的噪声)。
此函数根据内部变量 `counter` 调用 `step_prk()` 或 `step_plms()`。
Args:
state (`PNDMSchedulerState`): `FlaxPNDMScheduler` 状态数据类实例。
model_output (`jnp.ndarray`): 来自学习扩散模型的直接输出。
timestep (`int`): 当前扩散链中的离散时间步。
sample (`jnp.ndarray`):
正在通过扩散过程创建的当前样本实例。
return_dict (`bool`): 返回元组而不是 `FlaxPNDMSchedulerOutput` 类的选项。
Returns:
[`FlaxPNDMSchedulerOutput`] 或 `tuple`: 如果 `return_dict` 为 True,则返回 [`FlaxPNDMSchedulerOutput`],
否则返回 `tuple`。返回元组时,第一个元素是样本张量。
"""
# 检查推理步骤数量是否为 None,抛出错误以提醒用户
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# 如果配置跳过 PRK 步骤,调用 PLMS 步骤
if self.config.skip_prk_steps:
prev_sample, state = self.step_plms(state, model_output, timestep, sample)
else:
# 否则,首先执行 PRK 步骤
prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample)
# 然后执行 PLMS 步骤
plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample)
# 检查当前计数器是否小于 PRK 时间步的长度
cond = state.counter < len(state.prk_timesteps)
# 根据条件选择前一个样本
prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample)
# 更新状态,选择相应的当前模型输出和其他状态变量
state = state.replace(
cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output),
ets=jax.lax.select(cond, prk_state.ets, plms_state.ets),
cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample),
counter=jax.lax.select(cond, prk_state.counter, plms_state.counter),
)
# 如果不返回字典,则返回前一个样本和状态的元组
if not return_dict:
return (prev_sample, state)
# 否则返回 FlaxPNDMSchedulerOutput 对象
return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
# 定义 step_prk 方法,用于执行 PRK 步骤
def step_prk(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 检查推理步骤数量是否为 None,如果是则抛出异常
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# 根据当前计数器决定与上一步的差值,计算上一步的时间步
diff_to_prev = jnp.where(
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
)
prev_timestep = timestep - diff_to_prev # 计算前一个时间步
timestep = state.prk_timesteps[state.counter // 4 * 4] # 更新当前时间步
# 选择当前模型输出,基于计数器的余数决定逻辑
model_output = jax.lax.select(
(state.counter % 4) != 3,
model_output, # 余数为 0, 1, 2
state.cur_model_output + 1 / 6 * model_output, # 余数为 3
)
# 更新状态,替换当前模型输出、ets 和当前样本
state = state.replace(
cur_model_output=jax.lax.select_n(
state.counter % 4,
state.cur_model_output + 1 / 6 * model_output, # 余数为 0
state.cur_model_output + 1 / 3 * model_output, # 余数为 1
state.cur_model_output + 1 / 3 * model_output, # 余数为 2
jnp.zeros_like(state.cur_model_output), # 余数为 3
),
ets=jax.lax.select(
(state.counter % 4) == 0,
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # 余数为 0
state.ets, # 余数为 1, 2, 3
),
cur_sample=jax.lax.select(
(state.counter % 4) == 0,
sample, # 余数为 0
state.cur_sample, # 余数为 1, 2, 3
),
)
cur_sample = state.cur_sample # 获取当前样本
# 获取前一个样本,基于当前状态和模型输出
prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output)
# 更新状态计数器
state = state.replace(counter=state.counter + 1)
# 返回前一个样本和更新后的状态
return (prev_sample, state)
# 定义 step_plms 函数,参数包括状态、模型输出、时间步和样本
def step_plms(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
# 计算前一个样本,使用 PNDM 算法中的公式 (9)
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
# 查看 PNDM 论文中的公式 (9)
# 此函数使用公式 (9) 计算 x_(t−δ)
# 注意:需要将 x_t 加到方程的两边
# 符号约定 (<变量名> -> <论文中的名称>
# alpha_prod_t -> α_t
# alpha_prod_t_prev -> α_(t−δ)
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(t−δ))
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
# 获取当前时间步的累积 α 值
alpha_prod_t = state.common.alphas_cumprod[timestep]
# 如果 prev_timestep 大于等于 0,获取前一个时间步的累积 α 值,否则使用最终的累积 α 值
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
)
# 计算当前时间步的 β 值
beta_prod_t = 1 - alpha_prod_t
# 计算前一个时间步的 β 值
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 根据预测类型进行不同的处理
if self.config.prediction_type == "v_prediction":
# 使用公式调整模型输出
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
elif self.config.prediction_type != "epsilon":
# 如果预测类型不符合要求,则抛出异常
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
)
# 计算样本系数,对应公式 (9) 中的分母部分加 1
# 注意:公式简化后可得 (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
# sqrt(α_(t−δ)) / sqrt(α_t)
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
# 计算模型输出的分母系数,对应公式 (9) 中 e_θ(x_t, t) 的分母
model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
) ** (0.5)
# 根据公式 (9) 计算前一个样本
prev_sample = (
sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
)
# 返回计算得到的前一个样本
return prev_sample
# 添加噪声到样本中
def add_noise(
self,
state: PNDMSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
# 调用公共函数添加噪声
return add_noise_common(state.common, original_samples, noise, timesteps)
# 返回训练时间步的数量
def __len__(self):
return self.config.num_train_timesteps