Transformer computation formulas
LLM inference workflow
Generative Inference. A typical LLM generative inference task consists of two stages: i) the prefill stage which takes a prompt sequence to generate the key-value cache (KV cache) for each transformer layer of the LLM; and ii) the decoding stage which utilizes and updates the KV cache to generate tokens step-by-step, where the current token generation depends on previously generated tokens.
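As a rough illustration of the control flow only, the sketch below uses hypothetical stand-in functions: a real `prefill` would run the whole prompt through every transformer layer and return per-layer K/V tensors, and a real `decode_one` would run one token through the model. The per-layer math for each stage is spelled out in the two subsections that follow.

```python
# Toy control-flow sketch of the two-stage workflow. `prefill` and
# `decode_one` are hypothetical stand-ins, not a real model.

def prefill(prompt_ids):
    """Stage i): process the whole prompt at once, building the KV cache."""
    kv_cache = list(prompt_ids)              # a real cache stores per-layer K/V tensors
    first_token = sum(prompt_ids) % 100      # stand-in for sampling from the logits
    return first_token, kv_cache

def decode_one(token_id, kv_cache):
    """Stage ii): generate one token, reading and extending the KV cache."""
    kv_cache.append(token_id)
    return (token_id * 31 + 7) % 100         # stand-in next-token rule

token, cache = prefill([11, 42, 7])
generated = [token]
for _ in range(4):                           # each step depends on the previous token
    token = decode_one(token, cache)
    generated.append(token)
print(generated)
```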
Prefill phase
During the prefill phase, let \(\mathrm{x}^i \in \mathcal{R}^{b \times s \times h_1}\) denote the input of the \(i\)-th layer, where \(b\) is the batch size, \(s\) the prompt length, and \(h_1\) the hidden dimension. The cached key and value can be computed by:
\(\mathrm{x}_K^i=\mathrm{x}^i \cdot \mathrm{w}_K^i ; \quad \mathrm{x}_V^i=\mathrm{x}^i \cdot \mathrm{w}_V^i\)
The rest of the computation in the \(i\)-th layer, with \(h\) the attention head dimension, is:
\(\begin{gathered}\mathrm{x}_Q^i=\mathrm{x}^i \cdot \mathrm{w}_Q^i \\ \mathrm{x}_{\text{Out}}^i=f_{\text{Softmax}}\left(\frac{\mathrm{x}_Q^i \left(\mathrm{x}_K^i\right)^{\top}}{\sqrt{h}}\right) \cdot \mathrm{x}_V^i \cdot \mathrm{w}_O^i+\mathrm{x}^i \\ \mathrm{x}^{i+1}=f_{\text{ReLU}}\left(\mathrm{x}_{\text{Out}}^i \cdot \mathrm{w}_1^i\right) \cdot \mathrm{w}_2^i+\mathrm{x}_{\text{Out}}^i\end{gathered}\)
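To make the shapes concrete, here is a minimal single-head NumPy sketch of the prefill equations above. The random weights are placeholders rather than a real checkpoint, and since there is only one head, \(h\) is taken to be \(h_1\).

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def prefill_layer(x, w_Q, w_K, w_V, w_O, w_1, w_2):
    """One layer over the full prompt; returns next-layer input and this layer's KV cache."""
    x_K = x @ w_K                                   # cached keys,   (b, s, h1)
    x_V = x @ w_V                                   # cached values, (b, s, h1)
    x_Q = x @ w_Q                                   # queries,       (b, s, h1)
    h = x.shape[-1]                                 # single head, so h = h1 here
    attn = softmax(x_Q @ x_K.transpose(0, 2, 1) / np.sqrt(h))   # (b, s, s)
    x_out = attn @ x_V @ w_O + x                    # attention + residual
    x_next = np.maximum(x_out @ w_1, 0) @ w_2 + x_out  # ReLU MLP + residual
    return x_next, x_K, x_V

b, s, h1, h2 = 2, 5, 8, 32
rng = np.random.default_rng(0)
x = rng.standard_normal((b, s, h1))
w_Q, w_K, w_V, w_O = (0.1 * rng.standard_normal((h1, h1)) for _ in range(4))
w_1 = 0.1 * rng.standard_normal((h1, h2))
w_2 = 0.1 * rng.standard_normal((h2, h1))
x, x_K, x_V = prefill_layer(x, w_Q, w_K, w_V, w_O, w_1, w_2)
print(x.shape, x_K.shape, x_V.shape)   # (2, 5, 8) (2, 5, 8) (2, 5, 8)
```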
Decode phase
During the decode phase, given \(\mathbf{t}^i \in \mathcal{R}^{b \times 1 \times h_1}\) as the embedding of the currently generated token in the \(i\)-th layer, the inference computation needs to i) update the KV cache:
\(\mathrm{x}_K^i \leftarrow \operatorname{Concat}\left(\mathrm{x}_K^i, \mathbf{t}^i \cdot \mathrm{w}_K^i\right); \quad \mathrm{x}_V^i \leftarrow \operatorname{Concat}\left(\mathrm{x}_V^i, \mathbf{t}^i \cdot \mathrm{w}_V^i\right)\)
and ii) compute the output of the current layer:
\(\begin{gathered}\mathbf{t}_Q^i=\mathbf{t}^i \cdot \mathrm{w}_Q^i \\ \mathbf{t}_{\text{Out}}^i=f_{\text{Softmax}}\left(\frac{\mathbf{t}_Q^i \left(\mathrm{x}_K^i\right)^{\top}}{\sqrt{h}}\right) \cdot \mathrm{x}_V^i \cdot \mathrm{w}_O^i+\mathbf{t}^i \\ \mathbf{t}^{i+1}=f_{\text{ReLU}}\left(\mathbf{t}_{\text{Out}}^i \cdot \mathrm{w}_1^i\right) \cdot \mathrm{w}_2^i+\mathbf{t}_{\text{Out}}^i\end{gathered}\)
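Continuing the prefill sketch above (reusing its `softmax` helper, weights, and cache), a matching NumPy sketch of one decode step: the new token's key and value are appended to the cache, and its query then attends over the whole cache.

```python
def decode_step(t, x_K, x_V, w_Q, w_K, w_V, w_O, w_1, w_2):
    """One decode step: extend the KV cache with the new token, then attend over it."""
    x_K = np.concatenate([x_K, t @ w_K], axis=1)    # i) update cached keys,   (b, s+1, h1)
    x_V = np.concatenate([x_V, t @ w_V], axis=1)    #    update cached values, (b, s+1, h1)
    t_Q = t @ w_Q                                   # query for the new token, (b, 1, h1)
    h = t.shape[-1]
    attn = softmax(t_Q @ x_K.transpose(0, 2, 1) / np.sqrt(h))   # (b, 1, s+1)
    t_out = attn @ x_V @ w_O + t                    # attention + residual
    t_next = np.maximum(t_out @ w_1, 0) @ w_2 + t_out           # ReLU MLP + residual
    return t_next, x_K, x_V

t = rng.standard_normal((b, 1, h1))     # stand-in embedding of the first generated token
for _ in range(3):
    t, x_K, x_V = decode_step(t, x_K, x_V, w_Q, w_K, w_V, w_O, w_1, w_2)
print(x_K.shape)                        # (2, 8, 8): the cache grew from s=5 to s=8
```

Note how the decode step processes a sequence of length 1 while the prefill step processes the full length-\(s\) prompt; this is why decoding is memory-bound on the growing KV cache rather than compute-bound.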