diffusers 源码解析(六十五)
# 该文件由命令 `make fix-copies` 自动生成,请勿编辑。
from ..utils import DummyObject, requires_backends # 从上级目录的 utils 模块导入 DummyObject 和 requires_backends 函数
class SpectrogramDiffusionPipeline(metaclass=DummyObject): # 定义 SpectrogramDiffusionPipeline 类,使用 DummyObject 作为其元类
_backends = ["transformers", "torch", "note_seq"] # 定义类属性 _backends,包含支持的后端列表
def __init__(self, *args, **kwargs): # 初始化方法,接受可变参数
requires_backends(self, ["transformers", "torch", "note_seq"]) # 检查当前实例是否支持指定的后端
@classmethod # 定义类方法
def from_config(cls, *args, **kwargs): # 接受可变参数
requires_backends(cls, ["transformers", "torch", "note_seq"]) # 检查类是否支持指定的后端
@classmethod # 定义类方法
def from_pretrained(cls, *args, **kwargs): # 接受可变参数
requires_backends(cls, ["transformers", "torch", "note_seq"]) # 检查类是否支持指定的后端
.\diffusers\utils\dynamic_modules_utils.py
# coding=utf-8 # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team. # 版权声明,指明版权所有者
#
# Licensed under the Apache License, Version 2.0 (the "License"); # 指明许可证类型
# you may not use this file except in compliance with the License. # 说明使用条件
# You may obtain a copy of the License at # 提供获取许可证的链接
#
# http://www.apache.org/licenses/LICENSE-2.0 # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software # 说明免责条款
# distributed under the License is distributed on an "AS IS" BASIS, # 指出软件按原样分发
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # 声明无任何保证或条件
# See the License for the specific language governing permissions and # 提示查看许可证以获取权限说明
# limitations under the License. # 说明许可证的限制
"""Utilities to dynamically load objects from the Hub.""" # 模块描述,说明功能
import importlib # 导入模块以动态导入其他模块
import inspect # 导入模块以检查对象的类型和属性
import json # 导入模块以处理 JSON 数据
import os # 导入模块以进行操作系统相关的操作
import re # 导入模块以进行正则表达式匹配
import shutil # 导入模块以进行文件和目录操作
import sys # 导入模块以访问解释器使用的变量和函数
from pathlib import Path # 从路径模块导入 Path 类以处理文件路径
from typing import Dict, Optional, Union # 从 typing 模块导入类型注解
from urllib import request # 导入模块以进行 URL 请求
from huggingface_hub import hf_hub_download, model_info # 从 huggingface_hub 导入特定功能
from huggingface_hub.utils import RevisionNotFoundError, validate_hf_hub_args # 导入错误和验证函数
from packaging import version # 导入版本管理模块
from .. import __version__ # 从上级模块导入当前版本
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging # 导入当前包中的常量和日志模块
logger = logging.get_logger(__name__) # pylint: disable=invalid-name # 初始化日志记录器
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror # 提供社区管道镜像的链接
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" # 定义社区管道镜像的 ID
def get_diffusers_versions(): # 定义获取 diffusers 版本的函数
url = "https://pypi.org/pypi/diffusers/json" # 设置获取版本信息的 URL
releases = json.loads(request.urlopen(url).read())["releases"].keys() # 请求 URL 并解析 JSON,获取版本键
return sorted(releases, key=lambda x: version.Version(x)) # 返回按版本排序的版本列表
def init_hf_modules(): # 定义初始化 HF 模块的函数
"""
Creates the cache directory for modules with an init, and adds it to the Python path.
""" # 函数说明,创建缓存目录并添加到 Python 路径
# This function has already been executed if HF_MODULES_CACHE already is in the Python path. # 如果缓存目录已在路径中,直接返回
if HF_MODULES_CACHE in sys.path: # 检查缓存目录是否在 Python 路径中
return # 如果在,则退出函数
sys.path.append(HF_MODULES_CACHE) # 将缓存目录添加到 Python 路径
os.makedirs(HF_MODULES_CACHE, exist_ok=True) # 创建缓存目录,如果已存在则不报错
init_path = Path(HF_MODULES_CACHE) / "__init__.py" # 定义缓存目录中初始化文件的路径
if not init_path.exists(): # 检查初始化文件是否存在
init_path.touch() # 如果不存在,则创建初始化文件
def create_dynamic_module(name: Union[str, os.PathLike]): # 定义创建动态模块的函数,接受字符串或路径对象
"""
Creates a dynamic module in the cache directory for modules.
""" # 函数说明,创建动态模块
init_hf_modules() # 调用初始化函数以确保缓存目录存在
dynamic_module_path = Path(HF_MODULES_CACHE) / name # 定义动态模块的路径
# If the parent module does not exist yet, recursively create it. # 如果父模块不存在,则递归创建
if not dynamic_module_path.parent.exists(): # 检查父模块路径是否存在
create_dynamic_module(dynamic_module_path.parent) # 如果不存在,则递归调用创建函数
os.makedirs(dynamic_module_path, exist_ok=True) # 创建动态模块目录,如果已存在则不报错
init_path = dynamic_module_path / "__init__.py" # 定义动态模块中初始化文件的路径
if not init_path.exists(): # 检查初始化文件是否存在
init_path.touch() # 如果不存在,则创建初始化文件
def get_relative_imports(module_file): # 定义获取相对导入的函数,接受模块文件路径
"""
Get the list of modules that are relatively imported in a module file.
Args:
module_file (`str` or `os.PathLike`): The module file to inspect. # 函数说明,接受字符串或路径对象作为参数
"""
with open(module_file, "r", encoding="utf-8") as f: # 以 UTF-8 编码打开模块文件
content = f.read() # 读取文件内容
# Imports of the form `import .xxx` # 说明以下是相对导入的处理
# 使用正则表达式查找以相对导入形式书写的模块名
relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
# 查找以相对导入形式书写的具体从属模块名
relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
# 将结果去重,确保唯一性
return list(set(relative_imports))
# 获取给定模块所需的所有文件列表,包括相对导入的文件
def get_relative_import_files(module_file):
# 初始化一个标志,用于控制递归循环
no_change = False
# 存储待检查的文件列表,初始为传入的模块文件
files_to_check = [module_file]
# 存储所有找到的相对导入文件
all_relative_imports = []
# 递归遍历所有相对导入文件
while not no_change:
# 存储新发现的导入文件
new_imports = []
# 遍历待检查的文件列表
for f in files_to_check:
# 获取当前文件的相对导入,并添加到新导入列表中
new_imports.extend(get_relative_imports(f))
# 获取当前模块文件的目录路径
module_path = Path(module_file).parent
# 将新导入的模块文件转为绝对路径
new_import_files = [str(module_path / m) for m in new_imports]
# 过滤掉已经找到的导入文件
new_import_files = [f for f in new_import_files if f not in all_relative_imports]
# 更新待检查的文件列表,添加 .py 后缀
files_to_check = [f"{f}.py" for f in new_import_files]
# 检查是否有新导入文件,如果没有,则结束循环
no_change = len(new_import_files) == 0
# 将当前待检查文件加入所有相对导入列表
all_relative_imports.extend(files_to_check)
# 返回所有找到的相对导入文件
return all_relative_imports
# 检查当前 Python 环境是否包含文件中导入的所有库
def check_imports(filename):
# 以 UTF-8 编码打开指定文件并读取内容
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
# 正则表达式查找 `import xxx` 形式的导入
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# 正则表达式查找 `from xxx import yyy` 形式的导入
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# 仅保留顶级模块,过滤掉相对导入
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
# 去重并确保导入模块的唯一性
imports = list(set(imports))
# 存储缺失的包列表
missing_packages = []
# 遍历每个导入模块并尝试导入
for imp in imports:
try:
importlib.import_module(imp)
except ImportError:
# 如果导入失败,记录缺失的包
missing_packages.append(imp)
# 如果有缺失的包,抛出 ImportError 异常并提示用户
if len(missing_packages) > 0:
raise ImportError(
"This modeling file requires the following packages that were not found in your environment: "
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)
# 返回文件的所有相对导入文件
return get_relative_imports(filename)
# 从模块缓存中导入指定的类
def get_class_in_module(class_name, module_path):
# 将模块路径中的分隔符替换为点,以便导入
module_path = module_path.replace(os.path.sep, ".")
# 导入指定模块
module = importlib.import_module(module_path)
# 如果类名为空,查找管道类
if class_name is None:
return find_pipeline_class(module)
# 返回指定类的引用
return getattr(module, class_name)
# 获取继承自 `DiffusionPipeline` 的管道类
def find_pipeline_class(loaded_module):
# 从上级导入 DiffusionPipeline 类
from ..pipelines import DiffusionPipeline
# 获取加载模块中所有的类成员
cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
# 初始化管道类变量
pipeline_class = None
# 遍历 cls_members 字典中的每个类名及其对应的类
for cls_name, cls in cls_members.items():
# 检查类名不是 DiffusionPipeline 的名称,且是其子类,且模块不是 diffusers
if (
cls_name != DiffusionPipeline.__name__
and issubclass(cls, DiffusionPipeline)
and cls.__module__.split(".")[0] != "diffusers"
):
# 如果已经找到一个管道类,则抛出值错误,表示发现多个类
if pipeline_class is not None:
raise ValueError(
# 错误信息,包含找到的多个类的信息
f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
f" {loaded_module}."
)
# 记录找到的管道类
pipeline_class = cls
# 返回找到的管道类
return pipeline_class
# 装饰器,用于验证传入的参数是否符合预期
@validate_hf_hub_args
# 定义获取缓存模块文件的函数,接受多个参数
def get_cached_module_file(
# 预训练模型名称或路径,可以是字符串或路径类型
pretrained_model_name_or_path: Union[str, os.PathLike],
# 模块文件名称,字符串类型
module_file: str,
# 缓存目录,可选参数,路径类型或字符串
cache_dir: Optional[Union[str, os.PathLike]] = None,
# 强制下载标志,可选参数,布尔类型,默认为 False
force_download: bool = False,
# 代理服务器字典,可选参数
proxies: Optional[Dict[str, str]] = None,
# 授权令牌,可选参数,可以是布尔或字符串类型
token: Optional[Union[bool, str]] = None,
# 版本修订信息,可选参数,字符串类型
revision: Optional[str] = None,
# 仅使用本地文件标志,可选参数,布尔类型,默认为 False
local_files_only: bool = False,
):
"""
准备从本地文件夹或远程仓库下载模块,并返回其在缓存中的路径。
参数:
pretrained_model_name_or_path (`str` 或 `os.PathLike`):
可以是预训练模型配置的模型 ID 或包含配置文件的目录路径。
module_file (`str`):
包含要查找的类的模块文件名称。
cache_dir (`str` 或 `os.PathLike`, *可选*):
下载的预训练模型配置的缓存目录路径。
force_download (`bool`, *可选*, 默认值为 `False`):
是否强制重新下载配置文件并覆盖已存在的缓存版本。
proxies (`Dict[str, str]`, *可选*):
代理服务器字典,用于每个请求。
token (`str` 或 *bool*, *可选*):
用作远程文件的 HTTP 授权令牌。
revision (`str`, *可选*, 默认值为 `"main"`):
要使用的特定模型版本。
local_files_only (`bool`, *可选*, 默认值为 `False`):
如果为 `True`,仅尝试从本地文件加载配置。
返回:
`str`: 模块在缓存中的路径。
"""
# 下载并缓存来自 `pretrained_model_name_or_path` 的 module_file,或获取本地文件
pretrained_model_name_or_path = str(pretrained_model_name_or_path) # 将预训练模型路径转换为字符串
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) # 组合路径和文件名形成完整路径
if os.path.isfile(module_file_or_url): # 检查该路径是否指向一个文件
resolved_module_file = module_file_or_url # 如果是文件,保存该路径
submodule = "local" # 标记为本地文件
elif pretrained_model_name_or_path.count("/") == 0: # 如果路径中没有斜杠,表示这是一个模型名称而非路径
available_versions = get_diffusers_versions() # 获取可用版本列表
# 去掉 ".dev0" 部分
latest_version = "v" + ".".join(__version__.split(".")[:3]) # 获取最新的版本号
# 获取匹配的 GitHub 版本
if revision is None: # 如果没有指定修订版本
revision = latest_version if latest_version[1:] in available_versions else "main" # 默认选择最新版本或主分支
logger.info(f"Defaulting to latest_version: {revision}.") # 记录默认选择的版本
elif revision in available_versions: # 如果指定版本在可用列表中
revision = f"v{revision}" # 格式化版本号
elif revision == "main": # 如果指定版本为主分支
revision = revision # 保持不变
else: # 如果指定版本不在可用版本中
raise ValueError(
f"`custom_revision`: {revision} does not exist. Please make sure to choose one of"
f" {', '.join(available_versions + ['main'])}." # 提示可用版本
)
try:
resolved_module_file = hf_hub_download( # 从 Hugging Face Hub 下载指定文件
repo_id=COMMUNITY_PIPELINES_MIRROR_ID, # 设定资源库 ID
repo_type="dataset", # 指定资源类型为数据集
filename=f"{revision}/{pretrained_model_name_or_path}.py", # 构造文件名
cache_dir=cache_dir, # 设置缓存目录
force_download=force_download, # 决定是否强制下载
proxies=proxies, # 设置代理
local_files_only=local_files_only, # 是否只考虑本地文件
)
submodule = "git" # 标记为从 GitHub 下载的文件
module_file = pretrained_model_name_or_path + ".py" # 更新模块文件名
except RevisionNotFoundError as e: # 捕获未找到修订版本的异常
raise EnvironmentError(
f"Revision '{revision}' not found in the community pipelines mirror. Check available revisions on"
" https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main."
" If you don't find the revision you are looking for, please open an issue on https://github.com/huggingface/diffusers/issues."
) from e # 抛出环境错误并提供信息
except EnvironmentError: # 捕获环境错误
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") # 记录错误
raise # 重新抛出异常
else: # 如果不是文件
try:
# 从 URL 加载或从缓存加载(如果已缓存)
resolved_module_file = hf_hub_download( # 从 Hugging Face Hub 下载文件
pretrained_model_name_or_path, # 使用预训练模型名称作为资源路径
module_file, # 指定模块文件名
cache_dir=cache_dir, # 设置缓存目录
force_download=force_download, # 决定是否强制下载
proxies=proxies, # 设置代理
local_files_only=local_files_only, # 是否只考虑本地文件
token=token, # 传递身份验证令牌
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) # 构造本地子模块路径
except EnvironmentError: # 捕获环境错误
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") # 记录错误
raise # 重新抛出异常
# 检查环境中是否具备所有所需的模块
modules_needed = check_imports(resolved_module_file)
# 将模块移动到我们的缓存动态模块中
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule) # 创建动态模块
submodule_path = Path(HF_MODULES_CACHE) / full_submodule # 构建子模块路径
if submodule == "local" or submodule == "git": # 检查子模块类型
# 始终复制本地文件(可以通过哈希判断是否有变化)
# 复制的目的是避免将过多文件夹放入 sys.path
shutil.copy(resolved_module_file, submodule_path / module_file) # 复制模块文件到子模块路径
for module_needed in modules_needed: # 遍历所需的模块
if len(module_needed.split(".")) == 2: # 检查模块名称是否有两部分
module_needed = "/".join(module_needed.split(".")) # 将模块名称转换为路径格式
module_folder = module_needed.split("/")[0] # 获取模块文件夹名
if not os.path.exists(submodule_path / module_folder): # 检查文件夹是否存在
os.makedirs(submodule_path / module_folder) # 创建文件夹
module_needed = f"{module_needed}.py" # 添加 .py 后缀
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) # 复制所需模块文件
else:
# 获取提交哈希值
# TODO: 未来将从 etag 获取此信息,而不是这里
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha # 获取模型的提交哈希
# 模块文件将放置在带有 git 哈希的子文件夹中,以便实现版本控制
submodule_path = submodule_path / commit_hash # 更新子模块路径以包含哈希
full_submodule = full_submodule + os.path.sep + commit_hash # 更新完整子模块名称
create_dynamic_module(full_submodule) # 创建新的动态模块
if not (submodule_path / module_file).exists(): # 检查模块文件是否已存在
if len(module_file.split("/")) == 2: # 检查模块文件路径是否包含两个部分
module_folder = module_file.split("/")[0] # 获取模块文件夹名
if not os.path.exists(submodule_path / module_folder): # 检查文件夹是否存在
os.makedirs(submodule_path / module_folder) # 创建文件夹
shutil.copy(resolved_module_file, submodule_path / module_file) # 复制模块文件
# 确保每个相对文件都存在
for module_needed in modules_needed: # 遍历所需模块
if len(module_needed.split(".")) == 2: # 检查模块名称是否有两部分
module_needed = "/".join(module_needed.split(".")) # 将模块名称转换为路径格式
if not (submodule_path / module_needed).exists(): # 检查模块文件是否存在
get_cached_module_file( # 获取缓存的模块文件
pretrained_model_name_or_path,
f"{module_needed}.py", # 模块文件名
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
)
return os.path.join(full_submodule, module_file) # 返回完整子模块路径及模块文件名
# 装饰器,用于验证传入的参数是否符合预期
@validate_hf_hub_args
def get_class_from_dynamic_module(
# 预训练模型的名称或路径,可以是字符串或路径类型
pretrained_model_name_or_path: Union[str, os.PathLike],
# 模块文件的名称,包含要查找的类
module_file: str,
# 要导入的类的名称,默认为 None
class_name: Optional[str] = None,
# 缓存目录的路径,默认为 None
cache_dir: Optional[Union[str, os.PathLike]] = None,
# 是否强制下载配置文件,默认为 False
force_download: bool = False,
# 代理服务器的字典,默认为 None
proxies: Optional[Dict[str, str]] = None,
# 用于远程文件的 HTTP 认证令牌,默认为 None
token: Optional[Union[bool, str]] = None,
# 具体的模型版本,默认为 "main"
revision: Optional[str] = None,
# 是否仅加载本地文件,默认为 False
local_files_only: bool = False,
# 其他可选的关键字参数
**kwargs,
):
"""
从模块文件中提取一个类,该模块文件可以位于本地文件夹或模型的仓库中。
<Tip warning={true}>
调用此函数将执行在本地找到的模块文件或从 Hub 下载的代码。
因此,仅应在可信的仓库上调用。
</Tip>
Args:
# 预训练模型的名称或路径,可以是 huggingface.co 上的模型 id 或本地目录路径
pretrained_model_name_or_path (`str` or `os.PathLike`):
可以是字符串,表示在 huggingface.co 上托管的预训练模型配置的模型 id。
有效的模型 id 可以位于根目录下,例如 `bert-base-uncased`,
或者在用户或组织名称下命名,例如 `dbmdz/bert-base-german-cased`。
也可以是一个目录的路径,包含使用 [`~PreTrainedTokenizer.save_pretrained`] 方法保存的配置文件,
例如 `./my_model_directory/`。
# 模块文件的名称,包含要查找的类
module_file (`str`):
包含要查找的类的模块文件的名称。
# 类的名称,默认为 None
class_name (`str`):
要导入的类的名称。
# 缓存目录的路径,默认为 None
cache_dir (`str` or `os.PathLike`, *optional*):
下载的预训练模型配置应缓存的目录路径,
如果不使用标准缓存的话。
# 是否强制下载配置文件,默认为 False
force_download (`bool`, *optional*, defaults to `False`):
是否强制重新下载配置文件,并覆盖已存在的缓存版本。
# 代理服务器的字典,默认为 None
proxies (`Dict[str, str]`, *optional*):
以协议或端点为基础使用的代理服务器字典,例如 `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`。代理将在每个请求中使用。
# 用于远程文件的 HTTP 认证令牌,默认为 None
token (`str` or `bool`, *optional*):
用作远程文件的 HTTP bearer 认证的令牌。
如果为 True,将使用在运行 `transformers-cli login` 时生成的令牌(存储在 `~/.huggingface` 中)。
# 具体的模型版本,默认为 "main"
revision (`str`, *optional*, defaults to `"main"`):
使用的特定模型版本。可以是分支名、标签名或提交 ID,
因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,
所以 `revision` 可以是 git 允许的任何标识符。
# 是否仅加载本地文件,默认为 False
local_files_only (`bool`, *optional*, defaults to `False`):
如果为 True,仅尝试从本地文件加载标记器配置。
<Tip>
# 如果未登录(`huggingface-cli login`),可以通过 `token` 参数传递令牌,以便使用私有或
# [受限模型](https://huggingface.co/docs/hub/models-gated#gated-models)。
# 返回值:
# `type`: 从模块动态导入的类。
# 示例:
# ```python
# 从 huggingface.co 下载模块 `modeling.py`,缓存并提取类 `MyBertModel`。
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
# ```py"""
# 最后,我们获取新创建模块中的类
final_module = get_cached_module_file(
# 获取预训练模型的名称或路径
pretrained_model_name_or_path,
# 模块文件名称
module_file,
# 缓存目录
cache_dir=cache_dir,
# 强制下载标志
force_download=force_download,
# 代理设置
proxies=proxies,
# 访问令牌
token=token,
# 版本控制
revision=revision,
# 仅本地文件标志
local_files_only=local_files_only,
)
# 从最终模块中获取类名,去掉 ".py" 后缀
return get_class_in_module(class_name, final_module.replace(".py", ""))
.\diffusers\utils\export_utils.py
# 导入所需的模块
import io # 用于处理输入输出
import random # 用于生成随机数
import struct # 用于处理C语言风格的二进制数据
import tempfile # 用于创建临时文件
from contextlib import contextmanager # 用于创建上下文管理器
from typing import List, Union # 用于类型注解
import numpy as np # 用于数值计算
import PIL.Image # 用于处理图像
import PIL.ImageOps # 用于图像操作
from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available # 导入工具函数和映射
from .logging import get_logger # 导入日志记录器
# 创建全局随机数生成器
global_rng = random.Random()
# 获取当前模块的日志记录器
logger = get_logger(__name__)
# 定义上下文管理器以便于缓冲写入
@contextmanager
def buffered_writer(raw_f):
# 创建缓冲写入对象
f = io.BufferedWriter(raw_f)
# 生成缓冲写入对象
yield f
# 刷新缓冲区,确保所有数据写入
f.flush()
# 导出图像列表为 GIF 文件
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None, fps: int = 10) -> str:
# 如果没有提供输出路径,则创建一个临时文件
if output_gif_path is None:
output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
# 保存图像为 GIF 文件
image[0].save(
output_gif_path,
save_all=True, # 保存所有帧
append_images=image[1:], # 附加后续图像
optimize=False, # 不优化图像
duration=1000 // fps, # 设置帧间隔
loop=0, # 循环次数
)
# 返回生成的 GIF 文件路径
return output_gif_path
# 导出网格数据为 PLY 文件
def export_to_ply(mesh, output_ply_path: str = None):
"""
写入网格的 PLY 文件。
"""
# 如果没有提供输出路径,则创建一个临时文件
if output_ply_path is None:
output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
# 获取网格顶点坐标并转换为 NumPy 数组
coords = mesh.verts.detach().cpu().numpy()
# 获取网格面信息并转换为 NumPy 数组
faces = mesh.faces.cpu().numpy()
# 获取网格顶点的 RGB 颜色信息并堆叠为数组
rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
# 使用缓冲写入器打开输出文件
with buffered_writer(open(output_ply_path, "wb")) as f:
f.write(b"ply\n") # 写入 PLY 文件头
f.write(b"format binary_little_endian 1.0\n") # 指定文件格式
f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) # 写入顶点数量
f.write(b"property float x\n") # 写入 x 坐标属性
f.write(b"property float y\n") # 写入 y 坐标属性
f.write(b"property float z\n") # 写入 z 坐标属性
if rgb is not None: # 如果有 RGB 颜色信息
f.write(b"property uchar red\n") # 写入红色属性
f.write(b"property uchar green\n") # 写入绿色属性
f.write(b"property uchar blue\n") # 写入蓝色属性
if faces is not None: # 如果有面信息
f.write(bytes(f"element face {len(faces)}\n", "ascii")) # 写入面数量
f.write(b"property list uchar int vertex_index\n") # 写入顶点索引属性
f.write(b"end_header\n") # 写入文件头结束标记
if rgb is not None: # 如果有 RGB 颜色信息
rgb = (rgb * 255.499).round().astype(int) # 将 RGB 值转换为整数
vertices = [
(*coord, *rgb) # 合并坐标和颜色信息
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
format = struct.Struct("<3f3B") # 定义数据格式
for item in vertices: # 写入每个顶点数据
f.write(format.pack(*item)) # 使用打包格式写入数据
else: # 如果没有 RGB 信息
format = struct.Struct("<3f") # 定义仅包含坐标的数据格式
for vertex in coords.tolist(): # 写入每个顶点坐标
f.write(format.pack(*vertex)) # 使用打包格式写入数据
if faces is not None: # 如果有面信息
format = struct.Struct("<B3I") # 定义面数据格式
for tri in faces.tolist(): # 写入每个面数据
f.write(format.pack(len(tri), *tri)) # 使用打包格式写入数据
# 返回生成的 PLY 文件路径
return output_ply_path
# 导出网格数据为 OBJ 文件
def export_to_obj(mesh, output_obj_path: str = None):
# 如果没有提供输出路径,则创建一个临时文件
if output_obj_path is None:
output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name
# 获取网格顶点坐标并转换为 NumPy 数组
verts = mesh.verts.detach().cpu().numpy()
# 获取网格面信息并转换为 NumPy 数组
faces = mesh.faces.cpu().numpy()
# 将网格顶点的颜色通道合并成一个多维数组
vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
# 将顶点坐标和颜色组合成格式化字符串,形成顶点列表
vertices = [
"{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
]
# 将每个三角形的索引格式化为面定义字符串,索引加1以符合 OBJ 格式
faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
# 将顶点和面数据合并为最终输出列表
combined_data = ["v " + vertex for vertex in vertices] + faces
# 打开指定路径的文件以写入数据
with open(output_obj_path, "w") as f:
# 将合并的数据写入文件,每个元素占一行
f.writelines("\n".join(combined_data))
# 导出视频的私有函数,接受视频帧列表、输出视频路径和帧率作为参数
def _legacy_export_to_video(
# 视频帧,可以是 NumPy 数组或 PIL 图像的列表,输出视频的路径,帧率
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
):
# 检查 OpenCV 是否可用
if is_opencv_available():
# 导入 OpenCV 库
import cv2
else:
# 如果不可用,抛出导入错误
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
# 如果没有提供输出视频路径,则创建一个临时文件
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
# 如果视频帧是 NumPy 数组,则将其值从 [0, 1] 乘以 255 并转换为 uint8 类型
if isinstance(video_frames[0], np.ndarray):
video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
# 如果视频帧是 PIL 图像,则将其转换为 NumPy 数组
elif isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
# 获取视频编码器,指定使用 mp4v 编码
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# 获取视频帧的高度、宽度和通道数
h, w, c = video_frames[0].shape
# 创建视频写入对象,设置输出路径、编码器、帧率和帧尺寸
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
# 遍历每一帧,转换颜色并写入视频
for i in range(len(video_frames)):
# 将帧从 RGB 转换为 BGR 格式
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
# 将转换后的帧写入视频文件
video_writer.write(img)
# 返回输出视频的路径
return output_video_path
# 导出视频的公共函数,接受视频帧列表、输出视频路径和帧率作为参数
def export_to_video(
# 视频帧,可以是 NumPy 数组或 PIL 图像的列表,输出视频的路径,帧率
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
) -> str:
# TODO: Dhruv. 在 Diffusers 版本 0.33.0 发布时删除
# 为了防止现有代码中断而添加的
# 检查 imageio 是否可用
if not is_imageio_available():
# 记录警告信息,提示用户建议使用 imageio 和 imageio-ffmpeg 作为后端
logger.warning(
(
"It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n"
"These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n"
"Support for the OpenCV backend will be deprecated in a future Diffusers version"
)
)
# 如果不使用 imageio,则调用旧版导出函数
return _legacy_export_to_video(video_frames, output_video_path, fps)
# 如果 imageio 可用,则导入它
if is_imageio_available():
import imageio
else:
# 如果不可用,抛出导入错误
raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video"))
# 尝试获取 imageio ffmpeg 插件的执行文件
try:
imageio.plugins.ffmpeg.get_exe()
except AttributeError:
# 如果未找到兼容的 ffmpeg 安装,抛出属性错误
raise AttributeError(
(
"Found an existing imageio backend in your environment. Attempting to export video with imageio. \n"
"Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg"
)
)
# 如果没有提供输出视频路径,则创建一个临时文件
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
# 如果视频帧是 NumPy 数组,则将其值从 [0, 1] 乘以 255 并转换为 uint8 类型
if isinstance(video_frames[0], np.ndarray):
video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
# 如果视频帧是 PIL 图像,则将其转换为 NumPy 数组
elif isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
# 使用 imageio 创建视频写入器,指定输出路径和帧率
with imageio.get_writer(output_video_path, fps=fps) as writer:
# 遍历每一帧,将其附加到视频写入器中
for frame in video_frames:
writer.append_data(frame)
# 返回输出视频的路径
return output_video_path
.\diffusers\utils\hub_utils.py
# coding=utf-8 # 指定文件的编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team. # 文件版权信息
#
# Licensed under the Apache License, Version 2.0 (the "License"); # 许可证声明
# you may not use this file except in compliance with the License. # 使用许可证的条件说明
# You may obtain a copy of the License at # 许可证获取链接
#
# http://www.apache.org/licenses/LICENSE-2.0 # 许可证链接
#
# Unless required by applicable law or agreed to in writing, software # 免责条款
# distributed under the License is distributed on an "AS IS" BASIS, # 免责条款的进一步说明
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # 不提供任何明示或暗示的保证
# See the License for the specific language governing permissions and # 查看许可证以获取特定权限
# limitations under the License. # 以及在许可证下的限制
import json # 导入 json 模块,用于处理 JSON 数据
import os # 导入 os 模块,用于与操作系统交互
import re # 导入 re 模块,用于正则表达式操作
import sys # 导入 sys 模块,用于访问与 Python 解释器相关的信息
import tempfile # 导入 tempfile 模块,用于创建临时文件
import traceback # 导入 traceback 模块,用于处理异常的跟踪信息
import warnings # 导入 warnings 模块,用于发出警告
from pathlib import Path # 从 pathlib 导入 Path 类,用于路径操作
from typing import Dict, List, Optional, Union # 导入类型提示相关的类
from uuid import uuid4 # 从 uuid 导入 uuid4 函数,用于生成唯一标识符
# 从 huggingface_hub 导入所需的模块和函数
from huggingface_hub import (
ModelCard, # 导入 ModelCard 类,用于模型卡片管理
ModelCardData, # 导入 ModelCardData 类,用于处理模型卡片数据
create_repo, # 导入 create_repo 函数,用于创建模型仓库
hf_hub_download, # 导入 hf_hub_download 函数,用于下载模型
model_info, # 导入 model_info 函数,用于获取模型信息
snapshot_download, # 导入 snapshot_download 函数,用于下载快照
upload_folder, # 导入 upload_folder 函数,用于上传文件夹
)
# 导入 huggingface_hub 的常量
from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
# 从 huggingface_hub.file_download 导入正则表达式相关内容
from huggingface_hub.file_download import REGEX_COMMIT_HASH
# 导入 huggingface_hub.utils 的多个异常处理和实用函数
from huggingface_hub.utils import (
EntryNotFoundError, # 导入找不到条目的异常
RepositoryNotFoundError, # 导入找不到仓库的异常
RevisionNotFoundError, # 导入找不到修订版本的异常
is_jinja_available, # 导入检查 Jinja 模板是否可用的函数
validate_hf_hub_args, # 导入验证 Hugging Face Hub 参数的函数
)
from packaging import version # 导入 version 模块,用于版本处理
from requests import HTTPError # 从 requests 导入 HTTPError 异常,用于处理 HTTP 错误
from .. import __version__ # 导入当前包的版本信息
from .constants import (
DEPRECATED_REVISION_ARGS, # 导入已弃用的修订参数常量
HUGGINGFACE_CO_RESOLVE_ENDPOINT, # 导入 Hugging Face 解析端点常量
SAFETENSORS_WEIGHTS_NAME, # 导入安全张量权重名称常量
WEIGHTS_NAME, # 导入权重名称常量
)
from .import_utils import (
ENV_VARS_TRUE_VALUES, # 导入环境变量真值集合
_flax_version, # 导入 Flax 版本
_jax_version, # 导入 JAX 版本
_onnxruntime_version, # 导入 ONNX 运行时版本
_torch_version, # 导入 PyTorch 版本
is_flax_available, # 导入检查 Flax 是否可用的函数
is_onnx_available, # 导入检查 ONNX 是否可用的函数
is_torch_available, # 导入检查 PyTorch 是否可用的函数
)
from .logging import get_logger # 从 logging 模块导入获取日志记录器的函数
logger = get_logger(__name__) # 获取当前模块的日志记录器实例
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" # 设置模型卡片模板文件的路径
SESSION_ID = uuid4().hex # 生成一个唯一的会话 ID
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: # 定义一个格式化用户代理字符串的函数
"""
Formats a user-agent string with basic info about a request. # 函数说明,格式化用户代理字符串
"""
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" # 构建基本用户代理字符串
if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE: # 检查是否禁用遥测或处于离线状态
return ua + "; telemetry/off" # 返回禁用遥测的用户代理字符串
if is_torch_available(): # 检查 PyTorch 是否可用
ua += f"; torch/{_torch_version}" # 将 PyTorch 版本信息添加到用户代理字符串
if is_flax_available(): # 检查 Flax 是否可用
ua += f"; jax/{_jax_version}" # 将 JAX 版本信息添加到用户代理字符串
ua += f"; flax/{_flax_version}" # 将 Flax 版本信息添加到用户代理字符串
if is_onnx_available(): # 检查 ONNX 是否可用
ua += f"; onnxruntime/{_onnxruntime_version}" # 将 ONNX 运行时版本信息添加到用户代理字符串
# CI will set this value to True # CI 会将此值设置为 True
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: # 检查环境变量是否指示在 CI 中运行
ua += "; is_ci/true" # 如果是 CI,添加相关信息到用户代理字符串
if isinstance(user_agent, dict): # 检查用户代理是否为字典类型
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) # 将字典项格式化为字符串并添加到用户代理
elif isinstance(user_agent, str): # 检查用户代理是否为字符串类型
ua += "; " + user_agent # 直接添加用户代理字符串
return ua # 返回最终的用户代理字符串
def load_or_create_model_card( # 定义加载或创建模型卡片的函数
repo_id_or_path: str = None, # 仓库 ID 或路径,默认为 None
token: Optional[str] = None, # 访问令牌,默认为 None
is_pipeline: bool = False, # 是否为管道模型,默认为 False
from_training: bool = False, # 是否从训练中加载,默认为 False
# 定义模型描述,类型为可选的字符串,默认值为 None
model_description: Optional[str] = None,
# 定义基础模型,类型为字符串,默认值为 None
base_model: str = None,
# 定义提示信息,类型为可选的字符串,默认值为 None
prompt: Optional[str] = None,
# 定义许可证信息,类型为可选的字符串,默认值为 None
license: Optional[str] = None,
# 定义小部件列表,类型为可选的字典列表,默认值为 None
widget: Optional[List[dict]] = None,
# 定义推理标志,类型为可选的布尔值,默认值为 None
inference: Optional[bool] = None,
# 定义一个函数,返回类型为 ModelCard
) -> ModelCard:
"""
加载或创建模型卡片。
参数:
repo_id_or_path (`str`):
仓库 ID(例如 "runwayml/stable-diffusion-v1-5")或查找模型卡片的本地路径。
token (`str`, *可选*):
认证令牌。默认为存储的令牌。详细信息见 https://huggingface.co/settings/token。
is_pipeline (`bool`):
布尔值,指示是否为 [`DiffusionPipeline`] 添加标签。
from_training: (`bool`): 布尔标志,表示模型卡片是否是从训练脚本创建的。
model_description (`str`, *可选*): 要添加到模型卡片的模型描述。在从训练脚本使用 `load_or_create_model_card` 时有用。
base_model (`str`): 基础模型标识符(例如 "stabilityai/stable-diffusion-xl-base-1.0")。对类似 DreamBooth 的训练有用。
prompt (`str`, *可选*): 用于训练的提示。对类似 DreamBooth 的训练有用。
license: (`str`, *可选*): 输出工件的许可证。在从训练脚本使用 `load_or_create_model_card` 时有用。
widget (`List[dict]`, *可选*): 附带画廊模板的部件。
inference: (`bool`, *可选*): 是否开启推理部件。在从训练脚本使用 `load_or_create_model_card` 时有用。
"""
# 检查是否安装了 Jinja 模板引擎
if not is_jinja_available():
# 如果未安装,抛出一个值错误,并提供安装建议
raise ValueError(
"Modelcard 渲染基于 Jinja 模板。"
" 请确保在使用 `load_or_create_model_card` 之前安装了 `jinja`."
" 要安装它,请运行 `pip install Jinja2`."
)
try:
# 检查远程仓库中是否存在模型卡片
model_card = ModelCard.load(repo_id_or_path, token=token)
except (EntryNotFoundError, RepositoryNotFoundError):
# 如果模型卡片不存在,则根据模板创建一个模型卡片
if from_training:
# 从模板创建模型卡片,并使用卡片数据作为 YAML 块
model_card = ModelCard.from_template(
card_data=ModelCardData( # 卡片元数据对象
license=license,
library_name="diffusers", # 指定库名
inference=inference, # 指定推理设置
base_model=base_model, # 指定基础模型
instance_prompt=prompt, # 指定实例提示
widget=widget, # 指定部件
),
template_path=MODEL_CARD_TEMPLATE_PATH, # 模板路径
model_description=model_description, # 模型描述
)
else:
# 创建一个空的模型卡片数据对象
card_data = ModelCardData()
# 根据 is_pipeline 变量确定组件类型
component = "pipeline" if is_pipeline else "model"
# 如果没有提供模型描述,则生成默认描述
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
# 从模板创建模型卡片
model_card = ModelCard.from_template(card_data, model_description=model_description)
# 返回模型卡片的内容
return model_card
# 定义一个函数,用于填充模型卡片的库名称和可选标签
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
# 如果模型卡片的库名称为空,则设置为 "diffusers"
if model_card.data.library_name is None:
model_card.data.library_name = "diffusers"
# 如果标签不为空
if tags is not None:
# 如果标签是字符串,则转换为列表
if isinstance(tags, str):
tags = [tags]
# 如果模型卡片的标签为空,则初始化为空列表
if model_card.data.tags is None:
model_card.data.tags = []
# 遍历所有标签,将它们添加到模型卡片的标签中
for tag in tags:
model_card.data.tags.append(tag)
# 返回更新后的模型卡片
return model_card
# 定义一个函数,从已解析的文件名中提取提交哈希
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
# 提取提交哈希,优先使用提供的提交哈希
if resolved_file is None or commit_hash is not None:
return commit_hash
# 将解析后的文件路径转换为 POSIX 格式
resolved_file = str(Path(resolved_file).as_posix())
# 在文件路径中搜索提交哈希的模式
search = re.search(r"snapshots/([^/]+)/", resolved_file)
# 如果未找到模式,则返回 None
if search is None:
return None
# 从搜索结果中提取提交哈希
commit_hash = search.groups()[0]
# 如果提交哈希符合规定格式,则返回它,否则返回 None
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
# 定义旧的默认缓存路径,可能需要迁移
# 该逻辑大体来源于 `transformers`,并有如下不同之处:
# - Diffusers 不使用自定义环境变量来指定缓存路径。
# - 无需迁移缓存格式,只需将文件移动到新位置。
hf_cache_home = os.path.expanduser(
# 获取环境变量 HF_HOME,默认路径为 ~/.cache/huggingface
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
# 定义旧的 diffusers 缓存路径
old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
# 定义一个函数,用于移动缓存目录
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
# 如果新缓存目录为空,则设置为 HF_HUB_CACHE
if new_cache_dir is None:
new_cache_dir = HF_HUB_CACHE
# 如果旧缓存目录为空,则使用旧的 diffusers 缓存路径
if old_cache_dir is None:
old_cache_dir = old_diffusers_cache
# 扩展用户目录路径
old_cache_dir = Path(old_cache_dir).expanduser()
new_cache_dir = Path(new_cache_dir).expanduser()
# 遍历旧缓存目录中的所有 blob 文件
for old_blob_path in old_cache_dir.glob("**/blobs/*"):
# 如果路径是文件且不是符号链接
if old_blob_path.is_file() and not old_blob_path.is_symlink():
# 计算新 blob 文件的路径
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
# 创建新路径的父目录
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
# 替换旧的 blob 文件为新的 blob 文件
os.replace(old_blob_path, new_blob_path)
# 尝试在旧路径和新路径之间创建符号链接
try:
os.symlink(new_blob_path, old_blob_path)
except OSError:
# 如果无法创建符号链接,发出警告
logger.warning(
"Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
)
# 现在,old_cache_dir 包含指向新缓存的符号链接(仍然可以使用)
# 定义缓存版本文件的路径
cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt")
# 如果缓存版本文件不存在,则设置缓存版本为 0
if not os.path.isfile(cache_version_file):
cache_version = 0
else:
# 打开文件以读取缓存版本
with open(cache_version_file) as f:
try:
# 尝试将读取内容转换为整数
cache_version = int(f.read())
except ValueError:
# 如果转换失败,则设置缓存版本为 0
cache_version = 0
# 如果缓存版本小于 1
if cache_version < 1:
# 检查旧的缓存目录是否存在且非空
old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
# 如果旧缓存不为空,则记录警告信息
if old_cache_is_not_empty:
logger.warning(
"The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
"existing cached models. This is a one-time operation, you can interrupt it or run it "
"later by calling `diffusers.utils.hub_utils.move_cache()`."
)
# 尝试移动缓存
try:
move_cache()
# 捕获任何异常并处理
except Exception as e:
# 获取异常的追踪信息并格式化为字符串
trace = "\n".join(traceback.format_tb(e.__traceback__))
# 记录错误信息,建议用户在 GitHub 提交问题
logger.error(
f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
"file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
"message and we will do our best to help."
)
# 检查缓存版本是否小于1
if cache_version < 1:
# 尝试创建缓存目录
try:
os.makedirs(HF_HUB_CACHE, exist_ok=True) # 创建目录,如果已存在则不报错
# 打开缓存版本文件以写入版本号
with open(cache_version_file, "w") as f:
f.write("1") # 写入版本号1
except Exception: # 捕获异常
# 记录警告信息,提示用户可能存在的问题
logger.warning(
f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure "
"the directory exists and can be written to."
)
# 定义函数以添加变体到权重名称
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
# 如果变体不为 None
if variant is not None:
# 按 '.' 分割权重名称
splits = weights_name.split(".")
# 确定分割索引
split_index = -2 if weights_name.endswith(".index.json") else -1
# 更新权重名称的分割部分,插入变体
splits = splits[:-split_index] + [variant] + splits[-split_index:]
# 重新连接分割部分为完整的权重名称
weights_name = ".".join(splits)
# 返回更新后的权重名称
return weights_name
# 装饰器用于验证 HF Hub 的参数
@validate_hf_hub_args
def _get_model_file(
pretrained_model_name_or_path: Union[str, Path], # 预训练模型的名称或路径
*,
weights_name: str, # 权重文件的名称
subfolder: Optional[str] = None, # 子文件夹,默认为 None
cache_dir: Optional[str] = None, # 缓存目录,默认为 None
force_download: bool = False, # 强制下载标志,默认为 False
proxies: Optional[Dict] = None, # 代理设置,默认为 None
local_files_only: bool = False, # 仅使用本地文件的标志,默认为 False
token: Optional[str] = None, # 访问令牌,默认为 None
user_agent: Optional[Union[Dict, str]] = None, # 用户代理设置,默认为 None
revision: Optional[str] = None, # 修订版本,默认为 None
commit_hash: Optional[str] = None, # 提交哈希值,默认为 None
):
# 将预训练模型路径转换为字符串
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
# 如果路径指向一个文件
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path # 直接返回文件路径
# 如果路径指向一个目录
elif os.path.isdir(pretrained_model_name_or_path):
# 检查目录中是否存在权重文件
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# 从 PyTorch 检查点加载模型文件
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file # 返回模型文件路径
# 如果有子文件夹且子文件夹中存在权重文件
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file # 返回子文件夹中的模型文件路径
else:
# 抛出环境错误,指示未找到权重文件
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
# 检查本地是否存在分片文件的函数
def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
# 构造分片文件的路径
shards_path = os.path.join(local_dir, subfolder)
# 获取所有分片文件的完整路径
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
# 遍历每个分片文件
for shard_file in shard_filenames:
# 检查分片文件是否存在
if not os.path.exists(shard_file):
# 如果不存在,抛出错误提示
raise ValueError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
# 获取检查点分片文件的函数定义
def _get_checkpoint_shard_files(
pretrained_model_name_or_path, # 预训练模型的名称或路径
index_filename, # 索引文件名
cache_dir=None, # 缓存目录,默认为 None
proxies=None, # 代理设置,默认为 None
# 设置是否仅使用本地文件,默认为 False
local_files_only=False,
# 设置访问令牌,默认为 None,表示不使用令牌
token=None,
# 设置用户代理字符串,默认为 None
user_agent=None,
# 设置修订版号,默认为 None
revision=None,
# 设置子文件夹路径,默认为空字符串
subfolder="",
):
"""
对于给定的模型:
- 如果 `pretrained_model_name_or_path` 是 Hub 上的模型 ID,则下载并缓存所有分片的检查点
- 返回所有分片的路径列表,以及一些元数据。
有关每个参数的描述,请参见 [`PreTrainedModel.from_pretrained`]。 `index_filename` 是索引的完整路径
(如果 `pretrained_model_name_or_path` 是 Hub 上的模型 ID,则下载并缓存)。
"""
# 检查索引文件是否存在,如果不存在则抛出错误
if not os.path.isfile(index_filename):
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
# 打开索引文件并读取内容,解析为 JSON 格式
with open(index_filename, "r") as f:
index = json.loads(f.read())
# 获取权重映射中的所有原始分片文件名,并去重后排序
original_shard_filenames = sorted(set(index["weight_map"].values()))
# 获取分片元数据
sharded_metadata = index["metadata"]
# 将所有检查点键的列表添加到元数据中
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
# 复制权重映射到元数据中
sharded_metadata["weight_map"] = index["weight_map"].copy()
# 构建分片的路径
shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
# 首先处理本地文件夹
if os.path.isdir(pretrained_model_name_or_path):
# 检查本地是否存在所需的分片
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
# 返回分片路径和分片元数据
return shards_path, sharded_metadata
# 此时 pretrained_model_name_or_path 是 Hub 上的模型标识符
# 设置允许的文件模式为原始分片文件名
allow_patterns = original_shard_filenames
# 如果提供了子文件夹,则更新允许的文件模式
if subfolder is not None:
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
# 定义需要忽略的文件模式
ignore_patterns = ["*.json", "*.md"]
# 如果不是仅使用本地文件
if not local_files_only:
# `model_info` 调用必须受到上述条件的保护
model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
# 遍历原始分片文件名
for shard_file in original_shard_filenames:
# 检查当前分片文件是否在模型文件信息中存在
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
# 如果分片文件不存在,则抛出环境错误
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} 不存在名为 {shard_file} 的文件,这是根据检查点索引所需的。"
)
try:
# 从 URL 加载
cached_folder = snapshot_download(
pretrained_model_name_or_path, # 要下载的模型路径
cache_dir=cache_dir, # 缓存目录
proxies=proxies, # 代理设置
local_files_only=local_files_only, # 是否仅使用本地文件
token=token, # 授权令牌
revision=revision, # 版本信息
allow_patterns=allow_patterns, # 允许的文件模式
ignore_patterns=ignore_patterns, # 忽略的文件模式
user_agent=user_agent, # 用户代理信息
)
# 如果指定了子文件夹,则更新缓存文件夹路径
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
# 已经在获取索引时处理了 RepositoryNotFoundError 和 RevisionNotFoundError,
# 所以这里不需要捕获它们。也处理了 EntryNotFoundError。
except HTTPError as e:
# 如果无法连接到指定的端点,则抛出环境错误
raise EnvironmentError(
f"我们无法连接到 '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' 来加载 {pretrained_model_name_or_path}。请检查您的互联网连接后重试。"
) from e
# 如果 `local_files_only=True`,则 `cached_folder` 可能不包含所有分片文件
elif local_files_only:
# 检查本地是否存在所有分片
_check_if_shards_exist_locally(
local_dir=cache_dir, # 本地目录
subfolder=subfolder, # 子文件夹
original_shard_filenames=original_shard_filenames # 原始分片文件名列表
)
# 如果指定了子文件夹,则更新缓存文件夹路径
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
# 返回缓存文件夹和分片元数据
return cached_folder, sharded_metadata
# 定义一个混合类,用于将模型、调度器或管道推送到 Hugging Face Hub
class PushToHubMixin:
"""
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
"""
# 定义一个私有方法,用于上传指定文件夹中的所有文件
def _upload_folder(
self,
working_dir: Union[str, os.PathLike], # 工作目录,包含待上传的文件
repo_id: str, # 目标仓库的 ID
token: Optional[str] = None, # 可选的认证令牌
commit_message: Optional[str] = None, # 可选的提交信息
create_pr: bool = False, # 是否创建拉取请求
):
"""
Uploads all files in `working_dir` to `repo_id`.
"""
# 如果未提供提交信息,则根据类名生成默认提交信息
if commit_message is None:
if "Model" in self.__class__.__name__:
commit_message = "Upload model" # 如果是模型类,设置默认信息
elif "Scheduler" in self.__class__.__name__:
commit_message = "Upload scheduler" # 如果是调度器类,设置默认信息
else:
commit_message = f"Upload {self.__class__.__name__}" # 否则,使用类名作为提交信息
# 记录上传文件的日志信息
logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
# 调用 upload_folder 函数上传文件,并返回其结果
return upload_folder(
repo_id=repo_id, # 目标仓库 ID
folder_path=working_dir, # 待上传的文件夹路径
token=token, # 认证令牌
commit_message=commit_message, # 提交信息
create_pr=create_pr # 是否创建拉取请求
)
# 定义一个公共方法,用于将文件推送到 Hugging Face Hub
def push_to_hub(
self,
repo_id: str, # 目标仓库的 ID
commit_message: Optional[str] = None, # 可选的提交信息
private: Optional[bool] = None, # 可选,是否将仓库设置为私有
token: Optional[str] = None, # 可选的认证令牌
create_pr: bool = False, # 是否创建拉取请求
safe_serialization: bool = True, # 是否安全序列化
variant: Optional[str] = None, # 可选的变体参数
) -> str: # 定义函数返回值类型为字符串
"""
Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub. # 函数文档字符串,说明功能
Parameters: # 参数说明部分
repo_id (`str`): # 仓库 ID,类型为字符串
The name of the repository you want to push your model, scheduler, or pipeline files to. It should
contain your organization name when pushing to an organization. `repo_id` can also be a path to a local
directory. # 描述 repo_id 的用途和格式
commit_message (`str`, *optional*): # 可选参数,类型为字符串
Message to commit while pushing. Default to `"Upload {object}".` # 提交消息的默认值
private (`bool`, *optional*): # 可选参数,类型为布尔值
Whether or not the repository created should be private. # 是否创建私有仓库
token (`str`, *optional*): # 可选参数,类型为字符串
The token to use as HTTP bearer authorization for remote files. The token generated when running
`huggingface-cli login` (stored in `~/.huggingface`). # 说明 token 的用途
create_pr (`bool`, *optional*, defaults to `False`): # 可选参数,类型为布尔值,默认值为 False
Whether or not to create a PR with the uploaded files or directly commit. # 是否创建 PR
safe_serialization (`bool`, *optional*, defaults to `True`): # 可选参数,类型为布尔值,默认值为 True
Whether or not to convert the model weights to the `safetensors` format. # 是否使用安全序列化格式
variant (`str`, *optional*): # 可选参数,类型为字符串
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. # 权重保存格式
Examples: # 示例说明部分
```py
from diffusers import UNet2DConditionModel # 从 diffusers 导入 UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet") # 从预训练模型加载 UNet
# Push the `unet` to your namespace with the name "my-finetuned-unet". # 推送到个人命名空间
unet.push_to_hub("my-finetuned-unet") # 将 unet 推送到指定名称的仓库
# Push the `unet` to an organization with the name "my-finetuned-unet". # 推送到组织
unet.push_to_hub("your-org/my-finetuned-unet") # 将 unet 推送到指定组织的仓库
```
""" # 结束文档字符串
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id # 创建仓库并获取仓库 ID
# Create a new empty model card and eventually tag it # 创建新的模型卡片并可能添加标签
model_card = load_or_create_model_card(repo_id, token=token) # 加载或创建模型卡片
model_card = populate_model_card(model_card) # 填充模型卡片信息
# Save all files. # 保存所有文件
save_kwargs = {"safe_serialization": safe_serialization} # 设置保存文件的参数
if "Scheduler" not in self.__class__.__name__: # 检查当前类名是否包含 "Scheduler"
save_kwargs.update({"variant": variant}) # 如果不包含,则添加 variant 参数
with tempfile.TemporaryDirectory() as tmpdir: # 创建临时目录
self.save_pretrained(tmpdir, **save_kwargs) # 将模型保存到临时目录
# Update model card if needed: # 如果需要,更新模型卡片
model_card.save(os.path.join(tmpdir, "README.md")) # 将模型卡片保存为 README.md 文件
return self._upload_folder( # 上传临时目录中的文件
tmpdir, # 临时目录路径
repo_id, # 仓库 ID
token=token, # 认证 token
commit_message=commit_message, # 提交消息
create_pr=create_pr, # 是否创建 PR
) # 返回上传结果
.\diffusers\utils\import_utils.py
# 版权信息,标明此文件的版权所有者及相关许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License, Version 2.0 许可,使用该文件需遵循该许可
# 许可的使用条件;可以在以下地址获取许可
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,按照许可分发的软件以 "现状" 方式提供,
# 不附带任何形式的保证或条件,明示或暗示。
# 详见许可中对权限和限制的具体规定
"""
导入工具:与导入和懒初始化相关的工具函数
"""
# 导入模块,提供动态导入和模块相关功能
import importlib.util
# 导入操作符模块,方便使用比较操作符
import operator as op
# 导入操作系统模块,提供与操作系统交互的功能
import os
# 导入系统模块,提供对 Python 解释器的访问
import sys
# 从 collections 模块导入有序字典
from collections import OrderedDict
# 从 itertools 模块导入链式迭代工具
from itertools import chain
# 导入模块类型
from types import ModuleType
# 导入类型注释工具
from typing import Any, Union
# 从 huggingface_hub.utils 导入检查 Jinja 是否可用的工具
from huggingface_hub.utils import is_jinja_available # noqa: F401
# 导入版本控制工具
from packaging import version
# 从 packaging.version 导入版本和解析函数
from packaging.version import Version, parse
# 导入当前目录下的 logging 模块
from . import logging
# 根据 Python 版本选择合适的 importlib_metadata 模块导入方式
if sys.version_info < (3, 8):
# 如果 Python 版本低于 3.8,导入 importlib_metadata
import importlib_metadata
else:
# 如果 Python 版本为 3.8 或更高,导入 importlib.metadata
import importlib.metadata as importlib_metadata
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定义环境变量的真值集合
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# 定义环境变量的真值和自动值集合
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
# 从环境变量中获取是否使用 TensorFlow,默认为 "AUTO"
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
# 从环境变量中获取是否使用 PyTorch,默认为 "AUTO"
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# 从环境变量中获取是否使用 JAX,默认为 "AUTO"
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
# 从环境变量中获取是否使用 SafeTensors,默认为 "AUTO"
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
# 从环境变量中获取是否进行慢导入,默认为 "FALSE"
DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper()
# 将慢导入的环境变量值转换为布尔值
DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
# 定义操作符与对应函数的映射
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
# 初始化 PyTorch 版本为 "N/A"
_torch_version = "N/A"
# 检查是否启用 PyTorch,并且 TensorFlow 未被启用
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
# 尝试查找 PyTorch 模块是否可用
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
# 获取 PyTorch 的版本信息
_torch_version = importlib_metadata.version("torch")
# 记录可用的 PyTorch 版本
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
# 如果 PyTorch 模块未找到,将其标记为不可用
_torch_available = False
else:
# 如果 USE_TORCH 被设置,则禁用 PyTorch
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
# 检查 PyTorch XLA 是否可用
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
try:
# 获取 PyTorch XLA 的版本信息
_torch_xla_version = importlib_metadata.version("torch_xla")
# 记录可用的 PyTorch XLA 版本
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
except ImportError:
# 如果 PyTorch XLA 导入失败,将其标记为不可用
_torch_xla_available = False
# 检查 torch_npu 是否可用
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
# 尝试获取 "torch_npu" 包的版本信息
try:
# 使用 importlib_metadata 获取 "torch_npu" 的版本
_torch_npu_version = importlib_metadata.version("torch_npu")
# 记录可用的 "torch_npu" 版本信息到日志
logger.info(f"torch_npu version {_torch_npu_version} available.")
# 捕获导入错误,表示 "torch_npu" 包不可用
except ImportError:
# 设置标志,表示 "torch_npu" 不可用
_torch_npu_available = False
# 初始化 JAX 版本为 "N/A"
_jax_version = "N/A"
# 初始化 Flax 版本为 "N/A"
_flax_version = "N/A"
# 检查 USE_JAX 是否在环境变量的真值或自动值列表中
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
# 检查 JAX 和 Flax 是否可用
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
# 如果 Flax 可用
if _flax_available:
try:
# 获取 JAX 的版本
_jax_version = importlib_metadata.version("jax")
# 获取 Flax 的版本
_flax_version = importlib_metadata.version("flax")
# 记录 JAX 和 Flax 的版本信息
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
except importlib_metadata.PackageNotFoundError:
# 如果找不到包,则设置 Flax 不可用
_flax_available = False
else:
# 如果 USE_JAX 不在环境变量的真值或自动值列表中,设置 Flax 不可用
_flax_available = False
# 检查 USE_SAFETENSORS 是否在环境变量的真值或自动值列表中
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
# 检查 safetensors 是否可用
_safetensors_available = importlib.util.find_spec("safetensors") is not None
# 如果 safetensors 可用
if _safetensors_available:
try:
# 获取 safetensors 的版本
_safetensors_version = importlib_metadata.version("safetensors")
# 记录 safetensors 的版本信息
logger.info(f"Safetensors version {_safetensors_version} available.")
except importlib_metadata.PackageNotFoundError:
# 如果找不到包,则设置 safetensors 不可用
_safetensors_available = False
else:
# 如果 USE_SAFETENSORS 不在环境变量的真值或自动值列表中,记录信息并设置 safetensors 不可用
logger.info("Disabling Safetensors because USE_TF is set")
_safetensors_available = False
# 检查 transformers 是否可用
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
# 获取 transformers 的版本
_transformers_version = importlib_metadata.version("transformers")
# 记录 transformers 的版本信息
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
# 如果找不到包,则设置 transformers 不可用
_transformers_available = False
# 检查 inflect 是否可用
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
# 获取 inflect 的版本
_inflect_version = importlib_metadata.version("inflect")
# 记录 inflect 的版本信息
logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
# 如果找不到包,则设置 inflect 不可用
_inflect_available = False
# 初始化 onnxruntime 版本为 "N/A"
_onnxruntime_version = "N/A"
# 检查 onnxruntime 是否可用
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
# 如果 onnxruntime 可用
if _onnx_available:
# 可能的 onnxruntime 包候选列表
candidates = (
"onnxruntime",
"onnxruntime-gpu",
"ort_nightly_gpu",
"onnxruntime-directml",
"onnxruntime-openvino",
"ort_nightly_directml",
"onnxruntime-rocm",
"onnxruntime-training",
)
# 初始化 onnxruntime 版本为 None
_onnxruntime_version = None
# 对于元数据,我们需要查找 onnxruntime 和 onnxruntime-gpu
for pkg in candidates:
try:
# 获取当前候选包的版本
_onnxruntime_version = importlib_metadata.version(pkg)
break # 如果找到版本,跳出循环
except importlib_metadata.PackageNotFoundError:
pass # 如果找不到包,继续下一个候选
# 如果找到了 onnxruntime 版本,则设置可用状态
_onnx_available = _onnxruntime_version is not None
# 如果 onnxruntime 可用,记录其版本信息
if _onnx_available:
logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
# (sayakpaul): importlib.util.find_spec("opencv-python") 返回 None 即使它已经被安装。
# 判断 OpenCV 是否可用,若未找到则返回 False
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
try:
# 定义可能的 OpenCV 包候选名称
candidates = (
"opencv-python",
"opencv-contrib-python",
"opencv-python-headless",
"opencv-contrib-python-headless",
)
# 初始化 OpenCV 版本变量
_opencv_version = None
# 遍历候选包名称
for pkg in candidates:
try:
# 尝试获取当前包的版本
_opencv_version = importlib_metadata.version(pkg)
# 成功获取版本后退出循环
break
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则继续尝试下一个候选包
pass
# 检查是否成功获取到 OpenCV 版本
_opencv_available = _opencv_version is not None
# 如果 OpenCV 可用,记录调试信息
if _opencv_available:
logger.debug(f"Successfully imported cv2 version {_opencv_version}")
except importlib_metadata.PackageNotFoundError:
# 如果没有找到 OpenCV 包,标记为不可用
_opencv_available = False
# 判断 SciPy 是否可用,若未找到则返回 False
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
# 尝试获取 SciPy 包的版本
_scipy_version = importlib_metadata.version("scipy")
# 如果成功,记录调试信息
logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则标记为不可用
_scipy_available = False
# 判断 librosa 是否可用,若未找到则返回 False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
# 尝试获取 librosa 包的版本
_librosa_version = importlib_metadata.version("librosa")
# 如果成功,记录调试信息
logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则标记为不可用
_librosa_available = False
# 判断 accelerate 是否可用,若未找到则返回 False
_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
# 尝试获取 accelerate 包的版本
_accelerate_version = importlib_metadata.version("accelerate")
# 如果成功,记录调试信息
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则标记为不可用
_accelerate_available = False
# 判断 xformers 是否可用,若未找到则返回 False
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
# 尝试获取 xformers 包的版本
_xformers_version = importlib_metadata.version("xformers")
# 如果 torch 可用,获取其版本并进行版本比较
if _torch_available:
_torch_version = importlib_metadata.version("torch")
# 如果 torch 版本小于 1.12,抛出错误
if version.Version(_torch_version) < version.Version("1.12"):
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
# 如果成功,记录调试信息
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则标记为不可用
_xformers_available = False
# 判断 k_diffusion 是否可用,若未找到则返回 False
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
# 尝试获取 k_diffusion 包的版本
_k_diffusion_version = importlib_metadata.version("k_diffusion")
# 如果成功,记录调试信息
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则标记为不可用
_k_diffusion_available = False
# 判断 note_seq 是否可用,若未找到则返回 False
_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
# 尝试获取 note_seq 包的版本
_note_seq_version = importlib_metadata.version("note_seq")
# 如果成功,记录调试信息
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
except importlib_metadata.PackageNotFoundError:
# 如果包未找到,则标记为不可用
_note_seq_available = False
# 判断 wandb 是否可用,若未找到则返回 False
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
# 尝试获取 wandb 包的版本
_wandb_version = importlib_metadata.version("wandb")
# 记录调试信息,指示成功导入 wandb 的版本
logger.debug(f"Successfully imported wandb version {_wandb_version }")
# 捕获导入错误,表示没有找到 'wandb' 包
except importlib_metadata.PackageNotFoundError:
# 设置 'wandb' 不可用标志为 False
_wandb_available = False
# 检查 'tensorboard' 包是否可用
_tensorboard_available = importlib.util.find_spec("tensorboard")
try:
# 获取 'tensorboard' 的版本信息
_tensorboard_version = importlib_metadata.version("tensorboard")
# 记录成功导入 'tensorboard' 的版本
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'tensorboard' 不可用标志为 False
_tensorboard_available = False
# 检查 'compel' 包是否可用
_compel_available = importlib.util.find_spec("compel")
try:
# 获取 'compel' 的版本信息
_compel_version = importlib_metadata.version("compel")
# 记录成功导入 'compel' 的版本
logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'compel' 不可用标志为 False
_compel_available = False
# 检查 'ftfy' 包是否可用
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
# 获取 'ftfy' 的版本信息
_ftfy_version = importlib_metadata.version("ftfy")
# 记录成功导入 'ftfy' 的版本
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'ftfy' 不可用标志为 False
_ftfy_available = False
# 检查 'bs4' 包是否可用
_bs4_available = importlib.util.find_spec("bs4") is not None
try:
# importlib metadata 以不同名称获取
_bs4_version = importlib_metadata.version("beautifulsoup4")
# 记录成功导入 'beautifulsoup4' 的版本
logger.debug(f"Successfully imported ftfy version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'bs4' 不可用标志为 False
_bs4_available = False
# 检查 'torchsde' 包是否可用
_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
# 获取 'torchsde' 的版本信息
_torchsde_version = importlib_metadata.version("torchsde")
# 记录成功导入 'torchsde' 的版本
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'torchsde' 不可用标志为 False
_torchsde_available = False
# 检查 'imwatermark' 包是否可用
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
# 获取 'invisible-watermark' 的版本信息
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
# 记录成功导入 'invisible-watermark' 的版本
logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'invisible-watermark' 不可用标志为 False
_invisible_watermark_available = False
# 检查 'peft' 包是否可用
_peft_available = importlib.util.find_spec("peft") is not None
try:
# 获取 'peft' 的版本信息
_peft_version = importlib_metadata.version("peft")
# 记录成功导入 'peft' 的版本
logger.debug(f"Successfully imported peft version {_peft_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'peft' 不可用标志为 False
_peft_available = False
# 检查 'torchvision' 包是否可用
_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
# 获取 'torchvision' 的版本信息
_torchvision_version = importlib_metadata.version("torchvision")
# 记录成功导入 'torchvision' 的版本
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'torchvision' 不可用标志为 False
_torchvision_available = False
# 检查 'sentencepiece' 包是否可用
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
try:
# 获取 'sentencepiece' 的版本信息
_sentencepiece_version = importlib_metadata.version("sentencepiece")
# 记录成功导入 'sentencepiece' 的版本
logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
except importlib_metadata.PackageNotFoundError:
# 设置 'sentencepiece' 不可用标志为 False
_sentencepiece_available = False
# 检查 'matplotlib' 包是否可用
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
try:
# 导入 matplotlib 的版本信息
_matplotlib_version = importlib_metadata.version("matplotlib")
# 记录成功导入 matplotlib 的版本信息到调试日志
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
# 捕获导入库时的包未找到异常
except importlib_metadata.PackageNotFoundError:
# 如果未找到 matplotlib,设置其可用状态为 False
_matplotlib_available = False
# 检查 "timm" 库是否可用,返回结果为布尔值
_timm_available = importlib.util.find_spec("timm") is not None
# 如果 "timm" 库可用
if _timm_available:
try:
# 获取 "timm" 库的版本信息
_timm_version = importlib_metadata.version("timm")
# 记录可用的 "timm" 版本信息
logger.info(f"Timm version {_timm_version} available.")
# 捕获导入库时的包未找到异常
except importlib_metadata.PackageNotFoundError:
# 如果未找到 "timm",设置其可用状态为 False
_timm_available = False
# 定义函数以返回 "timm" 库的可用状态
def is_timm_available():
return _timm_available
# 检查 "bitsandbytes" 库是否可用,返回结果为布尔值
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
try:
# 获取 "bitsandbytes" 库的版本信息
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
# 记录成功导入的 "bitsandbytes" 版本信息
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
# 捕获导入库时的包未找到异常
except importlib_metadata.PackageNotFoundError:
# 如果未找到 "bitsandbytes",设置其可用状态为 False
_bitsandbytes_available = False
# 检查当前是否在 Google Colab 环境中
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
# 检查 "imageio" 库是否可用,返回结果为布尔值
_imageio_available = importlib.util.find_spec("imageio") is not None
# 如果 "imageio" 库可用
if _imageio_available:
try:
# 获取 "imageio" 库的版本信息
_imageio_version = importlib_metadata.version("imageio")
# 记录成功导入的 "imageio" 版本信息
logger.debug(f"Successfully imported imageio version {_imageio_version}")
# 捕获导入库时的包未找到异常
except importlib_metadata.PackageNotFoundError:
# 如果未找到 "imageio",设置其可用状态为 False
_imageio_available = False
# 定义函数以返回 "torch" 库的可用状态
def is_torch_available():
return _torch_available
# 定义函数以返回 "torch_xla" 库的可用状态
def is_torch_xla_available():
return _torch_xla_available
# 定义函数以返回 "torch_npu" 库的可用状态
def is_torch_npu_available():
return _torch_npu_available
# 定义函数以返回 "flax" 库的可用状态
def is_flax_available():
return _flax_available
# 定义函数以返回 "transformers" 库的可用状态
def is_transformers_available():
return _transformers_available
# 定义函数以返回 "inflect" 库的可用状态
def is_inflect_available():
return _inflect_available
# 定义函数以返回 "unidecode" 库的可用状态
def is_unidecode_available():
return _unidecode_available
# 定义函数以返回 "onnx" 库的可用状态
def is_onnx_available():
return _onnx_available
# 定义函数以返回 "opencv" 库的可用状态
def is_opencv_available():
return _opencv_available
# 定义函数以返回 "scipy" 库的可用状态
def is_scipy_available():
return _scipy_available
# 定义函数以返回 "librosa" 库的可用状态
def is_librosa_available():
return _librosa_available
# 定义函数以返回 "xformers" 库的可用状态
def is_xformers_available():
return _xformers_available
# 定义函数以返回 "accelerate" 库的可用状态
def is_accelerate_available():
return _accelerate_available
# 定义函数以返回 "k_diffusion" 库的可用状态
def is_k_diffusion_available():
return _k_diffusion_available
# 定义函数以返回 "note_seq" 库的可用状态
def is_note_seq_available():
return _note_seq_available
# 定义函数以返回 "wandb" 库的可用状态
def is_wandb_available():
return _wandb_available
# 定义函数以返回 "tensorboard" 库的可用状态
def is_tensorboard_available():
return _tensorboard_available
# 定义函数以返回 "compel" 库的可用状态
def is_compel_available():
return _compel_available
# 定义函数以返回 "ftfy" 库的可用状态
def is_ftfy_available():
return _ftfy_available
# 定义函数以返回 "bs4" 库的可用状态
def is_bs4_available():
return _bs4_available
# 定义函数以返回 "torchsde" 库的可用状态
def is_torchsde_available():
return _torchsde_available
# 定义函数以返回 "invisible_watermark" 库的可用状态
def is_invisible_watermark_available():
return _invisible_watermark_available
# 定义函数以返回 "peft" 库的可用状态
def is_peft_available():
return _peft_available
# 定义函数以返回 "torchvision" 库的可用状态
def is_torchvision_available():
return _torchvision_available
# 定义函数以返回 "matplotlib" 库的可用状态
def is_matplotlib_available():
return _matplotlib_available
# 定义函数以返回 "safetensors" 库的可用状态
def is_safetensors_available():
return _safetensors_available
# 定义函数以返回 "bitsandbytes" 库的可用状态
def is_bitsandbytes_available():
return _bitsandbytes_available
# 定义函数以返回是否在 Google Colab 环境中
def is_google_colab():
return _is_google_colab
# 定义函数以返回 "sentencepiece" 库的可用状态
def is_sentencepiece_available():
# 返回句子分割工具的可用性状态
return _sentencepiece_available
# 定义一个函数来检查 imageio 库是否可用
def is_imageio_available():
# 返回 _imageio_available 的值,表示 imageio 库的可用性
return _imageio_available
# docstyle-ignore
# 定义一个字符串,提示用户缺少 FLAX 库并提供安装链接
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 inflect 库并提供安装命令
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 PyTorch 库并提供安装链接
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 onnxruntime 库并提供安装命令
ONNX_IMPORT_ERROR = """
{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
install onnxruntime`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 OpenCV 库并提供安装命令
OPENCV_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
install opencv-python`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 scipy 库并提供安装命令
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 librosa 库并提供安装链接
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. Checkout the instructions on the
installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment.
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 transformers 库并提供安装命令
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 unidecode 库并提供安装命令
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 k-diffusion 库并提供安装命令
K_DIFFUSION_IMPORT_ERROR = """
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 note-seq 库并提供安装命令
NOTE_SEQ_IMPORT_ERROR = """
{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip
install note-seq`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 wandb 库并提供安装命令
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 tensorboard 库并提供安装命令
TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
install tensorboard`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 compel 库并提供安装命令
COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""
# docstyle-ignore
# 定义一个字符串,提示用户缺少 Beautiful Soup 库并提供安装命令
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 ftfy 库
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 torchsde 库
TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 invisible-watermark 库
INVISIBLE_WATERMARK_IMPORT_ERROR = """
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 peft 库
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 safetensors 库
SAFETENSORS_IMPORT_ERROR = """
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 sentencepiece 库
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece`
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 bitsandbytes 库
BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 imageio 库和 ffmpeg
IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""
# 定义一个有序字典,用于映射后端
BACKENDS_MAPPING = OrderedDict(
# 创建一个包含库名称及其可用性检查和导入错误的元组列表
[
# BS4 库的可用性检查及对应的导入错误信息
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
# Flax 库的可用性检查及对应的导入错误信息
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
# Inflect 库的可用性检查及对应的导入错误信息
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
# ONNX 库的可用性检查及对应的导入错误信息
("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
# OpenCV 库的可用性检查及对应的导入错误信息
("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
# SciPy 库的可用性检查及对应的导入错误信息
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
# PyTorch 库的可用性检查及对应的导入错误信息
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
# Transformers 库的可用性检查及对应的导入错误信息
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
# Unidecode 库的可用性检查及对应的导入错误信息
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
# Librosa 库的可用性检查及对应的导入错误信息
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
# K-Diffusion 库的可用性检查及对应的导入错误信息
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
# Note Seq 库的可用性检查及对应的导入错误信息
("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
# WandB 库的可用性检查及对应的导入错误信息
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
# TensorBoard 库的可用性检查及对应的导入错误信息
("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
# Compel 库的可用性检查及对应的导入错误信息
("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
# FTFY 库的可用性检查及对应的导入错误信息
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
# TorchSDE 库的可用性检查及对应的导入错误信息
("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
# Invisible Watermark 库的可用性检查及对应的导入错误信息
("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
# PEFT 库的可用性检查及对应的导入错误信息
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
# SafeTensors 库的可用性检查及对应的导入错误信息
("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
# BitsAndBytes 库的可用性检查及对应的导入错误信息
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
# SentencePiece 库的可用性检查及对应的导入错误信息
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
# ImageIO 库的可用性检查及对应的导入错误信息
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
]
# 定义一个装饰器函数,检查所需后端是否可用
def requires_backends(obj, backends):
# 如果后端参数不是列表或元组,则将其转换为列表
if not isinstance(backends, (list, tuple)):
backends = [backends]
# 获取对象的名称,如果对象有 __name__ 属性则使用该属性,否则使用类名
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
# 根据后端列表生成检查器元组
checks = (BACKENDS_MAPPING[backend] for backend in backends)
# 遍历检查器,生成未通过检查的错误信息列表
failed = [msg.format(name) for available, msg in checks if not available()]
# 如果有检查失败,则抛出导入错误,显示所有错误信息
if failed:
raise ImportError("".join(failed))
# 如果对象名称在特定管道列表中,并且 transformers 版本小于 4.25.0,则抛出错误
if name in [
"VersatileDiffusionTextToImagePipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionDualGuidedPipeline",
"StableDiffusionImageVariationPipeline",
"UnCLIPPipeline",
] and is_transformers_version("<", "4.25.0"):
raise ImportError(
f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
" --upgrade transformers \n```py"
)
# 如果对象名称在另一个特定管道列表中,并且 transformers 版本小于 4.26.0,则抛出错误
if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
"<", "4.26.0"
):
raise ImportError(
f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
" --upgrade transformers \n```py"
)
# 定义一个元类,用于生成虚拟对象
class DummyObject(type):
"""
Dummy 对象的元类。任何继承自它的类在用户尝试访问其任何方法时都会返回由
`requires_backend` 生成的 ImportError。
"""
# 重写 __getattr__ 方法,处理属性访问
def __getattr__(cls, key):
# 如果属性以 "_" 开头且不是特定的两个属性,则返回默认实现
if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
return super().__getattr__(cls, key)
# 检查类是否满足后端要求
requires_backends(cls, cls._backends)
# 该函数用于比较库版本与要求版本的关系,来源于指定的 GitHub 链接
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
"""
参数:
比较库版本与某个要求之间的关系,使用给定的操作符。
library_or_version (`str` 或 `packaging.version.Version`):
要检查的库名或版本。
operation (`str`):
操作符的字符串表示,如 `">"` 或 `"<="`。
requirement_version (`str`):
要与库版本比较的版本
"""
# 检查操作符是否在预定义的操作符字典中
if operation not in STR_OPERATION_TO_FUNC.keys():
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
# 获取对应的操作函数
operation = STR_OPERATION_TO_FUNC[operation]
# 如果传入的库或版本是字符串,则获取其版本信息
if isinstance(library_or_version, str):
library_or_version = parse(importlib_metadata.version(library_or_version))
# 使用操作函数比较版本,并返回结果
return operation(library_or_version, parse(requirement_version))
# 该函数用于检查 PyTorch 版本与给定要求的关系,来源于指定的 GitHub 链接
def is_torch_version(operation: str, version: str):
"""
参数:
# 比较当前 PyTorch 版本与给定参考版本及操作符
# operation (`str`): 操作符的字符串表示,例如 `">"` 或 `"<="`
# version (`str`): PyTorch 的字符串版本
"""
# 返回解析后的当前 PyTorch 版本与指定版本的比较结果
return compare_versions(parse(_torch_version), operation, version)
# 定义一个函数,用于比较当前 Transformers 版本与给定版本之间的关系
def is_transformers_version(operation: str, version: str):
"""
Args:
Compares the current Transformers version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
# 检查 Transformers 是否可用,如果不可用则返回 False
if not _transformers_available:
return False
# 将当前 Transformers 版本与给定版本进行比较,并返回比较结果
return compare_versions(parse(_transformers_version), operation, version)
# 定义一个函数,用于比较当前 Accelerate 版本与给定版本之间的关系
def is_accelerate_version(operation: str, version: str):
"""
Args:
Compares the current Accelerate version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
# 检查 Accelerate 是否可用,如果不可用则返回 False
if not _accelerate_available:
return False
# 将当前 Accelerate 版本与给定版本进行比较,并返回比较结果
return compare_versions(parse(_accelerate_version), operation, version)
# 定义一个函数,用于比较当前 PEFT 版本与给定版本之间的关系
def is_peft_version(operation: str, version: str):
"""
Args:
Compares the current PEFT version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
# 检查 PEFT 版本是否可用,如果不可用则返回 False
if not _peft_version:
return False
# 将当前 PEFT 版本与给定版本进行比较,并返回比较结果
return compare_versions(parse(_peft_version), operation, version)
# 定义一个函数,用于比较当前 k-diffusion 版本与给定版本之间的关系
def is_k_diffusion_version(operation: str, version: str):
"""
Args:
Compares the current k-diffusion version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
# 检查 k-diffusion 是否可用,如果不可用则返回 False
if not _k_diffusion_available:
return False
# 将当前 k-diffusion 版本与给定版本进行比较,并返回比较结果
return compare_versions(parse(_k_diffusion_version), operation, version)
# 定义一个函数,用于从指定模块中提取对象
def get_objects_from_module(module):
"""
Args:
Returns a dict of object names and values in a module, while skipping private/internal objects
module (ModuleType):
Module to extract the objects from.
Returns:
dict: Dictionary of object names and corresponding values
"""
# 创建一个空字典,用于存储对象名称及其对应的值
objects = {}
# 遍历模块中的所有对象名称
for name in dir(module):
# 跳过以 "_" 开头的私有对象
if name.startswith("_"):
continue
# 获取对象并将名称和值存入字典
objects[name] = getattr(module, name)
# 返回包含对象名称及值的字典
return objects
# 定义一个异常类,用于表示可选依赖未在环境中找到的错误
class OptionalDependencyNotAvailable(BaseException):
"""An error indicating that an optional dependency of Diffusers was not found in the environment."""
# 定义一个懒加载模块类,只有在请求对象时才会执行相关导入
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# 受到 optuna.integration._IntegrationModule 的启发
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
# 初始化方法,接收模块的名称、文件、导入结构及可选参数
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
# 调用父类初始化方法
super().__init__(name)
# 存储导入结构中模块的集合
self._modules = set(import_structure.keys())
# 存储类与模块的映射字典
self._class_to_module = {}
# 遍历导入结构,将每个类与其对应模块关联
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# 为 IDE 的自动补全准备的属性
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
# 记录模块文件的路径
self.__file__ = module_file
# 记录模块的规格
self.__spec__ = module_spec
# 存储模块路径
self.__path__ = [os.path.dirname(module_file)]
# 存储额外对象,如果没有则初始化为空字典
self._objects = {} if extra_objects is None else extra_objects
# 存储模块名称
self._name = name
# 存储导入结构
self._import_structure = import_structure
# 为 IDE 的自动补全准备的方法
def __dir__(self):
# 获取当前对象的所有属性
result = super().__dir__()
# 将未包含在当前属性中的__all__元素添加到结果中
for attr in self.__all__:
if attr not in result:
result.append(attr)
# 返回包含所有属性的列表
return result
def __getattr__(self, name: str) -> Any:
# 如果对象中包含该名称,返回对应对象
if name in self._objects:
return self._objects[name]
# 如果名称在模块中,获取模块
if name in self._modules:
value = self._get_module(name)
# 如果名称在类到模块的映射中,获取对应模块的属性
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
# 如果名称不存在,抛出属性错误
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
# 将获取到的值设置为对象的属性
setattr(self, name, value)
# 返回值
return value
def _get_module(self, module_name: str):
# 尝试导入指定模块
try:
return importlib.import_module("." + module_name, self.__name__)
# 捕获异常,抛出运行时错误
except Exception as e:
raise RuntimeError(
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e
def __reduce__(self):
# 返回对象的序列化信息
return (self.__class__, (self._name, self.__file__, self._import_structure))
.\diffusers\utils\loading_utils.py
# 导入操作系统模块
import os
# 导入临时文件模块
import tempfile
# 导入类型相关的类型提示
from typing import Callable, List, Optional, Union
# 导入 PIL 库中的图像模块
import PIL.Image
# 导入 PIL 库中的图像处理模块
import PIL.ImageOps
# 导入请求库
import requests
# 从本地模块中导入一些实用工具
from .import_utils import BACKENDS_MAPPING, is_imageio_available
# 定义加载图像的函数
def load_image(
# 接受字符串或 PIL 图像作为输入
image: Union[str, PIL.Image.Image],
# 可选的转换方法,用于加载后处理图像
convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
"""
加载给定的 `image` 为 PIL 图像。
参数:
image (`str` 或 `PIL.Image.Image`):
要转换为 PIL 图像格式的图像。
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
加载后应用于图像的转换方法。如果为 `None`,则图像将被转换为 "RGB"。
返回:
`PIL.Image.Image`:
一个 PIL 图像。
"""
# 检查 image 是否为字符串类型
if isinstance(image, str):
# 如果字符串以 http 或 https 开头,则认为是 URL
if image.startswith("http://") or image.startswith("https://"):
# 通过请求获取图像并打开为 PIL 图像
image = PIL.Image.open(requests.get(image, stream=True).raw)
# 检查给定路径是否为有效文件
elif os.path.isfile(image):
# 打开本地文件为 PIL 图像
image = PIL.Image.open(image)
else:
# 如果路径无效,抛出错误
raise ValueError(
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
)
# 检查 image 是否为 PIL 图像对象
elif isinstance(image, PIL.Image.Image):
# 如果是 PIL 图像,保持不变
image = image
else:
# 如果格式不正确,抛出错误
raise ValueError(
"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
)
# 应用 EXIF 转换,调整图像方向
image = PIL.ImageOps.exif_transpose(image)
# 如果提供了转换方法,则应用该方法
if convert_method is not None:
image = convert_method(image)
else:
# 否则将图像转换为 RGB 格式
image = image.convert("RGB")
# 返回处理后的图像
return image
# 定义加载视频的函数
def load_video(
# 接受视频的字符串路径或 URL
video: str,
# 可选的转换方法,用于加载后处理图像列表
convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
"""
加载给定的 `video` 为 PIL 图像列表。
参数:
video (`str`):
视频的 URL 或路径,转换为 PIL 图像列表。
convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
加载后应用于视频的转换方法。如果为 `None`,则图像将被转换为 "RGB"。
返回:
`List[PIL.Image.Image]`:
视频作为 PIL 图像列表。
"""
# 检查视频是否为 URL
is_url = video.startswith("http://") or video.startswith("https://")
# 检查视频路径是否为有效文件
is_file = os.path.isfile(video)
# 标记是否创建了临时文件
was_tempfile_created = False
# 如果既不是 URL 也不是文件,抛出错误
if not (is_url or is_file):
raise ValueError(
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
)
# 如果是 URL,获取视频数据
if is_url:
video_data = requests.get(video, stream=True).raw
# 获取视频的文件后缀,如果没有则默认使用 .mp4
suffix = os.path.splitext(video)[1] or ".mp4"
# 创建一个带后缀的临时文件
video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
# 标记已创建临时文件
was_tempfile_created = True
# 将视频数据写入临时文件
with open(video_path, "wb") as f:
f.write(video_data.read())
# 更新视频变量为临时文件路径
video = video_path
# 初始化一个空列表,用于存储 PIL 图像
pil_images = []
# 检查视频文件是否为 GIF 格式
if video.endswith(".gif"):
# 打开 GIF 文件
gif = PIL.Image.open(video)
try:
# 持续读取 GIF 的每一帧
while True:
# 将当前帧的副本添加到列表中
pil_images.append(gif.copy())
# 移动到下一帧
gif.seek(gif.tell() + 1)
# 捕捉到 EOFError 时停止循环
except EOFError:
pass
else:
# 检查 imageio 库是否可用
if is_imageio_available():
# 导入 imageio 库
import imageio
else:
# 如果不可用,抛出导入错误
raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))
try:
# 尝试获取 ffmpeg 执行文件
imageio.plugins.ffmpeg.get_exe()
except AttributeError:
# 如果未找到 ffmpeg,抛出属性错误
raise AttributeError(
"`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg"
)
# 使用 imageio 创建视频读取器
with imageio.get_reader(video) as reader:
# 读取所有帧
for frame in reader:
# 将每一帧转换为 PIL 图像并添加到列表中
pil_images.append(PIL.Image.fromarray(frame))
# 如果创建了临时文件,删除它
if was_tempfile_created:
os.remove(video_path)
# 如果提供了转换方法,应用该方法到 PIL 图像列表
if convert_method is not None:
pil_images = convert_method(pil_images)
# 返回 PIL 图像列表
return pil_images