diffusers-源码解析-六十五-

龙哥盟 / 2024-11-09 / 原文

diffusers 源码解析(六十五)

.\diffusers\utils\dummy_transformers_and_torch_and_note_seq_objects.py

# 该文件由命令 `make fix-copies` 自动生成,请勿编辑。
from ..utils import DummyObject, requires_backends  # 从上级目录的 utils 模块导入 DummyObject 和 requires_backends 函数


class SpectrogramDiffusionPipeline(metaclass=DummyObject):  # 定义 SpectrogramDiffusionPipeline 类,使用 DummyObject 作为其元类
    _backends = ["transformers", "torch", "note_seq"]  # 定义类属性 _backends,包含支持的后端列表

    def __init__(self, *args, **kwargs):  # 初始化方法,接受可变参数
        requires_backends(self, ["transformers", "torch", "note_seq"])  # 检查当前实例是否支持指定的后端

    @classmethod  # 定义类方法
    def from_config(cls, *args, **kwargs):  # 接受可变参数
        requires_backends(cls, ["transformers", "torch", "note_seq"])  # 检查类是否支持指定的后端

    @classmethod  # 定义类方法
    def from_pretrained(cls, *args, **kwargs):  # 接受可变参数
        requires_backends(cls, ["transformers", "torch", "note_seq"])  # 检查类是否支持指定的后端

.\diffusers\utils\dynamic_modules_utils.py

# coding=utf-8  # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team.  # 版权声明,指明版权所有者
#
# Licensed under the Apache License, Version 2.0 (the "License");  # 指明许可证类型
# you may not use this file except in compliance with the License.  # 说明使用条件
# You may obtain a copy of the License at  # 提供获取许可证的链接
#
#     http://www.apache.org/licenses/LICENSE-2.0  # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software  # 说明免责条款
# distributed under the License is distributed on an "AS IS" BASIS,  # 指出软件按原样分发
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  # 声明无任何保证或条件
# See the License for the specific language governing permissions and  # 提示查看许可证以获取权限说明
# limitations under the License.  # 说明许可证的限制
"""Utilities to dynamically load objects from the Hub."""  # 模块描述,说明功能

import importlib  # 导入模块以动态导入其他模块
import inspect  # 导入模块以检查对象的类型和属性
import json  # 导入模块以处理 JSON 数据
import os  # 导入模块以进行操作系统相关的操作
import re  # 导入模块以进行正则表达式匹配
import shutil  # 导入模块以进行文件和目录操作
import sys  # 导入模块以访问解释器使用的变量和函数
from pathlib import Path  # 从路径模块导入 Path 类以处理文件路径
from typing import Dict, Optional, Union  # 从 typing 模块导入类型注解
from urllib import request  # 导入模块以进行 URL 请求

from huggingface_hub import hf_hub_download, model_info  # 从 huggingface_hub 导入特定功能
from huggingface_hub.utils import RevisionNotFoundError, validate_hf_hub_args  # 导入错误和验证函数
from packaging import version  # 导入版本管理模块

from .. import __version__  # 从上级模块导入当前版本
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging  # 导入当前包中的常量和日志模块


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name  # 初始化日志记录器
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror  # 提供社区管道镜像的链接
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"  # 定义社区管道镜像的 ID


def get_diffusers_versions():  # 定义获取 diffusers 版本的函数
    url = "https://pypi.org/pypi/diffusers/json"  # 设置获取版本信息的 URL
    releases = json.loads(request.urlopen(url).read())["releases"].keys()  # 请求 URL 并解析 JSON,获取版本键
    return sorted(releases, key=lambda x: version.Version(x))  # 返回按版本排序的版本列表


def init_hf_modules():  # 定义初始化 HF 模块的函数
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.
    """  # 函数说明,创建缓存目录并添加到 Python 路径
    # This function has already been executed if HF_MODULES_CACHE already is in the Python path.  # 如果缓存目录已在路径中,直接返回
    if HF_MODULES_CACHE in sys.path:  # 检查缓存目录是否在 Python 路径中
        return  # 如果在,则退出函数

    sys.path.append(HF_MODULES_CACHE)  # 将缓存目录添加到 Python 路径
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)  # 创建缓存目录,如果已存在则不报错
    init_path = Path(HF_MODULES_CACHE) / "__init__.py"  # 定义缓存目录中初始化文件的路径
    if not init_path.exists():  # 检查初始化文件是否存在
        init_path.touch()  # 如果不存在,则创建初始化文件


def create_dynamic_module(name: Union[str, os.PathLike]):  # 定义创建动态模块的函数,接受字符串或路径对象
    """
    Creates a dynamic module in the cache directory for modules.
    """  # 函数说明,创建动态模块
    init_hf_modules()  # 调用初始化函数以确保缓存目录存在
    dynamic_module_path = Path(HF_MODULES_CACHE) / name  # 定义动态模块的路径
    # If the parent module does not exist yet, recursively create it.  # 如果父模块不存在,则递归创建
    if not dynamic_module_path.parent.exists():  # 检查父模块路径是否存在
        create_dynamic_module(dynamic_module_path.parent)  # 如果不存在,则递归调用创建函数
    os.makedirs(dynamic_module_path, exist_ok=True)  # 创建动态模块目录,如果已存在则不报错
    init_path = dynamic_module_path / "__init__.py"  # 定义动态模块中初始化文件的路径
    if not init_path.exists():  # 检查初始化文件是否存在
        init_path.touch()  # 如果不存在,则创建初始化文件


def get_relative_imports(module_file):  # 定义获取相对导入的函数,接受模块文件路径
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.  # 函数说明,接受字符串或路径对象作为参数
    """
    with open(module_file, "r", encoding="utf-8") as f:  # 以 UTF-8 编码打开模块文件
        content = f.read()  # 读取文件内容

    # Imports of the form `import .xxx`  # 说明以下是相对导入的处理
    # 使用正则表达式查找以相对导入形式书写的模块名
        relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
        # 查找以相对导入形式书写的具体从属模块名
        relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
        # 将结果去重,确保唯一性
        return list(set(relative_imports))
# 获取给定模块所需的所有文件列表,包括相对导入的文件
def get_relative_import_files(module_file):
    # 初始化一个标志,用于控制递归循环
    no_change = False
    # 存储待检查的文件列表,初始为传入的模块文件
    files_to_check = [module_file]
    # 存储所有找到的相对导入文件
    all_relative_imports = []

    # 递归遍历所有相对导入文件
    while not no_change:
        # 存储新发现的导入文件
        new_imports = []
        # 遍历待检查的文件列表
        for f in files_to_check:
            # 获取当前文件的相对导入,并添加到新导入列表中
            new_imports.extend(get_relative_imports(f))

        # 获取当前模块文件的目录路径
        module_path = Path(module_file).parent
        # 将新导入的模块文件转为绝对路径
        new_import_files = [str(module_path / m) for m in new_imports]
        # 过滤掉已经找到的导入文件
        new_import_files = [f for f in new_import_files if f not in all_relative_imports]
        # 更新待检查的文件列表,添加 .py 后缀
        files_to_check = [f"{f}.py" for f in new_import_files]

        # 检查是否有新导入文件,如果没有,则结束循环
        no_change = len(new_import_files) == 0
        # 将当前待检查文件加入所有相对导入列表
        all_relative_imports.extend(files_to_check)

    # 返回所有找到的相对导入文件
    return all_relative_imports


# 检查当前 Python 环境是否包含文件中导入的所有库
def check_imports(filename):
    # 以 UTF-8 编码打开指定文件并读取内容
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # 正则表达式查找 `import xxx` 形式的导入
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # 正则表达式查找 `from xxx import yyy` 形式的导入
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # 仅保留顶级模块,过滤掉相对导入
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

    # 去重并确保导入模块的唯一性
    imports = list(set(imports))
    # 存储缺失的包列表
    missing_packages = []
    # 遍历每个导入模块并尝试导入
    for imp in imports:
        try:
            importlib.import_module(imp)
        except ImportError:
            # 如果导入失败,记录缺失的包
            missing_packages.append(imp)

    # 如果有缺失的包,抛出 ImportError 异常并提示用户
    if len(missing_packages) > 0:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    # 返回文件的所有相对导入文件
    return get_relative_imports(filename)


# 从模块缓存中导入指定的类
def get_class_in_module(class_name, module_path):
    # 将模块路径中的分隔符替换为点,以便导入
    module_path = module_path.replace(os.path.sep, ".")
    # 导入指定模块
    module = importlib.import_module(module_path)

    # 如果类名为空,查找管道类
    if class_name is None:
        return find_pipeline_class(module)
    # 返回指定类的引用
    return getattr(module, class_name)


# 获取继承自 `DiffusionPipeline` 的管道类
def find_pipeline_class(loaded_module):
    # 从上级导入 DiffusionPipeline 类
    from ..pipelines import DiffusionPipeline

    # 获取加载模块中所有的类成员
    cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))

    # 初始化管道类变量
    pipeline_class = None
    # 遍历 cls_members 字典中的每个类名及其对应的类
    for cls_name, cls in cls_members.items():
        # 检查类名不是 DiffusionPipeline 的名称,且是其子类,且模块不是 diffusers
        if (
            cls_name != DiffusionPipeline.__name__
            and issubclass(cls, DiffusionPipeline)
            and cls.__module__.split(".")[0] != "diffusers"
        ):
            # 如果已经找到一个管道类,则抛出值错误,表示发现多个类
            if pipeline_class is not None:
                raise ValueError(
                    # 错误信息,包含找到的多个类的信息
                    f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
                    f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
                    f" {loaded_module}."
                )
            # 记录找到的管道类
            pipeline_class = cls

    # 返回找到的管道类
    return pipeline_class
