diffusers 源码解析(五十八)
# Copyright 2024 FLAIR Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 你可以在遵循许可证的情况下使用此文件。
# 可以通过以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,软件在“按原样”基础上分发,
# 不提供任何形式的明示或暗示的保证或条件。
# 有关许可证的具体权限和限制,请参见许可证文档。
# 声明:请查看 https://arxiv.org/abs/2204.13902 和 https://github.com/qsh-zh/deis 以获取更多信息
# 此代码库是基于 https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py 进行修改的
import math # 导入数学模块,用于数学计算
from typing import List, Optional, Tuple, Union # 导入类型提示相关的类型
import numpy as np # 导入 NumPy 库,用于数组和数学运算
import torch # 导入 PyTorch 库,用于张量操作和深度学习
from ..configuration_utils import ConfigMixin, register_to_config # 从配置工具导入混合类和注册函数
from ..utils import deprecate # 从工具模块导入弃用标记函数
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput # 从调度工具模块导入调度相关的类
# 从 diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 复制的函数
def betas_for_alpha_bar(
num_diffusion_timesteps, # 传入参数:扩散时间步数
max_beta=0.999, # 传入参数:使用的最大 beta 值,默认值为 0.999
alpha_transform_type="cosine", # 传入参数:alpha 转换类型,默认为 "cosine"
):
"""
创建一个 beta 调度,以离散化给定的 alpha_t_bar 函数,该函数定义了时间 t = [0,1] 上
(1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受参数 t 并将其转换为扩散过程中该部分的 (1-beta) 的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用小于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认值为 `cosine`): alpha_bar 的噪声调度类型。
可选择 `cosine` 或 `exp`
返回:
betas (`np.ndarray`): 调度器用于更新模型输出的 beta 值
"""
if alpha_transform_type == "cosine": # 如果选择的转换类型为 "cosine"
def alpha_bar_fn(t): # 定义 alpha_bar 函数
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 # 计算并返回余弦值的平方
elif alpha_transform_type == "exp": # 如果选择的转换类型为 "exp"
def alpha_bar_fn(t): # 定义 alpha_bar 函数
return math.exp(t * -12.0) # 计算并返回指数衰减值
else: # 如果选择的转换类型不支持
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") # 抛出值错误
betas = [] # 初始化一个空列表用于存储 beta 值
for i in range(num_diffusion_timesteps): # 遍历每个扩散时间步
t1 = i / num_diffusion_timesteps # 计算当前时间步 t1
t2 = (i + 1) / num_diffusion_timesteps # 计算下一个时间步 t2
# 计算 beta 值并添加到列表中,限制最大值为 max_beta
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32) # 将 beta 列表转换为 PyTorch 张量并返回
class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): # 定义 DEISMultistepScheduler 类,继承自调度器和配置混合类
"""
`DEISMultistepScheduler` 是一个快速高阶解算器,用于扩散常微分方程(ODEs)。
# 该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查阅父类文档以获取库为所有调度程序实现的通用方法,例如加载和保存。
# 参数说明:
# num_train_timesteps (`int`, defaults to 1000):
# 用于训练模型的扩散步骤数量。
# beta_start (`float`, defaults to 0.0001):
# 推断的起始 `beta` 值。
# beta_end (`float`, defaults to 0.02):
# 最终的 `beta` 值。
# beta_schedule (`str`, defaults to `"linear"`):
# beta 计划,从 beta 范围映射到一系列用于模型步骤的 betas。可选择 `linear`、`scaled_linear` 或 `squaredcos_cap_v2`。
# trained_betas (`np.ndarray`, *optional*):
# 直接传递 beta 数组给构造函数,以绕过 `beta_start` 和 `beta_end`。
# solver_order (`int`, defaults to 2):
# DEIS 顺序,可以是 `1`、`2` 或 `3`。建议使用 `solver_order=2` 进行引导采样,使用 `solver_order=3` 进行无条件采样。
# prediction_type (`str`, defaults to `epsilon`):
# 调度程序函数的预测类型;可以是 `epsilon`(预测扩散过程的噪声)、`sample`(直接预测噪声样本)或 `v_prediction`(见 [Imagen Video](https://imagen.research.google/video/paper.pdf) 论文第 2.4 节)。
# thresholding (`bool`, defaults to `False`):
# 是否使用“动态阈值”方法。这对于如稳定扩散的潜空间扩散模型不适用。
# dynamic_thresholding_ratio (`float`, defaults to 0.995):
# 动态阈值方法的比率。仅在 `thresholding=True` 时有效。
# sample_max_value (`float`, defaults to 1.0):
# 动态阈值的阈值值。仅在 `thresholding=True` 时有效。
# algorithm_type (`str`, defaults to `deis`):
# 求解器的算法类型。
# lower_order_final (`bool`, defaults to `True`):
# 是否在最后步骤中使用低阶求解器。仅在推理步骤小于 15 时有效。
# use_karras_sigmas (`bool`, *optional*, defaults to `False`):
# 是否在采样过程中使用 Karras sigmas 作为噪声计划中的步长。如果为 `True`,则 sigmas 根据噪声水平序列 {σi} 确定。
# timestep_spacing (`str`, defaults to `"linspace"`):
# 时间步的缩放方式。有关更多信息,请参考 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) 的表 2。
# steps_offset (`int`, defaults to 0):
# 添加到推理步骤的偏移量,根据某些模型系列的要求。
# 创建一个包含所有 KarrasDiffusionSchedulers 名称的列表
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
# 设置默认的求解器阶数为 1
order = 1
# 注册到配置中,定义初始化函数
@register_to_config
def __init__(
# 设置训练时间步的数量,默认值为 1000
num_train_timesteps: int = 1000,
# 设置 beta 的起始值,默认值为 0.0001
beta_start: float = 0.0001,
# 设置 beta 的结束值,默认值为 0.02
beta_end: float = 0.02,
# 设置 beta 的调度方式,默认值为 "linear"
beta_schedule: str = "linear",
# 可选参数,设置训练的 beta 数组,默认值为 None
trained_betas: Optional[np.ndarray] = None,
# 设置求解器的阶数,默认值为 2
solver_order: int = 2,
# 设置预测类型,默认值为 "epsilon"
prediction_type: str = "epsilon",
# 设置是否使用阈值处理,默认值为 False
thresholding: bool = False,
# 设置动态阈值比例,默认值为 0.995
dynamic_thresholding_ratio: float = 0.995,
# 设置样本的最大值,默认值为 1.0
sample_max_value: float = 1.0,
# 设置算法类型,默认值为 "deis"
algorithm_type: str = "deis",
# 设置求解器类型,默认值为 "logrho"
solver_type: str = "logrho",
# 设置是否在最后阶段使用较低的阶数,默认值为 True
lower_order_final: bool = True,
# 可选参数,设置是否使用 Karras sigma,默认值为 False
use_karras_sigmas: Optional[bool] = False,
# 设置时间步的间距类型,默认值为 "linspace"
timestep_spacing: str = "linspace",
# 设置步数偏移,默认值为 0
steps_offset: int = 0,
):
# 检查已训练的 beta 值是否为 None
if trained_betas is not None:
# 将训练的 beta 值转换为浮点型张量
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 检查 beta 调度类型是否为线性
elif beta_schedule == "linear":
# 生成从 beta_start 到 beta_end 的线性序列
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 检查 beta 调度类型是否为缩放线性
elif beta_schedule == "scaled_linear":
# 该调度特定于潜在扩散模型
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 检查 beta 调度类型是否为平方余弦 cap v2
elif beta_schedule == "squaredcos_cap_v2":
# Glide 余弦调度
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
# 如果不支持的调度类型,抛出未实现错误
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算 alphas,等于 1 减去 betas
self.alphas = 1.0 - self.betas
# 计算 alphas 的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 当前仅支持 VP 类型噪声调度
self.alpha_t = torch.sqrt(self.alphas_cumprod)
# 计算 sigma_t,等于 1 减去 alphas_cumprod 的平方根
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
# 计算 lambda_t,等于 alpha_t 和 sigma_t 的对数差
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
# 计算 sigmas,等于 (1 - alphas_cumprod) 除以 alphas_cumprod 的平方根
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# 设置初始噪声分布的标准差
self.init_noise_sigma = 1.0
# DEIS 设置
if algorithm_type not in ["deis"]:
# 如果算法类型是 dpmsolver 或 dpmsolver++
if algorithm_type in ["dpmsolver", "dpmsolver++"]:
# 注册算法类型到配置
self.register_to_config(algorithm_type="deis")
else:
# 抛出未实现错误
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
# 检查求解器类型是否为 logrho
if solver_type not in ["logrho"]:
# 如果求解器类型是 midpoint, heun, bh1, bh2
if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
# 注册求解器类型到配置
self.register_to_config(solver_type="logrho")
else:
# 抛出未实现错误
raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")
# 可设置的值
self.num_inference_steps = None
# 生成从 0 到 num_train_timesteps - 1 的时间步,反转顺序
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
# 将时间步转换为张量
self.timesteps = torch.from_numpy(timesteps)
# 初始化模型输出列表,长度为 solver_order
self.model_outputs = [None] * solver_order
# 记录低阶数
self.lower_order_nums = 0
# 初始化步索引
self._step_index = None
# 初始化开始索引
self._begin_index = None
# 将 sigmas 移到 CPU,以避免过多的 CPU/GPU 通信
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
第一个时间步的索引。应通过 `set_begin_index` 方法从管道中设置。
"""
return self._begin_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 复制
# 设置调度器的起始索引,默认值为0
def set_begin_index(self, begin_index: int = 0):
# 文档字符串,说明函数的用途和参数
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
# 将传入的起始索引值存储到实例变量中
self._begin_index = begin_index
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 复制的函数
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# 文档字符串,描述动态阈值处理的原理和效果
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
# 获取输入样本的数值类型
dtype = sample.dtype
# 获取样本的批次大小、通道数及剩余维度
batch_size, channels, *remaining_dims = sample.shape
# 检查数据类型,如果不是浮点数,则转换为浮点数
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# 将样本扁平化以进行量化计算
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
# 计算样本的绝对值
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
# 计算每个图像的动态阈值
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
# 限制阈值在指定范围内
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
# 扩展维度以适应广播
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
# 将样本限制在[-s, s]范围内并归一化
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
# 恢复样本的原始形状
sample = sample.reshape(batch_size, channels, *remaining_dims)
# 将样本转换回原始数据类型
sample = sample.to(dtype)
# 返回处理后的样本
return sample
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 复制的函数
def _sigma_to_t(self, sigma, log_sigmas):
# 计算对数sigma值,确保不小于1e-10
log_sigma = np.log(np.maximum(sigma, 1e-10))
# 计算对数sigma的分布
dists = log_sigma - log_sigmas[:, np.newaxis]
# 找到sigma的范围
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取低和高的对数sigma值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 进行sigma的插值
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
# 重新调整形状以匹配sigma的形状
t = t.reshape(sigma.shape)
# 返回时间值
return t
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep 导入的函数,用于将 sigma 转换为 alpha 和 sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
# 计算 alpha_t,公式为 1 / sqrt(sigma^2 + 1)
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
# 计算 sigma_t,公式为 sigma * alpha_t
sigma_t = sigma * alpha_t
# 返回计算得到的 alpha_t 和 sigma_t
return alpha_t, sigma_t
# 从 diffusers.schedulers.scheduling_euler_discrete 导入的函数,用于将输入 sigma 转换为 Karras 的格式
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Karras 等人 (2022) 的噪声调度。"""
# 确保其他调度器复制此函数时不会出错的黑客方案
# TODO: 将此逻辑添加到其他调度器中
if hasattr(self.config, "sigma_min"):
# 获取 sigma_min,如果配置中存在
sigma_min = self.config.sigma_min
else:
# 如果配置中不存在,设置为 None
sigma_min = None
if hasattr(self.config, "sigma_max"):
# 获取 sigma_max,如果配置中存在
sigma_max = self.config.sigma_max
else:
# 如果配置中不存在,设置为 None
sigma_max = None
# 设置 sigma_min 为输入 sigmas 的最后一个值,如果它是 None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
# 设置 sigma_max 为输入 sigmas 的第一个值,如果它是 None
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
# 定义 rho 的值为 7.0,引用文献中使用的值
rho = 7.0 # 7.0 is the value used in the paper
# 生成从 0 到 1 的 ramp 数组,长度为 num_inference_steps
ramp = np.linspace(0, 1, num_inference_steps)
# 计算 min_inv_rho 和 max_inv_rho
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
# 根据公式生成 sigmas 数组
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 返回生成的 sigmas
return sigmas
# 定义 convert_model_output 函数,用于处理模型输出
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
将模型输出转换为 DEIS 算法所需的对应类型。
参数:
model_output (`torch.Tensor`):
来自学习的扩散模型的直接输出。
timestep (`int`):
当前扩散链中的离散时间步。
sample (`torch.Tensor`):
扩散过程中创建的当前样本实例。
返回:
`torch.Tensor`:
转换后的模型输出。
"""
# 从 args 中提取 timestep,如果没有则从 kwargs 中提取,默认为 None
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
# 如果 sample 为 None,尝试从 args 中提取
if sample is None:
if len(args) > 1:
sample = args[1]
else:
# 如果没有提供 sample,则抛出错误
raise ValueError("missing `sample` as a required keyward argument")
# 如果 timestep 不是 None,发出弃用警告
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 获取当前步的 sigma 值
sigma = self.sigmas[self.step_index]
# 将 sigma 转换为 alpha_t 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
# 根据配置类型进行不同的模型输出处理
if self.config.prediction_type == "epsilon":
# 计算基于 epsilon 的预测
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
# 直接将模型输出作为预测
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
# 计算基于 v 的预测
x0_pred = alpha_t * sample - sigma_t * model_output
else:
# 如果 prediction_type 不符合要求,抛出错误
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DEISMultistepScheduler."
)
# 如果开启阈值处理,则对预测值进行阈值处理
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
# 如果算法类型为 deis,返回转换后的样本
if self.config.algorithm_type == "deis":
return (sample - alpha_t * x0_pred) / sigma_t
else:
# 抛出未实现错误,表明仅支持 log-rho multistep deis
raise NotImplementedError("only support log-rho multistep deis now")
# 定义 deis_first_order_update 函数,接受模型输出和可变参数
def deis_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor: # 定义函数返回类型为 torch.Tensor
""" # 开始函数的文档字符串
One step for the first-order DEIS (equivalent to DDIM). # 描述该函数为一阶 DEIS 步骤(等同于 DDIM)
Args: # 参数说明开始
model_output (`torch.Tensor`): # 参数 model_output,类型为 torch.Tensor
The direct output from the learned diffusion model. # 描述为从学习到的扩散模型获得的直接输出
timestep (`int`): # 参数 timestep,类型为 int
The current discrete timestep in the diffusion chain. # 描述为扩散链中的当前离散时间步
prev_timestep (`int`): # 参数 prev_timestep,类型为 int
The previous discrete timestep in the diffusion chain. # 描述为扩散链中的前一个离散时间步
sample (`torch.Tensor`): # 参数 sample,类型为 torch.Tensor
A current instance of a sample created by the diffusion process. # 描述为扩散过程创建的当前样本实例
Returns: # 返回值说明开始
`torch.Tensor`: # 返回类型为 torch.Tensor
The sample tensor at the previous timestep. # 描述为在前一个时间步的样本张量
""" # 结束函数的文档字符串
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) # 获取当前时间步,如果没有则从关键字参数中提取
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) # 获取前一个时间步,如果没有则从关键字参数中提取
if sample is None: # 检查 sample 是否为 None
if len(args) > 2: # 如果 args 的长度大于 2
sample = args[2] # 从 args 中获取 sample
else: # 否则
raise ValueError(" missing `sample` as a required keyward argument") # 抛出缺少 sample 的异常
if timestep is not None: # 如果当前时间步不为 None
deprecate( # 调用 deprecate 函数以发出弃用警告
"timesteps", # 被弃用的参数名称
"1.0.0", # 版本号
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", # 弃用说明
)
if prev_timestep is not None: # 如果前一个时间步不为 None
deprecate( # 调用 deprecate 函数以发出弃用警告
"prev_timestep", # 被弃用的参数名称
"1.0.0", # 版本号
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", # 弃用说明
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # 获取当前和前一个时间步的 sigma 值
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) # 将 sigma_t 转换为 alpha_t 和 sigma_t
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) # 将 sigma_s 转换为 alpha_s 和 sigma_s
lambda_t = torch.log(alpha_t) - torch.log(sigma_t) # 计算 lambda_t 为 alpha_t 和 sigma_t 的对数差
lambda_s = torch.log(alpha_s) - torch.log(sigma_s) # 计算 lambda_s 为 alpha_s 和 sigma_s 的对数差
h = lambda_t - lambda_s # 计算 h 为 lambda_t 和 lambda_s 的差
if self.config.algorithm_type == "deis": # 检查算法类型是否为 "deis"
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output # 计算当前样本 x_t
else: # 否则
raise NotImplementedError("only support log-rho multistep deis now") # 抛出不支持的算法类型异常
return x_t # 返回计算得到的样本 x_t
def multistep_deis_second_order_update( # 定义 multistep_deis_second_order_update 函数
self, # 类实例
model_output_list: List[torch.Tensor], # 参数 model_output_list,类型为 torch.Tensor 的列表
*args, # 可变位置参数
sample: torch.Tensor = None, # 参数 sample,默认为 None
**kwargs, # 可变关键字参数
# 定义一个函数,返回类型为 torch.Tensor
) -> torch.Tensor:
"""
第二阶多步 DEIS 的一步计算。
参数:
model_output_list (`List[torch.Tensor]`):
当前和后续时间步的学习扩散模型直接输出。
sample (`torch.Tensor`):
扩散过程生成的当前样本实例。
返回:
`torch.Tensor`:
上一时间步的样本张量。
"""
# 获取时间步列表,如果没有则从关键字参数中获取
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
# 获取前一个时间步,如果没有则从关键字参数中获取
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
# 如果样本为 None,则尝试从参数中获取样本
if sample is None:
if len(args) > 2:
sample = args[2]
else:
# 如果样本仍然为 None,则引发错误
raise ValueError(" missing `sample` as a required keyward argument")
# 如果时间步列表不为 None,则发出弃用警告
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 如果前一个时间步不为 None,则发出弃用警告
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 获取当前和前后时间步的 sigma 值
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
# 将 sigma 转换为 alpha 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
# 获取最后两个模型输出
m0, m1 = model_output_list[-1], model_output_list[-2]
# 计算 rho 值
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
# 检查算法类型是否为 "deis"
if self.config.algorithm_type == "deis":
# 定义积分函数
def ind_fn(t, b, c):
# Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))
# 计算系数
coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)
# 计算 x_t
x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
# 返回计算结果
return x_t
else:
# 如果算法类型不支持,则引发未实现的错误
raise NotImplementedError("only support log-rho multistep deis now")
# 定义一个多步 DEIS 第三阶更新的函数
def multistep_deis_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
# 当前样本实例,默认为 None
sample: torch.Tensor = None,
**kwargs,
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep 复制
# 根据时间步初始化索引
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果未提供时间调度步,则使用默认时间步
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 找到与当前时间步匹配的候选索引
index_candidates = (schedule_timesteps == timestep).nonzero()
# 如果没有找到匹配的候选索引
if len(index_candidates) == 0:
# 将步骤索引设置为时间步的最后一个索引
step_index = len(self.timesteps) - 1
# 如果找到多个候选索引
# 第一个步骤的 sigma 索引总是第二个索引(如果只有一个则是最后一个)
# 这样可以确保在去噪调度中不会意外跳过 sigma
elif len(index_candidates) > 1:
# 使用第二个候选索引作为步骤索引
step_index = index_candidates[1].item()
else:
# 否则,使用第一个候选索引作为步骤索引
step_index = index_candidates[0].item()
# 返回最终步骤索引
return step_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index 中复制
def _init_step_index(self, timestep):
"""
初始化调度器的步骤索引计数器。
"""
# 如果开始索引为 None
if self.begin_index is None:
# 如果时间步是张量类型,则将其转移到相应设备
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 使用 index_for_timestep 方法初始化步骤索引
self._step_index = self.index_for_timestep(timestep)
else:
# 否则使用预设的开始索引
self._step_index = self._begin_index
# 执行一步计算
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
从前一个时间步预测样本,通过反转 SDE。此函数使用多步 DEIS 传播样本。
参数:
model_output (`torch.Tensor`):
从学习的扩散模型直接输出的张量。
timestep (`int`):
扩散链中当前离散时间步。
sample (`torch.Tensor`):
通过扩散过程创建的当前样本实例。
return_dict (`bool`):
是否返回 [`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`。
返回:
[`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`:
如果 return_dict 为 `True`,则返回 [`~schedulers.scheduling_utils.SchedulerOutput`],否则返回一个元组,
其中第一个元素是样本张量。
"""
# 检查推理步骤数量是否为 None,若是则抛出异常
if self.num_inference_steps is None:
raise ValueError(
"推理步骤数量为 'None',您需要在创建调度器后运行 'set_timesteps'"
)
# 检查当前步骤索引是否为 None,若是则初始化步骤索引
if self.step_index is None:
self._init_step_index(timestep)
# 判断是否为较低阶最终更新的条件
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
# 判断是否为较低阶第二更新的条件
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
# 转换模型输出为适合当前样本的格式
model_output = self.convert_model_output(model_output, sample=sample)
# 更新模型输出缓存,将当前模型输出存储到最后一个位置
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
# 根据配置选择合适的更新方法
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
# 使用一阶更新方法计算前一个样本
prev_sample = self.deis_first_order_update(model_output, sample=sample)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
# 使用二阶更新方法计算前一个样本
prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
else:
# 使用三阶更新方法计算前一个样本
prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
# 更新较低阶次数计数器
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# 完成后将步骤索引加一
self._step_index += 1
# 如果不返回字典,则返回包含前一个样本的元组
if not return_dict:
return (prev_sample,)
# 返回前一个样本的调度输出
return SchedulerOutput(prev_sample=prev_sample)
# 定义一个方法,用于根据当前时间步缩放去噪模型输入
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器之间的互换性。
Args:
sample (`torch.Tensor`):
输入样本。
Returns:
`torch.Tensor`:
缩放后的输入样本。
"""
# 返回未修改的输入样本
return sample
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise 复制的代码
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# 确保 sigmas 和 timesteps 与 original_samples 具有相同的设备和数据类型
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
# 检查设备类型,如果是 MPS 且 timesteps 是浮点型
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# MPS 不支持 float64 数据类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 将调度时间步转换为与原始样本相同的设备
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# 如果 begin_index 为 None,表示调度器用于训练或管道未实现 set_begin_index
if self.begin_index is None:
# 根据时间步获取步索引
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# 在第一次去噪步骤后调用 add_noise(用于修补)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# 在第一次去噪步骤之前调用 add_noise 以创建初始潜在图像(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
# 根据步索引获取 sigma,并将其扁平化
sigma = sigmas[step_indices].flatten()
# 如果 sigma 的形状小于原始样本的形状,则在最后一个维度添加一个维度
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 将 sigma 转换为 alpha_t 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
# 生成带噪声的样本
noisy_samples = alpha_t * original_samples + sigma_t * noise
# 返回带噪声的样本
return noisy_samples
# 定义方法以返回训练时间步的数量
def __len__(self):
return self.config.num_train_timesteps
# 版权声明,指明该文件的版权所有者和使用条款
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# 按照 Apache 2.0 许可证使用本文件的声明
# Licensed under the Apache License, Version 2.0 (the "License");
# 只能在遵守许可证的情况下使用该文件
# you may not use this file except in compliance with the License.
# 可以在以下网址获取许可证副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 如果没有法律适用或书面协议的情况下,该文件按 "AS IS" 基础分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解具体的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 免责声明,说明此文件受到特定项目的影响
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
# 导入数学模块
import math
# 从 typing 模块导入 List, Optional, Tuple, Union 类型
from typing import List, Optional, Tuple, Union
# 导入 numpy 模块并使用 np 别名
import numpy as np
# 导入 torch 模块
import torch
# 从配置工具中导入 ConfigMixin 和 register_to_config
from ..configuration_utils import ConfigMixin, register_to_config
# 从 utils 中导入 deprecate
from ..utils import deprecate
# 从 torch_utils 中导入 randn_tensor
from ..utils.torch_utils import randn_tensor
# 从调度工具中导入 KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# 从 diffusers 中复制的函数,用于生成 beta 调度
def betas_for_alpha_bar(
num_diffusion_timesteps, # 指定生成 beta 数量的参数
max_beta=0.999, # 设置最大 beta 值的默认参数
alpha_transform_type="cosine", # 指定 alpha 转换类型的默认参数
):
"""
创建一个 beta 调度,它离散化给定的 alpha_t_bar 函数,定义时间 t = [0,1] 上 (1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受参数 t 并将其转换为该部分扩散过程的 (1-beta) 的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成 beta 的数量。
max_beta (`float`): 使用的最大 beta 值;使用小于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
从 `cosine` 或 `exp` 中选择
返回:
betas (`np.ndarray`): 调度器用于更新模型输出的 betas
"""
# 检查 alpha_transform_type 是否为 "cosine"
if alpha_transform_type == "cosine":
# 定义 alpha_bar_fn 函数,计算 cos 函数的平方
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
# 检查 alpha_transform_type 是否为 "exp"
elif alpha_transform_type == "exp":
# 定义 alpha_bar_fn 函数,计算指数函数
def alpha_bar_fn(t):
return math.exp(t * -12.0)
# 如果 alpha_transform_type 不是支持的类型,则引发错误
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
# 初始化 betas 列表
betas = []
# 遍历每个扩散时间步
for i in range(num_diffusion_timesteps):
# 计算当前时间步 t1
t1 = i / num_diffusion_timesteps
# 计算下一个时间步 t2
t2 = (i + 1) / num_diffusion_timesteps
# 计算 beta 值并添加到 betas 列表中,确保不超过 max_beta
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 将 betas 列表转换为张量并返回
return torch.tensor(betas, dtype=torch.float32)
# 从 diffusers 中复制的函数,用于重新调整 beta
def rescale_zero_terminal_snr(betas):
"""
根据 https://arxiv.org/pdf/2305.08891.pdf (算法 1) 重新调整 beta 以具有零终端 SNR
参数:
betas (`torch.Tensor`):
用于初始化调度器的 beta。
# 返回 rescaled betas,且终端信噪比为零
Returns:
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# 将 betas 转换为 alphas_bar_sqrt
alphas = 1.0 - betas # 计算 alphas
alphas_cumprod = torch.cumprod(alphas, dim=0) # 计算 alphas 的累积乘积
alphas_bar_sqrt = alphas_cumprod.sqrt() # 计算累积乘积的平方根
# 存储旧值
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() # 克隆第一个 alphas_bar_sqrt 的值
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() # 克隆最后一个 alphas_bar_sqrt 的值
# 平移,使最后一个时间步为零
alphas_bar_sqrt -= alphas_bar_sqrt_T # 从每个元素中减去最后一个值
# 缩放,使第一个时间步恢复到旧值
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # 缩放 alphas_bar_sqrt
# 将 alphas_bar_sqrt 转换为 betas
alphas_bar = alphas_bar_sqrt**2 # 还原平方根
alphas = alphas_bar[1:] / alphas_bar[:-1] # 还原累积乘积
alphas = torch.cat([alphas_bar[0:1], alphas]) # 将第一个 alphas_bar 的值添加到 alphas 前面
betas = 1 - alphas # 计算 betas
# 返回计算后的 betas
return betas
# 定义一个多步调度器类,专用于快速高阶求解扩散常微分方程
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`DPMSolverMultistepScheduler` 是一个快速的专用高阶求解器,用于扩散 ODE。
该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看父类文档以了解该库为所有调度器实现的通用方法,例如加载和保存。
"""
# 存储与 KarrasDiffusionSchedulers 兼容的调度器名称列表
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
# 设置默认的求解器阶数为 1
order = 1
@register_to_config
def __init__(
# 训练的时间步数,默认为 1000
num_train_timesteps: int = 1000,
# 初始 beta 值,默认为 0.0001
beta_start: float = 0.0001,
# 最终 beta 值,默认为 0.02
beta_end: float = 0.02,
# beta 的调度类型,默认为线性
beta_schedule: str = "linear",
# 经过训练的 beta 值,默认为 None
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
# 求解器阶数,默认为 2
solver_order: int = 2,
# 预测类型,默认为 "epsilon"
prediction_type: str = "epsilon",
# 是否启用阈值处理,默认为 False
thresholding: bool = False,
# 动态阈值处理比例,默认为 0.995
dynamic_thresholding_ratio: float = 0.995,
# 样本的最大值,默认为 1.0
sample_max_value: float = 1.0,
# 算法类型,默认为 "dpmsolver++"
algorithm_type: str = "dpmsolver++",
# 求解器类型,默认为 "midpoint"
solver_type: str = "midpoint",
# 是否在最后阶段使用低阶方法,默认为 True
lower_order_final: bool = True,
# 最后阶段是否使用欧拉法,默认为 False
euler_at_final: bool = False,
# 是否使用 Karras 的 sigma 值,默认为 None
use_karras_sigmas: Optional[bool] = False,
# 是否使用 LU 的 lambda 值,默认为 None
use_lu_lambdas: Optional[bool] = False,
# 最终 sigma 类型,默认为 "zero"
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
# 最小 lambda 值被裁剪,默认为 -inf
lambda_min_clipped: float = -float("inf"),
# 方差类型,默认为 None
variance_type: Optional[str] = None,
# 时间步的间隔类型,默认为 "linspace"
timestep_spacing: str = "linspace",
# 步骤偏移量,默认为 0
steps_offset: int = 0,
# 是否在 SNR 为零时重新缩放 beta,默认为 False
rescale_betas_zero_snr: bool = False,
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后会增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
第一个时间步的索引。应通过 `set_begin_index` 方法从管道设置。
"""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的起始索引。该函数应在推理前通过管道运行。
参数:
begin_index (`int`):
调度器的起始索引。
"""
# 更新调度器的起始索引
self._begin_index = begin_index
def set_timesteps(
# 设置推理步骤的数量,默认为 None
num_inference_steps: int = None,
# 设备类型,可以是字符串或 torch.device,默认为 None
device: Union[str, torch.device] = None,
# 指定时间步,默认为 None
timesteps: Optional[List[int]] = None,
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 中复制
# 定义一个私有方法,进行动态阈值采样,输入为一个张量,返回处理后的张量
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"动态阈值处理:在每个采样步骤中,我们将 s 设置为 xt0(在时间步 t 对 x_0 的预测)中的某个百分位绝对像素值,如果 s > 1,则将 xt0 阈值化到范围 [-s, s],然后除以 s。动态阈值处理将饱和像素(接近 -1 和 1 的像素)推入内部,从而在每一步主动防止像素饱和。我们发现动态阈值处理可以显著提高照片真实感以及图像与文本的对齐,尤其是在使用非常大的引导权重时。"
https://arxiv.org/abs/2205.11487
"""
# 获取输入样本的数值类型
dtype = sample.dtype
# 解包样本的形状,获取批量大小、通道数及其余维度
batch_size, channels, *remaining_dims = sample.shape
# 如果数据类型不是 float32 或 float64,则将样本上升为 float 类型
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # 为量化计算上升数据类型,且 cpu half 不支持 clamp
# 将样本展平,以便对每幅图像进行量化计算
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
# 计算样本的绝对值,用于获取"某个百分位绝对像素值"
abs_sample = sample.abs() # "某个百分位绝对像素值"
# 计算绝对样本的指定百分位值
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
# 将 s 的值限制在 [1, sample_max_value] 范围内
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # 当 clamped 到 min=1 时,相当于标准的裁剪到 [-1, 1]
# 在第一维增加一个维度,以便后续操作中能够正确广播
s = s.unsqueeze(1) # (batch_size, 1) 因为 clamp 将在 dim=0 上广播
# 将样本的值限制在 [-s, s] 范围内,并将其归一化
sample = torch.clamp(sample, -s, s) / s # "我们将 xt0 阈值化到范围 [-s, s],然后除以 s"
# 恢复样本的原始形状
sample = sample.reshape(batch_size, channels, *remaining_dims)
# 将样本转换回原始的数据类型
sample = sample.to(dtype)
# 返回处理后的样本
return sample
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 拷贝的方法
def _sigma_to_t(self, sigma, log_sigmas):
# 计算 log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
# 计算分布
dists = log_sigma - log_sigmas[:, np.newaxis]
# 获取 sigma 的范围
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取低和高的 log sigma 值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 进行 sigma 的插值
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
# 将 t 的形状恢复为 sigma 的形状
t = t.reshape(sigma.shape)
# 返回时间值
return t
# 定义一个方法,将 sigma 转换为 alpha_t 和 sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
# 计算 alpha_t,使用 sigma 的平方根
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
# 计算 sigma_t,通过 sigma 和 alpha_t 进行计算
sigma_t = sigma * alpha_t
# 返回 alpha_t 和 sigma_t
return alpha_t, sigma_t
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras 拷贝的方法
# 定义一个私有方法,用于将输入的 sigma 值转换为 Karras 的噪声调度
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Karras 等人 (2022) 的噪声调度。"""
# 确保其他调度器复制此函数时不会出错的 Hack
# TODO: 将此逻辑添加到其他调度器中
# 检查配置中是否存在 sigma_min 属性
if hasattr(self.config, "sigma_min"):
# 如果存在,则获取其值
sigma_min = self.config.sigma_min
else:
# 否则,将 sigma_min 设置为 None
sigma_min = None
# 检查配置中是否存在 sigma_max 属性
if hasattr(self.config, "sigma_max"):
# 如果存在,则获取其值
sigma_max = self.config.sigma_max
else:
# 否则,将 sigma_max 设置为 None
sigma_max = None
# 如果 sigma_min 为 None,则使用 in_sigmas 中最后一个值
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
# 如果 sigma_max 为 None,则使用 in_sigmas 中第一个值
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
# rho 的值为论文中使用的 7.0
rho = 7.0 # 7.0 是论文中使用的值
# 创建一个从 0 到 1 的线性 ramp,长度为 num_inference_steps
ramp = np.linspace(0, 1, num_inference_steps)
# 计算 sigma_min 的倒数 rho
min_inv_rho = sigma_min ** (1 / rho)
# 计算 sigma_max 的倒数 rho
max_inv_rho = sigma_max ** (1 / rho)
# 根据 ramp 计算出对应的 sigma 值
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 返回计算得到的 sigma 值
return sigmas
# 定义一个私有方法,用于将输入的 lambda 值转换为 Lu 的噪声调度
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Lu 等人 (2022) 的噪声调度。"""
# 获取输入 lambdas 的最小值(最后一个值)
lambda_min: float = in_lambdas[-1].item()
# 获取输入 lambdas 的最大值(第一个值)
lambda_max: float = in_lambdas[0].item()
# rho 的值为论文中使用的 1.0
rho = 1.0 # 1.0 是论文中使用的值
# 创建一个从 0 到 1 的线性 ramp,长度为 num_inference_steps
ramp = np.linspace(0, 1, num_inference_steps)
# 计算 lambda_min 的倒数 rho
min_inv_rho = lambda_min ** (1 / rho)
# 计算 lambda_max 的倒数 rho
max_inv_rho = lambda_max ** (1 / rho)
# 根据 ramp 计算出对应的 lambda 值
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 返回计算得到的 lambda 值
return lambdas
# 定义一个方法,用于转换模型输出
def convert_model_output(
self,
model_output: torch.Tensor, # 模型的输出张量
*args, # 额外的位置参数
sample: torch.Tensor = None, # 可选的样本张量
**kwargs, # 额外的关键字参数
# 定义一个方法,用于 DPM 求解器的一阶更新
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor, # 模型的输出张量
*args, # 额外的位置参数
sample: torch.Tensor = None, # 可选的样本张量
noise: Optional[torch.Tensor] = None, # 可选的噪声张量
**kwargs, # 额外的关键字参数
# 返回一个张量,表示第一阶 DPMSolver 的一步(等效于 DDIM)
) -> torch.Tensor:
"""
一步用于第一阶 DPMSolver(等效于 DDIM)。
参数:
model_output (`torch.Tensor`):
从学习的扩散模型直接输出的张量。
sample (`torch.Tensor`):
扩散过程中创建的当前样本实例。
返回:
`torch.Tensor`:
上一个时间步的样本张量。
"""
# 从参数中获取当前时间步,若没有则从关键字参数中获取
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
# 从参数中获取上一个时间步,若没有则从关键字参数中获取
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
# 如果样本为空,尝试从参数中获取样本
if sample is None:
if len(args) > 2:
sample = args[2]
else:
# 抛出错误,样本为必需的关键字参数
raise ValueError(" missing `sample` as a required keyward argument")
# 如果时间步不为空,发出弃用警告
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 如果上一个时间步不为空,发出弃用警告
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 获取当前和前一个时间步的 sigma 值
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
# 将 sigma 转换为 alpha 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
# 将前一个 sigma 转换为 alpha 和 sigma_s
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
# 计算 lambda_t 和 lambda_s
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
# 计算 h 值
h = lambda_t - lambda_s
# 根据配置的算法类型进行不同的计算
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None # 确保噪声不为空
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None # 确保噪声不为空
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
# 返回计算得到的样本张量
return x_t
# 定义多步 DPM 求解器的二阶更新方法
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor], # 模型输出列表
*args, # 额外的参数
sample: torch.Tensor = None, # 当前样本
noise: Optional[torch.Tensor] = None, # 可选噪声
**kwargs, # 额外的关键字参数
# 定义一个三阶更新的多步 DPM 求解器方法
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor], # 输入的模型输出列表,包含多个张量
*args, # 其他位置参数
sample: torch.Tensor = None, # 输入样本,默认为 None
**kwargs, # 其他关键字参数
# 定义一个用于获取时间步索引的方法
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果未提供调度时间步,则使用默认时间步
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 查找与给定时间步相等的调度时间步的索引
index_candidates = (schedule_timesteps == timestep).nonzero()
# 如果没有找到匹配的索引
if len(index_candidates) == 0:
# 设置步索引为时间步列表的最后一个索引
step_index = len(self.timesteps) - 1
# 对于多个匹配的情况,选择第二个索引(或最后一个索引)
# 以确保不会意外跳过 sigma
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
# 如果只找到一个匹配,取第一个索引
step_index = index_candidates[0].item()
# 返回计算得到的步索引
return step_index
# 初始化调度器的步索引计数器的方法
def _init_step_index(self, timestep):
"""
初始化调度器的步索引计数器。
"""
# 如果开始索引为 None
if self.begin_index is None:
# 如果时间步是张量类型,则将其转移到时间步的设备上
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 根据时间步获取步索引
self._step_index = self.index_for_timestep(timestep)
else:
# 如果开始索引已定义,则直接使用它
self._step_index = self._begin_index
# 定义一个步骤方法,执行一个更新步骤
def step(
self,
model_output: torch.Tensor, # 模型输出的张量
timestep: Union[int, torch.Tensor], # 当前的时间步,可以是整数或张量
sample: torch.Tensor, # 输入样本
generator=None, # 可选的生成器
variance_noise: Optional[torch.Tensor] = None, # 可选的方差噪声张量
return_dict: bool = True, # 是否返回字典格式的结果
# 定义一个缩放模型输入的方法
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器互换。
参数:
sample (`torch.Tensor`):
输入样本。
返回:
`torch.Tensor`:
缩放后的输入样本。
"""
# 返回输入样本,当前没有进行缩放
return sample
# 定义一个添加噪声的方法
def add_noise(
self,
original_samples: torch.Tensor, # 原始样本的张量
noise: torch.Tensor, # 要添加的噪声张量
timesteps: torch.IntTensor, # 时间步的整数张量
) -> torch.Tensor:
# 确保 sigmas 和 timesteps 与 original_samples 在同一设备和数据类型上
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
# 检查设备是否为 mps 且 timesteps 是否为浮点数
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps 不支持 float64 类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 将 timesteps 转换到与 original_samples 相同的设备
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# 当 scheduler 用于训练或 pipeline 未实现 set_begin_index 时,begin_index 为 None
if self.begin_index is None:
# 根据 timesteps 计算 step_indices
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# 在第一次去噪步骤后调用 add_noise(用于修复)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# 在第一次去噪步骤之前调用 add_noise 以创建初始潜在图像(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
# 根据 step_indices 获取对应的 sigmas,并将其展平
sigma = sigmas[step_indices].flatten()
# 通过增加维度来匹配 sigma 的形状与 original_samples 的形状
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 将 sigma 转换为 alpha_t 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
# 生成带噪声的样本
noisy_samples = alpha_t * original_samples + sigma_t * noise
# 返回带噪声的样本
return noisy_samples
def __len__(self):
# 返回训练时间步的数量
return self.config.num_train_timesteps
# 版权所有 2024 TSAIL 团队和 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版("许可证")许可;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,按“原样”分发的软件不附带任何明示或暗示的担保或条件。
# 有关许可证所适用权限和限制的具体说明,请参见许可证。
# 免责声明:此文件受到 https://github.com/LuChengTHU/dpm-solver 的强烈影响
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 导入类型相关的模块
from typing import List, Optional, Tuple, Union
# 导入 flax 库
import flax
# 导入 jax 库
import jax
# 导入 jax 数组处理模块
import jax.numpy as jnp
# 从配置工具导入配置混合类和注册函数
from ..configuration_utils import ConfigMixin, register_to_config
# 从调度工具导入常用的调度器状态和类
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
)
# 定义 DPMSolverMultistepSchedulerState 数据类,表示调度器的状态
@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
# 调度器的通用状态
common: CommonSchedulerState
# 当前时间步的 alpha 值
alpha_t: jnp.ndarray
# 当前时间步的 sigma 值
sigma_t: jnp.ndarray
# 当前时间步的 lambda 值
lambda_t: jnp.ndarray
# 可设置的值
init_noise_sigma: jnp.ndarray # 初始化噪声标准差
timesteps: jnp.ndarray # 时间步数组
num_inference_steps: Optional[int] = None # 推理步骤数(可选)
# 运行时值
model_outputs: Optional[jnp.ndarray] = None # 模型输出(可选)
lower_order_nums: Optional[jnp.int32] = None # 较低阶数(可选)
prev_timestep: Optional[jnp.int32] = None # 上一个时间步(可选)
cur_sample: Optional[jnp.ndarray] = None # 当前样本(可选)
# 定义类方法以创建调度器状态
@classmethod
def create(
cls,
common: CommonSchedulerState,
alpha_t: jnp.ndarray,
sigma_t: jnp.ndarray,
lambda_t: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
# 返回新的调度器状态实例
return cls(
common=common,
alpha_t=alpha_t,
sigma_t=sigma_t,
lambda_t=lambda_t,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
# 定义 FlaxDPMSolverMultistepSchedulerOutput 数据类,表示调度器输出
@dataclass
class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
# 调度器的状态
state: DPMSolverMultistepSchedulerState
# 定义 FlaxDPMSolverMultistepScheduler 类,继承调度器混合类和配置混合类
class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
DPM-Solver(以及改进版 DPM-Solver++)是一个快速的专用高阶求解器,用于扩散 ODE,并提供收敛阶数保证。
实证表明,使用 DPM-Solver 仅 20 步就能生成高质量样本,即使仅用 10 步也能生成相当不错的样本。
有关更多详细信息,请参见原始论文: https://arxiv.org/abs/2206.00927 和 https://arxiv.org/abs/2211.01095
目前,我们支持多步 DPM-Solver 适用于噪声预测模型和数据预测模型。
我们建议使用 `solver_order=2` 进行引导采样,使用 `solver_order=3` 进行无条件采样。
# 支持 Imagen 中的“动态阈值”方法,参考文献:https://arxiv.org/abs/2205.11487
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487).
# 对于像素空间扩散模型,可以同时设置 `algorithm_type="dpmsolver++"` 和 `thresholding=True` 来使用动态阈值
For pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
# 注意,阈值方法不适合于潜空间扩散模型(如 stable-diffusion)
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
# `ConfigMixin` 负责存储在调度器的 `__init__` 函数中传递的所有配置属性,例如 `num_train_timesteps`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
# 这些属性可以通过 `scheduler.config.num_train_timesteps` 访问
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
# `SchedulerMixin` 提供通用的加载和保存功能,通过 [`SchedulerMixin.save_pretrained`] 和 [`~SchedulerMixin.from_pretrained`] 函数
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
# 有关更多详细信息,请参见原始论文: https://arxiv.org/abs/2206.00927 和 https://arxiv.org/abs/2211.01095
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
# 兼容的调度器列表,从 FlaxKarrasDiffusionSchedulers 中提取名称
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
# 数据类型变量
dtype: jnp.dtype
# 属性,返回是否有状态
@property
def has_state(self):
return True
# 注册到配置的初始化函数,定义多个参数的默认值
@register_to_config
def __init__(
# 训练时间步数,默认为 1000
num_train_timesteps: int = 1000,
# beta 的起始值,默认为 0.0001
beta_start: float = 0.0001,
# beta 的结束值,默认为 0.02
beta_end: float = 0.02,
# beta 的调度方式,默认为 "linear"
beta_schedule: str = "linear",
# 已训练的 beta 值,可选
trained_betas: Optional[jnp.ndarray] = None,
# 解算器阶数,默认为 2
solver_order: int = 2,
# 预测类型,默认为 "epsilon"
prediction_type: str = "epsilon",
# 是否启用阈值处理,默认为 False
thresholding: bool = False,
# 动态阈值比例,默认为 0.995
dynamic_thresholding_ratio: float = 0.995,
# 采样最大值,默认为 1.0
sample_max_value: float = 1.0,
# 算法类型,默认为 "dpmsolver++"
algorithm_type: str = "dpmsolver++",
# 解算器类型,默认为 "midpoint"
solver_type: str = "midpoint",
# 最后阶段是否降低阶数,默认为 True
lower_order_final: bool = True,
# 时间步的间隔类型,默认为 "linspace"
timestep_spacing: str = "linspace",
# 数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32,
):
# 将数据类型赋值给实例变量
self.dtype = dtype
# 创建状态的方法,接受一个可选的公共调度状态参数,返回 DPM 求解器多步调度状态
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
# 如果没有提供公共调度状态,则创建一个新的实例
if common is None:
common = CommonSchedulerState.create(self)
# 当前仅支持 VP 类型的噪声调度
alpha_t = jnp.sqrt(common.alphas_cumprod) # 计算累积 alpha 的平方根
sigma_t = jnp.sqrt(1 - common.alphas_cumprod) # 计算 1 减去累积 alpha 的平方根
lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) # 计算 alpha_t 和 sigma_t 的对数差
# DPM 求解器的设置
if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
# 如果算法类型不在支持的列表中,则抛出未实现异常
raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
if self.config.solver_type not in ["midpoint", "heun"]:
# 如果求解器类型不在支持的列表中,则抛出未实现异常
raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
# 初始化噪声分布的标准差
init_noise_sigma = jnp.array(1.0, dtype=self.dtype) # 创建一个值为 1.0 的数组,类型为实例的 dtype
# 生成时间步的数组,从 0 到 num_train_timesteps,取整后反转
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
# 创建并返回 DPM 求解器多步调度状态
return DPMSolverMultistepSchedulerState.create(
common=common, # 传入公共调度状态
alpha_t=alpha_t, # 传入计算得到的 alpha_t
sigma_t=sigma_t, # 传入计算得到的 sigma_t
lambda_t=lambda_t, # 传入计算得到的 lambda_t
init_noise_sigma=init_noise_sigma, # 传入初始化噪声的标准差
timesteps=timesteps, # 传入时间步数组
)
# 设置时间步的方法,接受当前状态、推理步骤数和形状作为参数
def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
) -> DPMSolverMultistepSchedulerState: # 定义返回类型为 DPMSolverMultistepSchedulerState
""" # 文档字符串,描述该函数的功能
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. # 设置用于扩散链的离散时间步,支持在推断前运行的函数
Args: # 参数说明
state (`DPMSolverMultistepSchedulerState`): # 状态参数,类型为 DPMSolverMultistepSchedulerState
the `FlaxDPMSolverMultistepScheduler` state data class instance. # FlaxDPMSolverMultistepScheduler 的状态数据类实例
num_inference_steps (`int`): # 推断步骤数量参数,类型为 int
the number of diffusion steps used when generating samples with a pre-trained model. # 生成样本时使用的扩散步骤数量
shape (`Tuple`): # 样本形状参数,类型为元组
the shape of the samples to be generated. # 要生成的样本的形状
""" # 文档字符串结束
last_timestep = self.config.num_train_timesteps # 获取训练时的最后时间步
if self.config.timestep_spacing == "linspace": # 检查时间步间距配置是否为线性空间
timesteps = ( # 生成线性空间的时间步
jnp.linspace(0, last_timestep - 1, num_inference_steps + 1) # 生成从0到最后时间步的线性间隔
.round()[::-1][:-1] # 取反并去掉最后一个元素
.astype(jnp.int32) # 转换为整型
)
elif self.config.timestep_spacing == "leading": # 检查时间步间距配置是否为前导
step_ratio = last_timestep // (num_inference_steps + 1) # 计算步骤比率
# creates integer timesteps by multiplying by ratio # 通过乘以比率创建整数时间步
# casting to int to avoid issues when num_inference_step is power of 3 # 强制转换为整数以避免在 num_inference_step 为 3 的幂时的问题
timesteps = ( # 生成前导时间步
(jnp.arange(0, num_inference_steps + 1) * step_ratio) # 创建范围并乘以步骤比率
.round()[::-1][:-1] # 取反并去掉最后一个元素
.copy().astype(jnp.int32) # 复制并转换为整型
)
timesteps += self.config.steps_offset # 加上步骤偏移量
elif self.config.timestep_spacing == "trailing": # 检查时间步间距配置是否为后置
step_ratio = self.config.num_train_timesteps / num_inference_steps # 计算步骤比率
# creates integer timesteps by multiplying by ratio # 通过乘以比率创建整数时间步
# casting to int to avoid issues when num_inference_step is power of 3 # 强制转换为整数以避免在 num_inference_step 为 3 的幂时的问题
timesteps = jnp.arange(last_timestep, 0, -step_ratio) # 从最后时间步到0生成时间步
.round().copy().astype(jnp.int32) # 四舍五入、复制并转换为整型
timesteps -= 1 # 时间步减去1
else: # 如果没有匹配的时间步间距配置
raise ValueError( # 抛出值错误
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." # 提示用户选择有效的时间步间距
)
# initial running values # 初始化运行值
model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) # 创建模型输出数组,初始化为零
lower_order_nums = jnp.int32(0) # 初始化低阶数字为0
prev_timestep = jnp.int32(-1) # 初始化前一个时间步为-1
cur_sample = jnp.zeros(shape, dtype=self.dtype) # 创建当前样本数组,初始化为零
return state.replace( # 返回更新后的状态
num_inference_steps=num_inference_steps, # 更新推断步骤数量
timesteps=timesteps, # 更新时间步
model_outputs=model_outputs, # 更新模型输出
lower_order_nums=lower_order_nums, # 更新低阶数字
prev_timestep=prev_timestep, # 更新前一个时间步
cur_sample=cur_sample, # 更新当前样本
)
def convert_model_output( # 定义转换模型输出的函数
self, # 实例对象
state: DPMSolverMultistepSchedulerState, # 状态参数,类型为 DPMSolverMultistepSchedulerState
model_output: jnp.ndarray, # 模型输出参数,类型为 jnp.ndarray
timestep: int, # 当前时间步参数,类型为 int
sample: jnp.ndarray, # 样本参数,类型为 jnp.ndarray
def dpm_solver_first_order_update( # 定义一阶更新的扩散模型求解器函数
self, # 实例对象
state: DPMSolverMultistepSchedulerState, # 状态参数,类型为 DPMSolverMultistepSchedulerState
model_output: jnp.ndarray, # 模型输出参数,类型为 jnp.ndarray
timestep: int, # 当前时间步参数,类型为 int
prev_timestep: int, # 前一个时间步参数,类型为 int
sample: jnp.ndarray, # 样本参数,类型为 jnp.ndarray
# 函数返回一个一阶DPM求解器的步骤结果,等效于DDIM
) -> jnp.ndarray:
# 文档字符串,说明函数的用途及详细推导链接
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
Args:
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
# 将前一个时间步和当前时间步赋值给变量
t, s0 = prev_timestep, timestep
# 获取模型输出
m0 = model_output
# 获取当前和前一个时间步的lambda值
lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
# 获取当前和前一个时间步的alpha值
alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
# 获取当前和前一个时间步的sigma值
sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
# 计算h值,表示lambda_t与lambda_s的差异
h = lambda_t - lambda_s
# 根据配置的算法类型选择相应的计算公式
if self.config.algorithm_type == "dpmsolver++":
# 计算当前样本的更新值,使用dpmsolver++公式
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
elif self.config.algorithm_type == "dpmsolver":
# 计算当前样本的更新值,使用dpmsolver公式
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
# 返回更新后的样本
return x_t
# 定义一个多步骤DPM求解器的二阶更新函数
def multistep_dpm_solver_second_order_update(
# 接受当前状态作为参数
self,
state: DPMSolverMultistepSchedulerState,
# 接受模型输出列表作为参数
model_output_list: jnp.ndarray,
# 接受时间步列表作为参数
timestep_list: List[int],
# 接受前一个时间步作为参数
prev_timestep: int,
# 接受当前样本作为参数
sample: jnp.ndarray,
# 返回上一个时间步的样本张量
) -> jnp.ndarray:
# DPM-Solver的二阶多步一步
# 参数说明:
# model_output_list:当前和后续时间步的扩散模型直接输出的列表
# timestep:当前和后续离散时间步
# prev_timestep:前一个离散时间步
# sample:当前扩散过程中的样本实例
# 返回值为上一个时间步的样本张量
"""
# 从前一个时间步获取时间步
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
# 从模型输出列表中获取当前和前一个时间步的输出
m0, m1 = model_output_list[-1], model_output_list[-2]
# 获取状态中当前和前两个时间步的lambda值
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
# 获取状态中当前和前两个时间步的alpha值
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
# 获取状态中当前和前一个时间步的sigma值
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
# 计算h和h_0的差值
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
# 计算r0值
r0 = h_0 / h
# D0和D1的值
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
# 根据算法类型判断处理方式
if self.config.algorithm_type == "dpmsolver++":
# 参考详细推导文献
if self.config.solver_type == "midpoint":
# 使用中点法计算x_t
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
- 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
# 使用Heun法计算x_t
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
# 参考详细推导文献
if self.config.solver_type == "midpoint":
# 使用中点法计算x_t
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
# 使用Heun法计算x_t
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
)
# 返回计算得到的x_t
return x_t
# 定义三阶更新的多步DPM求解器
def multistep_dpm_solver_third_order_update(
# 传入状态
state: DPMSolverMultistepSchedulerState,
# 传入模型输出列表
model_output_list: jnp.ndarray,
# 传入时间步列表
timestep_list: List[int],
# 传入前一个时间步
prev_timestep: int,
# 传入样本实例
sample: jnp.ndarray,
) -> jnp.ndarray: # 定义函数返回类型为 jnp.ndarray
""" # 开始文档字符串
One step for the third-order multistep DPM-Solver. # 描述该函数为三阶多步 DPM 求解器的一步
Args: # 开始参数说明
model_output_list (`List[jnp.ndarray]`): # 定义模型输出列表参数
direct outputs from learned diffusion model at current and latter timesteps. # 描述该参数为当前及后续时间步的扩散模型直接输出
timestep (`int`): # 定义当前时间步参数
current and latter discrete timestep in the diffusion chain. # 描述该参数为扩散链中当前及后续离散时间步
prev_timestep (`int`): # 定义前一个时间步参数
previous discrete timestep in the diffusion chain. # 描述该参数为扩散链中前一个离散时间步
sample (`jnp.ndarray`): # 定义样本参数
current instance of sample being created by diffusion process. # 描述该参数为当前通过扩散过程创建的样本实例
Returns: # 开始返回值说明
`jnp.ndarray`: the sample tensor at the previous timestep. # 描述返回值为前一个时间步的样本张量
""" # 结束文档字符串
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] # 获取当前和最近的四个时间步
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] # 获取最近三个模型输出
lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( # 从状态中提取对应时间步的 lambda 值
state.lambda_t[t], # 当前时间步的 lambda 值
state.lambda_t[s0], # 最近时间步的 lambda 值
state.lambda_t[s1], # 倒数第二个时间步的 lambda 值
state.lambda_t[s2], # 倒数第三个时间步的 lambda 值
) # 结束 lambda 值提取
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] # 提取当前和最近时间步的 alpha 值
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] # 提取当前和最近时间步的 sigma 值
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 # 计算 h 相关变量
r0, r1 = h_0 / h, h_1 / h # 计算 r0 和 r1
D0 = m0 # 将最近的模型输出赋值给 D0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) # 计算 D1_0 和 D1_1
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) # 计算 D1
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) # 计算 D2
if self.config.algorithm_type == "dpmsolver++": # 检查算法类型是否为 "dpmsolver++"
# See https://arxiv.org/abs/2206.00927 for detailed derivations # 引用文献以获取详细推导
x_t = ( # 计算 x_t
(sigma_t / sigma_s0) * sample # 计算与 sigma_t 相关的项
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0 # 计算与 D0 相关的项
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 # 计算与 D1 相关的项
- (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 # 计算与 D2 相关的项
) # 结束 x_t 计算
elif self.config.algorithm_type == "dpmsolver": # 检查算法类型是否为 "dpmsolver"
# See https://arxiv.org/abs/2206.00927 for detailed derivations # 引用文献以获取详细推导
x_t = ( # 计算 x_t
(alpha_t / alpha_s0) * sample # 计算与 alpha_t 相关的项
- (sigma_t * (jnp.exp(h) - 1.0)) * D0 # 计算与 D0 相关的项
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 # 计算与 D1 相关的项
- (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 # 计算与 D2 相关的项
) # 结束 x_t 计算
return x_t # 返回计算出的 x_t
def step( # 定义 step 函数
self, # 传入自身引用
state: DPMSolverMultistepSchedulerState, # 定义状态参数
model_output: jnp.ndarray, # 定义模型输出参数
timestep: int, # 定义时间步参数
sample: jnp.ndarray, # 定义样本参数
return_dict: bool = True, # 定义是否返回字典的参数,默认为 True
def scale_model_input( # 定义 scale_model_input 函数
self, # 传入自身引用
state: DPMSolverMultistepSchedulerState, # 定义状态参数
sample: jnp.ndarray, # 定义样本参数
timestep: Optional[int] = None # 定义时间步参数,可选,默认为 None
) -> jnp.ndarray: # 指定函数返回类型为 jnp.ndarray
""" # 文档字符串,描述函数的作用及参数
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. # 确保与需要根据当前时间步缩放去噪模型输入的调度器的可互换性
Args: # 参数说明部分
state (`DPMSolverMultistepSchedulerState`): # state 参数,类型为 DPMSolverMultistepSchedulerState
the `FlaxDPMSolverMultistepScheduler` state data class instance. # FlaxDPMSolverMultistepScheduler 的状态数据类实例
sample (`jnp.ndarray`): input sample # sample 参数,类型为 jnp.ndarray,表示输入样本
timestep (`int`, optional): current timestep # timestep 参数,类型为 int,可选,表示当前时间步
Returns: # 返回值说明部分
`jnp.ndarray`: scaled input sample # 返回一个 jnp.ndarray,表示缩放后的输入样本
"""
return sample # 返回输入样本,不做任何修改
def add_noise( # 定义 add_noise 函数
self, # 对象自身引用
state: DPMSolverMultistepSchedulerState, # state 参数,类型为 DPMSolverMultistepSchedulerState
original_samples: jnp.ndarray, # original_samples 参数,类型为 jnp.ndarray,表示原始样本
noise: jnp.ndarray, # noise 参数,类型为 jnp.ndarray,表示要添加的噪声
timesteps: jnp.ndarray, # timesteps 参数,类型为 jnp.ndarray,表示时间步
) -> jnp.ndarray: # 指定函数返回类型为 jnp.ndarray
return add_noise_common(state.common, original_samples, noise, timesteps) # 调用 add_noise_common 函数,返回添加噪声后的样本
def __len__(self): # 定义获取对象长度的方法
return self.config.num_train_timesteps # 返回配置中定义的训练时间步数
# 版权声明,说明文件归属及授权信息
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证进行授权;用户必须遵循许可证使用本文件。
# 可以在以下网址获取许可证的副本:
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,软件按“原样”提供,不附带任何形式的保证或条件。
# 请查看许可证以了解特定的权限和限制。
# 本文件受到 https://github.com/LuChengTHU/dpm-solver 的强烈影响
# 导入数学库
import math
# 导入类型提示
from typing import List, Optional, Tuple, Union
# 导入 numpy 库
import numpy as np
# 导入 torch 库
import torch
# 从配置工具导入混合配置类和注册功能
from ..configuration_utils import ConfigMixin, register_to_config
# 导入弃用工具
from ..utils import deprecate
# 导入随机张量工具
from ..utils.torch_utils import randn_tensor
# 导入调度工具
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# 从 diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 复制的函数
def betas_for_alpha_bar(
num_diffusion_timesteps, # 设置扩散时间步数
max_beta=0.999, # 设置最大 beta 值
alpha_transform_type="cosine", # 设置 alpha 转换类型
):
"""
创建一个 beta 调度,离散化给定的 alpha_t_bar 函数,该函数定义了随时间变化的 (1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受参数 t 并将其转换为扩散过程的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用低于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
可选值为 `cosine` 或 `exp`
返回:
betas (`np.ndarray`): 调度器用来更新模型输出的 betas。
"""
# 根据 alpha_transform_type 的类型定义 alpha_bar 函数
if alpha_transform_type == "cosine":
# 定义余弦类型的 alpha_bar 函数
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
# 定义指数类型的 alpha_bar 函数
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
# 如果 alpha_transform_type 不被支持,抛出异常
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = [] # 初始化空列表用于存储 beta 值
# 遍历每一个扩散时间步
for i in range(num_diffusion_timesteps):
# 计算当前时间步的 t1 和 t2
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
# 计算 beta 值并添加到列表中,确保不超过 max_beta
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 返回一个张量形式的 beta 值
return torch.tensor(betas, dtype=torch.float32)
# DPMSolverMultistepInverseScheduler 类定义,继承调度混合类和配置混合类
class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
`DPMSolverMultistepInverseScheduler` 是 [`DPMSolverMultistepScheduler`] 的反向调度器。
该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。有关通用的信息,请查看父类文档。
# 文档字符串,描述库为所有调度程序实现的方法,例如加载和保存功能。
methods the library implements for all schedulers such as loading and saving.
"""
# 定义兼容的调度器名称列表,来源于 KarrasDiffusionSchedulers 枚举
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
# 定义默认的顺序参数
order = 1
@register_to_config
# 初始化方法,接受多个超参数
def __init__(
# 训练时间步数,默认值为 1000
num_train_timesteps: int = 1000,
# β 值的起始值,默认值为 0.0001
beta_start: float = 0.0001,
# β 值的结束值,默认值为 0.02
beta_end: float = 0.02,
# β 值的调度方式,默认值为 "linear"
beta_schedule: str = "linear",
# 训练后的 β 值,默认为 None,可以是数组或列表
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
# 求解器的顺序,默认值为 2
solver_order: int = 2,
# 预测类型,默认值为 "epsilon"
prediction_type: str = "epsilon",
# 是否使用阈值处理,默认值为 False
thresholding: bool = False,
# 动态阈值处理的比例,默认值为 0.995
dynamic_thresholding_ratio: float = 0.995,
# 采样的最大值,默认值为 1.0
sample_max_value: float = 1.0,
# 算法类型,默认值为 "dpmsolver++"
algorithm_type: str = "dpmsolver++",
# 求解器类型,默认值为 "midpoint"
solver_type: str = "midpoint",
# 最后阶数是否较低,默认值为 True
lower_order_final: bool = True,
# 最后一步是否使用欧拉法,默认值为 False
euler_at_final: bool = False,
# 是否使用 Karras 的 sigma 值,默认值为 None
use_karras_sigmas: Optional[bool] = False,
# λ 最小值裁剪,默认值为负无穷
lambda_min_clipped: float = -float("inf"),
# 方差类型,默认值为 None
variance_type: Optional[str] = None,
# 时间步间距类型,默认值为 "linspace"
timestep_spacing: str = "linspace",
# 步骤偏移量,默认值为 0
steps_offset: int = 0,
):
# 检查算法类型是否为已弃用类型
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
# 构建弃用信息消息
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
# 调用弃用函数,传递相关信息
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
# 如果训练的beta值不为None,则初始化self.betas
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 如果beta_schedule为"linear",则生成线性beta值
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 如果beta_schedule为"scaled_linear",生成特定的beta值
elif beta_schedule == "scaled_linear":
# 此调度特定于潜在扩散模型
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 如果beta_schedule为"squaredcos_cap_v2",生成Glide余弦调度的beta值
elif beta_schedule == "squaredcos_cap_v2":
# Glide余弦调度
self.betas = betas_for_alpha_bar(num_train_timesteps)
# 如果beta_schedule不在已实现的调度中,则抛出未实现异常
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算alphas值
self.alphas = 1.0 - self.betas
# 计算alphas的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 目前只支持VP类型噪声调度
self.alpha_t = torch.sqrt(self.alphas_cumprod)
# 计算sigma_t值
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
# 计算lambda_t值
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
# 计算sigmas值
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# 设置初始噪声分布的标准差
self.init_noise_sigma = 1.0
# DPM-Solver的设置
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
# 如果算法类型为"deis",则注册为"dpmsolver++"
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
# 否则抛出未实现异常
else:
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
# 检查solver_type是否合法
if solver_type not in ["midpoint", "heun"]:
# 如果solver_type在特定类型中,则注册为"midpoint"
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
# 否则抛出未实现异常
else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
# 可设置的值
self.num_inference_steps = None
# 创建时间步长的线性数组
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
# 将numpy数组转换为torch张量
self.timesteps = torch.from_numpy(timesteps)
# 初始化模型输出列表
self.model_outputs = [None] * solver_order
# 初始化低阶数字
self.lower_order_nums = 0
# 初始化步骤索引
self._step_index = None
# 将sigmas转移到CPU,减少CPU/GPU通信
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# 设置是否使用Karras sigmas
self.use_karras_sigmas = use_karras_sigmas
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后增加1。
"""
return self._step_index
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 拷贝而来
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"动态阈值处理:在每个采样步骤中,我们将 s 设置为 xt0(在时间步 t 预测的 x_0)中的某个百分位绝对像素值,
如果 s > 1,则我们将 xt0 阈值处理到范围 [-s, s],然后除以 s。动态阈值处理将饱和像素(接近 -1 和 1 的像素)
向内推,以主动防止每一步的饱和。我们发现动态阈值处理显著提高了照片真实感以及图像-文本对齐,特别是在使用非常大的引导权重时。"
https://arxiv.org/abs/2205.11487
"""
# 获取输入样本的数据类型
dtype = sample.dtype
# 解包输入样本的形状信息,得到批量大小、通道数及其他维度
batch_size, channels, *remaining_dims = sample.shape
# 如果数据类型不是 float32 或 float64,则将样本转换为 float
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # 为了计算分位数而上升精度,并且 CPU 半精度的 clamp 未实现
# 将样本扁平化,以便在每幅图像上进行分位数计算
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
# 计算样本的绝对值,以获得“某个百分位绝对像素值”
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
# 在每个图像的维度上计算绝对样本的分位数
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
# 限制 s 的范围,最小为 1,最大为配置中的 sample_max_value
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # 当最小值限制为 1 时,相当于标准剪切到 [-1, 1]
# 增加维度以便后续广播处理
s = s.unsqueeze(1) # (batch_size, 1) 因为 clamp 会在维度 0 上广播
# 将样本限制在范围 [-s, s] 内,并除以 s
sample = torch.clamp(sample, -s, s) / s # "我们将 xt0 阈值处理到范围 [-s, s] 并除以 s"
# 恢复样本的原始形状
sample = sample.reshape(batch_size, channels, *remaining_dims)
# 将样本转换回原始数据类型
sample = sample.to(dtype)
# 返回处理后的样本
return sample
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 拷贝而来
def _sigma_to_t(self, sigma, log_sigmas):
# 获取 log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
# 计算分布
dists = log_sigma - log_sigmas[:, np.newaxis]
# 获取 sigma 的范围
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取 low 和 high 的值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 对 sigma 进行插值
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
# 将 t 形状恢复为 sigma 的形状
t = t.reshape(sigma.shape)
# 返回计算得到的 t
return t
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t 拷贝而来
def _sigma_to_alpha_sigma_t(self, sigma):
# 根据 sigma 计算 alpha_t
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
# 计算 sigma_t
sigma_t = sigma * alpha_t
# 返回 alpha_t 和 sigma_t
return alpha_t, sigma_t
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras 复制的代码
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Karras 等人(2022年)的噪声调度。"""
# 确保其他复制此函数的调度器不会出错的临时处理
# TODO: 将此逻辑添加到其他调度器
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min # 获取配置中的最小 sigma 值
else:
sigma_min = None # 如果没有定义,则设置为 None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max # 获取配置中的最大 sigma 值
else:
sigma_max = None # 如果没有定义,则设置为 None
# 如果 sigma_min 为空,则使用输入 sigma 的最后一个值
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
# 如果 sigma_max 为空,则使用输入 sigma 的第一个值
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 论文中使用的 rho 值为 7.0
ramp = np.linspace(0, 1, num_inference_steps) # 创建线性 ramp 从 0 到 1
min_inv_rho = sigma_min ** (1 / rho) # 计算最小 sigma 的倒数
max_inv_rho = sigma_max ** (1 / rho) # 计算最大 sigma 的倒数
# 根据线性 ramp 和倒数值计算 sigmas
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas # 返回计算得到的 sigmas
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output 复制的代码
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
):
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update 复制的代码
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
):
# 返回前一时间步的样本张量
) -> torch.Tensor:
"""
对第一阶 DPMSolver 执行一步(相当于 DDIM)。
参数:
model_output (`torch.Tensor`):
从学习的扩散模型直接输出的张量。
sample (`torch.Tensor`):
扩散过程中生成的当前样本实例。
返回:
`torch.Tensor`:
前一时间步的样本张量。
"""
# 从位置参数或关键字参数获取当前时间步
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
# 从位置参数或关键字参数获取前一个时间步
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
# 检查样本是否为 None
if sample is None:
# 如果存在第三个位置参数,则将其赋值给 sample
if len(args) > 2:
sample = args[2]
# 如果没有样本,则引发错误
else:
raise ValueError(" missing `sample` as a required keyward argument")
# 如果当前时间步不为 None
if timestep is not None:
# 警告用户时间步已弃用
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 如果前一个时间步不为 None
if prev_timestep is not None:
# 警告用户前一个时间步已弃用
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 获取当前和前一个时间步的 sigma 值
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
# 将 sigma 转换为 alpha 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
# 将前一个 sigma 转换为 alpha 和 sigma_s
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
# 计算 lambda_t 和 lambda_s
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
# 计算 h 值
h = lambda_t - lambda_s
# 根据算法类型计算 x_t
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None # 确保噪声不为 None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None # 确保噪声不为 None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
# 返回计算得到的 x_t
return x_t
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update 复制
# 定义一个多步 DPM 求解器的二阶更新方法
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor], # 模型输出的张量列表
*args, # 可变位置参数
sample: torch.Tensor = None, # 输入样本,默认为 None
noise: Optional[torch.Tensor] = None, # 噪声张量,默认为 None
**kwargs, # 可变关键字参数
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler 复制的三阶更新方法
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor], # 模型输出的张量列表
*args, # 可变位置参数
sample: torch.Tensor = None, # 输入样本,默认为 None
**kwargs, # 可变关键字参数
# 定义初始化步骤索引的方法
def _init_step_index(self, timestep):
# 检查 timestep 是否为张量,如果是,则移动到当前时刻的设备
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 查找与当前 timestep 匹配的时间索引候选
index_candidates = (self.timesteps == timestep).nonzero()
# 如果没有找到匹配的索引,则使用最后一个时间索引
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# 如果找到多个匹配的索引,选择第二个索引(确保不会跳过 sigma)
# 这样可以在去噪调度中间开始时,保证不会意外跳过 sigma
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
# 如果只找到一个匹配的索引,则使用该索引
else:
step_index = index_candidates[0].item()
# 将计算出的步骤索引保存到实例变量中
self._step_index = step_index
# 定义步骤方法
def step(
self,
model_output: torch.Tensor, # 模型输出的张量
timestep: Union[int, torch.Tensor], # 当前的时间步,整数或张量
sample: torch.Tensor, # 输入样本
generator=None, # 随机数生成器,默认为 None
variance_noise: Optional[torch.Tensor] = None, # 可选的方差噪声张量
return_dict: bool = True, # 是否返回字典,默认为 True
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler 复制的模型输入缩放方法
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器互换性。
参数:
sample (`torch.Tensor`):
输入样本。
返回:
`torch.Tensor`:
一个缩放后的输入样本。
"""
# 返回原样本,未进行缩放
return sample
# 定义添加噪声的方法
def add_noise(
self,
original_samples: torch.Tensor, # 原始样本张量
noise: torch.Tensor, # 噪声张量
timesteps: torch.IntTensor, # 时间步张量
) -> torch.Tensor: # 定义返回类型为 torch.Tensor 的函数,函数的开始部分
# 确保 sigmas 和 timesteps 的设备和数据类型与 original_samples 相同
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
# 如果设备类型为 "mps" 并且 timesteps 为浮点数
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps 不支持 float64 类型
# 将时间步长转换为与 original_samples 相同设备和 float32 类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 将时间步长转换为与 original_samples 相同设备
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [] # 初始化一个空列表,用于存储步骤索引
for timestep in timesteps: # 遍历每个时间步
# 找到与当前时间步相等的调度时间步的索引候选
index_candidates = (schedule_timesteps == timestep).nonzero()
# 如果没有找到索引候选
if len(index_candidates) == 0:
# 使用最后一个调度时间步的索引
step_index = len(schedule_timesteps) - 1
# 如果找到多个索引候选
elif len(index_candidates) > 1:
# 取第二个候选的索引作为步骤索引
step_index = index_candidates[1].item()
else:
# 只有一个候选,取它的索引
step_index = index_candidates[0].item()
# 将步骤索引添加到列表中
step_indices.append(step_index)
# 根据步骤索引获取对应的 sigma 值,并将其展平
sigma = sigmas[step_indices].flatten()
# 如果 sigma 的维度小于 original_samples 的维度,则在最后增加一个维度
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 调用 _sigma_to_alpha_sigma_t 方法计算 alpha_t 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
# 生成带噪声的样本
noisy_samples = alpha_t * original_samples + sigma_t * noise
# 返回带噪声的样本
return noisy_samples
def __len__(self): # 定义 __len__ 方法
# 返回训练时间步的数量
return self.config.num_train_timesteps
# 版权信息,声明版权所有者及许可证信息
# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# 按照 Apache 许可证第 2.0 版授权
# 该文件只能在遵循许可证的情况下使用
# 可以在以下网址获取许可证副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用的法律另有规定或书面达成协议,否则按 "现状" 分发软件
# 不提供任何形式的担保或条件
# 查看许可证以了解有关权限和限制的具体内容
# 导入数学模块
import math
# 从 typing 导入多个类型
from typing import List, Optional, Tuple, Union
# 导入 numpy 库
import numpy as np
# 导入 torch 库
import torch
# 导入 torchsde 库
import torchsde
# 从配置工具模块导入类和方法
from ..configuration_utils import ConfigMixin, register_to_config
# 从调度工具模块导入类
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
class BatchedBrownianTree:
"""封装 torchsde.BrownianTree 以支持批量熵的类。"""
def __init__(self, x, t0, t1, seed=None, **kwargs):
# 对 t0 和 t1 进行排序,并获取符号信息
t0, t1, self.sign = self.sort(t0, t1)
# 从关键字参数中获取 w0,默认初始化为与 x 形状相同的零张量
w0 = kwargs.get("w0", torch.zeros_like(x))
# 如果未提供种子,随机生成一个种子
if seed is None:
seed = torch.randint(0, 2**63 - 1, []).item()
# 设置批量处理标志
self.batched = True
try:
# 确保种子的长度与 x 的第一个维度匹配
assert len(seed) == x.shape[0]
w0 = w0[0] # 取第一个 w0
except TypeError:
# 如果种子类型不匹配,将其转为列表,并设置为单个种子
seed = [seed]
self.batched = False
# 根据种子创建一组 BrownianTree 实例
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
# 返回排序后的值及其符号
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
# 对 t0 和 t1 进行排序
t0, t1, sign = self.sort(t0, t1)
# 调用每棵树并将结果堆叠起来,考虑符号
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
# 如果不是批量处理,返回单一结果
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""基于 torchsde.BrownianTree 的噪声采样器。
参数:
x (Tensor): 用于生成随机样本的张量,其形状、设备和数据类型将被使用。
sigma_min (float): 有效区间的下限。
sigma_max (float): 有效区间的上限。
seed (int 或 List[int]): 随机种子。如果提供了种子列表而不是单个整数,
则噪声采样器将为每个批量项目使用一个 BrownianTree,每个都有自己的种子。
transform (callable): 一个函数,将 sigma 映射到采样器的内部时间步。
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
# 保存变换函数
self.transform = transform
# 变换 sigma_min 和 sigma_max,获得 t0 和 t1
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
# 创建 BatchedBrownianTree 实例
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
# 变换 sigma 和 sigma_next,获得 t0 和 t1
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
# 返回计算的噪声,进行归一化
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
# 从 diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 复制的代码
def betas_for_alpha_bar(
num_diffusion_timesteps, # 生成的 beta 数量
max_beta=0.999, # 使用的最大 beta 值,值低于 1 可防止奇点
alpha_transform_type="cosine", # alpha_bar 的噪声调度类型,默认为 cosine
):
"""
创建一个 beta 调度,离散化给定的 alpha_t_bar 函数,该函数定义了时间 t = [0,1] 的
(1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接收 t 参数并将其转换为 (1-beta) 的累积乘积
直至扩散过程的该部分。
Args:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用低于 1 的值来防止奇点。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
从 `cosine` 或 `exp` 中选择
Returns:
betas (`np.ndarray`): 调度程序用于步骤模型输出的 betas
"""
# 检查 alpha_transform_type 是否为 cosine
if alpha_transform_type == "cosine":
# 定义 alpha_bar_fn 函数,使用 cosine 计算 alpha_bar
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
# 检查 alpha_transform_type 是否为 exp
elif alpha_transform_type == "exp":
# 定义 alpha_bar_fn 函数,使用指数计算 alpha_bar
def alpha_bar_fn(t):
return math.exp(t * -12.0)
# 抛出不支持的 alpha_transform_type 错误
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = [] # 初始化 beta 列表
# 遍历每个扩散时间步
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps # 当前时间步
t2 = (i + 1) / num_diffusion_timesteps # 下一个时间步
# 计算 beta 值并添加到 betas 列表,确保不超过 max_beta
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
# 返回 beta 的张量表示
return torch.tensor(betas, dtype=torch.float32)
class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
"""
DPMSolverSDEScheduler 实现了 [Elucidating the Design Space of Diffusion-Based
Generative Models](https://huggingface.co/papers/2206.00364) 论文中的随机采样器。
此模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看超类文档以获取库为所有调度器实现的通用
方法,例如加载和保存。
# 定义初始化方法的参数及其默认值
Args:
num_train_timesteps (`int`, defaults to 1000): # 训练模型的扩散步骤数量,默认为1000
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.00085): # 推理的起始 beta 值,默认为0.00085
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.012): # 推理的最终 beta 值,默认为0.012
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`): # beta 计划,定义 beta 范围到模型步骤的映射,默认为线性
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*): # 可选参数,直接传递 beta 数组以跳过 beta_start 和 beta_end
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
prediction_type (`str`, defaults to `epsilon`, *optional*): # 调度函数的预测类型,默认为预测扩散过程的噪声
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
use_karras_sigmas (`bool`, *optional*, defaults to `False`): # 是否在采样过程中使用 Karras sigmas 来调整噪声调度的步长,默认为 False
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
noise_sampler_seed (`int`, *optional*, defaults to `None`): # 噪声采样器使用的随机种子,默认为 None 时生成随机种子
The random seed to use for the noise sampler. If `None`, a random seed is generated.
timestep_spacing (`str`, defaults to `"linspace"`): # 定义时间步的缩放方式,默认为线性间隔
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): # 推理步骤的偏移量,某些模型家族所需,默认为0
An offset added to the inference steps, as required by some model families.
"""
# 创建一个兼容的调度器列表,包含 KarrasDiffusionSchedulers 中的所有调度器名称
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
# 设置调度器的顺序,默认为2
order = 2
# 注册到配置中的初始化函数
@register_to_config
def __init__(
# 初始化函数参数及其默认值
num_train_timesteps: int = 1000, # 训练模型的扩散步骤数量,默认为1000
beta_start: float = 0.00085, # 推理的起始 beta 值,默认为0.00085
beta_end: float = 0.012, # 推理的最终 beta 值,默认为0.012
beta_schedule: str = "linear", # beta 计划,默认为线性
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, # 可选参数,跳过 beta_start 和 beta_end
prediction_type: str = "epsilon", # 默认预测类型为噪声
use_karras_sigmas: Optional[bool] = False, # 是否使用 Karras sigmas,默认为 False
noise_sampler_seed: Optional[int] = None, # 噪声采样器的随机种子,默认为 None
timestep_spacing: str = "linspace", # 时间步的缩放方式,默认为线性间隔
steps_offset: int = 0, # 推理步骤的偏移量,默认为0
):
# 检查训练的 beta 是否为 None
if trained_betas is not None:
# 将训练的 beta 转换为张量,数据类型为 float32
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 如果 beta_schedule 为线性
elif beta_schedule == "linear":
# 生成从 beta_start 到 beta_end 的线性间隔值
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 如果 beta_schedule 为 scaled_linear
elif beta_schedule == "scaled_linear":
# 此调度非常特定于潜在扩散模型
# 生成 beta_start 和 beta_end 的平方根线性间隔值,并平方
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 如果 beta_schedule 为 squaredcos_cap_v2
elif beta_schedule == "squaredcos_cap_v2":
# Glide 余弦调度
self.betas = betas_for_alpha_bar(num_train_timesteps)
# 如果 beta_schedule 不在以上选项中
else:
# 抛出未实现的错误
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算 alphas,等于 1 减去 betas
self.alphas = 1.0 - self.betas
# 计算 alphas 的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 设置所有值
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
# 记录是否使用 Karras sigma
self.use_karras_sigmas = use_karras_sigmas
# 初始化噪声采样器为 None
self.noise_sampler = None
# 记录噪声采样器种子
self.noise_sampler_seed = noise_sampler_seed
# 初始化步索引为 None
self._step_index = None
# 初始化开始索引为 None
self._begin_index = None
# 将 sigmas 移动到 CPU,以避免过多的 CPU/GPU 通信
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep 复制
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果没有提供 schedule_timesteps,则使用当前的 timesteps
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 查找与 timestep 相等的索引
indices = (schedule_timesteps == timestep).nonzero()
# 确保在调度开始时不会意外跳过 sigma
# 如果 indices 长度大于 1,则 pos 设为 1,否则设为 0
pos = 1 if len(indices) > 1 else 0
# 返回相应位置的索引
return indices[pos].item()
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 复制
def _init_step_index(self, timestep):
# 如果 begin_index 为 None
if self.begin_index is None:
# 如果 timestep 是张量,移动到 timesteps 的设备上
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 通过索引方法获取当前步索引
self._step_index = self.index_for_timestep(timestep)
else:
# 将步索引设为开始索引
self._step_index = self._begin_index
@property
def init_noise_sigma(self):
# 返回初始噪声分布的标准差
if self.config.timestep_spacing in ["linspace", "trailing"]:
# 返回 sigmas 的最大值
return self.sigmas.max()
# 计算并返回噪声的标准差
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度步骤后增加 1。
"""
# 返回当前步索引
return self._step_index
@property
# 定义获取初始时间步的函数
def begin_index(self):
"""
返回初始时间步索引,应该通过 `set_begin_index` 方法设置。
"""
# 返回当前的初始时间步索引
return self._begin_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler 复制而来,设置初始时间步的函数
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的初始时间步。此函数应在推断前从管道运行。
Args:
begin_index (`int`):
调度器的初始时间步。
"""
# 将传入的初始时间步索引保存到类属性
self._begin_index = begin_index
# 定义缩放模型输入的函数
def scale_model_input(
self,
sample: torch.Tensor,
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器的可互换性。
Args:
sample (`torch.Tensor`):
输入样本。
timestep (`int`, *optional*):
扩散链中的当前时间步。
Returns:
`torch.Tensor`:
缩放后的输入样本。
"""
# 如果步索引为空,则初始化步索引
if self.step_index is None:
self._init_step_index(timestep)
# 获取当前步的 sigma 值
sigma = self.sigmas[self.step_index]
# 根据状态选择 sigma 值
sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
# 将样本缩放
sample = sample / ((sigma_input**2 + 1) ** 0.5)
# 返回缩放后的样本
return sample
# 定义设置时间步的函数
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
def _second_order_timesteps(self, sigmas, log_sigmas):
# 定义 sigma 的函数
def sigma_fn(_t):
return np.exp(-_t)
# 定义时间的函数
def t_fn(_sigma):
return -np.log(_sigma)
# 设置中点比例
midpoint_ratio = 0.5
# 获取时间步
t = t_fn(sigmas)
# 计算时间间隔
delta_time = np.diff(t)
# 提出新时间步
t_proposed = t[:-1] + delta_time * midpoint_ratio
# 计算提出的 sigma 值
sig_proposed = sigma_fn(t_proposed)
# 将 sigma 转换为时间步
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed])
# 返回时间步数组
return timesteps
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler 复制而来,将 sigma 转换为时间的函数
def _sigma_to_t(self, sigma, log_sigmas):
# 获取 sigma 的对数
log_sigma = np.log(np.maximum(sigma, 1e-10))
# 计算分布
dists = log_sigma - log_sigmas[:, np.newaxis]
# 获取 sigma 范围的索引
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取低高边界的对数 sigma 值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 对 sigma 进行插值
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
# 调整形状以匹配输入 sigma 的形状
t = t.reshape(sigma.shape)
# 返回时间步
return t
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras 复制而来
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
"""构建 Karras 等人(2022)提出的噪声调度。"""
# 获取输入 sigmas 的最小值
sigma_min: float = in_sigmas[-1].item()
# 获取输入 sigmas 的最大值
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 论文中使用的值
# 创建从 0 到 1 的线性 ramp
ramp = np.linspace(0, 1, self.num_inference_steps)
# 计算最小和最大倒数 sigma
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
# 计算最终的 sigmas 值
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 返回计算得到的 sigmas
return sigmas
@property
def state_in_first_order(self):
# 判断当前 sample 是否为 None,以确定状态
return self.sample is None
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
s_noise: float = 1.0,
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise 复制而来
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# 确保 sigmas 和 timesteps 与 original_samples 具有相同的设备和数据类型
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps 不支持 float64 类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 将调度时间步转换为原始样本的设备
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# 当 scheduler 用于训练时,self.begin_index 为 None,或管道未实现 set_begin_index
if self.begin_index is None:
# 根据时间步计算步骤索引
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# 在第一个去噪步骤之后调用 add_noise(用于修复)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# 在第一个去噪步骤之前调用 add_noise,以创建初始潜在图像(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
# 根据步骤索引提取 sigma,并展平
sigma = sigmas[step_indices].flatten()
# 如果 sigma 的形状小于原始样本,则在最后添加一个维度
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 计算添加噪声后的样本
noisy_samples = original_samples + noise * sigma
# 返回添加噪声后的样本
return noisy_samples
def __len__(self):
# 返回训练时间步的数量
return self.config.num_train_timesteps