CogView3---CogView-3Plus-微调代码源码解析-一-

龙哥盟 / 2024-11-11 / 原文

CogView3 & CogView-3Plus 微调代码源码解析(一)

Raise valuable PR / 提出有价值的PR

Caution / 注意事项:

Users should keep the following points in mind when submitting PRs:

  1. Ensure that your code meets the requirements in the specification.
  2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs.

用户在提交PR时候应该注意以下几点:

  1. 确保您的代码符合 规范 中的要求。
  2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。

不应该提出的PR / PRs that should not be proposed

If a developer proposes a PR about any of the following, it may be closed or Rejected.

  1. those that don't describe improvement options.
  2. multiple issues of different types combined in one PR.
  3. The proposed PR is highly duplicative of already existing PRs.

如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。

  1. 没有说明改进方案的。
  2. 多个不同类型的问题合并在一个PR中的。
  3. 提出的PR与已经存在的PR高度重复的。

检查您的PR

.\cogview3-finetune\inference\cli_demo.py

# 该脚本演示如何使用 Hugging Face 的 `diffusers` 管道生成图像
"""
This script demonstrates how to generate an image using the CogView3-Plus-3B model with the Hugging Face `diffusers` pipeline.
It showcases memory-efficient techniques like model offloading, VAE slicing, and tiling to reduce memory consumption during inference.
The prompt describes an image to be generated by the model, and the final image is saved to disk.

Running the Script:
To run the script, use the following command with appropriate arguments:


python cli_demo.py --prompt "A beautiful sunset over a mountain" --width 1024 --height 1024


Additional options are available to specify the model path, guidance scale, number of inference steps, image generation type, and output paths.
"""
from diffusers import CogView3PlusPipeline  # 从 diffusers 库导入 CogView3PlusPipeline 类
import torch  # 导入 PyTorch 库以进行张量操作
import argparse  # 导入 argparse 库以解析命令行参数


def generate_image(  # 定义生成图像的函数
    prompt, model_path, guidance_scale, num_images_per_prompt, num_inference_steps, width, height, output_path, dtype
):
    # 使用指定精度加载预训练模型
    pipe = CogView3PlusPipeline.from_pretrained(model_path, torch_dtype=dtype)

    # 启用 CPU 卸载,以便在层未被使用时释放 GPU 内存
    pipe.enable_model_cpu_offload()

    # 启用 VAE 切片和拼接以优化内存使用
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    # 根据提示生成图像
    image = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
    ).images[0]  # 获取生成的第一幅图像

    # 将生成的图像保存到本地文件系统
    image.save(output_path)

    # 打印保存路径
    print(f"Image saved to {output_path}")


if __name__ == "__main__":  # 检查是否为主程序
    parser = argparse.ArgumentParser(description="Generate an image using the CogView3-Plus-3B model.")  # 创建参数解析器

    # 定义用于提示、模型路径等的参数
    parser.add_argument("--prompt", type=str, required=True, help="The text description for generating the image.")  # 添加提示参数
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogView3-Plus-3B", help="Path to the pre-trained model."  # 添加模型路径参数
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.0, help="The guidance scale for classifier-free guidance."  # 添加引导比例参数
    )
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt."  # 添加每个提示生成图像数量的参数
    )
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.")  # 添加推理步骤数量的参数
    parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")  # 添加生成图像宽度的参数
    parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")  # 添加生成图像高度的参数
    parser.add_argument("--output_path", type=str, default="cogview3.png", help="Path to save the generated image.")  # 添加输出路径参数
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Precision type (float16 or bfloat16).")  # 添加数据类型参数
    # 解析命令行参数
    args = parser.parse_args()

    # 将 dtype 参数转换为 PyTorch 的数据类型
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16

    # 调用生成图像的函数
    generate_image(
        # 传入用户输入的提示文本
        prompt=args.prompt,
        # 传入模型的路径
        model_path=args.model_path,
        # 传入引导尺度
        guidance_scale=args.guidance_scale,
        # 传入每个提示生成的图像数量
        num_images_per_prompt=args.num_images_per_prompt,
        # 传入推理步骤的数量
        num_inference_steps=args.num_inference_steps,
        # 传入生成图像的宽度
        width=args.width,
        # 传入生成图像的高度
        height=args.height,
        # 传入输出图像的路径
        output_path=args.output_path,
        # 传入转换后的数据类型
        dtype=dtype,
    )

.\cogview3-finetune\inference\gradio_web_demo.py

# 主文件用于 Gradio 网络演示,使用 CogView3-Plus-3B 模型生成图像
"""
THis is the main file for the gradio web demo. It uses the CogView3-Plus-3B model to generate images gradio web demo.
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.

Usage:
    OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
"""

# 导入必要的库
import os  # 用于处理操作系统功能,如环境变量
import re  # 用于正则表达式处理字符串
import threading  # 用于多线程操作
import time  # 用于时间相关操作
from datetime import datetime, timedelta  # 用于日期和时间处理

import gradio as gr  # 导入 Gradio 库以创建用户界面
import random  # 用于生成随机数
from diffusers import CogView3PlusPipeline  # 导入用于图像生成的管道
import torch  # 导入 PyTorch 库以处理深度学习模型
from openai import OpenAI  # 导入 OpenAI 库以使用 API

import gc  # 导入垃圾回收模块

# 检查是否可以使用 GPU,设置设备类型
device = "cuda" if torch.cuda.is_available() else "cpu"

# 加载预训练的 CogView3-Plus-3B 模型并将其移到指定设备
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16).to(device)

# 创建用于临时文件的目录,如果已经存在则不报错
os.makedirs("./gradio_tmp", exist_ok=True)


# 定义函数以清理字符串
def clean_string(s):
    # 将字符串中的换行符替换为空格
    s = s.replace("\n", " ")
    # 去掉字符串开头和结尾的空白
    s = s.strip()
    # 用单个空格替换两个或更多的空白
    s = re.sub(r"\s{2,}", " ", s)
    # 返回清理后的字符串
    return s


# 定义函数以转换提示词
def convert_prompt(
    prompt: str,
    retry_times: int = 5,
) -> str:
    # 检查环境变量是否设置了 OPENAI_API_KEY
    if not os.environ.get("OPENAI_API_KEY"):
        # 如果未设置,直接返回原始提示词
        return prompt
    # 创建 OpenAI 客户端实例
    client = OpenAI()
    # 定义系统指令,指导图像描述生成
    system_instruction = """
    You are part of a team of bots that creates images . You work with an assistant bot that will draw anything you say. 
    For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an image of a forest morning , as described. 
    You will be prompted by people looking to create detailed , amazing images. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive. 
    There are a few rules to follow : 
    - Prompt should always be written in English, regardless of the input language. Please provide the prompts in English.
    - You will only ever output a single image description per user request.
    - Image descriptions must be detailed and specific, including keyword categories such as subject, medium, style, additional details, color, and lighting. 
    - When generating descriptions, focus on portraying the visual elements rather than delving into abstract psychological and emotional aspects. Provide clear and concise details that vividly depict the scene and its composition, capturing the tangible elements that make up the setting.
    - Do not provide the process and explanation, just return the modified English description . Image descriptions must be between 100-200 words. Extra words will be ignored. 
    """

    # 去除提示词两端的空白
    text = prompt.strip()
    # 返回原始提示词(此处未对提示词进行修改)
    return prompt


# 定义函数以删除旧文件
def delete_old_files():
    # 无限循环,用于持续检查和清理文件
        while True:
            # 获取当前的日期和时间
            now = datetime.now()
            # 计算截止时间,5分钟前的时间点
            cutoff = now - timedelta(minutes=5)
            # 定义需要检查的目录列表
            directories = ["./gradio_tmp"]
    
            # 遍历每个目录
            for directory in directories:
                # 列出目录中的所有文件
                for filename in os.listdir(directory):
                    # 生成文件的完整路径
                    file_path = os.path.join(directory, filename)
                    # 检查路径是否为文件
                    if os.path.isfile(file_path):
                        # 获取文件的最后修改时间
                        file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                        # 如果文件的修改时间早于截止时间,则删除该文件
                        if file_mtime < cutoff:
                            os.remove(file_path)
            # 暂停600秒(10分钟),然后继续循环
            time.sleep(600)
# 创建并启动一个后台线程用于删除旧文件
threading.Thread(target=delete_old_files, daemon=True).start()


# 定义推断函数,接受多个参数
def infer(
    # 输入提示词
    prompt,
    # 随机种子
    seed,
    # 是否随机化种子标志
    randomize_seed,
    # 图像宽度
    width,
    # 图像高度
    height,
    # 引导缩放参数
    guidance_scale,
    # 推断步骤数量
    num_inference_steps,
    # 进度条对象,跟踪进度
    progress=gr.Progress(track_tqdm=True),
):
    # 垃圾回收,释放内存
    gc.collect()
    # 清空 CUDA 的缓存
    torch.cuda.empty_cache()
    # 收集 CUDA 进程间通信的资源
    torch.cuda.ipc_collect()
    
    # 如果需要随机化种子
    if randomize_seed:
        # 生成一个新的随机种子
        seed = random.randint(0, 65536)

    # 使用管道进行推断,生成图像
    image = pipe(
        prompt=prompt,  # 输入提示词
        guidance_scale=guidance_scale,  # 指导缩放
        num_images_per_prompt=1,  # 每个提示生成一张图像
        num_inference_steps=num_inference_steps,  # 指定推断步骤
        width=width,  # 图像宽度
        height=height,  # 图像高度
        generator=torch.Generator().manual_seed(seed),  # 使用手动设置的种子生成器
    ).images[0]  # 获取生成的第一张图像
    # 返回生成的图像和种子
    return image, seed


# 示例提示词列表
examples = [
    # 描述一辆复古粉色敞篷车的场景
    "A vintage pink convertible with glossy chrome finishes and whitewall tires sits parked on an open road, surrounded by a field of wildflowers under a clear blue sky. The car's body is a delicate pastel pink, complementing the vibrant greens and colors of the meadow. Its interior boasts cream leather seats and a polished wooden dashboard, evoking a sense of classic elegance. The sun casts a soft light on the vehicle, highlighting its curves and shiny surfaces, creating a picture of nostalgia mixed with dreamy escapism.",
    # 描述一只黑色拉布拉多犬在草地上的场景
    "A noble black Labrador retriever sits serenely in a sunlit meadow, its glossy coat absorbing the golden rays of a late afternoon sun. The dog's intelligent eyes sparkle with a mixture of curiosity and loyalty, as it gazes off into the distance where the meadow meets a line of tall, slender birch trees. The dog's posture is regal, yet approachable, with its tongue playfully hanging out to the side slightly, suggesting a friendly disposition. The idyllic setting is filled with the vibrant greens of lush grass and the soft colors of wildflowers speckled throughout, creating a peaceful harmony between the dog and its natural surroundings.",
    # 描述一只红色犬在秋季森林中的场景
    "A vibrant red-colored dog of medium build stands attentively in an autumn forest setting. Its fur is a deep, rich red, reminiscent of autumn leaves, contrasting with its bright, intelligent eyes, a clear sky blue. The dog's ears perk up, and its tail wags slightly as it looks off into the distance, its posture suggesting alertness and curiosity. Golden sunlight filters through the canopy of russet and gold leaves above, casting dappled light onto the forest floor and the glossy coat of the canine, creating a serene and heartwarming scene.",
]

# CSS 样式定义
css = """
#col-container {
    margin: 0 auto;  # 设置外边距为 0,使容器居中
    max-width: 640px;  # 设置容器最大宽度为 640 像素
}
"""

# 使用 Gradio 创建块式界面
with gr.Blocks(css=css) as demo:
    # 设置触发器,当运行按钮点击或提示提交时调用推断函数
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,  # 指定调用的函数
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],  # 输入参数
        outputs=[result, seed],  # 输出结果和种子
    )

# 启动 Gradio 应用并处理请求
demo.queue().launch()

.\cogview3-finetune\prompt_optimize.py

# 导入正则表达式模块
import re
# 导入命令行参数解析模块
import argparse
# 从 OpenAI 库导入 OpenAI 类
from openai import OpenAI
# 导入追踪模块以便于调试
import traceback


# 定义一个函数,用于清理字符串
def clean_string(s):
    # 将字符串中的换行符替换为空格
    s = s.replace("\n", " ")
    # 去除字符串前后的空白字符
    s = s.strip()
    # 使用正则表达式替换多个空格为一个空格
    s = re.sub(r"\s{2,}", " ", s)
    # 返回清理后的字符串
    return s