# 装饰器,用于验证传入的参数是否符合预期
@validate_hf_hub_args
# 定义获取缓存模块文件的函数,接受多个参数
def get_cached_module_file(
    # 预训练模型名称或路径,可以是字符串或路径类型
    pretrained_model_name_or_path: Union[str, os.PathLike],
    # 模块文件名称,字符串类型
    module_file: str,
    # 缓存目录,可选参数,路径类型或字符串
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    # 强制下载标志,可选参数,布尔类型,默认为 False
    force_download: bool = False,
    # 代理服务器字典,可选参数
    proxies: Optional[Dict[str, str]] = None,
    # 授权令牌,可选参数,可以是布尔或字符串类型
    token: Optional[Union[bool, str]] = None,
    # 版本修订信息,可选参数,字符串类型
    revision: Optional[str] = None,
    # 仅使用本地文件标志,可选参数,布尔类型,默认为 False
    local_files_only: bool = False,
):
    """
    准备从本地文件夹或远程仓库下载模块,并返回其在缓存中的路径。
    
    参数:
        pretrained_model_name_or_path (`str` 或 `os.PathLike`):
            可以是预训练模型配置的模型 ID 或包含配置文件的目录路径。
        module_file (`str`):
            包含要查找的类的模块文件名称。
        cache_dir (`str` 或 `os.PathLike`, *可选*):
            下载的预训练模型配置的缓存目录路径。
        force_download (`bool`, *可选*, 默认值为 `False`):
            是否强制重新下载配置文件并覆盖已存在的缓存版本。
        proxies (`Dict[str, str]`, *可选*):
            代理服务器字典,用于每个请求。
        token (`str` 或 *bool*, *可选*):
            用作远程文件的 HTTP 授权令牌。
        revision (`str`, *可选*, 默认值为 `"main"`):
            要使用的特定模型版本。
        local_files_only (`bool`, *可选*, 默认值为 `False`):
            如果为 `True`,仅尝试从本地文件加载配置。
    
    返回:
        `str`: 模块在缓存中的路径。
    """
    # 下载并缓存来自 `pretrained_model_name_or_path` 的 module_file,或获取本地文件
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)  # 将预训练模型路径转换为字符串

    module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)  # 组合路径和文件名形成完整路径

    if os.path.isfile(module_file_or_url):  # 检查该路径是否指向一个文件
        resolved_module_file = module_file_or_url  # 如果是文件,保存该路径
        submodule = "local"  # 标记为本地文件
    elif pretrained_model_name_or_path.count("/") == 0:  # 如果路径中没有斜杠,表示这是一个模型名称而非路径
        available_versions = get_diffusers_versions()  # 获取可用版本列表
        # 去掉 ".dev0" 部分
        latest_version = "v" + ".".join(__version__.split(".")[:3])  # 获取最新的版本号

        # 获取匹配的 GitHub 版本
        if revision is None:  # 如果没有指定修订版本
            revision = latest_version if latest_version[1:] in available_versions else "main"  # 默认选择最新版本或主分支
            logger.info(f"Defaulting to latest_version: {revision}.")  # 记录默认选择的版本
        elif revision in available_versions:  # 如果指定版本在可用列表中
            revision = f"v{revision}"  # 格式化版本号
        elif revision == "main":  # 如果指定版本为主分支
            revision = revision  # 保持不变
        else:  # 如果指定版本不在可用版本中
            raise ValueError(
                f"`custom_revision`: {revision} does not exist. Please make sure to choose one of"
                f" {', '.join(available_versions + ['main'])}."  # 提示可用版本
            )

        try:
            resolved_module_file = hf_hub_download(  # 从 Hugging Face Hub 下载指定文件
                repo_id=COMMUNITY_PIPELINES_MIRROR_ID,  # 设定资源库 ID
                repo_type="dataset",  # 指定资源类型为数据集
                filename=f"{revision}/{pretrained_model_name_or_path}.py",  # 构造文件名
                cache_dir=cache_dir,  # 设置缓存目录
                force_download=force_download,  # 决定是否强制下载
                proxies=proxies,  # 设置代理
                local_files_only=local_files_only,  # 是否只考虑本地文件
            )
            submodule = "git"  # 标记为从 GitHub 下载的文件
            module_file = pretrained_model_name_or_path + ".py"  # 更新模块文件名
        except RevisionNotFoundError as e:  # 捕获未找到修订版本的异常
            raise EnvironmentError(
                f"Revision '{revision}' not found in the community pipelines mirror. Check available revisions on"
                " https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main."
                " If you don't find the revision you are looking for, please open an issue on https://github.com/huggingface/diffusers/issues."
            ) from e  # 抛出环境错误并提供信息
        except EnvironmentError:  # 捕获环境错误
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")  # 记录错误
            raise  # 重新抛出异常
    else:  # 如果不是文件
        try:
            # 从 URL 加载或从缓存加载(如果已缓存)
            resolved_module_file = hf_hub_download(  # 从 Hugging Face Hub 下载文件
                pretrained_model_name_or_path,  # 使用预训练模型名称作为资源路径
                module_file,  # 指定模块文件名
                cache_dir=cache_dir,  # 设置缓存目录
                force_download=force_download,  # 决定是否强制下载
                proxies=proxies,  # 设置代理
                local_files_only=local_files_only,  # 是否只考虑本地文件
                token=token,  # 传递身份验证令牌
            )
            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))  # 构造本地子模块路径
        except EnvironmentError:  # 捕获环境错误
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")  # 记录错误
            raise  # 重新抛出异常
    # 检查环境中是否具备所有所需的模块
    modules_needed = check_imports(resolved_module_file)

    # 将模块移动到我们的缓存动态模块中
    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)  # 创建动态模块
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule  # 构建子模块路径
    if submodule == "local" or submodule == "git":  # 检查子模块类型
        # 始终复制本地文件(可以通过哈希判断是否有变化)
        # 复制的目的是避免将过多文件夹放入 sys.path
        shutil.copy(resolved_module_file, submodule_path / module_file)  # 复制模块文件到子模块路径
        for module_needed in modules_needed:  # 遍历所需的模块
            if len(module_needed.split(".")) == 2:  # 检查模块名称是否有两部分
                module_needed = "/".join(module_needed.split("."))  # 将模块名称转换为路径格式
                module_folder = module_needed.split("/")[0]  # 获取模块文件夹名
                if not os.path.exists(submodule_path / module_folder):  # 检查文件夹是否存在
                    os.makedirs(submodule_path / module_folder)  # 创建文件夹
            module_needed = f"{module_needed}.py"  # 添加 .py 后缀
            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)  # 复制所需模块文件
    else:
        # 获取提交哈希值
        # TODO: 未来将从 etag 获取此信息,而不是这里
        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha  # 获取模型的提交哈希

        # 模块文件将放置在带有 git 哈希的子文件夹中,以便实现版本控制
        submodule_path = submodule_path / commit_hash  # 更新子模块路径以包含哈希
        full_submodule = full_submodule + os.path.sep + commit_hash  # 更新完整子模块名称
        create_dynamic_module(full_submodule)  # 创建新的动态模块

        if not (submodule_path / module_file).exists():  # 检查模块文件是否已存在
            if len(module_file.split("/")) == 2:  # 检查模块文件路径是否包含两个部分
                module_folder = module_file.split("/")[0]  # 获取模块文件夹名
                if not os.path.exists(submodule_path / module_folder):  # 检查文件夹是否存在
                    os.makedirs(submodule_path / module_folder)  # 创建文件夹
            shutil.copy(resolved_module_file, submodule_path / module_file)  # 复制模块文件

        # 确保每个相对文件都存在
        for module_needed in modules_needed:  # 遍历所需模块
            if len(module_needed.split(".")) == 2:  # 检查模块名称是否有两部分
                module_needed = "/".join(module_needed.split("."))  # 将模块名称转换为路径格式
            if not (submodule_path / module_needed).exists():  # 检查模块文件是否存在
                get_cached_module_file(  # 获取缓存的模块文件
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",  # 模块文件名
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)  # 返回完整子模块路径及模块文件名
# 装饰器,用于验证传入的参数是否符合预期
@validate_hf_hub_args
def get_class_from_dynamic_module(
    # 预训练模型的名称或路径,可以是字符串或路径类型
    pretrained_model_name_or_path: Union[str, os.PathLike],
    # 模块文件的名称,包含要查找的类
    module_file: str,
    # 要导入的类的名称,默认为 None
    class_name: Optional[str] = None,
    # 缓存目录的路径,默认为 None
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    # 是否强制下载配置文件,默认为 False
    force_download: bool = False,
    # 代理服务器的字典,默认为 None
    proxies: Optional[Dict[str, str]] = None,
    # 用于远程文件的 HTTP 认证令牌,默认为 None
    token: Optional[Union[bool, str]] = None,
    # 具体的模型版本,默认为 "main"
    revision: Optional[str] = None,
    # 是否仅加载本地文件,默认为 False
    local_files_only: bool = False,
    # 其他可选的关键字参数
    **kwargs,
):
    """
    从模块文件中提取一个类,该模块文件可以位于本地文件夹或模型的仓库中。

    <Tip warning={true}>

    调用此函数将执行在本地找到的模块文件或从 Hub 下载的代码。
    因此,仅应在可信的仓库上调用。

    </Tip>

    Args:
        # 预训练模型的名称或路径,可以是 huggingface.co 上的模型 id 或本地目录路径
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            可以是字符串,表示在 huggingface.co 上托管的预训练模型配置的模型 id。
            有效的模型 id 可以位于根目录下,例如 `bert-base-uncased`,
            或者在用户或组织名称下命名,例如 `dbmdz/bert-base-german-cased`。
            也可以是一个目录的路径,包含使用 [`~PreTrainedTokenizer.save_pretrained`] 方法保存的配置文件,
            例如 `./my_model_directory/`。

        # 模块文件的名称,包含要查找的类
        module_file (`str`):
            包含要查找的类的模块文件的名称。
        # 类的名称,默认为 None
        class_name (`str`):
            要导入的类的名称。
        # 缓存目录的路径,默认为 None
        cache_dir (`str` or `os.PathLike`, *optional*):
            下载的预训练模型配置应缓存的目录路径,
            如果不使用标准缓存的话。
        # 是否强制下载配置文件,默认为 False
        force_download (`bool`, *optional*, defaults to `False`):
            是否强制重新下载配置文件,并覆盖已存在的缓存版本。
        # 代理服务器的字典,默认为 None
        proxies (`Dict[str, str]`, *optional*):
            以协议或端点为基础使用的代理服务器字典,例如 `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`。代理将在每个请求中使用。
        # 用于远程文件的 HTTP 认证令牌,默认为 None
        token (`str` or `bool`, *optional*):
            用作远程文件的 HTTP bearer 认证的令牌。
            如果为 True,将使用在运行 `transformers-cli login` 时生成的令牌(存储在 `~/.huggingface` 中)。
        # 具体的模型版本,默认为 "main"
        revision (`str`, *optional*, defaults to `"main"`):
            使用的特定模型版本。可以是分支名、标签名或提交 ID,
            因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,
            所以 `revision` 可以是 git 允许的任何标识符。
        # 是否仅加载本地文件,默认为 False
        local_files_only (`bool`, *optional*, defaults to `False`):
            如果为 True,仅尝试从本地文件加载标记器配置。

    <Tip>
    # 如果未登录(`huggingface-cli login`),可以通过 `token` 参数传递令牌,以便使用私有或
    # [受限模型](https://huggingface.co/docs/hub/models-gated#gated-models)。
        
    # 返回值:
    #     `type`: 从模块动态导入的类。
    
    # 示例:
    
    # ```python
    # 从 huggingface.co 下载模块 `modeling.py`,缓存并提取类 `MyBertModel`。
    cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
    # ```py"""
    # 最后,我们获取新创建模块中的类
    final_module = get_cached_module_file(
        # 获取预训练模型的名称或路径
        pretrained_model_name_or_path,
        # 模块文件名称
        module_file,
        # 缓存目录
        cache_dir=cache_dir,
        # 强制下载标志
        force_download=force_download,
        # 代理设置
        proxies=proxies,
        # 访问令牌
        token=token,
        # 版本控制
        revision=revision,
        # 仅本地文件标志
        local_files_only=local_files_only,
    )
    # 从最终模块中获取类名,去掉 ".py" 后缀
    return get_class_in_module(class_name, final_module.replace(".py", ""))

.\diffusers\utils\export_utils.py

# 导入所需的模块
import io  # 用于处理输入输出
import random  # 用于生成随机数
import struct  # 用于处理C语言风格的二进制数据
import tempfile  # 用于创建临时文件
from contextlib import contextmanager  # 用于创建上下文管理器
from typing import List, Union  # 用于类型注解

import numpy as np  # 用于数值计算
import PIL.Image  # 用于处理图像
import PIL.ImageOps  # 用于图像操作

from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available  # 导入工具函数和映射
from .logging import get_logger  # 导入日志记录器

# 创建全局随机数生成器
global_rng = random.Random()

# 获取当前模块的日志记录器
logger = get_logger(__name__)

# 定义上下文管理器以便于缓冲写入
@contextmanager
def buffered_writer(raw_f):
    # 创建缓冲写入对象
    f = io.BufferedWriter(raw_f)
    # 生成缓冲写入对象
    yield f
    # 刷新缓冲区,确保所有数据写入
    f.flush()

# 导出图像列表为 GIF 文件
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None, fps: int = 10) -> str:
    # 如果没有提供输出路径,则创建一个临时文件
    if output_gif_path is None:
        output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name

    # 保存图像为 GIF 文件
    image[0].save(
        output_gif_path,
        save_all=True,  # 保存所有帧
        append_images=image[1:],  # 附加后续图像
        optimize=False,  # 不优化图像
        duration=1000 // fps,  # 设置帧间隔
        loop=0,  # 循环次数
    )
    # 返回生成的 GIF 文件路径
    return output_gif_path

