In-Context Learning and Contrastive Learning

3 minute read

Published:

Paper reading: In-context Learning with Transformer Is Really Equivalent to a Contrastive Learning Pattern.

阅读了一篇arXiv2023.10的新文章 [1],文章建立了In-Context Learning (ICL) 以及对比学习之间的关系。之前的结果大多是建立了ICL对于线性Attention机制下与梯度下降的关系,但本文的框架更加general,可以适用于非线性的Attention,并且证明其与某个特定的对比学习损失的梯度下降的联系。

Linear Attention and GD

给定一个权重矩阵 $W$, 模型的预测为 $\hat y = W x$, 对应损失函数为 $\mathcal{L}(\hat y)$. 对 $W$ 进行梯度下降得到

\[\begin{align*} \hat W = W_0 + \Delta W = W_0 - \frac{\eta}{n} \sum_{i=1}^n e_i \otimes x_i. \end{align*}\]

其中 $e_i = (\partial \mathcal{L} / \partial \hat y_i )$. 给定数据点 $x_{\rm test}$, 上述过程在 $\hat y_{\rm test}$ 上产生的变化可以用线性Attention机制实现,

\[\begin{align*} \hat y_{\rm test} &= W x_{\rm test} \\ &= W_0 x_{\rm test} - \frac{\eta}{n} \left(\sum_{i=1}^n e_i \otimes x_i \right) x_{\rm test} \\ &= W_0 x_{\rm test} + P V M K^\top q. \end{align*}\]

其中

\[\begin{align*} P= -\frac{\eta}{n} I_n, \quad M = \begin{pmatrix} I_n & 0 \\ 0 & 0 \end{pmatrix}, \quad q = x_{\rm test}, \quad K = (x_1, \cdots, x_n , x_{\rm test}), \quad V = (e_1,\cdots, e_n, e_{\rm test}). \end{align*}\]

其中掩码矩阵 $M$ 的目的是使得最后一个样本的Attention仅仅注意到前面的 $n$ 个样本。

上式可以理解为ICL使用prompt作为训练集,进行了一步梯度下降,结果作用于测试样本上。

下面我们根据 [2] 中的构造,说明特定选取的 $V$ 可以满足对应的损失为最小二乘损失。

对于 $\mathcal{L}(\hat y_i ; y_i) = \frac{1}{2}( \hat y_i - y_i)^2$ ,我们知道在初始权重 $W_0$ 处对应的梯度信号为,

\[\begin{align*} e_i = - \frac{\eta}{n} ( W_0 x_i - y_i). \end{align*}\]

考虑输入的tocken为 $(x_1;y_1), \cdots , (x_n;y_n), (x_{\rm test};W_0 x_{\rm test})$. 通过构造

\[\begin{align*} W_V = \begin{pmatrix} O & O \\ W_0 &- I \end{pmatrix}, \quad W_K = W_Q = \begin{pmatrix} I & 0 \\ 0 & 0 \end{pmatrix}. \end{align*}\]

我们可以知道

\[\begin{align*} W_V \begin{pmatrix} x_i \\ y_i \end{pmatrix} = \begin{pmatrix} 0 \\ W_0 x_i - y_i \end{pmatrix}, \quad W_K \begin{pmatrix} x_i \\ y_i \end{pmatrix} = W_Q \begin{pmatrix} x_i \\ y_i \end{pmatrix} = \begin{pmatrix} x_i \\ 0 \end{pmatrix}. \end{align*}\]

考虑经过如下单层Linear Attention后的网络

\[\begin{align*} \hat H = H+ P W_V H M H^\top W_K^\top W_Q H, \quad H = ( h_1,\cdots h_n,h_{\rm test}), ~~h_{i} = \begin{pmatrix} x_i \\ y_i \end{pmatrix}. \end{align*}\]

可以发现其输出的最后一个token的最后一维恰好是之前提到的梯度步(GD)的结果。

\[\begin{align*} \hat y_{\rm test } = W_0 x_{\rm test} - \frac{\eta}{n} \left(\sum_{i=1}^n e_i \otimes x_i \right) x_{\rm test}. \end{align*}\]

上述分析构成了 [2] 的主要结果.

Non-Linear Attention and Contrastive Learning

在本节,我们参考 [1], 引入核函数以推广结论至非线性的Attention机制,并且讨论ICL以及对比学习(CL)的联系。

根据高斯核函数的性质以及softmax算子的定义,我们知道存在非线性变换 $\phi(x)$ 使得

\[\begin{align*} V {\rm softmax}(K^\top Q / \sqrt{d_{\rm out}}) = V D^{-1} \phi(K)^\top \phi(Q). \end{align*}\]

其中 $D$ 为对角阵其对角元对应于 softmax 中的归一化常数,$\phi(K) = (\phi(k_1), \cdots , \phi(k_n))$ 且 $\phi(Q)$ 同理。

在非线性情况下,给定一个权重矩阵 $W$, 模型的预测为 $\hat y = W \phi(x)$ .

定义 $e_i = \partial \mathcal{L} / \partial \hat y_i$. 类似于线性Attention的推导,我们知道

\[\begin{align*} \hat y_{\rm test} &= W \phi(x_{\rm test}) \\ &= W_0 \phi(x_{\rm test}) - \frac{\eta}{n} \left(\sum_{i=1}^n e_i \otimes \phi(x_i) \right) \phi(x_{\rm test}) \\ &= W_0 x_{\rm test} + P V M {\rm softmax} \left(K^\top q / \sqrt{d_{\rm out}} \right). \end{align*}\]

其中

\[\begin{align*} P= -\frac{\eta}{n} D^{-1}, \quad M = \begin{pmatrix} I_n & 0 \\ 0 & 0 \end{pmatrix}, \quad q = \phi(x_{\rm test}), \quad K = (x_1, \cdots, x_n , x_{\rm test}), \quad V = (e_1,\cdots, e_n, e_{\rm test}). \end{align*}\]

最后,注意到给定 $V$ 根据 $e_i$ 的定义我们可以知道损失函数的形式为

\[\begin{align*} \mathcal{L}(\hat y_i;y_i) &= \langle e_i, \hat y_i \rangle \\ \mathcal{L}(W;y_i) &= \langle W_V h_i, W \phi(x_i) \rangle = \langle W_V h_i, W \phi(W_K h_i) \rangle. \end{align*}\]

可以发现这个损失函数正好为一个相似度的形式,用CL的语言可以解释为使用给定的矩阵 $W_V,W_K$ 对token $h$ 进行一个线性变换作为数据增强,然后将 $W \phi(W_K h_i)$ 作为一个网络的输出,最小化该输出与 $W_V h_i$ 的相似度。

上述的分析即建立了ICL以及对比学习(CL) 之间的联系。该观点于 [1] 首次提出。

Reference

[1] Ren, Ruifeng, and Yong Liu. “In-context Learning with Transformer Is Really Equivalent to a Contrastive Learning Pattern.” arXiv preprint arXiv:2310.13220 (2023).

[2] Von Oswald, Johannes, et al. “Transformers learn in-context by gradient descent.” In ICML, 2023.