# 定义一个函数,用于增强提示内容
def upsample_prompt(
        prompt: str,
        api_key: str,
        url: str,
        model: str
) -> str:
    # 创建 OpenAI 客户端实例
    client = OpenAI(api_key=api_key, base_url=url)
    # 定义系统指令,说明 bot 的工作职责和行为准则
    system_instruction = """
    You are part of a team of bots that creates images . You work with an assistant bot that will draw anything you say. 
    For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an image of a forest morning , as described. 
    You will be prompted by people looking to create detailed , amazing images. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive. 
    There are a few rules to follow : 
    - Prompt should always be written in English, regardless of the input language. Please provide the prompts in English.
    - You will only ever output a single image description per user request.
    - Image descriptions must be detailed and specific, including keyword categories such as subject, medium, style, additional details, color, and lighting. 
    - When generating descriptions, focus on portraying the visual elements rather than delving into abstract psychological and emotional aspects. Provide clear and concise details that vividly depict the scene and its composition, capturing the tangible elements that make up the setting.
    - Do not provide the process and explanation, just return the modified English description . Image descriptions must be between 100-200 words. Extra words will be ignored. 
    """
    # 去除提示文本前后的空白字符
    text = prompt.strip()
    # 捕获并打印异常信息(当前代码块没有执行内容,可能是个错误)
    except Exception as e:
        traceback.print_exc()
    # 返回原始提示(此处逻辑似乎有问题,应返回增强后的内容)
    return prompt


# 主程序入口
if __name__ == "__main__":
    # 创建参数解析器实例
    parser = argparse.ArgumentParser()
    # 添加 API 密钥参数
    parser.add_argument("--api_key", type=str, help="api key")
    # 添加提示内容参数
    parser.add_argument("--prompt", type=str, help="Prompt to upsample")
    # 添加基础 URL 参数,设置默认值
    parser.add_argument(
        "--base_url",
        type=str,
        default="https://open.bigmodel.cn/api/paas/v4",
        help="base url"
    )
    # 添加模型参数,设置默认值
    parser.add_argument(
        "--model",
        type=str,
        default="glm-4-plus",
        help="LLM using for upsampling"
    )
    # 解析命令行参数
    args = parser.parse_args()

    # 获取 API 密钥
    api_key = args.api_key
    # 获取提示内容
    prompt = args.prompt

    # 调用 upsample_prompt 函数进行增强提示
    prompt_enhanced = upsample_prompt(
        prompt=prompt,
        api_key=api_key,
        url=args.base_url,
        model=args.model
    )
    # 打印增强后的提示内容
    print(prompt_enhanced)

CogView3 & CogView-3Plus

Read this in English

在 🤗 Huggingface Space 在线体验 CogView3-Plus-3B 模型

📚 查看 论文

👋 加入我们的 微信

📍 前往 清言 和 API平台 体验更大规模的商业版视频生成模型。

项目更新

  • 🔥🔥 py/10/13: 我们适配和开源了 diffusers 版本的 CogView-3Plus-3B
    模型。你可以前往在线体验。
  • 🔥 py/9/29: 我们已经开源了 CogView3 以及 CogView-3Plus-3BCogView3 是一个基于级联扩散的文本生成图像系统,采用了接力扩散框架。
    CogView-3Plus 是一系列新开发的基 Diffusion Transformer 的文本生成图像模型。

模型介绍

CogView-3-Plus 在 CogView3(ECCV'24) 的基础上引入了最新的 DiT 框架,以实现整体性能的进一步提升。CogView-3-Plus 采用了
Zero-SNR
扩散噪声调度,并引入了文本-图像联合注意力机制。与常用的 MMDiT 结构相比,它在保持模型基本能力的同时,有效降低了训练和推理成本。CogView-3Plus
使用潜在维度为 16 的 VAE。

下表显示了我们目前提供的文本生成图像模型列表及其基础信息。

模型名称 CogView3-Base-3B CogView3-Base-3B-distill CogView3-Plus-3B
模型描述 CogView3 的基础阶段和接力阶段模型,支持 512x512 文本生成图像以及 2x 超分辨率生成。 CogView3 的蒸馏版本,分别在两个阶段采样 4 和 1 步(或 8 和 2 步)。 DIT 版本的图像生成模型 ,支持从 512 到 2048 范围内的图像生成。
分辨率 512 * 512 512 <= H, W <= 2048
H * W <= 2^{21}
H, W \mod 32 = 0
推理精度 FP16(推荐), BF16, FP32 BF16*(推荐), FP16, FP32
显存占用 (bs = 4) 17G 64G 30G(2048 * 2048)
20G(1024 * 1024)
提示词语言 English*
提示词长度上限 225 Tokens 224 Tokens
下载链接 (SAT) SAT
下载链接 (Diffusers) 未适配 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel

数据解释

  • 所有推理测试均在单卡A100上运行,批量大小为4。并使用PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True以节约显存。
  • 模型仅支持英语输入,其他语言可以通过大模型润色时翻译为英语。
  • 本次测试环境均使用SAT框架测试,众多优化点还未完善,我们会联合社区一起制作diffusers库版本的模型。diffusers
    仓库支持后,将会使用diffusers 测试。预计将于 2024 年 11 月发布。

快速开始

提示词优化

虽然 CogView3 系列模型都是通过长篇合成图像描述进行训练的,但我们强烈建议在文本生成图像之前,基于大语言模型(LLMs)进行提示词的重写操作,这将大大提高生成质量。

我们提供了一个 示例脚本。我们建议您运行这个脚本,以实现对提示词对润色

python prompt_optimize.py --api_key "智谱AI API Key" --prompt {你的提示词} --base_url "https://open.bigmodel.cn/api/paas/v4" --model "glm-4-plus"

推理模型(Diffusers)

首先,确保从源代码安装diffusers库。

pip install git+https://github.com/huggingface/diffusers.git

接着,运行以下代码:

from diffusers import CogView3PlusPipeline
import torch

pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.float16).to("cuda")

# Open it for reduce GPU memory usage
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."
image = pipe(
    prompt=prompt,
    guidance_scale=7.0,
    num_images_per_prompt=1,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]

image.save("cogview3.png")

更多推理代码,请关注inference,该文件夹还包含一个Gradio封装的简单WEBUI代码。

推理模型 (SAT)

请查看 sat 手把手教程实现模型推理。

开源计划

由于项目处于初步阶段,我们正在制作以下内容:

