新闻  |   论坛  |   博客  |   在线研讨会
在Transformer时代重塑RNN,RWKV将非Transformer架构扩展到数百亿参数(2)
机器之心 | 2023-05-24 20:58:50    阅读:212   发布文章

架构细节

RWKV 架构由一系列堆叠的残差块组成,每个残差块又由具有循环结构的时间混合和通道混合子块组成。


循环被表示为当前输入和前一个时间步的输入之间的线性插值(研究者称这种技术为时移混合或 token shift,如下图 3 所示),该插值可以针对输入嵌入的每个线性投影进行独立调整(比如时间混合中的 R、K 和 V,通道混合中的 R 和 K),并作为公式 14 中形式化的 WKV 的时变更新。


图片

图片

类 Transformer 的并行化

RWKV 可以在时间并行模式下进行高效地并行化,让人联想到 Transformer。单个层中一个 batch 序列的时间复杂度为 O (BTd^2 ),它主要由矩阵乘法 W_□,  □ ∈ {r, k, v, o}(假设 B 个序列、T 个最大 token 和 d 个通道)。同时更新注意力分数 wkv_t 需要串行扫描,并且复杂度为 O (BTd)。


类 RNN 的序列解码


在循环网络中,将状态 t 时的输出用作状态 t+1 时的输入很常见。这在语言模型的自回归解码推理中尤为明显,要求每一个 token 在馈入下一步之前必须进行计算,从而使 RWKV 可以利用类 RNN 结构(即时序模式)。在这种情况下,RWKV 可以方便地循环用于推理解码,从而利用每个输出 token 仅依赖于最新状态的优势。


然后 RWKV 充当 RNN ****,在序列长度方面保持恒定速度和内存占用,从而更高效地处理更长的序列。相比之下,自注意力通常需要 KV 缓存相对于序列长度呈线性增长,这会导致效率下降,并随序列长度增加消耗更多内存和时间。


软件实现

RWKV 最初使用 PyTorch 深度学习库和自定义 CUDA 内核(它用于 WKV 计算)来实现。尽管 RWKV 是一个通用循环网络,但其当前的实现主要集中在语言建模任务(RWKV-LM)。该模型架构包含了一个嵌入层,为此研究者遵循第 4.7 节中的设置,并按照第 4.6 节中的原则依次应用几个相同的残差块,具体如上图 2 和 3 所示。


梯度稳定性和层堆叠

RWKV 架构被设计为 Transformer 和 RNN 的融合,与传统的 RNN 相比,Transformers 具有稳定梯度和更深层次架构的优势,同时推理效率高。


RWKV 模型具有用于更新类似注意力分数的单步过程,其中包括一个依赖于时间的 softmax 操作,该操作有助于数值稳定性并防止梯度消失(有关严格证明,请参见附录 F)。直观地说,此操作可确保梯度沿最相关的路径传播。Layer normalization (Ba et al., 2016) 是架构的另一个关键方面,它通过稳定梯度、解决梯度消失和爆炸问题来增强深度神经网络的训练动态。


利用时间结构进行时序数据处理


RWKV 通过三种机制的组合来捕获和传播时序信息:循环、时间衰减和 token shift。


RWKV 时间混合块中的循环是模型捕获序列元素之间复杂关系和随时间传播局部信息的能力的基础。


时间衰减机制(等式 14 中的 e^−w 和 e^u)保持了对序列元素之间位置关系的敏感性。通过逐渐减少以往信息随时间的影响,该模型保留了时间局部性和进展感,这对于时序处理至关重要。


token shift 或 time-shift 混合或(图 3 中的对角线箭头),也有助于模型适应时序数据。通过在当前输入和前一个时间步输入之间进行线性插值,模型自然地聚合和门控输入通道中的信息。


实验结果


实验的重点是回答以下问题:


  • RQ1:在参数数量和训练 token 数量相等的情况下,RWKV 与二次 transformer 架构相比具有竞争力吗?

  • RQ2:增加参数数量时,RWKV 是否仍然具有与二次 transformer 架构相竞争的能力?

  • RQ3:当 RWKV 模型被训练用于开源二次 transformer 无法高效处理的上下文长度时,增加 RWKV 的参数是否能够获得更好的语言建模损失?


首先是回答 RQ1 和 RQ2 问题,从图 4 可以看出,在六个基准测试中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 与开源二次复杂度 transformer 模型 Pythia、OPT 和 BLOOM 具有相当的竞争力。RWKV 甚至在四个任务(PIQA、OBQA、ARC-E 和 COPA)中胜过了 Pythia 和 GPT-Neo。


图片


对于 RQ3,图 5 显示,增加上下文长度会导致 Pile 上的测试损失降低,这表明 RWKV 能够有效利用较长的上下文信息。


图片


*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。

参与讨论
登录后参与讨论
推荐文章
最近访客