# 导出网格数据为 PLY 文件
def export_to_ply(mesh, output_ply_path: str = None):
    """
    写入网格的 PLY 文件。
    """
    # 如果没有提供输出路径,则创建一个临时文件
    if output_ply_path is None:
        output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name

    # 获取网格顶点坐标并转换为 NumPy 数组
    coords = mesh.verts.detach().cpu().numpy()
    # 获取网格面信息并转换为 NumPy 数组
    faces = mesh.faces.cpu().numpy()
    # 获取网格顶点的 RGB 颜色信息并堆叠为数组
    rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)

    # 使用缓冲写入器打开输出文件
    with buffered_writer(open(output_ply_path, "wb")) as f:
        f.write(b"ply\n")  # 写入 PLY 文件头
        f.write(b"format binary_little_endian 1.0\n")  # 指定文件格式
        f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))  # 写入顶点数量
        f.write(b"property float x\n")  # 写入 x 坐标属性
        f.write(b"property float y\n")  # 写入 y 坐标属性
        f.write(b"property float z\n")  # 写入 z 坐标属性
        if rgb is not None:  # 如果有 RGB 颜色信息
            f.write(b"property uchar red\n")  # 写入红色属性
            f.write(b"property uchar green\n")  # 写入绿色属性
            f.write(b"property uchar blue\n")  # 写入蓝色属性
        if faces is not None:  # 如果有面信息
            f.write(bytes(f"element face {len(faces)}\n", "ascii"))  # 写入面数量
            f.write(b"property list uchar int vertex_index\n")  # 写入顶点索引属性
        f.write(b"end_header\n")  # 写入文件头结束标记

        if rgb is not None:  # 如果有 RGB 颜色信息
            rgb = (rgb * 255.499).round().astype(int)  # 将 RGB 值转换为整数
            vertices = [
                (*coord, *rgb)  # 合并坐标和颜色信息
                for coord, rgb in zip(
                    coords.tolist(),
                    rgb.tolist(),
                )
            ]
            format = struct.Struct("<3f3B")  # 定义数据格式
            for item in vertices:  # 写入每个顶点数据
                f.write(format.pack(*item))  # 使用打包格式写入数据
        else:  # 如果没有 RGB 信息
            format = struct.Struct("<3f")  # 定义仅包含坐标的数据格式
            for vertex in coords.tolist():  # 写入每个顶点坐标
                f.write(format.pack(*vertex))  # 使用打包格式写入数据

        if faces is not None:  # 如果有面信息
            format = struct.Struct("<B3I")  # 定义面数据格式
            for tri in faces.tolist():  # 写入每个面数据
                f.write(format.pack(len(tri), *tri))  # 使用打包格式写入数据

    # 返回生成的 PLY 文件路径
    return output_ply_path

# 导出网格数据为 OBJ 文件
def export_to_obj(mesh, output_obj_path: str = None):
    # 如果没有提供输出路径,则创建一个临时文件
    if output_obj_path is None:
        output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name

    # 获取网格顶点坐标并转换为 NumPy 数组
    verts = mesh.verts.detach().cpu().numpy()
    # 获取网格面信息并转换为 NumPy 数组
    faces = mesh.faces.cpu().numpy()
    # 将网格顶点的颜色通道合并成一个多维数组
        vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
        # 将顶点坐标和颜色组合成格式化字符串,形成顶点列表
        vertices = [
            "{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
        ]
    
        # 将每个三角形的索引格式化为面定义字符串,索引加1以符合 OBJ 格式
        faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
    
        # 将顶点和面数据合并为最终输出列表
        combined_data = ["v " + vertex for vertex in vertices] + faces
    
        # 打开指定路径的文件以写入数据
        with open(output_obj_path, "w") as f:
            # 将合并的数据写入文件,每个元素占一行
            f.writelines("\n".join(combined_data))
# 导出视频的私有函数,接受视频帧列表、输出视频路径和帧率作为参数
def _legacy_export_to_video(
    # 视频帧,可以是 NumPy 数组或 PIL 图像的列表,输出视频的路径,帧率
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
):
    # 检查 OpenCV 是否可用
    if is_opencv_available():
        # 导入 OpenCV 库
        import cv2
    else:
        # 如果不可用,抛出导入错误
        raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
    
    # 如果没有提供输出视频路径,则创建一个临时文件
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    # 如果视频帧是 NumPy 数组,则将其值从 [0, 1] 乘以 255 并转换为 uint8 类型
    if isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]

    # 如果视频帧是 PIL 图像,则将其转换为 NumPy 数组
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    # 获取视频编码器,指定使用 mp4v 编码
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # 获取视频帧的高度、宽度和通道数
    h, w, c = video_frames[0].shape
    # 创建视频写入对象,设置输出路径、编码器、帧率和帧尺寸
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
    
    # 遍历每一帧,转换颜色并写入视频
    for i in range(len(video_frames)):
        # 将帧从 RGB 转换为 BGR 格式
        img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
        # 将转换后的帧写入视频文件
        video_writer.write(img)

    # 返回输出视频的路径
    return output_video_path


# 导出视频的公共函数,接受视频帧列表、输出视频路径和帧率作为参数
def export_to_video(
    # 视频帧,可以是 NumPy 数组或 PIL 图像的列表,输出视频的路径,帧率
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
) -> str:
    # TODO: Dhruv. 在 Diffusers 版本 0.33.0 发布时删除
    # 为了防止现有代码中断而添加的
    # 检查 imageio 是否可用
    if not is_imageio_available():
        # 记录警告信息,提示用户建议使用 imageio 和 imageio-ffmpeg 作为后端
        logger.warning(
            (
                "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n"
                "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n"
                "Support for the OpenCV backend will be deprecated in a future Diffusers version"
            )
        )
        # 如果不使用 imageio,则调用旧版导出函数
        return _legacy_export_to_video(video_frames, output_video_path, fps)

    # 如果 imageio 可用,则导入它
    if is_imageio_available():
        import imageio
    else:
        # 如果不可用,抛出导入错误
        raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video"))

    # 尝试获取 imageio ffmpeg 插件的执行文件
    try:
        imageio.plugins.ffmpeg.get_exe()
    except AttributeError:
        # 如果未找到兼容的 ffmpeg 安装,抛出属性错误
        raise AttributeError(
            (
                "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n"
                "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg"
            )
        )

    # 如果没有提供输出视频路径,则创建一个临时文件
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    # 如果视频帧是 NumPy 数组,则将其值从 [0, 1] 乘以 255 并转换为 uint8 类型
    if isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]

    # 如果视频帧是 PIL 图像,则将其转换为 NumPy 数组
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    # 使用 imageio 创建视频写入器,指定输出路径和帧率
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        # 遍历每一帧,将其附加到视频写入器中
        for frame in video_frames:
            writer.append_data(frame)

    # 返回输出视频的路径
    return output_video_path

.\diffusers\utils\hub_utils.py

# coding=utf-8  # 指定文件的编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team.  # 文件版权信息
#
# Licensed under the Apache License, Version 2.0 (the "License");  # 许可证声明
# you may not use this file except in compliance with the License.  # 使用许可证的条件说明
# You may obtain a copy of the License at  # 许可证获取链接
#
#     http://www.apache.org/licenses/LICENSE-2.0  # 许可证链接
#
# Unless required by applicable law or agreed to in writing, software  # 免责条款
# distributed under the License is distributed on an "AS IS" BASIS,  # 免责条款的进一步说明
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  # 不提供任何明示或暗示的保证
# See the License for the specific language governing permissions and  # 查看许可证以获取特定权限
# limitations under the License.  # 以及在许可证下的限制


import json  # 导入 json 模块,用于处理 JSON 数据
import os  # 导入 os 模块,用于与操作系统交互
import re  # 导入 re 模块,用于正则表达式操作
import sys  # 导入 sys 模块,用于访问与 Python 解释器相关的信息
import tempfile  # 导入 tempfile 模块,用于创建临时文件
import traceback  # 导入 traceback 模块,用于处理异常的跟踪信息
import warnings  # 导入 warnings 模块,用于发出警告
from pathlib import Path  # 从 pathlib 导入 Path 类,用于路径操作
from typing import Dict, List, Optional, Union  # 导入类型提示相关的类
from uuid import uuid4  # 从 uuid 导入 uuid4 函数,用于生成唯一标识符

# 从 huggingface_hub 导入所需的模块和函数
from huggingface_hub import (
    ModelCard,  # 导入 ModelCard 类,用于模型卡片管理
    ModelCardData,  # 导入 ModelCardData 类,用于处理模型卡片数据
    create_repo,  # 导入 create_repo 函数,用于创建模型仓库
    hf_hub_download,  # 导入 hf_hub_download 函数,用于下载模型
    model_info,  # 导入 model_info 函数,用于获取模型信息
    snapshot_download,  # 导入 snapshot_download 函数,用于下载快照
    upload_folder,  # 导入 upload_folder 函数,用于上传文件夹
)
# 导入 huggingface_hub 的常量
from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
# 从 huggingface_hub.file_download 导入正则表达式相关内容
from huggingface_hub.file_download import REGEX_COMMIT_HASH
# 导入 huggingface_hub.utils 的多个异常处理和实用函数
from huggingface_hub.utils import (
    EntryNotFoundError,  # 导入找不到条目的异常
    RepositoryNotFoundError,  # 导入找不到仓库的异常
    RevisionNotFoundError,  # 导入找不到修订版本的异常
    is_jinja_available,  # 导入检查 Jinja 模板是否可用的函数
    validate_hf_hub_args,  # 导入验证 Hugging Face Hub 参数的函数
)
from packaging import version  # 导入 version 模块,用于版本处理
from requests import HTTPError  # 从 requests 导入 HTTPError 异常,用于处理 HTTP 错误

from .. import __version__  # 导入当前包的版本信息
from .constants import (
    DEPRECATED_REVISION_ARGS,  # 导入已弃用的修订参数常量
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,  # 导入 Hugging Face 解析端点常量
    SAFETENSORS_WEIGHTS_NAME,  # 导入安全张量权重名称常量
    WEIGHTS_NAME,  # 导入权重名称常量
)
from .import_utils import (
    ENV_VARS_TRUE_VALUES,  # 导入环境变量真值集合
    _flax_version,  # 导入 Flax 版本
    _jax_version,  # 导入 JAX 版本
    _onnxruntime_version,  # 导入 ONNX 运行时版本
    _torch_version,  # 导入 PyTorch 版本
    is_flax_available,  # 导入检查 Flax 是否可用的函数
    is_onnx_available,  # 导入检查 ONNX 是否可用的函数
    is_torch_available,  # 导入检查 PyTorch 是否可用的函数
)
from .logging import get_logger  # 从 logging 模块导入获取日志记录器的函数

logger = get_logger(__name__)  # 获取当前模块的日志记录器实例

MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"  # 设置模型卡片模板文件的路径
SESSION_ID = uuid4().hex  # 生成一个唯一的会话 ID