CogView3(ECCV'24)

官方论文仓库:CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion

CogView3 是一种新颖的文本生成图像系统,采用了接力扩散的方式,将生成高分辨率图像的过程分解为多个阶段。通过接力的超分辨率过程,对低分辨率生成结果添加高斯噪声,并从这些带噪声的图像开始扩散。我们的结果显示,CogView3
的表现优于 SDXL,获胜率达到 77.0%。此外,通过对扩散模型的逐步蒸馏,CogView3 能够在推理时间仅为 SDXL 1/10 的情况下,生成可比的结果。

CogView3 示例
CogView3 流程

人类评估的对比结果:

CogView3 evaluation

引用

🌟 如果您发现我们的工作有所帮助,欢迎引用我们的文章,留下宝贵的stars

@article{zheng2024cogview3,
  title={Cogview3: Finer and faster text-to-image generation via relay diffusion},
  author={Zheng, Wendi and Teng, Jiayan and Yang, Zhuoyi and Wang, Weihan and Chen, Jidong and Gu, Xiaotao and Dong, Yuxiao and Ding, Ming and Tang, Jie},
  journal={arXiv preprint arXiv:2403.05121},
  year={2024}
}

我们欢迎您的贡献,您可以点击这里查看更多信息。

模型协议

该代码库基于 Apache 2.0 License 协议发布。

CogView3-Base、CogView3-Relay 和 CogView3-Plus 模型(包括 UNet 模块、Transformers 模块和 VAE
模块)基于 Apache 2.0 License 协议发布。

Contribution Guide

There may still be many incomplete aspects in this project.

We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above
and are willing to submit a PR and share it with the community, upon review, we
will acknowledge your contribution on the project homepage.

Model Algorithms

  • Support for model quantization inference (Int4 quantization project)
  • Optimization of model fine-tuning data loading (replacing the existing decord tool)

Model Engineering

  • diffusers version of the model implementation
  • Model fine-tuning examples / Best prompt practices
  • Inference adaptation on different devices (e.g., MLX framework)
  • Any tools related to the model

Code Standards

Good code style is an art. We have prepared a pyproject.toml configuration file for the project to standardize code
style. You can organize the code according to the following specifications:

  1. Install the ruff tool
pip install ruff

Then, run the ruff tool

ruff check tools sat inference

Check the code style. If there are issues, you can automatically fix them using the ruff format command.

ruff format tools sat inference

Once your code meets the standard, there should be no errors.

Naming Conventions

  1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
  2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.

贡献指南

本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。

模型工程

  • diffusers 版本的模型实现
  • 模型微调示例 / 最佳提示词实践
  • 不同设备上的推理适配(MLX等框架)
  • 任何模型周边工具

代码规范

良好的代码风格是一种艺术,我们已经为项目准备好了pyproject.toml配置文件,用于规范代码风格。您可以按照以下规范梳理代码:

  1. 安装ruff工具
pip install ruff

接着,运行ruff工具

ruff check tools sat inference

检查代码风格,如果有问题,您可以通过ruff format .命令自动修复。

ruff format tools sat inference

如果您的代码符合规范,应该不会出现任何的错误。

命名规范

  • 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
  • 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。

扫码关注公众号,加入「 CogView 交流群」

Scan the QR code to follow the official account and join the "CogView Discussion Group"

.\cogview3-finetune\sat\arguments.py

# 导入所需的库
import argparse  # 处理命令行参数
import os  # 与操作系统交互的功能
import torch  # 深度学习框架
import json  # JSON 数据处理
import warnings  # 发出警告的功能
import omegaconf  # 处理配置文件的库
from omegaconf import OmegaConf  # 导入 OmegaConf 用于配置文件
from sat.helpers import print_rank0  # 导入打印函数,用于仅在主进程中输出信息
from sat import mpu  # 导入分布式计算相关的功能
from sat.arguments import set_random_seed  # 设置随机种子的功能
from sat.arguments import add_training_args, add_evaluation_args, add_data_args  # 导入参数添加功能

# 定义函数以添加模型配置参数
def add_model_config_args(parser):
    """Model arguments"""  # 函数说明:模型参数

    # 创建一个新的参数组,命名为 "model"
    group = parser.add_argument_group("model", "model configuration")
    # 添加基本配置参数,接收多个字符串作为输入
    group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
    # 添加模型并行大小参数,默认为 1,仅供专家使用
    group.add_argument(
        "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
    )
    # 添加强制预训练标志
    group.add_argument("--force-pretrain", action="store_true")
    # 添加设备参数,默认为 -1
    group.add_argument("--device", type=int, default=-1)

    # 返回修改后的解析器
    return parser

# 定义函数以添加采样配置参数
def add_sampling_config_args(parser):
    """Sampling configurations"""  # 函数说明:采样配置

    # 创建一个新的参数组,命名为 "sampling"
    group = parser.add_argument_group("sampling", "Sampling Configurations")
    # 添加输入目录参数,默认为 None
    group.add_argument("--input-dir", type=str, default=None)
    # 添加输出目录参数,默认为 "samples"
    group.add_argument("--output-dir", type=str, default="samples")
    # 添加输入类型参数,默认为 "cli"
    group.add_argument("--input-type", type=str, default="cli")
    # 添加中继模型参数,默认为 False
    group.add_argument("--relay-model", type=bool, default=False)
    # 添加输入文件参数,默认为 "input.txt"
    group.add_argument("--input-file", type=str, default="input.txt")
    # 添加采样图像大小参数,默认为 1024
    group.add_argument("--sampling-image-size", type=int, default=1024)
    # 添加采样潜在维度参数,默认为 4
    group.add_argument("--sampling-latent-dim", type=int, default=4)
    # 添加采样 F 参数,默认为 8
    group.add_argument("--sampling-f", type=int, default=8)
    # 添加采样图像宽度参数,默认为 None
    group.add_argument("--sampling-image-size-x", type=int, default=None)
    # 添加采样图像高度参数,默认为 None
    group.add_argument("--sampling-image-size-y", type=int, default=None)
    # 添加 SDEdit 标志
    group.add_argument("--sdedit", action="store_true")
    # 添加 IP2P 标志
    group.add_argument("--ip2p", action="store_true")
    # 添加网格列数参数,默认为 1
    group.add_argument("--grid-num-columns", type=int, default=1)
    # 添加强制推理标志
    group.add_argument("--force-inference", action="store_true")

    # 返回修改后的解析器
    return parser

# 定义函数以添加额外配置参数
def add_additional_config_args(parser):
    # 创建一个新的参数组,命名为 "additional"
    group = parser.add_argument_group("additional", "Additional Configurations")
    # 添加多方面训练标志
    group.add_argument("--multiaspect-training", action="store_true")
    # 添加多方面形状参数,接收多个整数
    group.add_argument("--multiaspect-shapes", nargs="+", default=None, type=int)

    # 返回修改后的解析器
    return parser

# 定义函数以获取所有参数
def get_args(args_list=None, parser=None):
    """Parse all the args."""  # 函数说明:解析所有参数
    # 如果未提供解析器,则创建一个新的 ArgumentParser 实例
    if parser is None:
        parser = argparse.ArgumentParser(description="sat")
    else:
        # 确保提供的解析器是 ArgumentParser 的实例
        assert isinstance(parser, argparse.ArgumentParser)
    # 添加模型配置参数
    parser = add_model_config_args(parser)
    # 添加采样配置参数
    parser = add_sampling_config_args(parser)
    # 添加训练参数
    parser = add_training_args(parser)
    # 添加评估参数
    parser = add_evaluation_args(parser)
    # 添加数据参数
    parser = add_data_args(parser)
    # 添加额外配置参数
    parser = add_additional_config_args(parser)

    # 导入 DeepSpeed 库
    import deepspeed
    # 包含 DeepSpeed 配置参数
    parser = deepspeed.add_config_arguments(parser)

    # 解析提供的参数列表
    args = parser.parse_args(args_list)
    # 处理配置并转换为参数
    args = process_config_to_args(args)

    # 如果没有指定训练数据,则发出警告
    if not args.train_data:
        print_rank0("No training data specified", level="WARNING")
    # 确保 train_iters 和 epochs 仅有一个被设置
        assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
        # 如果两个参数都没有设置
        if args.train_iters is None and args.epochs is None:
            # 默认设置为 10000 次迭代
            args.train_iters = 10000  # default 10k iters
            # 打印警告信息,使用默认的迭代次数
            print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
    
        # 检查 CUDA 是否可用
        args.cuda = torch.cuda.is_available()
    
        # 从环境变量获取当前进程的排名
        args.rank = int(os.getenv("RANK", "0"))
        # 从环境变量获取世界大小
        args.world_size = int(os.getenv("WORLD_SIZE", "1"))
        # 如果本地排名未设置
        if args.local_rank is None:
            # 从环境变量获取本地排名
            args.local_rank = int(os.getenv("LOCAL_RANK", "0"))  # torchrun
    
        # 如果设备未手动设置
        if args.device == -1:  # not set manually
            # 如果没有可用的 CUDA 设备
            if torch.cuda.device_count() == 0:
                # 使用 CPU 作为设备
                args.device = "cpu"
            # 如果本地排名已设置
            elif args.local_rank is not None:
                # 将本地排名设置为设备
                args.device = args.local_rank
            else:
                # 使用当前排名与 CUDA 设备数量的余数作为设备
                args.device = args.rank % torch.cuda.device_count()
    
        # 本地排名在 DeepSpeed 中应与设备一致
        if args.local_rank != args.device and args.mode != "inference":
            # 抛出不一致错误
            raise ValueError(
                "LOCAL_RANK (default 0) and args.device inconsistent. "
                "This can only happens in inference mode. "
                "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
            )
    
        # args.model_parallel_size = min(args.model_parallel_size, args.world_size)
        # 如果当前进程为 0
        if args.rank == 0:
            # 打印世界大小
            print_rank0("using world size: {}".format(args.world_size))
        # if args.vocab_size > 0:
        #     _adjust_vocab_size(args)
    
        # 如果训练数据权重已设置
        if args.train_data_weights is not None:
            # 确保权重和训练数据的长度一致
            assert len(args.train_data_weights) == len(args.train_data)
    
        # 如果模式不是推理,则进行训练
        if args.mode != "inference":  # training with deepspeed
            # 启用 DeepSpeed
            args.deepspeed = True
            # 如果未指定 DeepSpeed 配置
            if args.deepspeed_config is None:  # not specified
                # 生成 DeepSpeed 配置路径
                deepspeed_config_path = os.path.join(
                    os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
                )
                # 打开并加载 DeepSpeed 配置文件
                with open(deepspeed_config_path) as file:
                    args.deepspeed_config = json.load(file)
                # 标记为覆盖 DeepSpeed 配置
                override_deepspeed_config = True
            else:
                # 不覆盖 DeepSpeed 配置
                override_deepspeed_config = False
    
        # 确保不能同时指定 fp16 和 bf16
        assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
    
        # 如果 zero_stage 大于 0 并且 fp16 和 bf16 均未设置
        if args.zero_stage > 0 and not args.fp16 and not args.bf16:
            # 自动设置 fp16 为 True
            print_rank0("Automatically set fp16=True to use ZeRO.")
            args.fp16 = True
            # 设置 bf16 为 False
            args.bf16 = False
    # 检查是否启用 DeepSpeed
        if args.deepspeed:
            # 检查是否启用检查点激活
            if args.checkpoint_activations:
                # 启用 DeepSpeed 激活检查点
                args.deepspeed_activation_checkpointing = True
            else:
                # 禁用 DeepSpeed 激活检查点
                args.deepspeed_activation_checkpointing = False
            # 检查是否指定了 DeepSpeed 配置
            if args.deepspeed_config is not None:
                # 将 DeepSpeed 配置赋值
                deepspeed_config = args.deepspeed_config
                # 注释掉的代码,读取 JSON 格式的 DeepSpeed 配置
                # with open(args.deepspeed_config) as file:
                #     deepspeed_config = json.load(file)
    
            # 如果覆盖 DeepSpeed 配置
            if override_deepspeed_config:  # not specify deepspeed_config, use args
                # 检查是否启用 FP16 精度
                if args.fp16:
                    deepspeed_config["fp16"]["enabled"] = True
                # 检查是否启用 BF16 精度
                elif args.bf16:
                    deepspeed_config["bf16"]["enabled"] = True
                    deepspeed_config["fp16"]["enabled"] = False
                else:
                    # 禁用 FP16 精度
                    deepspeed_config["fp16"]["enabled"] = False
                # 设置每个 GPU 的微批大小
                deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
                # 设置梯度累积步数
                deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
                # 获取优化器参数配置
                optimizer_params_config = deepspeed_config["optimizer"]["params"]
                # 设置学习率
                optimizer_params_config["lr"] = args.lr
                # 设置权重衰减
                optimizer_params_config["weight_decay"] = args.weight_decay
            else:  # override args with values in deepspeed_config
                # 如果当前进程为主进程,输出提示信息
                if args.rank == 0:
                    print_rank0("Will override arguments with manually specified deepspeed_config!")
                # 检查 FP16 配置并更新参数
                if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
                    args.fp16 = True
                else:
                    args.fp16 = False
                # 检查 BF16 配置并更新参数
                if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
                    args.bf16 = True
                else:
                    args.bf16 = False
                # 更新每个 GPU 的微批大小
                if "train_micro_batch_size_per_gpu" in deepspeed_config:
                    args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
                # 更新梯度累积步数,如果没有则设为 None
                if "gradient_accumulation_steps" in deepspeed_config:
                    args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
                else:
                    args.gradient_accumulation_steps = None
                # 更新优化器参数
                if "optimizer" in deepspeed_config:
                    optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
                    args.lr = optimizer_params_config.get("lr", args.lr)
                    args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
            # 将 DeepSpeed 配置存储到 args 中
            args.deepspeed_config = deepspeed_config
    
        # 注释掉的代码,处理 sandwich 层归一化(在 v0.3 中移除)
        # if args.sandwich_ln: # removed in v0.3
        #     args.layernorm_order = 'sandwich'
    
        # 初始化分布式环境和随机种子
        initialize_distributed(args)
        # 设置种子,增加当前进程的排名
        args.seed = args.seed + torch.distributed.get_rank()
        # 设置随机种子
        set_random_seed(args.seed)
        # 返回更新后的参数
        return args
# 初始化分布式训练的设置
def initialize_distributed(args):
    """Initialize torch.distributed."""
    # 检查分布式训练是否已初始化
    if torch.distributed.is_initialized():
        # 检查模型并行是否已初始化
        if mpu.model_parallel_is_initialized():
            # 检查模型并行大小是否与之前的配置一致
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError(
                    "model_parallel_size is inconsistent with prior configuration."
                    "We currently do not support changing model_parallel_size."
                )
            return False
        else:
            # 如果模型并行大小大于1且未通过SAT初始化分布式
            if args.model_parallel_size > 1:
                warnings.warn(
                    "model_parallel_size > 1 but torch.distributed is not initialized via SAT."
                    "Please carefully make sure the correctness on your own."
                )
            # 初始化模型并行
            mpu.initialize_model_parallel(args.model_parallel_size)
        return True
    # 将设备的自动分配移至arguments.py
    if args.device == "cpu":
        pass
    else:
        # 设置当前CUDA设备
        torch.cuda.set_device(args.device)
    # 设置初始化方法
    init_method = "tcp://"
    # 获取主节点IP,默认为localhost
    args.master_ip = os.getenv("MASTER_ADDR", "localhost")

    # 如果世界规模为1,获取一个可用的端口
    if args.world_size == 1:
        from sat.helpers import get_free_port

        default_master_port = str(get_free_port())
    else:
        # 否则使用默认端口6000
        default_master_port = "6000"
    # 获取主节点端口,优先使用环境变量
    args.master_port = os.getenv("MASTER_PORT", default_master_port)
    # 构建初始化方法字符串
    init_method += args.master_ip + ":" + args.master_port
    # 初始化进程组
    torch.distributed.init_process_group(
        backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
    )

    # 设置模型并行和数据并行的通信器
    # mpu.initialize_model_parallel(args.model_parallel_size)
    mpu.initialize_model_parallel(1)
    # 可选的DeepSpeed激活检查点功能
    if args.deepspeed:
        import deepspeed

        # 初始化DeepSpeed分布式设置
        deepspeed.init_distributed(
            dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
        )
        # # 配置检查点,即使未使用也似乎没有负面影响
        # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    else:
        # 在仅模型模式下,不想初始化DeepSpeed,但仍需初始化rng追踪器,以便在丢弃时保存种子
        try:
            import deepspeed
            from deepspeed.runtime.activation_checkpointing.checkpointing import (
                _CUDA_RNG_STATE_TRACKER,
                _MODEL_PARALLEL_RNG_TRACKER_NAME,
            )

            # 默认种子为1,添加到RNG状态追踪器
            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1)  # default seed 1
        except Exception as e:
            from sat.helpers import print_rank0

            # 打印调试级别的错误信息
            print_rank0(str(e), level="DEBUG")

    # 返回初始化成功
    return True


# 从--base中提取参数
def process_config_to_args(args):
    """Fetch args from only --base"""
    # 从给定的基本配置路径加载每个配置文件,并将它们组成一个列表
        configs = [OmegaConf.load(cfg) for cfg in args.base]
        # 合并所有加载的配置,形成一个单一的配置对象
        config = OmegaConf.merge(*configs)
    
        # 从合并后的配置中提取 "args" 部分,若不存在则创建一个空的 OmegaConf 对象
        args_config = config.pop("args", OmegaConf.create())
        # 遍历 args_config 中的每个键
        for key in args_config:
            # 检查值是否为字典或列表配置,若是则转换为普通对象
            if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
                arg = OmegaConf.to_object(args_config[key])
            else:
                # 否则直接获取其值
                arg = args_config[key]
            # 如果 args 中有该键,则设置其属性为对应的值
            if hasattr(args, key):
                setattr(args, key, arg)
    
        # 检查配置中是否包含 "model" 键
        if "model" in config:
            # 从配置中提取 "model" 部分,若不存在则创建一个空的 OmegaConf 对象
            model_config = config.pop("model", OmegaConf.create())
            # 将提取的模型配置赋值给 args 的 model_config 属性
            args.model_config = model_config
        # 检查配置中是否包含 "deepspeed" 键
        if "deepspeed" in config:
            # 从配置中提取 "deepspeed" 部分,若不存在则创建一个空的 OmegaConf 对象
            deepspeed_config = config.pop("deepspeed", OmegaConf.create())
            # 将提取的深度学习加速配置转换为对象并赋值给 args 的 deepspeed_config 属性
            args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
        # 检查配置中是否包含 "data" 键
        if "data" in config:
            # 从配置中提取 "data" 部分,若不存在则创建一个空的 OmegaConf 对象
            data_config = config.pop("data", OmegaConf.create())
            # 将提取的数据配置赋值给 args 的 data_config 属性
            args.data_config = data_config
    
        # 返回更新后的 args 对象
        return args

