以LLaMA模型为例,计算一下参数量

首先,假设词表大小为 V,模型包含 L 层解码器,中间状态的维度大小为 H',前馈网络层的中间状态维度大小为 H。我们主要关注计算以下几个部分的参数量:

  • 输入嵌入层:首先,输入嵌入层(E \in \mathbb{R}^{V \times H})将词表中的每个单词映射到一个 H 维的向量,因此输入编码层有 VH 个参数。
  • 多头注意力层:传统的注意力机制部分包含查询 (W^Q \in \mathbb{R}^{H \times H})、键 (W^K \in \mathbb{R}^{H \times H}) 和值 (W^V \in \mathbb{R}^{H \times H}) 的线性变换矩阵,每个变换矩阵都包含 H^2 个参数,因此这部分需要 3 \times H^2 个参数。同时还需要一个额外的线性变换将多头注意力机制的输出拼接后映射为最终输出 (W^O \in \mathbb{R}^{H \times H}),这又需要 H^2 个参数。因此,多头注意力层总共需要 4H^2 个参数。
  • 前馈网络层:LLaMA 的前馈网络层由三个线性变换组成,其中有一个非线性激活函数。前两个线性变换 (W^U \in \mathbb{R}^{H \times H'}W^G \in \mathbb{R}^{H \times H'}) 将输入从 H 维映射到 H' 维,需要 2 \times HH' 个参数;最后一个线性变换 (W^D \in \mathbb{R}^{H' \times H}) 将输出从 H' 维映射回 H 维,需要 HH' 个参数。因此,前馈网络层总共需要 3 \times HH' 个参数。
  • 归一化层:每一层解码器都包含两个 RMSNorm 操作,分别用于对多头注意力层和前馈网络层的输入进行归一化处理,共需要 2 \times H 个参数。此外,最后一层的输出也需要进行归一化处理,这又需要额外的 H 个参数。
  • 输出层:最后,LLaMA 的输出层包含一个线性变换 (W^L \in \mathbb{R}^{H \times V}),将解码器的输出映射到词表大小 V 的维度上,使用 softmax 函数归一化后预测下一个单词的概率分布。这个线性变换需要 VH 个参数。

综合上述,累积输入嵌入层、输出层和 L 层解码器每层的多头注意力层、前馈网络层和归一化层,LLaMA 模型的参数量计算公式为:

\text{参数量} = 2VH + H + L \cdot (4H^2 + 3HH' + 2H)

以 LLaMA (7B) 为例计算其参数量,给定 V = 32000L = 32H = 4096H' = 11008,将这些值代入上述公式中:

\text{参数量} = 2 \times 32000 \times 4096 + 4096 + 32 \times (4 \times 4096^2 + 3 \times 4096 \times 11008 + 2 \times 4096) = 6,738,415,616

计算得到的参数量与 LLaMA (7B) 模型的实际参数量基本完全一致。