def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:  # 定义一个格式化用户代理字符串的函数
    """
    Formats a user-agent string with basic info about a request.  # 函数说明,格式化用户代理字符串
    """
    ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"  # 构建基本用户代理字符串
    if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE:  # 检查是否禁用遥测或处于离线状态
        return ua + "; telemetry/off"  # 返回禁用遥测的用户代理字符串
    if is_torch_available():  # 检查 PyTorch 是否可用
        ua += f"; torch/{_torch_version}"  # 将 PyTorch 版本信息添加到用户代理字符串
    if is_flax_available():  # 检查 Flax 是否可用
        ua += f"; jax/{_jax_version}"  # 将 JAX 版本信息添加到用户代理字符串
        ua += f"; flax/{_flax_version}"  # 将 Flax 版本信息添加到用户代理字符串
    if is_onnx_available():  # 检查 ONNX 是否可用
        ua += f"; onnxruntime/{_onnxruntime_version}"  # 将 ONNX 运行时版本信息添加到用户代理字符串
    # CI will set this value to True  # CI 会将此值设置为 True
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:  # 检查环境变量是否指示在 CI 中运行
        ua += "; is_ci/true"  # 如果是 CI,添加相关信息到用户代理字符串
    if isinstance(user_agent, dict):  # 检查用户代理是否为字典类型
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())  # 将字典项格式化为字符串并添加到用户代理
    elif isinstance(user_agent, str):  # 检查用户代理是否为字符串类型
        ua += "; " + user_agent  # 直接添加用户代理字符串
    return ua  # 返回最终的用户代理字符串


def load_or_create_model_card(  # 定义加载或创建模型卡片的函数
    repo_id_or_path: str = None,  # 仓库 ID 或路径,默认为 None
    token: Optional[str] = None,  # 访问令牌,默认为 None
    is_pipeline: bool = False,  # 是否为管道模型,默认为 False
    from_training: bool = False,  # 是否从训练中加载,默认为 False
    # 定义模型描述,类型为可选的字符串,默认值为 None
        model_description: Optional[str] = None,
        # 定义基础模型,类型为字符串,默认值为 None
        base_model: str = None,
        # 定义提示信息,类型为可选的字符串,默认值为 None
        prompt: Optional[str] = None,
        # 定义许可证信息,类型为可选的字符串,默认值为 None
        license: Optional[str] = None,
        # 定义小部件列表,类型为可选的字典列表,默认值为 None
        widget: Optional[List[dict]] = None,
        # 定义推理标志,类型为可选的布尔值,默认值为 None
        inference: Optional[bool] = None,
# 定义一个函数,返回类型为 ModelCard
) -> ModelCard:
    """
    加载或创建模型卡片。

    参数:
        repo_id_or_path (`str`):
            仓库 ID(例如 "runwayml/stable-diffusion-v1-5")或查找模型卡片的本地路径。
        token (`str`, *可选*):
            认证令牌。默认为存储的令牌。详细信息见 https://huggingface.co/settings/token。
        is_pipeline (`bool`):
            布尔值,指示是否为 [`DiffusionPipeline`] 添加标签。
        from_training: (`bool`): 布尔标志,表示模型卡片是否是从训练脚本创建的。
        model_description (`str`, *可选*): 要添加到模型卡片的模型描述。在从训练脚本使用 `load_or_create_model_card` 时有用。
        base_model (`str`): 基础模型标识符(例如 "stabilityai/stable-diffusion-xl-base-1.0")。对类似 DreamBooth 的训练有用。
        prompt (`str`, *可选*): 用于训练的提示。对类似 DreamBooth 的训练有用。
        license: (`str`, *可选*): 输出工件的许可证。在从训练脚本使用 `load_or_create_model_card` 时有用。
        widget (`List[dict]`, *可选*): 附带画廊模板的部件。
        inference: (`bool`, *可选*): 是否开启推理部件。在从训练脚本使用 `load_or_create_model_card` 时有用。
    """
    # 检查是否安装了 Jinja 模板引擎
    if not is_jinja_available():
        # 如果未安装,抛出一个值错误,并提供安装建议
        raise ValueError(
            "Modelcard 渲染基于 Jinja 模板。"
            " 请确保在使用 `load_or_create_model_card` 之前安装了 `jinja`."
            " 要安装它,请运行 `pip install Jinja2`."
        )

    try:
        # 检查远程仓库中是否存在模型卡片
        model_card = ModelCard.load(repo_id_or_path, token=token)
    except (EntryNotFoundError, RepositoryNotFoundError):
        # 如果模型卡片不存在,则根据模板创建一个模型卡片
        if from_training:
            # 从模板创建模型卡片,并使用卡片数据作为 YAML 块
            model_card = ModelCard.from_template(
                card_data=ModelCardData(  # 卡片元数据对象
                    license=license,
                    library_name="diffusers",  # 指定库名
                    inference=inference,  # 指定推理设置
                    base_model=base_model,  # 指定基础模型
                    instance_prompt=prompt,  # 指定实例提示
                    widget=widget,  # 指定部件
                ),
                template_path=MODEL_CARD_TEMPLATE_PATH,  # 模板路径
                model_description=model_description,  # 模型描述
            )
        else:
            # 创建一个空的模型卡片数据对象
            card_data = ModelCardData()
            # 根据 is_pipeline 变量确定组件类型
            component = "pipeline" if is_pipeline else "model"
            # 如果没有提供模型描述,则生成默认描述
            if model_description is None:
                model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
            # 从模板创建模型卡片
            model_card = ModelCard.from_template(card_data, model_description=model_description)
    # 返回模型卡片的内容
        return model_card
# 定义一个函数,用于填充模型卡片的库名称和可选标签
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
    # 如果模型卡片的库名称为空,则设置为 "diffusers"
    if model_card.data.library_name is None:
        model_card.data.library_name = "diffusers"

    # 如果标签不为空
    if tags is not None:
        # 如果标签是字符串,则转换为列表
        if isinstance(tags, str):
            tags = [tags]
        # 如果模型卡片的标签为空,则初始化为空列表
        if model_card.data.tags is None:
            model_card.data.tags = []
        # 遍历所有标签,将它们添加到模型卡片的标签中
        for tag in tags:
            model_card.data.tags.append(tag)

    # 返回更新后的模型卡片
    return model_card


# 定义一个函数,从已解析的文件名中提取提交哈希
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
    # 提取提交哈希,优先使用提供的提交哈希
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    # 将解析后的文件路径转换为 POSIX 格式
    resolved_file = str(Path(resolved_file).as_posix())
    # 在文件路径中搜索提交哈希的模式
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    # 如果未找到模式,则返回 None
    if search is None:
        return None
    # 从搜索结果中提取提交哈希
    commit_hash = search.groups()[0]
    # 如果提交哈希符合规定格式,则返回它,否则返回 None
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None


# 定义旧的默认缓存路径,可能需要迁移
# 该逻辑大体来源于 `transformers`,并有如下不同之处:
# - Diffusers 不使用自定义环境变量来指定缓存路径。
# - 无需迁移缓存格式,只需将文件移动到新位置。
hf_cache_home = os.path.expanduser(
    # 获取环境变量 HF_HOME,默认路径为 ~/.cache/huggingface
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
# 定义旧的 diffusers 缓存路径
old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")


# 定义一个函数,用于移动缓存目录
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
    # 如果新缓存目录为空,则设置为 HF_HUB_CACHE
    if new_cache_dir is None:
        new_cache_dir = HF_HUB_CACHE
    # 如果旧缓存目录为空,则使用旧的 diffusers 缓存路径
    if old_cache_dir is None:
        old_cache_dir = old_diffusers_cache

    # 扩展用户目录路径
    old_cache_dir = Path(old_cache_dir).expanduser()
    new_cache_dir = Path(new_cache_dir).expanduser()
    # 遍历旧缓存目录中的所有 blob 文件
    for old_blob_path in old_cache_dir.glob("**/blobs/*"):
        # 如果路径是文件且不是符号链接
        if old_blob_path.is_file() and not old_blob_path.is_symlink():
            # 计算新 blob 文件的路径
            new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
            # 创建新路径的父目录
            new_blob_path.parent.mkdir(parents=True, exist_ok=True)
            # 替换旧的 blob 文件为新的 blob 文件
            os.replace(old_blob_path, new_blob_path)
            # 尝试在旧路径和新路径之间创建符号链接
            try:
                os.symlink(new_blob_path, old_blob_path)
            except OSError:
                # 如果无法创建符号链接,发出警告
                logger.warning(
                    "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
                )
    # 现在,old_cache_dir 包含指向新缓存的符号链接(仍然可以使用)


# 定义缓存版本文件的路径
cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt")
# 如果缓存版本文件不存在,则设置缓存版本为 0
if not os.path.isfile(cache_version_file):
    cache_version = 0
else:
    # 打开文件以读取缓存版本
    with open(cache_version_file) as f:
        try:
            # 尝试将读取内容转换为整数
            cache_version = int(f.read())
        except ValueError:
            # 如果转换失败,则设置缓存版本为 0
            cache_version = 0

# 如果缓存版本小于 1
if cache_version < 1:
    # 检查旧的缓存目录是否存在且非空
        old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
        # 如果旧缓存不为空,则记录警告信息
        if old_cache_is_not_empty:
            logger.warning(
                "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
                "existing cached models. This is a one-time operation, you can interrupt it or run it "
                "later by calling `diffusers.utils.hub_utils.move_cache()`."
            )
            # 尝试移动缓存
            try:
                move_cache()
            # 捕获任何异常并处理
            except Exception as e:
                # 获取异常的追踪信息并格式化为字符串
                trace = "\n".join(traceback.format_tb(e.__traceback__))
                # 记录错误信息,建议用户在 GitHub 提交问题
                logger.error(
                    f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
                    "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
                    "message and we will do our best to help."
                )
# 检查缓存版本是否小于1
if cache_version < 1:
    # 尝试创建缓存目录
    try:
        os.makedirs(HF_HUB_CACHE, exist_ok=True)  # 创建目录,如果已存在则不报错
        # 打开缓存版本文件以写入版本号
        with open(cache_version_file, "w") as f:
            f.write("1")  # 写入版本号1
    except Exception:  # 捕获异常
        # 记录警告信息,提示用户可能存在的问题
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure "
            "the directory exists and can be written to."
        )

# 定义函数以添加变体到权重名称
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    # 如果变体不为 None
    if variant is not None:
        # 按 '.' 分割权重名称
        splits = weights_name.split(".")
        # 确定分割索引
        split_index = -2 if weights_name.endswith(".index.json") else -1
        # 更新权重名称的分割部分,插入变体
        splits = splits[:-split_index] + [variant] + splits[-split_index:]
        # 重新连接分割部分为完整的权重名称
        weights_name = ".".join(splits)

    # 返回更新后的权重名称
    return weights_name

# 装饰器用于验证 HF Hub 的参数
@validate_hf_hub_args
def _get_model_file(
    pretrained_model_name_or_path: Union[str, Path],  # 预训练模型的名称或路径
    *,
    weights_name: str,  # 权重文件的名称
    subfolder: Optional[str] = None,  # 子文件夹,默认为 None
    cache_dir: Optional[str] = None,  # 缓存目录,默认为 None
    force_download: bool = False,  # 强制下载标志,默认为 False
    proxies: Optional[Dict] = None,  # 代理设置,默认为 None
    local_files_only: bool = False,  # 仅使用本地文件的标志,默认为 False
    token: Optional[str] = None,  # 访问令牌,默认为 None
    user_agent: Optional[Union[Dict, str]] = None,  # 用户代理设置,默认为 None
    revision: Optional[str] = None,  # 修订版本,默认为 None
    commit_hash: Optional[str] = None,  # 提交哈希值,默认为 None
):
    # 将预训练模型路径转换为字符串
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    # 如果路径指向一个文件
    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path  # 直接返回文件路径
    # 如果路径指向一个目录
    elif os.path.isdir(pretrained_model_name_or_path):
        # 检查目录中是否存在权重文件
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # 从 PyTorch 检查点加载模型文件
            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
            return model_file  # 返回模型文件路径
        # 如果有子文件夹且子文件夹中存在权重文件
        elif subfolder is not None and os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
        ):
            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
            return model_file  # 返回子文件夹中的模型文件路径
        else:
            # 抛出环境错误,指示未找到权重文件
            raise EnvironmentError(
                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
            )