.\cogview3-finetune\sat\diffusion.py

# 导入数学库以进行数学运算
import math
# 从 typing 模块导入类型提示相关的类
from typing import Any, Dict, List, Tuple, Union

# 导入 PyTorch 库及其 nn 模块
import torch
from torch import nn
# 导入 PyTorch 的功能模块
import torch.nn.functional as F

# 从 sgm.modules 导入未指定条件的配置
from sgm.modules import UNCONDITIONAL_CONFIG
# 从 sgm.modules.diffusionmodules.wrappers 导入 OPENAIUNETWRAPPER 类
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
# 从 sgm.util 导入默认值获取、字符串到对象的转换、配置实例化函数
from sgm.util import default, get_obj_from_str, instantiate_from_config


# 定义 SATDiffusionEngine 类,继承自 nn.Module
class SATDiffusionEngine(nn.Module):
    # 使用装饰器禁用梯度计算
    @torch.no_grad()
    # 定义解码第一阶段的方法
    def decode_first_stage(self, z):
        # 根据缩放因子调整 z 的值
        z = 1.0 / self.scale_factor * z
        # 获取每次解码的样本数量,使用默认值处理
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        # 计算所需的轮数以解码所有样本
        n_rounds = math.ceil(z.shape[0] / n_samples)
        # 创建一个空列表以存储输出
        all_out = []
        # 在自动混合精度的上下文中运行
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            # 遍历每一轮
            for n in range(n_rounds):
                # 解码当前样本批次
                out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples])
                # 将输出添加到输出列表中
                all_out.append(out)
        # 将所有输出在第0维拼接
        out = torch.cat(all_out, dim=0)
        # 返回拼接后的输出
        return out

    # 使用装饰器禁用梯度计算
    @torch.no_grad()
    # 定义编码第一阶段的方法
    def encode_first_stage(self, x):
        # 获取每次编码的样本数量,使用默认值处理
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        # 计算所需的轮数以编码所有样本
        n_rounds = math.ceil(x.shape[0] / n_samples)
        # 创建一个空列表以存储输出
        all_out = []
        # 在自动混合精度的上下文中运行
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            # 遍历每一轮
            for n in range(n_rounds):
                # 编码当前样本批次
                out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
                # 将输出添加到输出列表中
                all_out.append(out)
        # 将所有输出在第0维拼接
        z = torch.cat(all_out, dim=0)
        # 根据缩放因子调整 z 的值
        z = self.scale_factor * z
        # 返回编码后的结果
        return z

    # 定义前向传播的方法
    def forward(self, x, batch, **kwargs):
        # 计算损失
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
        # 计算损失的均值
        loss_mean = loss.mean()
        # 创建一个字典以存储损失
        loss_dict = {"loss": loss_mean}
        # 返回损失均值和损失字典
        return loss_mean, loss_dict

    # 定义共享步骤的方法
    def shared_step(self, batch: Dict) -> Any:
        # 从批次中获取输入
        x = self.get_input(batch)
        # 检查学习率缩放因子是否为 None
        if self.lr_scale is not None:
            # 对输入进行下采样
            lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
            # 对下采样后的输入进行上采样
            lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
            # 编码下采样后的输入
            lr_z = self.encode_first_stage(lr_x)
            # 将编码结果存入批次
            batch["lr_input"] = lr_z
        # 编码原始输入
        x = self.encode_first_stage(x)
        # batch["global_step"] = self.global_step  # 这行被注释掉,可能是为了调试或保留未来的扩展
        # 计算损失和损失字典
        loss, loss_dict = self(x, batch)
        # 返回损失和损失字典
        return loss, loss_dict

    # 使用装饰器禁用梯度计算
    @torch.no_grad()
    # 定义采样的方法,包含条件和可选参数
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        target_size=None,
        **kwargs,
    ):
        # 生成形状为 (batch_size, *shape) 的随机正态分布张量,并转换为 float32 类型,移动到指定设备
        randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)

        # 检查是否提供 target_size
        if target_size is not None:
            # 定义 denoiser 函数,包含 target_size 参数
            denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
                self.model, input, sigma, c, target_size=target_size, **additional_model_inputs
            )
        else:
            # 定义 denoiser 函数,不包含 target_size 参数
            denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
                self.model, input, sigma, c, **additional_model_inputs
            )
        # 使用采样器生成样本
        samples = self.sampler(denoiser, randn, cond, uc=uc)
        # 检查样本是否为列表
        if isinstance(samples, list):
            # 遍历样本列表,将每个样本转换为指定数据类型
            for i in range(len(samples)):
                samples[i] = samples[i].to(self.dtype)
        else:
            # 将样本转换为指定数据类型
            samples = samples.to(self.dtype)
        # 返回生成的样本
        return samples

    @torch.no_grad()  # 禁用梯度计算以节省内存和提高速度
    def sample_relay(
        self,
        image: torch.Tensor,  # 输入图像张量
        cond: Dict,  # 条件字典
        uc: Union[Dict, None] = None,  # 可选的未条件字典
        batch_size: int = 16,  # 批量大小,默认为16
        shape: Union[None, Tuple, List] = None,  # 输出形状,可选
        **kwargs,  # 其他可选参数
    ):
        # 生成形状为 (batch_size, *shape) 的随机正态分布张量,转换为指定数据类型,移动到指定设备
        randn = torch.randn(batch_size, *shape).to(self.dtype).to(self.device)
        # 定义 denoiser 函数
        denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
            self.model, input, sigma, c, **additional_model_inputs
        )
        # 使用采样器生成样本
        samples = self.sampler(denoiser, image, randn, cond, uc=uc)
        # 检查样本是否为列表
        if isinstance(samples, list):
            # 遍历样本列表,将每个样本转换为指定数据类型
            for i in range(len(samples)):
                samples[i] = samples[i].to(self.dtype)
        else:
            # 将样本转换为指定数据类型
            samples = samples.to(self.dtype)
        # 返回生成的样本
        return samples

SAT CogView3 & CogView-3-Plus

Read this in Chinese

This folder contains the inference code using the SAT weights, as well as fine-tuning code for SAT weights.

The code is the framework used by the team during model training. There are few comments, so it requires careful study.

Step-by-step guide to running the model

1. Environment setup

Ensure you have installed the dependencies required by this folder:

pip install -r requirements.txt

2. Download model weights

The following links are for different model weights:

CogView-3-Plus-3B

  • transformer: https://cloud.tsinghua.edu.cn/d/f913eabd3f3b4e28857c
  • vae: https://cloud.tsinghua.edu.cn/d/af4cc066ce8a4cf2ab79

CogView-3-Base-3B

  • transformer:

    • cogview3-base: https://cloud.tsinghua.edu.cn/d/242b66daf4424fa99bf0
    • cogview3-base-distill-4step: https://cloud.tsinghua.edu.cn/d/d10032a94db647f5aa0e
    • cogview3-base-distill-8step: https://cloud.tsinghua.edu.cn/d/1598d4fe4ebf4afcb6ae

    These three versions are interchangeable. Choose the one that suits your needs and run it with the corresponding configuration file.

  • vae: https://cloud.tsinghua.edu.cn/d/c8b9497fc5124d71818a/

CogView-3-Base-3B-Relay

  • transformer:

    • cogview3-relay: https://cloud.tsinghua.edu.cn/d/134951acced949c1a9e1/
    • cogview3-relay-distill-2step: https://cloud.tsinghua.edu.cn/d/6a902976fcb94ac48402
    • cogview3-relay-distill-1step: https://cloud.tsinghua.edu.cn/d/4d50ec092c64418f8418/

    These three versions are interchangeable. Choose the one that suits your needs and run it with the corresponding configuration file.

  • vae: Same as CogView-3-Base-3B

Next, arrange the model files into the following format:

.cogview3-plus-3b
├── transformer
│   ├── 1
│   │   └── mp_rank_00_model_states.pt
│   └── latest
└── vae
    └── imagekl_ch16.pt

Clone the T5 model. This model is not used for training or fine-tuning but is necessary. You can download the T5 model separately, but it must be in safetensors format, not bin format (otherwise an error may occur).

Since we have uploaded the T5 model in safetensors format in CogVideoX, a simple way is to clone the model from the CogVideoX-2B model and move it to the corresponding folder.

git clone https://huggingface.co/THUDM/CogVideoX-2b.git
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl

With this setup, you will have a safetensor format T5 file, ensuring no errors during Deepspeed fine-tuning.

├── added_tokens.json
├── config.json
├── model-00001-of-00002.safetensors
├── model-00002-of-00002.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── spiece.model
└── tokenizer_config.json

0 directories, 8 files

3. Modify the files in configs.

Here is an example using CogView3-Base, with explanations for some of the parameters:

args:
  mode: inference
  relay_model: False # Set to True when using CogView-3-Relay
  load: "cogview3_base/transformer" # Path to the transformer folder
  batch_size: 8 # Number of images per inference
  grid_num_columns: 2 # Number of columns in grid.png output
  input_type: txt # Input can be from command line or TXT file
  input_file: configs/test.txt # Not needed for command line input
  fp16: True # Set to bf16 for CogView-3-Plus inference
  # bf16: True
  sampling_image_size: 512 # Fixed size, supports 512x512 resolution images
  # For CogView-3-Plus, use the following:
  # sampling_image_size_x: 1024 (width)
  # sampling_image_size_y: 1024 (height)

  output_dir: "outputs/cogview3_base-512x512"
  # This section is for CogView-3-Relay. Set the input_dir to the folder with base model generated images.
  # input_dir: "outputs/cogview3_base-512x512" 
  deepspeed_config: { }

model:
  conditioner_config:
  target: sgm.modules.GeneralConditioner
  params:
    emb_models:
      - is_trainable: False
        input_key: txt
        target: sgm.modules.encoders.modules.FrozenT5Embedder
        params:
          model_dir: "google/t5-v1_1-xxl" # Path to T5 safetensors
          max_length: 225 # Maximum prompt length

  first_stage_config:
    target: sgm.models.autoencoder.AutoencodingEngine
    params:
      ckpt_path: "cogview3_base/vae/imagekl_ch16.pt" # Path to VAE PT file
      monitor: val/rec_loss

4. Running the model

Different models require different code for inference. Here are the inference commands for each model:

CogView-3Plus

python sample_dit.py --base configs/cogview3_plus.yaml

CogView-3-Base

  • Original model
python sample_unet.py --base configs/cogview3_base.yaml
  • Distilled model
python sample_unet.py --base configs/cogview3_base_distill_4step.yaml

CogView-3-Relay

  • Original model
python sample_unet.py --base configs/cogview3_relay.yaml
  • Distilled model
python sample_unet.py --base configs/cogview3_relay_distill_1step.yaml 

The output image format will be a folder. The folder name will consist of the sequence number and the first 15 characters of the prompt, containing multiple images. The number of images is based on the batch parameter. The structure should look like this:

.
├── 000000000.png
├── 000000001.png
├── 000000002.png
├── 000000003.png
├── 000000004.png
├── 000000005.png
├── 000000006.png
├── 000000007.png
└── grid.png

1 directory, 9 files

In this example, the batch size is 8, so there are 8 images along with one grid.png.

SAT CogView3 && CogView-3-Plus

本文件夹包含了使用 SAT 权重的推理代码,以及 SAT 权重的微调代码。

该代码是团队训练模型时使用的框架。注释较少,需要认真研究。

手把手带你运行模型

1. 环境安装

确保你已经正确安装本文件夹中的要求的依赖

pip install -r requirements.txt

2. 下载模型权重

以下链接为各个模型权重:

CogView-3-Plus-3B

  • transformer: https://cloud.tsinghua.edu.cn/d/f913eabd3f3b4e28857c
  • vae: https://cloud.tsinghua.edu.cn/d/af4cc066ce8a4cf2ab79

CogView-3-Base-3B

  • transformer:

    • cogview3-base: https://cloud.tsinghua.edu.cn/d/242b66daf4424fa99bf0
    • cogview3-base-distill-4step: https://cloud.tsinghua.edu.cn/d/d10032a94db647f5aa0e
    • cogview3-base-distill-8step: https://cloud.tsinghua.edu.cn/d/1598d4fe4ebf4afcb6ae

    以上三个版本为替换关系,选择适合自己的版本和对应的配置文件进行运行

  • vae: https://cloud.tsinghua.edu.cn/d/c8b9497fc5124d71818a/