# 检查本地是否存在分片文件的函数
def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
    # 构造分片文件的路径
    shards_path = os.path.join(local_dir, subfolder)
    # 获取所有分片文件的完整路径
    shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
    # 遍历每个分片文件
    for shard_file in shard_filenames:
        # 检查分片文件是否存在
        if not os.path.exists(shard_file):
            # 如果不存在,抛出错误提示
            raise ValueError(
                f"{shards_path} does not appear to have a file named {shard_file} which is "
                "required according to the checkpoint index."
            )

# 获取检查点分片文件的函数定义
def _get_checkpoint_shard_files(
    pretrained_model_name_or_path,  # 预训练模型的名称或路径
    index_filename,  # 索引文件名
    cache_dir=None,  # 缓存目录,默认为 None
    proxies=None,  # 代理设置,默认为 None
    # 设置是否仅使用本地文件,默认为 False
    local_files_only=False,
    # 设置访问令牌,默认为 None,表示不使用令牌
    token=None,
    # 设置用户代理字符串,默认为 None
    user_agent=None,
    # 设置修订版号,默认为 None
    revision=None,
    # 设置子文件夹路径,默认为空字符串
    subfolder="",
):
    """
    对于给定的模型:

    - 如果 `pretrained_model_name_or_path` 是 Hub 上的模型 ID,则下载并缓存所有分片的检查点
    - 返回所有分片的路径列表,以及一些元数据。

    有关每个参数的描述,请参见 [`PreTrainedModel.from_pretrained`]。 `index_filename` 是索引的完整路径
    (如果 `pretrained_model_name_or_path` 是 Hub 上的模型 ID,则下载并缓存)。
    """
    # 检查索引文件是否存在,如果不存在则抛出错误
    if not os.path.isfile(index_filename):
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    # 打开索引文件并读取内容,解析为 JSON 格式
    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    # 获取权重映射中的所有原始分片文件名,并去重后排序
    original_shard_filenames = sorted(set(index["weight_map"].values()))
    # 获取分片元数据
    sharded_metadata = index["metadata"]
    # 将所有检查点键的列表添加到元数据中
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    # 复制权重映射到元数据中
    sharded_metadata["weight_map"] = index["weight_map"].copy()
    # 构建分片的路径
    shards_path = os.path.join(pretrained_model_name_or_path, subfolder)

    # 首先处理本地文件夹
    if os.path.isdir(pretrained_model_name_or_path):
        # 检查本地是否存在所需的分片
        _check_if_shards_exist_locally(
            pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
        )
        # 返回分片路径和分片元数据
        return shards_path, sharded_metadata

    # 此时 pretrained_model_name_or_path 是 Hub 上的模型标识符
    # 设置允许的文件模式为原始分片文件名
    allow_patterns = original_shard_filenames
    # 如果提供了子文件夹,则更新允许的文件模式
    if subfolder is not None:
        allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

    # 定义需要忽略的文件模式
    ignore_patterns = ["*.json", "*.md"]
    # 如果不是仅使用本地文件
    if not local_files_only:
        # `model_info` 调用必须受到上述条件的保护
        model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
        # 遍历原始分片文件名
        for shard_file in original_shard_filenames:
            # 检查当前分片文件是否在模型文件信息中存在
            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
            # 如果分片文件不存在,则抛出环境错误
            if not shard_file_present:
                raise EnvironmentError(
                    f"{shards_path} 不存在名为 {shard_file} 的文件,这是根据检查点索引所需的。"
                )

        try:
            # 从 URL 加载
            cached_folder = snapshot_download(
                pretrained_model_name_or_path,  # 要下载的模型路径
                cache_dir=cache_dir,  # 缓存目录
                proxies=proxies,  # 代理设置
                local_files_only=local_files_only,  # 是否仅使用本地文件
                token=token,  # 授权令牌
                revision=revision,  # 版本信息
                allow_patterns=allow_patterns,  # 允许的文件模式
                ignore_patterns=ignore_patterns,  # 忽略的文件模式
                user_agent=user_agent,  # 用户代理信息
            )
            # 如果指定了子文件夹,则更新缓存文件夹路径
            if subfolder is not None:
                cached_folder = os.path.join(cached_folder, subfolder)

        # 已经在获取索引时处理了 RepositoryNotFoundError 和 RevisionNotFoundError,
        # 所以这里不需要捕获它们。也处理了 EntryNotFoundError。
        except HTTPError as e:
            # 如果无法连接到指定的端点,则抛出环境错误
            raise EnvironmentError(
                f"我们无法连接到 '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' 来加载 {pretrained_model_name_or_path}。请检查您的互联网连接后重试。"
            ) from e

    # 如果 `local_files_only=True`,则 `cached_folder` 可能不包含所有分片文件
    elif local_files_only:
        # 检查本地是否存在所有分片
        _check_if_shards_exist_locally(
            local_dir=cache_dir,  # 本地目录
            subfolder=subfolder,  # 子文件夹
            original_shard_filenames=original_shard_filenames  # 原始分片文件名列表
        )
        # 如果指定了子文件夹,则更新缓存文件夹路径
        if subfolder is not None:
            cached_folder = os.path.join(cached_folder, subfolder)

    # 返回缓存文件夹和分片元数据
    return cached_folder, sharded_metadata
# 定义一个混合类,用于将模型、调度器或管道推送到 Hugging Face Hub
class PushToHubMixin:
    """
    A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
    """

    # 定义一个私有方法,用于上传指定文件夹中的所有文件
    def _upload_folder(
        self,
        working_dir: Union[str, os.PathLike],  # 工作目录,包含待上传的文件
        repo_id: str,                           # 目标仓库的 ID
        token: Optional[str] = None,           # 可选的认证令牌
        commit_message: Optional[str] = None,  # 可选的提交信息
        create_pr: bool = False,                # 是否创建拉取请求
    ):
        """
        Uploads all files in `working_dir` to `repo_id`.
        """
        # 如果未提供提交信息,则根据类名生成默认提交信息
        if commit_message is None:
            if "Model" in self.__class__.__name__:
                commit_message = "Upload model"  # 如果是模型类,设置默认信息
            elif "Scheduler" in self.__class__.__name__:
                commit_message = "Upload scheduler"  # 如果是调度器类,设置默认信息
            else:
                commit_message = f"Upload {self.__class__.__name__}"  # 否则,使用类名作为提交信息

        # 记录上传文件的日志信息
        logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
        # 调用 upload_folder 函数上传文件,并返回其结果
        return upload_folder(
            repo_id=repo_id,                    # 目标仓库 ID
            folder_path=working_dir,            # 待上传的文件夹路径
            token=token,                        # 认证令牌
            commit_message=commit_message,      # 提交信息
            create_pr=create_pr                 # 是否创建拉取请求
        )

    # 定义一个公共方法,用于将文件推送到 Hugging Face Hub
    def push_to_hub(
        self,
        repo_id: str,                           # 目标仓库的 ID
        commit_message: Optional[str] = None,  # 可选的提交信息
        private: Optional[bool] = None,        # 可选,是否将仓库设置为私有
        token: Optional[str] = None,           # 可选的认证令牌
        create_pr: bool = False,                # 是否创建拉取请求
        safe_serialization: bool = True,        # 是否安全序列化
        variant: Optional[str] = None,          # 可选的变体参数
    ) -> str:  # 定义函数返回值类型为字符串
        """
        Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.  # 函数文档字符串,说明功能

        Parameters:  # 参数说明部分
            repo_id (`str`):  # 仓库 ID,类型为字符串
                The name of the repository you want to push your model, scheduler, or pipeline files to. It should
                contain your organization name when pushing to an organization. `repo_id` can also be a path to a local
                directory.  # 描述 repo_id 的用途和格式
            commit_message (`str`, *optional*):  # 可选参数,类型为字符串
                Message to commit while pushing. Default to `"Upload {object}".`  # 提交消息的默认值
            private (`bool`, *optional*):  # 可选参数,类型为布尔值
                Whether or not the repository created should be private.  # 是否创建私有仓库
            token (`str`, *optional*):  # 可选参数,类型为字符串
                The token to use as HTTP bearer authorization for remote files. The token generated when running
                `huggingface-cli login` (stored in `~/.huggingface`).  # 说明 token 的用途
            create_pr (`bool`, *optional*, defaults to `False`):  # 可选参数,类型为布尔值,默认值为 False
                Whether or not to create a PR with the uploaded files or directly commit.  # 是否创建 PR
            safe_serialization (`bool`, *optional*, defaults to `True`):  # 可选参数,类型为布尔值,默认值为 True
                Whether or not to convert the model weights to the `safetensors` format.  # 是否使用安全序列化格式
            variant (`str`, *optional*):  # 可选参数,类型为字符串
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.  # 权重保存格式

        Examples:  # 示例说明部分

        ```py
        from diffusers import UNet2DConditionModel  # 从 diffusers 导入 UNet2DConditionModel

        unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")  # 从预训练模型加载 UNet

        # Push the `unet` to your namespace with the name "my-finetuned-unet".  # 推送到个人命名空间
        unet.push_to_hub("my-finetuned-unet")  # 将 unet 推送到指定名称的仓库

        # Push the `unet` to an organization with the name "my-finetuned-unet".  # 推送到组织
        unet.push_to_hub("your-org/my-finetuned-unet")  # 将 unet 推送到指定组织的仓库
        ```
        """  # 结束文档字符串
        repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id  # 创建仓库并获取仓库 ID

        # Create a new empty model card and eventually tag it  # 创建新的模型卡片并可能添加标签
        model_card = load_or_create_model_card(repo_id, token=token)  # 加载或创建模型卡片
        model_card = populate_model_card(model_card)  # 填充模型卡片信息

        # Save all files.  # 保存所有文件
        save_kwargs = {"safe_serialization": safe_serialization}  # 设置保存文件的参数
        if "Scheduler" not in self.__class__.__name__:  # 检查当前类名是否包含 "Scheduler"
            save_kwargs.update({"variant": variant})  # 如果不包含,则添加 variant 参数

        with tempfile.TemporaryDirectory() as tmpdir:  # 创建临时目录
            self.save_pretrained(tmpdir, **save_kwargs)  # 将模型保存到临时目录

            # Update model card if needed:  # 如果需要,更新模型卡片
            model_card.save(os.path.join(tmpdir, "README.md"))  # 将模型卡片保存为 README.md 文件

            return self._upload_folder(  # 上传临时目录中的文件
                tmpdir,  # 临时目录路径
                repo_id,  # 仓库 ID
                token=token,  # 认证 token
                commit_message=commit_message,  # 提交消息
                create_pr=create_pr,  # 是否创建 PR
            )  # 返回上传结果

.\diffusers\utils\import_utils.py

# 版权信息,标明此文件的版权所有者及相关许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License, Version 2.0 许可,使用该文件需遵循该许可
# 许可的使用条件;可以在以下地址获取许可
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,按照许可分发的软件以 "现状" 方式提供,
# 不附带任何形式的保证或条件,明示或暗示。
# 详见许可中对权限和限制的具体规定
"""
导入工具:与导入和懒初始化相关的工具函数
"""

# 导入模块,提供动态导入和模块相关功能
import importlib.util
# 导入操作符模块,方便使用比较操作符
import operator as op
# 导入操作系统模块,提供与操作系统交互的功能
import os
# 导入系统模块,提供对 Python 解释器的访问
import sys
# 从 collections 模块导入有序字典
from collections import OrderedDict
# 从 itertools 模块导入链式迭代工具
from itertools import chain
# 导入模块类型
from types import ModuleType
# 导入类型注释工具
from typing import Any, Union