CogView-3-Base-3B-Relay

  • transformer:

    • cogview3-relay: https://cloud.tsinghua.edu.cn/d/134951acced949c1a9e1/
    • cogview3-relay-distill-2step: https://cloud.tsinghua.edu.cn/d/6a902976fcb94ac48402
    • cogview3-relay-distill-1step: https://cloud.tsinghua.edu.cn/d/4d50ec092c64418f8418/

    以上三个版本为替换关系,选择适合自己的版本和对应的配置文件进行运行

  • vae: 与 CogView-3-Base-3B 相同

接着,你需要将模型文件排版成如下格式:

.cogview3-plus-3b
├── transformer
│   ├── 1
│   │   └── mp_rank_00_model_states.pt
│   └── latest
└── vae
    └── imagekl_ch16.pt

克隆 T5 模型,该模型不用做训练和微调,但是必须使用。这里,您可以单独下载T5模型,必须是safetensors类型,不能是bin
类型(否则可能出现错误)。

由于我们在CogVideoX中上传过 safetensors 格式的T5模型,一个简单的办法是从CogVideX-2B模型中克隆模型,然后将其移动到对应的文件夹中。

git clone https://huggingface.co/THUDM/CogVideoX-2b.git #从huggingface下载模型
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git #从modelscope下载模型
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl

通过上述方案,你将会得到一个 safetensor 格式的T5文件,确保在 Deepspeed微调过程中读入的时候不会报错。

├── added_tokens.json
├── config.json
├── model-00001-of-00002.safetensors
├── model-00002-of-00002.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── spiece.model
└── tokenizer_config.json

0 directories, 8 files

3. 修改configs中的文件。

这里以CogView3-Base为例,提供部分参数的讲解和介绍:

args:
  mode: inference
  relay_model: False # 当模型类型为 CogView-3-Relay 时,需要将该参数设置为 True
  load: "cogview3_base/transformer" # 这里填写到transformer文件夹
  batch_size: 8 # 每次推理图像数
  grid_num_columns: 2 # 推理结束后,每个提示词文件夹下会有 grid.png 图片,该数字代表列数。
  input_type: txt # 可以选择命令行输入,或者TXT文件输入
  input_file: configs/test.txt # 如果使用命令行,不需要这个参数
  fp16: True # CogView-3-Plus 模型 需要更换为 bf16 推理
  # bf16: True
  sampling_image_size: 512 # 固定大小,支持512 * 512 分辨率图像
  # CogView-3-Plus 模型可以使用以下两个参数。
  # sampling_image_size_x: 1024 宽 
  # sampling_image_size_y: 1024 高

  output_dir: "outputs/cogview3_base-512x512"
  # # 这个部分是给 CogView-3-Relay 模型使用的,需要将该参数设置为推理模型的输入文件夹,提示词建议与 base 模型生成图片时的提示词的一致。
  # input_dir: "outputs/cogview3_base-512x512" 
  deepspeed_config: { }

model:
  conditioner_config:
  target: sgm.modules.GeneralConditioner
  params:
    emb_models:
      - is_trainable: False
        input_key: txt
        target: sgm.modules.encoders.modules.FrozenT5Embedder
        params:
          model_dir: "google/t5-v1_1-xxl" # T5 safetensors的绝对路径
          max_length: 225 # 支持输入的提示词的最大长度

  first_stage_config:
    target: sgm.models.autoencoder.AutoencodingEngine
    params:
      ckpt_path: "cogview3_base/vae/imagekl_ch16.pt" # VAE PT文件绝对路径
      monitor: val/rec_loss

4. 推理模型

由于不同的模型需要使用的代码不一样,在这里,我们列出了不同模型的推理代码:

CogView-3Plus

  python sample_dit.py --base configs/cogview3_plus.yaml

CogView-3-Base

  • 原始模型
python sample_unet.py --base configs/cogview3_base.yaml
  • 蒸馏版本模型
python sample_unet.py --base configs/cogview3_base_distill_4step.yaml

CogView-3-Relay

  • 原始模型
python sample_unet.py --base configs/cogview3_relay.yaml
  • 蒸馏版本模型
python sample_unet.py --base configs/cogview3_relay_distill_1step.yaml 

输出图片格式为文件夹,其中,文件夹的名字为生成的序号加提示词的前15个字母,文件夹中包含多张图片,具体数量以 batch 参数为准。
其结构应该如下:

.
├── 000000000.png
├── 000000001.png
├── 000000002.png
├── 000000003.png
├── 000000004.png
├── 000000005.png
├── 000000006.png
├── 000000007.png
└── grid.png

1 directory, 9 files

上述例子中,batch 为8。因此,有8张图像并带有一张grid.png的图像。

.\cogview3-finetune\sat\sample_dit.py

# 导入操作系统模块,用于与操作系统交互
import os
# 导入数学模块,提供数学函数和常量
import math
# 导入命令行参数解析模块
import argparse
# 导入类型提示模块
from typing import List, Union
# 导入进度条模块,用于显示进度
from tqdm import tqdm
# 导入 OmegaConf 的 ListConfig 类,用于处理配置
from omegaconf import ListConfig
# 导入图像处理库 PIL
from PIL import Image

# 导入 PyTorch 库
import torch
# 导入 NumPy 库
import numpy as np
# 从 einops 导入 rearrange 和 repeat 函数,用于处理张量
from einops import rearrange, repeat
# 从 torchvision 导入 make_grid 函数,用于生成图像网格
from torchvision.utils import make_grid

# 从自定义模型模块导入获取模型的函数
from sat.model.base_model import get_model
# 从自定义训练模块导入加载检查点的函数
from sat.training.model_io import load_checkpoint

# 从 diffusion 模块导入 SATDiffusionEngine 类
from diffusion import SATDiffusionEngine
# 从 arguments 模块导入获取命令行参数的函数
from arguments import get_args


# 定义从命令行读取输入的生成器函数
def read_from_cli():
    cnt = 0  # 初始化计数器
    try:
        # 无限循环,等待用户输入
        while True:
            # 提示用户输入英文文本,按 Ctrl-D 退出
            x = input("Please input English text (Ctrl-D quit): ")
            # 去除输入文本的前后空格并生成一个元组
            yield x.strip(), cnt
            cnt += 1  # 计数器递增
    except EOFError as e:
        pass  # 捕获文件结束错误,结束循环


# 定义从文件中读取输入的生成器函数
def read_from_file(p, rank=0, world_size=1):
    # 以只读模式打开文件
    with open(p, "r") as fin:
        cnt = -1  # 初始化计数器
        # 遍历文件中的每一行
        for l in fin:
            cnt += 1  # 计数器递增
            # 如果当前计数不是该进程的排名,则跳过
            if cnt % world_size != rank:
                continue
            # 去除行首尾空白并生成一个元组
            yield l.strip(), cnt


# 定义从调节器中获取唯一嵌入器键的函数
def get_unique_embedder_keys_from_conditioner(conditioner):
    # 从嵌入器中提取输入键,去重并转换为列表
    return list(set([x.input_key for x in conditioner.embedders]))


# 定义获取批次的函数
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}  # 初始化批次字典
    batch_uc = {}  # 初始化无条件批次字典
    # 遍历给定的键列表
    for key in keys:
        # 如果键是 "txt",处理相关的文本数据
        if key == "txt":
            # 通过重复提示文本构建 batch 中的 "txt" 数据
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            # 通过重复负面提示文本构建 batch_uc 中的 "txt" 数据
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        # 如果键是 "original_size_as_tuple",处理原始图像大小
        elif key == "original_size_as_tuple":
            # 将原始高度和宽度转换为张量并在设备上重复 N 次
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
            )
        # 如果键是 "crop_coords_top_left",处理裁剪坐标
        elif key == "crop_coords_top_left":
            # 将裁剪坐标转换为张量并在设备上重复 N 次
            batch["crop_coords_top_left"] = (
                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
            )
        # 如果键是 "aesthetic_score",处理美学评分
        elif key == "aesthetic_score":
            # 将美学评分转换为张量并在设备上重复 N 次
            batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            # 将负面美学评分转换为张量并在 batch_uc 中重复 N 次
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
            )

        # 如果键是 "target_size_as_tuple",处理目标大小
        elif key == "target_size_as_tuple":
            # 将目标高度和宽度转换为张量并在设备上重复 N 次
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
            )
        # 如果键是 "fps",处理帧率
        elif key == "fps":
            # 将帧率转换为张量并在设备上重复 math.prod(N) 次
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
        # 如果键是 "fps_id",处理帧率 ID
        elif key == "fps_id":
            # 将帧率 ID 转换为张量并在设备上重复 math.prod(N) 次
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
        # 如果键是 "motion_bucket_id",处理运动桶 ID
        elif key == "motion_bucket_id":
            # 将运动桶 ID 转换为张量并在设备上重复 math.prod(N) 次
            batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
        # 如果键是 "pool_image",处理图像数据
        elif key == "pool_image":
            # 使用 repeat 函数处理图像数据并在设备上转换数据类型为半精度
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
        # 如果键是 "cond_aug",处理条件增强
        elif key == "cond_aug":
            # 将条件增强转换为张量并在 CUDA 设备上重复
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                "1 -> b",
                b=math.prod(N),
            )
        # 如果键是 "cond_frames",处理条件帧
        elif key == "cond_frames":
            # 使用 repeat 函数处理条件帧数据
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        # 如果键是 "cond_frames_without_noise",处理无噪声条件帧
        elif key == "cond_frames_without_noise":
            # 使用 repeat 函数处理无噪声条件帧数据
            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
        # 如果键是 "cfg_scale",处理配置缩放
        elif key == "cfg_scale":
            # 将配置缩放值转换为张量并在设备上重复 math.prod(N) 次
            batch[key] = torch.tensor([value_dict["cfg_scale"]]).to(device).repeat(math.prod(N))
        # 处理其他键,将其值直接赋给 batch
        else:
            batch[key] = value_dict[key]

    # 如果 T 不为 None,添加视频帧数量到 batch 中
    if T is not None:
        batch["num_video_frames"] = T

    # 遍历 batch 中的所有键
    for key in batch.keys():
        # 如果键不在 batch_uc 中且对应值是张量,则进行克隆
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    # 返回 batch 和 batch_uc
    return batch, batch_uc
# 定义一个将样本保存到本地的函数
def perform_save_locally(save_path, samples, grid, only_save_grid=False):
    # 创建保存路径,如果已存在则不报错
    os.makedirs(save_path, exist_ok=True)

    # 如果不只保存网格图像
    if not only_save_grid:
        # 遍历样本列表,获取索引和样本
        for i, sample in enumerate(samples):
            # 将样本转换为 RGB 格式并缩放到 255 范围
            sample = 255.0 * rearrange(sample.numpy(), "c h w -> h w c")
            # 将样本保存为 PNG 图像,命名为索引格式
            Image.fromarray(sample.astype(np.uint8)).save(os.path.join(save_path, f"{i:09}.png"))

    # 如果网格不为空
    if grid is not None:
        # 将网格转换为 RGB 格式并缩放到 255 范围
        grid = 255.0 * rearrange(grid.numpy(), "c h w -> h w c")
        # 将网格保存为 PNG 图像
        Image.fromarray(grid.astype(np.uint8)).save(os.path.join(save_path, f"grid.png"))