# 从 huggingface_hub.utils 导入检查 Jinja 是否可用的工具
from huggingface_hub.utils import is_jinja_available  # noqa: F401
# 导入版本控制工具
from packaging import version
# 从 packaging.version 导入版本和解析函数
from packaging.version import Version, parse

# 导入当前目录下的 logging 模块
from . import logging

# 根据 Python 版本选择合适的 importlib_metadata 模块导入方式
if sys.version_info < (3, 8):
    # 如果 Python 版本低于 3.8,导入 importlib_metadata
    import importlib_metadata
else:
    # 如果 Python 版本为 3.8 或更高,导入 importlib.metadata
    import importlib.metadata as importlib_metadata

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义环境变量的真值集合
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# 定义环境变量的真值和自动值集合
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

# 从环境变量中获取是否使用 TensorFlow,默认为 "AUTO"
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
# 从环境变量中获取是否使用 PyTorch,默认为 "AUTO"
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# 从环境变量中获取是否使用 JAX,默认为 "AUTO"
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
# 从环境变量中获取是否使用 SafeTensors,默认为 "AUTO"
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
# 从环境变量中获取是否进行慢导入,默认为 "FALSE"
DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper()
# 将慢导入的环境变量值转换为布尔值
DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES

# 定义操作符与对应函数的映射
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

# 初始化 PyTorch 版本为 "N/A"
_torch_version = "N/A"
# 检查是否启用 PyTorch,并且 TensorFlow 未被启用
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    # 尝试查找 PyTorch 模块是否可用
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
        try:
            # 获取 PyTorch 的版本信息
            _torch_version = importlib_metadata.version("torch")
            # 记录可用的 PyTorch 版本
            logger.info(f"PyTorch version {_torch_version} available.")
        except importlib_metadata.PackageNotFoundError:
            # 如果 PyTorch 模块未找到,将其标记为不可用
            _torch_available = False
else:
    # 如果 USE_TORCH 被设置,则禁用 PyTorch
    logger.info("Disabling PyTorch because USE_TORCH is set")
    _torch_available = False

# 检查 PyTorch XLA 是否可用
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
    try:
        # 获取 PyTorch XLA 的版本信息
        _torch_xla_version = importlib_metadata.version("torch_xla")
        # 记录可用的 PyTorch XLA 版本
        logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
    except ImportError:
        # 如果 PyTorch XLA 导入失败,将其标记为不可用
        _torch_xla_available = False

# 检查 torch_npu 是否可用
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
    # 尝试获取 "torch_npu" 包的版本信息
        try:
            # 使用 importlib_metadata 获取 "torch_npu" 的版本
            _torch_npu_version = importlib_metadata.version("torch_npu")
            # 记录可用的 "torch_npu" 版本信息到日志
            logger.info(f"torch_npu version {_torch_npu_version} available.")
        # 捕获导入错误,表示 "torch_npu" 包不可用
        except ImportError:
            # 设置标志,表示 "torch_npu" 不可用
            _torch_npu_available = False
# 初始化 JAX 版本为 "N/A"
_jax_version = "N/A"
# 初始化 Flax 版本为 "N/A"
_flax_version = "N/A"
# 检查 USE_JAX 是否在环境变量的真值或自动值列表中
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    # 检查 JAX 和 Flax 是否可用
    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
    # 如果 Flax 可用
    if _flax_available:
        try:
            # 获取 JAX 的版本
            _jax_version = importlib_metadata.version("jax")
            # 获取 Flax 的版本
            _flax_version = importlib_metadata.version("flax")
            # 记录 JAX 和 Flax 的版本信息
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        except importlib_metadata.PackageNotFoundError:
            # 如果找不到包,则设置 Flax 不可用
            _flax_available = False
else:
    # 如果 USE_JAX 不在环境变量的真值或自动值列表中,设置 Flax 不可用
    _flax_available = False

# 检查 USE_SAFETENSORS 是否在环境变量的真值或自动值列表中
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    # 检查 safetensors 是否可用
    _safetensors_available = importlib.util.find_spec("safetensors") is not None
    # 如果 safetensors 可用
    if _safetensors_available:
        try:
            # 获取 safetensors 的版本
            _safetensors_version = importlib_metadata.version("safetensors")
            # 记录 safetensors 的版本信息
            logger.info(f"Safetensors version {_safetensors_version} available.")
        except importlib_metadata.PackageNotFoundError:
            # 如果找不到包,则设置 safetensors 不可用
            _safetensors_available = False
else:
    # 如果 USE_SAFETENSORS 不在环境变量的真值或自动值列表中,记录信息并设置 safetensors 不可用
    logger.info("Disabling Safetensors because USE_TF is set")
    _safetensors_available = False

# 检查 transformers 是否可用
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
    # 获取 transformers 的版本
    _transformers_version = importlib_metadata.version("transformers")
    # 记录 transformers 的版本信息
    logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果找不到包,则设置 transformers 不可用
    _transformers_available = False

# 检查 inflect 是否可用
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
    # 获取 inflect 的版本
    _inflect_version = importlib_metadata.version("inflect")
    # 记录 inflect 的版本信息
    logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果找不到包,则设置 inflect 不可用
    _inflect_available = False

# 初始化 onnxruntime 版本为 "N/A"
_onnxruntime_version = "N/A"
# 检查 onnxruntime 是否可用
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
# 如果 onnxruntime 可用
if _onnx_available:
    # 可能的 onnxruntime 包候选列表
    candidates = (
        "onnxruntime",
        "onnxruntime-gpu",
        "ort_nightly_gpu",
        "onnxruntime-directml",
        "onnxruntime-openvino",
        "ort_nightly_directml",
        "onnxruntime-rocm",
        "onnxruntime-training",
    )
    # 初始化 onnxruntime 版本为 None
    _onnxruntime_version = None
    # 对于元数据,我们需要查找 onnxruntime 和 onnxruntime-gpu
    for pkg in candidates:
        try:
            # 获取当前候选包的版本
            _onnxruntime_version = importlib_metadata.version(pkg)
            break  # 如果找到版本,跳出循环
        except importlib_metadata.PackageNotFoundError:
            pass  # 如果找不到包,继续下一个候选
    # 如果找到了 onnxruntime 版本,则设置可用状态
    _onnx_available = _onnxruntime_version is not None
    # 如果 onnxruntime 可用,记录其版本信息
    if _onnx_available:
        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
# (sayakpaul): importlib.util.find_spec("opencv-python") 返回 None 即使它已经被安装。
# 判断 OpenCV 是否可用,若未找到则返回 False
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
try:
    # 定义可能的 OpenCV 包候选名称
    candidates = (
        "opencv-python",
        "opencv-contrib-python",
        "opencv-python-headless",
        "opencv-contrib-python-headless",
    )
    # 初始化 OpenCV 版本变量
    _opencv_version = None
    # 遍历候选包名称
    for pkg in candidates:
        try:
            # 尝试获取当前包的版本
            _opencv_version = importlib_metadata.version(pkg)
            # 成功获取版本后退出循环
            break
        except importlib_metadata.PackageNotFoundError:
            # 如果包未找到,则继续尝试下一个候选包
            pass
    # 检查是否成功获取到 OpenCV 版本
    _opencv_available = _opencv_version is not None
    # 如果 OpenCV 可用,记录调试信息
    if _opencv_available:
        logger.debug(f"Successfully imported cv2 version {_opencv_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果没有找到 OpenCV 包,标记为不可用
    _opencv_available = False

# 判断 SciPy 是否可用,若未找到则返回 False
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
    # 尝试获取 SciPy 包的版本
    _scipy_version = importlib_metadata.version("scipy")
    # 如果成功,记录调试信息
    logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果包未找到,则标记为不可用
    _scipy_available = False

# 判断 librosa 是否可用,若未找到则返回 False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
    # 尝试获取 librosa 包的版本
    _librosa_version = importlib_metadata.version("librosa")
    # 如果成功,记录调试信息
    logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果包未找到,则标记为不可用
    _librosa_available = False

# 判断 accelerate 是否可用,若未找到则返回 False
_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
    # 尝试获取 accelerate 包的版本
    _accelerate_version = importlib_metadata.version("accelerate")
    # 如果成功,记录调试信息
    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果包未找到,则标记为不可用
    _accelerate_available = False

# 判断 xformers 是否可用,若未找到则返回 False
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    # 尝试获取 xformers 包的版本
    _xformers_version = importlib_metadata.version("xformers")
    # 如果 torch 可用,获取其版本并进行版本比较
    if _torch_available:
        _torch_version = importlib_metadata.version("torch")
        # 如果 torch 版本小于 1.12,抛出错误
        if version.Version(_torch_version) < version.Version("1.12"):
            raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")

    # 如果成功,记录调试信息
    logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果包未找到,则标记为不可用
    _xformers_available = False

# 判断 k_diffusion 是否可用,若未找到则返回 False
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
    # 尝试获取 k_diffusion 包的版本
    _k_diffusion_version = importlib_metadata.version("k_diffusion")
    # 如果成功,记录调试信息
    logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果包未找到,则标记为不可用
    _k_diffusion_available = False

# 判断 note_seq 是否可用,若未找到则返回 False
_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
    # 尝试获取 note_seq 包的版本
    _note_seq_version = importlib_metadata.version("note_seq")
    # 如果成功,记录调试信息
    logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
except importlib_metadata.PackageNotFoundError:
    # 如果包未找到,则标记为不可用
    _note_seq_available = False

# 判断 wandb 是否可用,若未找到则返回 False
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
    # 尝试获取 wandb 包的版本
    _wandb_version = importlib_metadata.version("wandb")
    # 记录调试信息,指示成功导入 wandb 的版本
    logger.debug(f"Successfully imported wandb version {_wandb_version }")
# 捕获导入错误,表示没有找到 'wandb' 包
except importlib_metadata.PackageNotFoundError:
    # 设置 'wandb' 不可用标志为 False
    _wandb_available = False

# 检查 'tensorboard' 包是否可用
_tensorboard_available = importlib.util.find_spec("tensorboard")
try:
    # 获取 'tensorboard' 的版本信息
    _tensorboard_version = importlib_metadata.version("tensorboard")
    # 记录成功导入 'tensorboard' 的版本
    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'tensorboard' 不可用标志为 False
    _tensorboard_available = False

# 检查 'compel' 包是否可用
_compel_available = importlib.util.find_spec("compel")
try:
    # 获取 'compel' 的版本信息
    _compel_version = importlib_metadata.version("compel")
    # 记录成功导入 'compel' 的版本
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'compel' 不可用标志为 False
    _compel_available = False

# 检查 'ftfy' 包是否可用
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    # 获取 'ftfy' 的版本信息
    _ftfy_version = importlib_metadata.version("ftfy")
    # 记录成功导入 'ftfy' 的版本
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'ftfy' 不可用标志为 False
    _ftfy_available = False

# 检查 'bs4' 包是否可用
_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata 以不同名称获取
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    # 记录成功导入 'beautifulsoup4' 的版本
    logger.debug(f"Successfully imported ftfy version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'bs4' 不可用标志为 False
    _bs4_available = False

# 检查 'torchsde' 包是否可用
_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
    # 获取 'torchsde' 的版本信息
    _torchsde_version = importlib_metadata.version("torchsde")
    # 记录成功导入 'torchsde' 的版本
    logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'torchsde' 不可用标志为 False
    _torchsde_available = False

# 检查 'imwatermark' 包是否可用
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
    # 获取 'invisible-watermark' 的版本信息
    _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
    # 记录成功导入 'invisible-watermark' 的版本
    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'invisible-watermark' 不可用标志为 False
    _invisible_watermark_available = False

# 检查 'peft' 包是否可用
_peft_available = importlib.util.find_spec("peft") is not None
try:
    # 获取 'peft' 的版本信息
    _peft_version = importlib_metadata.version("peft")
    # 记录成功导入 'peft' 的版本
    logger.debug(f"Successfully imported peft version {_peft_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'peft' 不可用标志为 False
    _peft_available = False

# 检查 'torchvision' 包是否可用
_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
    # 获取 'torchvision' 的版本信息
    _torchvision_version = importlib_metadata.version("torchvision")
    # 记录成功导入 'torchvision' 的版本
    logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'torchvision' 不可用标志为 False
    _torchvision_available = False

# 检查 'sentencepiece' 包是否可用
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
try:
    # 获取 'sentencepiece' 的版本信息
    _sentencepiece_version = importlib_metadata.version("sentencepiece")
    # 记录成功导入 'sentencepiece' 的版本
    logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
except importlib_metadata.PackageNotFoundError:
    # 设置 'sentencepiece' 不可用标志为 False
    _sentencepiece_available = False

# 检查 'matplotlib' 包是否可用
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
try:
    # 导入 matplotlib 的版本信息
        _matplotlib_version = importlib_metadata.version("matplotlib")
        # 记录成功导入 matplotlib 的版本信息到调试日志
        logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
# 捕获导入库时的包未找到异常
except importlib_metadata.PackageNotFoundError:
    # 如果未找到 matplotlib,设置其可用状态为 False
    _matplotlib_available = False

# 检查 "timm" 库是否可用,返回结果为布尔值
_timm_available = importlib.util.find_spec("timm") is not None
# 如果 "timm" 库可用
if _timm_available:
    try:
        # 获取 "timm" 库的版本信息
        _timm_version = importlib_metadata.version("timm")
        # 记录可用的 "timm" 版本信息
        logger.info(f"Timm version {_timm_version} available.")
    # 捕获导入库时的包未找到异常
    except importlib_metadata.PackageNotFoundError:
        # 如果未找到 "timm",设置其可用状态为 False
        _timm_available = False

# 定义函数以返回 "timm" 库的可用状态
def is_timm_available():
    return _timm_available

# 检查 "bitsandbytes" 库是否可用,返回结果为布尔值
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
try:
    # 获取 "bitsandbytes" 库的版本信息
    _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
    # 记录成功导入的 "bitsandbytes" 版本信息
    logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
# 捕获导入库时的包未找到异常
except importlib_metadata.PackageNotFoundError:
    # 如果未找到 "bitsandbytes",设置其可用状态为 False
    _bitsandbytes_available = False

# 检查当前是否在 Google Colab 环境中
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)

# 检查 "imageio" 库是否可用,返回结果为布尔值
_imageio_available = importlib.util.find_spec("imageio") is not None
# 如果 "imageio" 库可用
if _imageio_available:
    try:
        # 获取 "imageio" 库的版本信息
        _imageio_version = importlib_metadata.version("imageio")
        # 记录成功导入的 "imageio" 版本信息
        logger.debug(f"Successfully imported imageio version {_imageio_version}")
    # 捕获导入库时的包未找到异常
    except importlib_metadata.PackageNotFoundError:
        # 如果未找到 "imageio",设置其可用状态为 False
        _imageio_available = False

# 定义函数以返回 "torch" 库的可用状态
def is_torch_available():
    return _torch_available

# 定义函数以返回 "torch_xla" 库的可用状态
def is_torch_xla_available():
    return _torch_xla_available

# 定义函数以返回 "torch_npu" 库的可用状态
def is_torch_npu_available():
    return _torch_npu_available

# 定义函数以返回 "flax" 库的可用状态
def is_flax_available():
    return _flax_available

# 定义函数以返回 "transformers" 库的可用状态
def is_transformers_available():
    return _transformers_available

# 定义函数以返回 "inflect" 库的可用状态
def is_inflect_available():
    return _inflect_available

# 定义函数以返回 "unidecode" 库的可用状态
def is_unidecode_available():
    return _unidecode_available

# 定义函数以返回 "onnx" 库的可用状态
def is_onnx_available():
    return _onnx_available

# 定义函数以返回 "opencv" 库的可用状态
def is_opencv_available():
    return _opencv_available

# 定义函数以返回 "scipy" 库的可用状态
def is_scipy_available():
    return _scipy_available

# 定义函数以返回 "librosa" 库的可用状态
def is_librosa_available():
    return _librosa_available

# 定义函数以返回 "xformers" 库的可用状态
def is_xformers_available():
    return _xformers_available

# 定义函数以返回 "accelerate" 库的可用状态
def is_accelerate_available():
    return _accelerate_available

# 定义函数以返回 "k_diffusion" 库的可用状态
def is_k_diffusion_available():
    return _k_diffusion_available

# 定义函数以返回 "note_seq" 库的可用状态
def is_note_seq_available():
    return _note_seq_available

# 定义函数以返回 "wandb" 库的可用状态
def is_wandb_available():
    return _wandb_available

# 定义函数以返回 "tensorboard" 库的可用状态
def is_tensorboard_available():
    return _tensorboard_available

# 定义函数以返回 "compel" 库的可用状态
def is_compel_available():
    return _compel_available

# 定义函数以返回 "ftfy" 库的可用状态
def is_ftfy_available():
    return _ftfy_available

# 定义函数以返回 "bs4" 库的可用状态
def is_bs4_available():
    return _bs4_available

# 定义函数以返回 "torchsde" 库的可用状态
def is_torchsde_available():
    return _torchsde_available

# 定义函数以返回 "invisible_watermark" 库的可用状态
def is_invisible_watermark_available():
    return _invisible_watermark_available

# 定义函数以返回 "peft" 库的可用状态
def is_peft_available():
    return _peft_available

# 定义函数以返回 "torchvision" 库的可用状态
def is_torchvision_available():
    return _torchvision_available

# 定义函数以返回 "matplotlib" 库的可用状态
def is_matplotlib_available():
    return _matplotlib_available

# 定义函数以返回 "safetensors" 库的可用状态
def is_safetensors_available():
    return _safetensors_available

# 定义函数以返回 "bitsandbytes" 库的可用状态
def is_bitsandbytes_available():
    return _bitsandbytes_available

# 定义函数以返回是否在 Google Colab 环境中
def is_google_colab():
    return _is_google_colab

# 定义函数以返回 "sentencepiece" 库的可用状态
def is_sentencepiece_available():
    # 返回句子分割工具的可用性状态
        return _sentencepiece_available
# 定义一个函数来检查 imageio 库是否可用
def is_imageio_available():
    # 返回 _imageio_available 的值,表示 imageio 库的可用性
    return _imageio_available


# docstyle-ignore
# 定义一个字符串,提示用户缺少 FLAX 库并提供安装链接
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 inflect 库并提供安装命令
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 PyTorch 库并提供安装链接
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 onnxruntime 库并提供安装命令
ONNX_IMPORT_ERROR = """
{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
install onnxruntime`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 OpenCV 库并提供安装命令
OPENCV_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
install opencv-python`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 scipy 库并提供安装命令
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 librosa 库并提供安装链接
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment.  Checkout the instructions on the
installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment.
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 transformers 库并提供安装命令
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 unidecode 库并提供安装命令
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 k-diffusion 库并提供安装命令
K_DIFFUSION_IMPORT_ERROR = """
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 note-seq 库并提供安装命令
NOTE_SEQ_IMPORT_ERROR = """
{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip
install note-seq`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 wandb 库并提供安装命令
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 tensorboard 库并提供安装命令
TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
install tensorboard`
"""


# docstyle-ignore
# 定义一个字符串,提示用户缺少 compel 库并提供安装命令
COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""

# docstyle-ignore
# 定义一个字符串,提示用户缺少 Beautiful Soup 库并提供安装命令
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 ftfy 库
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 torchsde 库
TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 invisible-watermark 库
INVISIBLE_WATERMARK_IMPORT_ERROR = """
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 peft 库
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 safetensors 库
SAFETENSORS_IMPORT_ERROR = """
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 sentencepiece 库
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece`
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 bitsandbytes 库
BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""

# docstyle-ignore
# 定义一个错误消息模板,提示用户缺少 imageio 库和 ffmpeg
IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""

# 定义一个有序字典,用于映射后端
BACKENDS_MAPPING = OrderedDict(
    # 创建一个包含库名称及其可用性检查和导入错误的元组列表
        [
            # BS4 库的可用性检查及对应的导入错误信息
            ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
            # Flax 库的可用性检查及对应的导入错误信息
            ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
            # Inflect 库的可用性检查及对应的导入错误信息
            ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
            # ONNX 库的可用性检查及对应的导入错误信息
            ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
            # OpenCV 库的可用性检查及对应的导入错误信息
            ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
            # SciPy 库的可用性检查及对应的导入错误信息
            ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
            # PyTorch 库的可用性检查及对应的导入错误信息
            ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
            # Transformers 库的可用性检查及对应的导入错误信息
            ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
            # Unidecode 库的可用性检查及对应的导入错误信息
            ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
            # Librosa 库的可用性检查及对应的导入错误信息
            ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
            # K-Diffusion 库的可用性检查及对应的导入错误信息
            ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
            # Note Seq 库的可用性检查及对应的导入错误信息
            ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
            # WandB 库的可用性检查及对应的导入错误信息
            ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
            # TensorBoard 库的可用性检查及对应的导入错误信息
            ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
            # Compel 库的可用性检查及对应的导入错误信息
            ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
            # FTFY 库的可用性检查及对应的导入错误信息
            ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
            # TorchSDE 库的可用性检查及对应的导入错误信息
            ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
            # Invisible Watermark 库的可用性检查及对应的导入错误信息
            ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
            # PEFT 库的可用性检查及对应的导入错误信息
            ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
            # SafeTensors 库的可用性检查及对应的导入错误信息
            ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
            # BitsAndBytes 库的可用性检查及对应的导入错误信息
            ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
            # SentencePiece 库的可用性检查及对应的导入错误信息
            ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
            # ImageIO 库的可用性检查及对应的导入错误信息
            ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
        ]
# 定义一个装饰器函数,检查所需后端是否可用
def requires_backends(obj, backends):
    # 如果后端参数不是列表或元组,则将其转换为列表
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    # 获取对象的名称,如果对象有 __name__ 属性则使用该属性,否则使用类名
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    # 根据后端列表生成检查器元组
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    # 遍历检查器,生成未通过检查的错误信息列表
    failed = [msg.format(name) for available, msg in checks if not available()]
    # 如果有检查失败,则抛出导入错误,显示所有错误信息
    if failed:
        raise ImportError("".join(failed))

    # 如果对象名称在特定管道列表中,并且 transformers 版本小于 4.25.0,则抛出错误
    if name in [
        "VersatileDiffusionTextToImagePipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionDualGuidedPipeline",
        "StableDiffusionImageVariationPipeline",
        "UnCLIPPipeline",
    ] and is_transformers_version("<", "4.25.0"):
        raise ImportError(
            f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```py"
        )

    # 如果对象名称在另一个特定管道列表中,并且 transformers 版本小于 4.26.0,则抛出错误
    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
        "<", "4.26.0"
    ):
        raise ImportError(
            f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```py"
        )


# 定义一个元类,用于生成虚拟对象
class DummyObject(type):
    """
    Dummy 对象的元类。任何继承自它的类在用户尝试访问其任何方法时都会返回由
    `requires_backend` 生成的 ImportError。
    """

    # 重写 __getattr__ 方法,处理属性访问
    def __getattr__(cls, key):
        # 如果属性以 "_" 开头且不是特定的两个属性,则返回默认实现
        if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
            return super().__getattr__(cls, key)
        # 检查类是否满足后端要求
        requires_backends(cls, cls._backends)


# 该函数用于比较库版本与要求版本的关系,来源于指定的 GitHub 链接
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
    """
    参数:
    比较库版本与某个要求之间的关系,使用给定的操作符。
        library_or_version (`str` 或 `packaging.version.Version`):
            要检查的库名或版本。
        operation (`str`):
            操作符的字符串表示,如 `">"` 或 `"<="`。
        requirement_version (`str`):
            要与库版本比较的版本
    """
    # 检查操作符是否在预定义的操作符字典中
    if operation not in STR_OPERATION_TO_FUNC.keys():
        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
    # 获取对应的操作函数
    operation = STR_OPERATION_TO_FUNC[operation]
    # 如果传入的库或版本是字符串,则获取其版本信息
    if isinstance(library_or_version, str):
        library_or_version = parse(importlib_metadata.version(library_or_version))
    # 使用操作函数比较版本,并返回结果
    return operation(library_or_version, parse(requirement_version))


# 该函数用于检查 PyTorch 版本与给定要求的关系,来源于指定的 GitHub 链接
def is_torch_version(operation: str, version: str):
    """
    参数:
    # 比较当前 PyTorch 版本与给定参考版本及操作符
    # operation (`str`): 操作符的字符串表示,例如 `">"` 或 `"<="`
    # version (`str`): PyTorch 的字符串版本
    """
    # 返回解析后的当前 PyTorch 版本与指定版本的比较结果
    return compare_versions(parse(_torch_version), operation, version)
# 定义一个函数,用于比较当前 Transformers 版本与给定版本之间的关系
def is_transformers_version(operation: str, version: str):
    """
    Args:
    Compares the current Transformers version to a given reference with an operation.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    # 检查 Transformers 是否可用,如果不可用则返回 False
    if not _transformers_available:
        return False
    # 将当前 Transformers 版本与给定版本进行比较,并返回比较结果
    return compare_versions(parse(_transformers_version), operation, version)


# 定义一个函数,用于比较当前 Accelerate 版本与给定版本之间的关系
def is_accelerate_version(operation: str, version: str):
    """
    Args:
    Compares the current Accelerate version to a given reference with an operation.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    # 检查 Accelerate 是否可用,如果不可用则返回 False
    if not _accelerate_available:
        return False
    # 将当前 Accelerate 版本与给定版本进行比较,并返回比较结果
    return compare_versions(parse(_accelerate_version), operation, version)


# 定义一个函数,用于比较当前 PEFT 版本与给定版本之间的关系
def is_peft_version(operation: str, version: str):
    """
    Args:
    Compares the current PEFT version to a given reference with an operation.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    # 检查 PEFT 版本是否可用,如果不可用则返回 False
    if not _peft_version:
        return False
    # 将当前 PEFT 版本与给定版本进行比较,并返回比较结果
    return compare_versions(parse(_peft_version), operation, version)


# 定义一个函数,用于比较当前 k-diffusion 版本与给定版本之间的关系
def is_k_diffusion_version(operation: str, version: str):
    """
    Args:
    Compares the current k-diffusion version to a given reference with an operation.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    # 检查 k-diffusion 是否可用,如果不可用则返回 False
    if not _k_diffusion_available:
        return False
    # 将当前 k-diffusion 版本与给定版本进行比较,并返回比较结果
    return compare_versions(parse(_k_diffusion_version), operation, version)


# 定义一个函数,用于从指定模块中提取对象
def get_objects_from_module(module):
    """
    Args:
    Returns a dict of object names and values in a module, while skipping private/internal objects
        module (ModuleType):
            Module to extract the objects from.

    Returns:
        dict: Dictionary of object names and corresponding values
    """

    # 创建一个空字典,用于存储对象名称及其对应的值
    objects = {}
    # 遍历模块中的所有对象名称
    for name in dir(module):
        # 跳过以 "_" 开头的私有对象
        if name.startswith("_"):
            continue
        # 获取对象并将名称和值存入字典
        objects[name] = getattr(module, name)

    # 返回包含对象名称及值的字典
    return objects


# 定义一个异常类,用于表示可选依赖未在环境中找到的错误
class OptionalDependencyNotAvailable(BaseException):
    """An error indicating that an optional dependency of Diffusers was not found in the environment."""


# 定义一个懒加载模块类,只有在请求对象时才会执行相关导入
class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.
    """

    # 受到 optuna.integration._IntegrationModule 的启发
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    # 初始化方法,接收模块的名称、文件、导入结构及可选参数
        def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
            # 调用父类初始化方法
            super().__init__(name)
            # 存储导入结构中模块的集合
            self._modules = set(import_structure.keys())
            # 存储类与模块的映射字典
            self._class_to_module = {}
            # 遍历导入结构,将每个类与其对应模块关联
            for key, values in import_structure.items():
                for value in values:
                    self._class_to_module[value] = key
            # 为 IDE 的自动补全准备的属性
            self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
            # 记录模块文件的路径
            self.__file__ = module_file
            # 记录模块的规格
            self.__spec__ = module_spec
            # 存储模块路径
            self.__path__ = [os.path.dirname(module_file)]
            # 存储额外对象,如果没有则初始化为空字典
            self._objects = {} if extra_objects is None else extra_objects
            # 存储模块名称
            self._name = name
            # 存储导入结构
            self._import_structure = import_structure
    
        # 为 IDE 的自动补全准备的方法
        def __dir__(self):
            # 获取当前对象的所有属性
            result = super().__dir__()
            # 将未包含在当前属性中的__all__元素添加到结果中
            for attr in self.__all__:
                if attr not in result:
                    result.append(attr)
            # 返回包含所有属性的列表
            return result
    
        def __getattr__(self, name: str) -> Any:
            # 如果对象中包含该名称,返回对应对象
            if name in self._objects:
                return self._objects[name]
            # 如果名称在模块中,获取模块
            if name in self._modules:
                value = self._get_module(name)
            # 如果名称在类到模块的映射中,获取对应模块的属性
            elif name in self._class_to_module.keys():
                module = self._get_module(self._class_to_module[name])
                value = getattr(module, name)
            # 如果名称不存在,抛出属性错误
            else:
                raise AttributeError(f"module {self.__name__} has no attribute {name}")
    
            # 将获取到的值设置为对象的属性
            setattr(self, name, value)
            # 返回值
            return value
    
        def _get_module(self, module_name: str):
            # 尝试导入指定模块
            try:
                return importlib.import_module("." + module_name, self.__name__)
            # 捕获异常,抛出运行时错误
            except Exception as e:
                raise RuntimeError(
                    f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                    f" traceback):\n{e}"
                ) from e
    
        def __reduce__(self):
            # 返回对象的序列化信息
            return (self.__class__, (self._name, self.__file__, self._import_structure))

.\diffusers\utils\loading_utils.py

# 导入操作系统模块
import os
# 导入临时文件模块
import tempfile
# 导入类型相关的类型提示
from typing import Callable, List, Optional, Union

# 导入 PIL 库中的图像模块
import PIL.Image
# 导入 PIL 库中的图像处理模块
import PIL.ImageOps
# 导入请求库
import requests

# 从本地模块中导入一些实用工具
from .import_utils import BACKENDS_MAPPING, is_imageio_available


# 定义加载图像的函数
def load_image(
    # 接受字符串或 PIL 图像作为输入
    image: Union[str, PIL.Image.Image], 
    # 可选的转换方法,用于加载后处理图像
    convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    加载给定的 `image` 为 PIL 图像。

    参数:
        image (`str` 或 `PIL.Image.Image`):
            要转换为 PIL 图像格式的图像。
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            加载后应用于图像的转换方法。如果为 `None`,则图像将被转换为 "RGB"。

    返回:
        `PIL.Image.Image`:
            一个 PIL 图像。
    """
    # 检查 image 是否为字符串类型
    if isinstance(image, str):
        # 如果字符串以 http 或 https 开头,则认为是 URL
        if image.startswith("http://") or image.startswith("https://"):
            # 通过请求获取图像并打开为 PIL 图像
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        # 检查给定路径是否为有效文件
        elif os.path.isfile(image):
            # 打开本地文件为 PIL 图像
            image = PIL.Image.open(image)
        else:
            # 如果路径无效,抛出错误
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    # 检查 image 是否为 PIL 图像对象
    elif isinstance(image, PIL.Image.Image):
        # 如果是 PIL 图像,保持不变
        image = image
    else:
        # 如果格式不正确,抛出错误
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # 应用 EXIF 转换,调整图像方向
    image = PIL.ImageOps.exif_transpose(image)

    # 如果提供了转换方法,则应用该方法
    if convert_method is not None:
        image = convert_method(image)
    else:
        # 否则将图像转换为 RGB 格式
        image = image.convert("RGB")

    # 返回处理后的图像
    return image


# 定义加载视频的函数
def load_video(
    # 接受视频的字符串路径或 URL
    video: str,
    # 可选的转换方法,用于加载后处理图像列表
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    加载给定的 `video` 为 PIL 图像列表。

    参数:
        video (`str`):
            视频的 URL 或路径,转换为 PIL 图像列表。
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            加载后应用于视频的转换方法。如果为 `None`,则图像将被转换为 "RGB"。

    返回:
        `List[PIL.Image.Image]`:
            视频作为 PIL 图像列表。
    """
    # 检查视频是否为 URL
    is_url = video.startswith("http://") or video.startswith("https://")
    # 检查视频路径是否为有效文件
    is_file = os.path.isfile(video)
    # 标记是否创建了临时文件
    was_tempfile_created = False

    # 如果既不是 URL 也不是文件,抛出错误
    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    # 如果是 URL,获取视频数据
    if is_url:
        video_data = requests.get(video, stream=True).raw
        # 获取视频的文件后缀,如果没有则默认使用 .mp4
        suffix = os.path.splitext(video)[1] or ".mp4"
        # 创建一个带后缀的临时文件
        video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
        # 标记已创建临时文件
        was_tempfile_created = True
        # 将视频数据写入临时文件
        with open(video_path, "wb") as f:
            f.write(video_data.read())

        # 更新视频变量为临时文件路径
        video = video_path
    # 初始化一个空列表,用于存储 PIL 图像
        pil_images = []
        # 检查视频文件是否为 GIF 格式
        if video.endswith(".gif"):
            # 打开 GIF 文件
            gif = PIL.Image.open(video)
            try:
                # 持续读取 GIF 的每一帧
                while True:
                    # 将当前帧的副本添加到列表中
                    pil_images.append(gif.copy())
                    # 移动到下一帧
                    gif.seek(gif.tell() + 1)
            # 捕捉到 EOFError 时停止循环
            except EOFError:
                pass
    
        else:
            # 检查 imageio 库是否可用
            if is_imageio_available():
                # 导入 imageio 库
                import imageio
            else:
                # 如果不可用,抛出导入错误
                raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))
    
            try:
                # 尝试获取 ffmpeg 执行文件
                imageio.plugins.ffmpeg.get_exe()
            except AttributeError:
                # 如果未找到 ffmpeg,抛出属性错误
                raise AttributeError(
                    "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg"
                )
    
            # 使用 imageio 创建视频读取器
            with imageio.get_reader(video) as reader:
                # 读取所有帧
                for frame in reader:
                    # 将每一帧转换为 PIL 图像并添加到列表中
                    pil_images.append(PIL.Image.fromarray(frame))
    
        # 如果创建了临时文件,删除它
        if was_tempfile_created:
            os.remove(video_path)
    
        # 如果提供了转换方法,应用该方法到 PIL 图像列表
        if convert_method is not None:
            pil_images = convert_method(pil_images)
    
        # 返回 PIL 图像列表
        return pil_images