# 定义一个主函数用于采样
def sampling_main(args, model_cls):
    # 如果模型类是类型,则获取模型实例
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    # 加载模型的检查点
    load_checkpoint(model, args)
    # 设置模型为评估模式
    model.eval()

    # 根据输入类型读取数据
    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        # 获取当前进程的排名和总进程数
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
        # 从文件读取数据
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        # 如果输入类型未实现,则抛出异常
        raise NotImplementedError

    # 获取采样图像的尺寸
    image_size_x = args.sampling_image_size_x
    image_size_y = args.sampling_image_size_y
    # 组合成图像大小元组
    image_size = (image_size_x, image_size_y)
    # 获取潜在维度和采样参数
    latent_dim = args.sampling_latent_dim
    f = args.sampling_f

    # 检查图像尺寸是否在有效范围内
    assert (
        image_size_x >= 512 and image_size_y >= 512 and image_size_x <= 2048 and image_size_y <= 2048
    ), "Image size should be between 512 and 2048"
    # 检查图像尺寸是否为 32 的倍数
    assert image_size_x % 32 == 0 and image_size_y % 32 == 0, "Image size should be divisible by 32"

    # 获取模型的采样函数
    sample_func = model.sample

    # 定义图像的高、宽、通道数和采样参数
    H, W, C, F = image_size_x, image_size_y, latent_dim, f
    # 定义样本数量
    num_samples = [args.batch_size]
    # 定义强制使用的嵌入类型
    force_uc_zero_embeddings = ["txt"]
    # 禁用梯度计算,节省内存和加快计算速度
    with torch.no_grad():
        # 遍历数据迭代器中的文本和计数
        for text, cnt in tqdm(data_iter):
            # 创建一个字典来存储生成图像所需的参数
            value_dict = {
                # 提供的提示文本
                "prompt": text,
                # 负提示文本为空
                "negative_prompt": "",
                # 原始图像尺寸以元组形式存储
                "original_size_as_tuple": image_size,
                # 目标图像尺寸以元组形式存储
                "target_size_as_tuple": image_size,
                # 原始图像高度
                "orig_height": image_size_x,
                # 原始图像宽度
                "orig_width": image_size_y,
                # 目标图像高度
                "target_height": image_size_x,
                # 目标图像宽度
                "target_width": image_size_y,
                # 裁剪区域的上边界
                "crop_coords_top": 0,
                # 裁剪区域的左边界
                "crop_coords_left": 0,
            }

            # 获取批量数据和无条件批量数据
            batch, batch_uc = get_batch(
                # 从条件器中获取唯一的嵌入键
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )

            # 获取无条件条件的上下文
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                # 传递无条件批量
                batch_uc=batch_uc,
                # 是否强制将无条件嵌入设置为零
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            # 遍历条件上下文
            for k in c:
                # 如果不是交叉注意力
                if not k == "crossattn":
                    # 将每个上下文和无条件上下文映射到 CUDA
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            # 生成样本
            samples_z = sample_func(
                c,
                uc=uc,
                # 批量大小
                batch_size=args.batch_size,
                # 目标形状
                shape=(C, H // F, W // F),
                # 目标图像尺寸
                target_size=[image_size],
            )

            # 解码生成的样本
            samples_x = model.decode_first_stage(samples_z).to(torch.float32)
            # 将样本标准化到 [0, 1] 范围内,并移到 CPU
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            # 获取样本的批量大小
            batch_size = samples.shape[0]
            # 断言确保批量大小能被列数整除
            assert (batch_size // args.grid_num_columns) * args.grid_num_columns == batch_size

            # 如果批量大小为 1,则不生成网格
            if args.batch_size == 1:
                grid = None
            else:
                # 生成样本的网格
                grid = make_grid(samples, nrow=args.grid_num_columns)

            # 生成保存路径
            save_path = os.path.join(args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:20])
            # 在本地保存样本和网格
            perform_save_locally(save_path, samples, grid)
# 当脚本作为主程序运行时执行以下代码
if __name__ == "__main__":
    # 创建一个解析命令行参数的解析器,且不自动添加帮助信息
    py_parser = argparse.ArgumentParser(add_help=False)
    # 解析已知参数和剩余参数
    known, args_list = py_parser.parse_known_args()

    # 调用 get_args 函数处理剩余参数,返回结果
    args = get_args(args_list)
    # 将已知参数和处理后的参数合并为一个命名空间对象
    args = argparse.Namespace(**vars(args), **vars(known))

    # 调用 sampling_main 函数,传入参数和模型类
    sampling_main(args, model_cls=SATDiffusionEngine)

.\cogview3-finetune\sat\sample_unet.py

# 导入操作系统模块
import os
# 导入数学模块
import math
# 导入命令行参数解析模块
import argparse
# 导入进度条模块
from tqdm import tqdm
# 导入列表和联合类型的类型注解
from typing import List, Union
# 从 OmegaConf 导入列表配置
from omegaconf import ListConfig
# 导入图像处理库
from PIL import Image

# 导入 PyTorch 库
import torch
# 导入 PyTorch 中的功能模块
import torch.nn.functional as functional
# 导入 NumPy 库
import numpy as np
# 从 einops 导入重新排列和重复函数
from einops import rearrange, repeat
# 从 torchvision 导入生成网格的工具
from torchvision.utils import make_grid
# 导入 torchvision 的变换模块
import torchvision.transforms as TT

# 从自定义模型模块导入获取模型的函数
from sat.model.base_model import get_model
# 从自定义训练模块导入加载检查点的函数
from sat.training.model_io import load_checkpoint

# 导入扩散模型引擎
from diffusion import SATDiffusionEngine
# 导入命令行参数获取函数
from arguments import get_args


# 定义从命令行读取输入的生成器函数
def read_from_cli():
    # 初始化计数器
    cnt = 0
    # 尝试读取输入
    try:
        while True:
            # 提示用户输入英文文本,直到 Ctrl-D 结束
            x = input("Please input English text (Ctrl-D quit): ")
            # 去掉输入字符串的前后空白,并生成 (输入字符串, 计数) 元组
            yield x.strip(), cnt
            # 计数器加一
            cnt += 1
    # 捕获 EOFError 异常,表示输入结束
    except EOFError as e:
        pass


# 定义从文件读取输入的生成器函数
def read_from_file(p, rank=0, world_size=1):
    # 打开指定路径的文件,读取模式
    with open(p, "r") as fin:
        # 初始化计数器
        cnt = -1
        # 遍历文件中的每一行
        for l in fin:
            # 计数器加一
            cnt += 1
            # 如果当前计数不符合当前进程的 rank,则跳过该行
            if cnt % world_size != rank:
                continue
            # 去掉行末空白并生成 (行内容, 计数) 元组
            yield l.strip(), cnt


# 定义从条件器中获取唯一嵌入键的函数
def get_unique_embedder_keys_from_conditioner(conditioner):
    # 从条件器的嵌入器中提取输入键,去重并转换为列表
    return list(set([x.input_key for x in conditioner.embedders]))


# 定义获取批次数据的函数
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    # 初始化批次字典
    batch = {}
    # 初始化无条件批次字典
    batch_uc = {}
    # 遍历指定的键
        for key in keys:
            # 如果键是 "txt",则处理相关数据
            if key == "txt":
                # 重复 prompt 值,生成指定大小的数组,并转换为列表
                batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
                # 重复 negative_prompt 值,生成指定大小的数组,并转换为列表
                batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            # 如果键是 "original_size_as_tuple",则处理原始尺寸
            elif key == "original_size_as_tuple":
                # 创建一个张量,包含原始高度和宽度,并在设备上重复
                batch["original_size_as_tuple"] = (
                    torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
                )
            # 如果键是 "crop_coords_top_left",则处理裁剪坐标
            elif key == "crop_coords_top_left":
                # 创建一个张量,包含裁剪的顶部和左侧坐标,并在设备上重复
                batch["crop_coords_top_left"] = (
                    torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
                )
            # 如果键是 "aesthetic_score",则处理美学评分
            elif key == "aesthetic_score":
                # 创建一个张量,包含美学评分,并在设备上重复
                batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
                # 创建一个张量,包含负面美学评分,并在设备上重复
                batch_uc["aesthetic_score"] = (
                    torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
                )
            # 如果键是 "target_size_as_tuple",则处理目标尺寸
            elif key == "target_size_as_tuple":
                # 创建一个张量,包含目标高度和宽度,并在设备上重复
                batch["target_size_as_tuple"] = (
                    torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
                )
            # 如果键是 "fps",则处理帧率
            elif key == "fps":
                # 创建一个张量,包含帧率值,并在设备上重复
                batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
            # 如果键是 "fps_id",则处理帧率ID
            elif key == "fps_id":
                # 创建一个张量,包含帧率ID值,并在设备上重复
                batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
            # 如果键是 "motion_bucket_id",则处理运动桶ID
            elif key == "motion_bucket_id":
                # 创建一个张量,包含运动桶ID值,并在设备上重复
                batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
            # 如果键是 "pool_image",则处理池图像
            elif key == "pool_image":
                # 重复池图像值,转换维度,并在设备上设置数据类型
                batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
            # 如果键是 "cond_aug",则处理条件增强
            elif key == "cond_aug":
                # 创建一个张量,包含条件增强值,并在设备上重复
                batch[key] = repeat(
                    torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                    "1 -> b",
                    b=math.prod(N),
                )
            # 如果键是 "cond_frames",则处理条件帧
            elif key == "cond_frames":
                # 重复条件帧值,转换维度
                batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
            # 如果键是 "cond_frames_without_noise",则处理无噪声条件帧
            elif key == "cond_frames_without_noise":
                # 重复无噪声条件帧值,转换维度
                batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
            # 如果键是 "cfg_scale",则处理配置缩放
            elif key == "cfg_scale":
                # 创建一个张量,包含配置缩放值,并在设备上重复
                batch[key] = torch.tensor([value_dict["cfg_scale"]]).to(device).repeat(math.prod(N))
            # 如果键不在上述条件中,则直接从 value_dict 复制值
            else:
                batch[key] = value_dict[key]
    
        # 如果 T 不是 None,设置视频帧数量
        if T is not None:
            batch["num_video_frames"] = T
    
        # 遍历 batch 中的键
        for key in batch.keys():
            # 如果键不在 batch_uc 中,并且对应的值是张量
            if key not in batch_uc and isinstance(batch[key], torch.Tensor):
                # 克隆张量并存储到 batch_uc 中
                batch_uc[key] = torch.clone(batch[key])
        # 返回处理后的 batch 和 batch_uc
        return batch, batch_uc
# 定义一个将样本保存到本地的函数,参数包括保存路径、样本、网格和是否只保存网格的标志
def perform_save_locally(save_path, samples, grid, only_save_grid=False):
    # 创建保存路径的目录,如果已存在则不报错
    os.makedirs(save_path, exist_ok=True)

    # 如果不只保存网格
    if not only_save_grid:
        # 遍历样本及其索引
        for i, sample in enumerate(samples):
            # 将样本从张量格式转为图片格式,并进行归一化处理
            sample = 255.0 * rearrange(sample.numpy(), "c h w -> h w c")
            # 将处理后的样本保存为 PNG 格式,文件名以索引命名,前面填充零
            Image.fromarray(sample.astype(np.uint8)).save(os.path.join(save_path, f"{i:09}.png"))

    # 如果网格不为 None
    if grid is not None:
        # 将网格从张量格式转为图片格式,并进行归一化处理
        grid = 255.0 * rearrange(grid.numpy(), "c h w -> h w c")
        # 将处理后的网格保存为 PNG 格式,文件名为 grid.png
        Image.fromarray(grid.astype(np.uint8)).save(os.path.join(save_path, f"grid.png"))


# 定义一个主采样函数,参数包括输入参数和模型类
def sampling_main(args, model_cls):
    # 判断 model_cls 是否为类类型
    if isinstance(model_cls, type):
        # 获取模型实例
        model = get_model(args, model_cls)
    else:
        # 如果不是类,直接赋值为模型
        model = model_cls

    # 加载模型的检查点
    load_checkpoint(model, args)
    # 将模型设置为评估模式
    model.eval()

    # 根据输入类型读取数据
    if args.input_type == "cli":
        # 从命令行读取数据
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        # 获取当前进程的排名和总进程数
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
        # 从文件中读取数据
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        # 如果输入类型不在支持的范围,抛出未实现错误
        raise NotImplementedError

    # 获取采样图像的尺寸
    image_size = args.sampling_image_size
    input_sample_dirs = None
    # 如果启用中继模型
    if args.relay_model is True:
        # 使用中继采样函数
        sample_func = model.sample_relay
        # 设置图像的高度、宽度和通道数
        H, W, C, F = image_size, image_size, 4, 8
        # 确保输入目录不为 None
        assert args.input_dir is not None
        # 列出输入样本目录
        input_sample_dirs = os.listdir(args.input_dir)
        # 排序目录名称,并提取排名和名称
        input_sample_dirs_and_rank = sorted([(int(name.split("_")[0]), name) for name in input_sample_dirs])
        # 重新构建完整的输入样本目录路径
        input_sample_dirs = [os.path.join(args.input_dir, name) for _, name in input_sample_dirs_and_rank]
    else:
        # 使用常规采样函数
        sample_func = model.sample
        # 获取潜在维度和采样频率
        latent_dim = args.sampling_latent_dim
        f = args.sampling_f
        # 设置图像的高度、宽度、通道数和帧数
        H, W, C, F = image_size, image_size, latent_dim, f
    # 设置样本数量为批量大小
    num_samples = [args.batch_size]
    # 强制将特定嵌入维度设为零
    force_uc_zero_embeddings = ["txt"]
    # 禁用梯度计算,以节省内存和提高计算速度
        with torch.no_grad():
            # 遍历数据迭代器中的文本和计数
            for text, cnt in tqdm(data_iter):
                # 创建一个字典,存储与当前文本相关的参数
                value_dict = {
                    "prompt": text,  # 当前的提示文本
                    "negative_prompt": "",  # 负提示文本,初始为空
                    "original_size_as_tuple": (image_size, image_size),  # 原始图像尺寸
                    "target_size_as_tuple": (image_size, image_size),  # 目标图像尺寸
                    "orig_height": image_size,  # 原始高度
                    "orig_width": image_size,  # 原始宽度
                    "target_height": image_size,  # 目标高度
                    "target_width": image_size,  # 目标宽度
                    "crop_coords_top": 0,  # 裁剪坐标顶部
                    "crop_coords_left": 0,  # 裁剪坐标左侧
                }
    
                # 获取当前批次和无条件的嵌入
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
                )
                # 获取无条件的条件
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
    
                # 遍历条件字典,处理每个条件
                for k in c:
                    if not k == "crossattn":  # 如果键不是 "crossattn"
                        # 将条件和无条件条件的相应部分移动到 CUDA
                        c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
                # 如果参数 relay_model 为真
                if args.relay_model is True:
                    # 获取输入样本目录
                    input_sample_dir = input_sample_dirs[cnt]
                    images = []  # 存储图像的列表
                    # 遍历批次大小
                    for i in range(args.batch_size):
                        # 构建图像文件路径
                        filepath = os.path.join(input_sample_dir, f"{i:09}.png")
                        # 打开图像并转换为 RGB 格式
                        image = Image.open(filepath).convert("RGB")
                        # 将图像转换为张量并标准化
                        image = TT.ToTensor()(image) * 2 - 1
                        # 将处理后的图像添加到列表
                        images.append(image[None, ...])
                    # 将图像列表合并为一个张量
                    images = torch.cat(images, dim=0)
                    # 将图像上采样
                    images = functional.interpolate(images, scale_factor=2, mode="bilinear", align_corners=False)
                    # 转换图像为半精度并移动到 CUDA
                    images = images.to(torch.float16).cuda()
                    # 编码第一阶段的图像
                    images = model.encode_first_stage(images)
                    # 进行采样
                    samples_z = sample_func(images, c, uc=uc, batch_size=args.batch_size, shape=(C, H // F, W // F))
                else:
                    # 直接进行采样
                    samples_z = sample_func(c, uc=uc, batch_size=args.batch_size, shape=(C, H // F, W // F))
                # 解码第一阶段的样本
                samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                # 将样本归一化并转移到 CPU
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                # 获取批次大小
                batch_size = samples.shape[0]
                # 确保批次大小能够被列数整除
                assert (batch_size // args.grid_num_columns) * args.grid_num_columns == batch_size
    
                # 如果批次大小为 1,网格设为 None
                if args.batch_size == 1:
                    grid = None
                else:
                    # 创建网格,将样本放置在网格中
                    grid = make_grid(samples, nrow=args.grid_num_columns)
    
                # 构建保存路径
                save_path = os.path.join(args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:20])
                # 执行本地保存样本和网格
                perform_save_locally(save_path, samples, grid)
# 当脚本作为主程序运行时执行以下代码
if __name__ == "__main__":
    # 创建一个命令行参数解析器,不自动添加帮助信息
    py_parser = argparse.ArgumentParser(add_help=False)
    # 解析已知参数和位置参数,返回已知参数和剩余参数列表
    known, args_list = py_parser.parse_known_args()

    # 从剩余参数列表中获取自定义参数
    args = get_args(args_list)
    # 将已知参数和自定义参数合并为一个命名空间对象
    args = argparse.Namespace(**vars(args), **vars(known))
    # 调用主函数,传入参数和模型类
    sampling_main(args, model_cls=SATDiffusionEngine)

.\cogview3-finetune\sat\sgm\models\autoencoder.py

# 导入标准库和第三方库
import logging  # 用于记录日志信息
import math  # 提供数学函数
import re  # 提供正则表达式支持
from abc import abstractmethod  # 用于定义抽象方法
from contextlib import contextmanager  # 提供上下文管理器功能
from typing import Any, Dict, Tuple, Union  # 提供类型提示支持

import pytorch_lightning as pl  # 引入 PyTorch Lightning 框架
import torch  # 引入 PyTorch 库
from omegaconf import ListConfig  # 用于处理配置文件
from packaging import version  # 用于版本比较
from safetensors.torch import load_file as load_safetensors  # 用于加载 safetensors 格式文件

from ..modules.diffusionmodules.model import Decoder, Encoder  # 导入解码器和编码器
from ..modules.distributions.distributions import DiagonalGaussianDistribution  # 导入对角高斯分布
from ..modules.ema import LitEma  # 导入指数移动平均类
from ..util import default, get_obj_from_str, instantiate_from_config  # 导入实用工具函数


class AbstractAutoencoder(pl.LightningModule):
    """
    这是所有自编码器的基类,包括图像自编码器、带判别器的图像自编码器、
    unCLIP 模型等。因此,它是相当通用的,特定功能
    (例如判别器训练、编码、解码)必须在子类中实现。
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,  # 指定 EMA 衰减值
        monitor: Union[None, str] = None,  # 用于监控的指标名称
        input_key: str = "jpg",  # 输入数据的键名,默认为 "jpg"
        ckpt_path: Union[None, str] = None,  # 检查点文件路径
        ignore_keys: Union[Tuple, list, ListConfig] = (),  # 需要忽略的键
    ):
        super().__init__()  # 调用父类构造函数
        self.input_key = input_key  # 保存输入数据的键名
        self.use_ema = ema_decay is not None  # 判断是否使用 EMA
        if monitor is not None:  # 如果提供监控指标
            self.monitor = monitor  # 保存监控指标

        if self.use_ema:  # 如果使用 EMA
            self.model_ema = LitEma(self, decay=ema_decay)  # 初始化 EMA 对象
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")  # 打印 EMA 缓冲区的数量

        if ckpt_path is not None:  # 如果提供检查点路径
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)  # 从检查点初始化模型

        if version.parse(torch.__version__) >= version.parse("2.0.0"):  # 如果 PyTorch 版本 >= 2.0.0
            self.automatic_optimization = False  # 禁用自动优化

    def init_from_ckpt(
        self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()  # 初始化检查点
    ) -> None:
        if path.endswith("ckpt"):  # 如果路径以 "ckpt" 结尾
            sd = torch.load(path, map_location="cpu")["state_dict"]  # 加载检查点的状态字典
        elif path.endswith("safetensors"):  # 如果路径以 "safetensors" 结尾
            sd = load_safetensors(path)  # 加载 safetensors 文件
        else:  # 如果路径不符合以上两种格式
            raise NotImplementedError  # 抛出未实现异常

        keys = list(sd.keys())  # 获取状态字典的所有键
        for k in keys:  # 遍历每个键
            for ik in ignore_keys:  # 遍历忽略的键
                if re.match(ik, k):  # 如果键匹配忽略模式
                    print("Deleting key {} from state_dict.".format(k))  # 打印被删除的键
                    del sd[k]  # 从状态字典中删除该键
        missing, unexpected = self.load_state_dict(sd, strict=False)  # 加载状态字典,允许非严格匹配
        # print(
        #     f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        # )  # 打印恢复信息
        # if len(missing) > 0:  # 如果有缺失的键
        #     print(f"Missing Keys: {missing}")  # 打印缺失的键
        # if len(unexpected) > 0:  # 如果有意外的键
        #     print(f"Unexpected Keys: {unexpected}")  # 打印意外的键
    # 应用检查点,参数可以是 None、路径字符串或字典
    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        # 如果检查点为 None,直接返回
        if ckpt is None:
            return
        # 如果检查点是字符串,将其转换为字典格式
        if isinstance(ckpt, str):
            ckpt = {
                # 指定检查点引擎的目标
                "target": "sgm.modules.checkpoint.CheckpointEngine",
                # 指定检查点路径
                "params": {"ckpt_path": ckpt},
            }
        # 根据配置实例化检查点引擎
        engine = instantiate_from_config(ckpt)
        # 调用引擎并传入当前对象
        engine(self)
    
    # 抽象方法,获取输入数据,参数为 batch,返回类型为 Any
    @abstractmethod
    def get_input(self, batch) -> Any:
        # 抛出未实现错误
        raise NotImplementedError()
    
    # 训练批次结束时的回调函数
    def on_train_batch_end(self, *args, **kwargs):
        # 用于 EMA(Exponential Moving Average)计算
        if self.use_ema:
            # 调用 EMA 模型
            self.model_ema(self)
    
    # 上下文管理器,处理 EMA 权重的切换
    @contextmanager
    def ema_scope(self, context=None):
        # 如果使用 EMA
        if self.use_ema:
            # 存储当前参数
            self.model_ema.store(self.parameters())
            # 将 EMA 权重复制到当前模型
            self.model_ema.copy_to(self)
            # 如果提供上下文,打印切换消息
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            # 进入上下文
            yield None
        finally:
            # 离开上下文时恢复参数
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                # 如果提供上下文,打印恢复消息
                if context is not None:
                    print(f"{context}: Restored training weights")
    
    # 抽象方法,进行编码,返回类型为 torch.Tensor
    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        # 抛出未实现错误
        raise NotImplementedError("encode()-method of abstract base class called")
    
    # 抽象方法,进行解码,返回类型为 torch.Tensor
    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        # 抛出未实现错误
        raise NotImplementedError("decode()-method of abstract base class called")
    
    # 从配置实例化优化器
    def instantiate_optimizer_from_config(self, params, lr, cfg):
        # 打印正在加载的优化器信息
        print(f"loading >>> {cfg['target']} <<< optimizer from config")
        # 根据配置获取优化器对象,并初始化
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )
    
    # 配置优化器,返回类型为 Any
    def configure_optimizers(self) -> Any:
        # 抛出未实现错误
        raise NotImplementedError()
# 自动编码器引擎的基类,供所有图像自动编码器使用,如 VQGAN 或 AutoencoderKL
class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    # 初始化方法,配置编码器、解码器、损失函数等
    def __init__(
        self,
        *args,
        encoder_config: Dict,  # 编码器配置字典
        decoder_config: Dict,  # 解码器配置字典
        loss_config: Dict,  # 损失函数配置字典
        regularizer_config: Dict,  # 正则化器配置字典
        optimizer_config: Union[Dict, None] = None,  # 优化器配置字典(可选)
        lr_g_factor: float = 1.0,  # 学习率缩放因子
        ckpt_path=None,  # 检查点路径(可选)
        ignore_keys=[],  # 忽略的键列表
        **kwargs,  # 额外的关键字参数
    ):
        super().__init__(*args, **kwargs)  # 调用父类初始化方法
        # todo: add options to freeze encoder/decoder
        # 实例化编码器
        self.encoder = instantiate_from_config(encoder_config)
        # 实例化解码器
        self.decoder = instantiate_from_config(decoder_config)
        # 实例化损失函数
        self.loss = instantiate_from_config(loss_config)
        # 实例化正则化器
        self.regularization = instantiate_from_config(regularizer_config)
        # 设置优化器配置,默认为 Adam
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.Adam"}
        )
        # 设置学习率缩放因子
        self.lr_g_factor = lr_g_factor
        # 如果检查点路径不为空,初始化从检查点
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    # 从检查点加载模型状态
    def init_from_ckpt(self, path, ignore_keys=list()):
        # 根据文件扩展名加载状态字典
        if path.endswith("ckpt") or path.endswith("pt"):
            sd = torch.load(path, map_location="cpu")['state_dict']  # 加载 PyTorch 检查点
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)  # 加载 safetensors 格式
        else:
            raise NotImplementedError  # 未实现的文件格式处理
        keys = list(sd.keys())  # 获取状态字典中的所有键
        for k in keys:
            for ik in ignore_keys:
                # 如果键以忽略键开头,删除该键
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        # 加载状态字典,返回缺失和意外的键
        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
        print("Missing keys: ", missing_keys)  # 打印缺失的键
        print("Unexpected keys: ", unexpected_keys)  # 打印意外的键
        print(f"Restored from {path}")  # 打印恢复信息

    # 获取输入数据
    def get_input(self, batch: Dict) -> torch.Tensor:
        # 假设统一数据格式,数据加载器返回一个字典。
        # 图像张量应缩放至 -1 ... 1 并采用通道优先格式(如 bchw 而非 bhwc)
        return batch[self.input_key]  # 返回指定的输入键对应的张量

    # 获取自动编码器的所有参数
    def get_autoencoder_params(self) -> list:
        # 收集编码器、解码器、正则化器和损失函数的可训练参数
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.regularization.get_trainable_parameters())
            + list(self.loss.get_trainable_autoencoder_parameters())
        )
        return params  # 返回所有参数的列表

    # 获取鉴别器的参数
    def get_discriminator_params(self) -> list:
        # 获取损失函数中的可训练参数,例如鉴别器
        params = list(self.loss.get_trainable_parameters())  
        return params  # 返回鉴别器参数的列表

    # 获取解码器的最后一层
    def get_last_layer(self):
        return self.decoder.get_last_layer()  # 返回解码器的最后一层

    # 编码输入张量
    def encode(
        self,
        x: torch.Tensor,  # 输入张量
        return_reg_log: bool = False,  # 是否返回正则化日志
        unregularized: bool = False,  # 是否使用未正则化的编码
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        # 使用编码器对输入 x 进行编码,生成潜在表示 z
        z = self.encoder(x)
        # 如果 unregularized 为真,返回 z 和空字典
        if unregularized:
            return z, dict()
        # 对 z 进行正则化处理,得到正则化后的 z 和正则化日志 reg_log
        z, reg_log = self.regularization(z)
        # 如果需要返回正则化日志,返回 z 和 reg_log
        if return_reg_log:
            return z, reg_log
        # 否则只返回 z
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        # 使用解码器对潜在表示 z 进行解码,生成输出 x
        x = self.decoder(z, **kwargs)
        # 返回解码后的输出
        return x

    def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # 执行编码过程,获取潜在表示 z 和正则化日志 reg_log
        z, reg_log = self.encode(x, return_reg_log=True)
        # 对潜在表示 z 进行解码,获取重建的输入 dec
        dec = self.decode(z)
        # 返回潜在表示 z,重建输入 dec 和正则化日志 reg_log
        return z, dec, reg_log

    def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
        # 获取当前 batch 的输入数据 x
        x = self.get_input(batch)
        # 执行前向传播,获取潜在表示 z,重建输入 xrec 和正则化日志
        z, xrec, regularization_log = self(x)

        # 判断优化器的索引,进行不同的训练步骤
        if optimizer_idx == 0:
            # 处理自编码器的损失
            aeloss, log_dict_ae = self.loss(
                # 计算自编码器损失
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            # 记录自编码器的日志
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            # 返回自编码器的损失
            return aeloss

        if optimizer_idx == 1:
            # 处理判别器的损失
            discloss, log_dict_disc = self.loss(
                # 计算判别器损失
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            # 记录判别器的日志
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            # 返回判别器的损失
            return discloss

    def validation_step(self, batch, batch_idx) -> Dict:
        # 执行验证步骤,获取日志字典
        log_dict = self._validation_step(batch, batch_idx)
        # 在 EMA 范围内执行验证步骤,获取 EMA 日志
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            # 更新日志字典,合并 EMA 日志
            log_dict.update(log_dict_ema)
        # 返回更新后的日志字典
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
        # 获取当前 batch 的输入数据 x
        x = self.get_input(batch)

        # 执行前向传播,获取潜在表示 z,重建输入 xrec 和正则化日志
        z, xrec, regularization_log = self(x)
        # 计算自编码器的损失
        aeloss, log_dict_ae = self.loss(
            regularization_log,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        # 计算判别器的损失
        discloss, log_dict_disc = self.loss(
            regularization_log,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )
        # 记录重建损失
        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        # 更新自编码器日志字典,合并判别器日志
        log_dict_ae.update(log_dict_disc)
        # 记录合并后的日志字典
        self.log_dict(log_dict_ae)
        # 返回自编码器日志字典
        return log_dict_ae
    # 配置优化器,返回优化器和其他信息
    def configure_optimizers(self) -> Any:
        # 获取自动编码器的参数
        ae_params = self.get_autoencoder_params()
        # 获取鉴别器的参数
        disc_params = self.get_discriminator_params()

        # 从配置中实例化自动编码器的优化器,使用学习率的默认值
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        # 从配置中实例化鉴别器的优化器
        opt_disc = self.instantiate_optimizer_from_config(
            disc_params, self.learning_rate, self.optimizer_config
        )

        # 返回自动编码器和鉴别器的优化器,以及一个空列表
        return [opt_ae, opt_disc], []

    # 禁止梯度计算,避免在日志记录时影响计算图
    @torch.no_grad()
    def log_images(self, batch: Dict, **kwargs) -> Dict:
        # 初始化一个字典用于存储日志
        log = dict()
        # 从批次中获取输入数据
        x = self.get_input(batch)
        # 通过模型进行前向传播,获取重构结果
        _, xrec, _ = self(x)
        # 将输入数据和重构结果存入日志字典
        log["inputs"] = x
        log["reconstructions"] = xrec
        # 在 EMA(指数移动平均)作用域内进行操作
        with self.ema_scope():
            # 获取 EMA 重构结果
            _, xrec_ema, _ = self(x)
            # 将 EMA 重构结果存入日志字典
            log["reconstructions_ema"] = xrec_ema
        # 返回日志字典
        return log
# 定义 AutoencodingEngineLegacy 类,继承自 AutoencodingEngine
class AutoencodingEngineLegacy(AutoencodingEngine):
    # 初始化方法,接受嵌入维度及其他可选参数
    def __init__(self, embed_dim: int, **kwargs):
        # 从 kwargs 中提取最大批处理大小,如果没有则为 None
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        # 从 kwargs 中提取 ddconfig 配置
        ddconfig = kwargs.pop("ddconfig")
        # 从 kwargs 中提取检查点路径,如果没有则为 None
        ckpt_path = kwargs.pop("ckpt_path", None)
        # 从 kwargs 中提取检查点引擎,如果没有则为 None
        ckpt_engine = kwargs.pop("ckpt_engine", None)
        # 调用父类的初始化方法,配置编码器和解码器
        super().__init__(
            encoder_config={
                "target": "sgm.modules.diffusionmodules.model.Encoder",  # 编码器目标
                "params": ddconfig,  # 编码器参数
            },
            decoder_config={
                "target": "sgm.modules.diffusionmodules.model.Decoder",  # 解码器目标
                "params": ddconfig,  # 解码器参数
            },
            **kwargs,
        )
        # 定义量化卷积层,输入通道和输出通道根据配置计算
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],  # 输入通道数
            (1 + ddconfig["double_z"]) * embed_dim,  # 输出通道数
            1,  # 卷积核大小
        )
        # 定义后量化卷积层,将嵌入维度映射回 z_channels
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)  
        # 保存嵌入维度
        self.embed_dim = embed_dim

        # 应用检查点设置,使用默认方法获取路径和引擎
        self.apply_ckpt(default(ckpt_path, ckpt_engine))

    # 获取自动编码器的参数
    def get_autoencoder_params(self) -> list:
        # 调用父类方法获取参数
        params = super().get_autoencoder_params()
        return params

    # 编码输入张量 x,返回量化后的表示 z
    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        # 如果没有最大批处理大小
        if self.max_batch_size is None:
            # 直接编码并量化
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            # 获取输入样本数
            N = x.shape[0]
            # 获取批处理大小
            bs = self.max_batch_size
            # 计算总批次数
            n_batches = int(math.ceil(N / bs))
            z = list()  # 初始化存储编码结果的列表
            # 遍历每个批次
            for i_batch in range(n_batches):
                # 对当前批次进行编码
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)  # 量化编码结果
                z.append(z_batch)  # 添加到结果列表中
            # 将所有批次的结果连接成一个张量
            z = torch.cat(z, 0)

        # 应用正则化方法
        z, reg_log = self.regularization(z)
        # 如果需要返回正则化日志
        if return_reg_log:
            return z, reg_log  # 返回量化结果和日志
        return z  # 返回量化结果

    # 解码输入张量 z,返回重构结果
    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        # 如果没有最大批处理大小
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)  # 先经过后量化卷积
            dec = self.decoder(dec, **decoder_kwargs)  # 再经过解码器
        else:
            # 获取输入样本数
            N = z.shape[0]
            # 获取批处理大小
            bs = self.max_batch_size
            # 计算总批次数
            n_batches = int(math.ceil(N / bs))
            dec = list()  # 初始化存储解码结果的列表
            # 遍历每个批次
            for i_batch in range(n_batches):
                # 对当前批次进行后量化处理
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)  # 解码
                dec.append(dec_batch)  # 添加到结果列表中
            # 将所有批次的结果连接成一个张量
            dec = torch.cat(dec, 0)

        return dec  # 返回解码结果
    

# 定义 AutoencoderKL 类,继承自 AutoencodingEngine
class AutoencoderKL(AutoencodingEngine):
    # 初始化方法,接受嵌入维度和其他可选参数
        def __init__(self, embed_dim: int, **kwargs):
            # 从 kwargs 中提取 ddconfig 配置
            ddconfig = kwargs.pop("ddconfig")
            # 从 kwargs 中提取检查点路径,如果不存在则为 None
            ckpt_path = kwargs.pop("ckpt_path", None)
            # 从 kwargs 中提取忽略的键,默认为空元组
            ignore_keys = kwargs.pop("ignore_keys", ())
            # 调用父类初始化,配置编码器、解码器和正则化器
            super().__init__(
                encoder_config={"target": "torch.nn.Identity"},
                decoder_config={"target": "torch.nn.Identity"},
                regularizer_config={"target": "torch.nn.Identity"},
                loss_config=kwargs.pop("lossconfig"),
                **kwargs,
            )
            # 确保 ddconfig 中的 double_z 为真
            assert ddconfig["double_z"]
            # 初始化编码器,传入 ddconfig 参数
            self.encoder = Encoder(**ddconfig)
            # 初始化解码器,传入 ddconfig 参数
            self.decoder = Decoder(**ddconfig)
            # 创建一个卷积层,将输入的通道数映射到嵌入维度的两倍
            self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
            # 创建一个卷积层,将嵌入维度映射回原始通道数
            self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
            # 存储嵌入维度
            self.embed_dim = embed_dim
    
            # 如果检查点路径不为空,则从该路径初始化模型
            if ckpt_path is not None:
                self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
    
        # 编码方法,将输入数据编码为后验分布
        def encode(self, x):
            # 确保当前模式为推理模式
            assert (
                not self.training
            ), f"{self.__class__.__name__} only supports inference currently"
            # 通过编码器处理输入数据
            h = self.encoder(x)
            # 通过量化卷积层处理编码结果
            moments = self.quant_conv(h)
            # 创建一个对角高斯分布,基于量化结果
            posterior = DiagonalGaussianDistribution(moments)
            # 返回后验分布
            return posterior
    
        # 解码方法,将后验分布样本解码为输出
        def decode(self, z, **decoder_kwargs):
            # 通过反量化卷积层处理输入样本
            z = self.post_quant_conv(z)
            # 通过解码器生成最终输出
            dec = self.decoder(z, **decoder_kwargs)
            # 返回解码结果
            return dec
# 定义一个名为 AutoencoderKLInferenceWrapper 的类,继承自 AutoencoderKL
class AutoencoderKLInferenceWrapper(AutoencoderKL):
    # 定义 encode 方法,接受参数 x
    def encode(self, x):
        # 调用父类的 encode 方法并返回其结果的样本
        return super().encode(x).sample()

# 定义一个名为 IdentityFirstStage 的类,继承自 AbstractAutoencoder
class IdentityFirstStage(AbstractAutoencoder):
    # 定义初始化方法,接受可变数量的参数
    def __init__(self, *args, **kwargs):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

    # 定义 get_input 方法,接受参数 x,返回输入
    def get_input(self, x: Any) -> Any:
        return x

    # 定义 encode 方法,接受参数 x 和其他可变参数
    def encode(self, x: Any, *args, **kwargs) -> Any:
        # 返回输入 x
        return x

    # 定义 decode 方法,接受参数 x 和其他可变参数
    def decode(self, x: Any, *args, **kwargs) -> Any:
        # 返回输入 x
        return x


# 定义一个名为 AutoencoderKLModeOnly 的类,继承自 AutoencodingEngineLegacy
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
    # 定义初始化方法,接受关键字参数
    def __init__(self, **kwargs):
        # 如果 kwargs 中包含 'lossconfig' 键,则将其改名为 'loss_config'
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        # 调用父类的初始化方法,并传入 regularizer_config 和 kwargs
        super().__init__(
            regularizer_config={
                # 定义目标为 DiagonalGaussianRegularizer
                "target": (
                    "sgm.modules.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                ),
                # 设置参数 sample 为 False
                "params": {"sample": False},
            },
            **kwargs,
        )