Momentum-Based Variance Reduction in Non-Convex SGD
Paper Reading: Momentum-Based Variance Reduction in Non-Convex SGD.
Consider the classical stochastic optimization problem
\[\begin{align*} \min_x \left\{f(x) \triangleq E_\xi F(x;\xi) \right\}. \end{align*}\]Assume the following averaged smoothness condition holds:
\[\begin{align*} E_\xi \Vert \nabla F(x;\xi) - \nabla F(y; \xi) \Vert^2 \le L^2 \Vert x - y \Vert^2, \quad \forall x,y. \end{align*}\]The paper proposes a simple momentum-based variance-reduction method (STORM) that achieves the optimal $\mathcal{O}(\epsilon^{-3})$ convergence rate.
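As a concrete toy instance of these assumptions (my own example, not from the paper), take 1-D least squares $F(x;\xi) = (a_\xi x)^2/2$ with $a_\xi \sim N(0,1)$; then averaged smoothness holds with $L^2 = E[a_\xi^4]$, which can be checked empirically:

```python
import random

# Toy instance (illustrative, not from the paper): F(x; xi) = (a_xi * x)^2 / 2,
# so grad F(x; xi) = a_xi^2 * x and averaged smoothness holds with L^2 = E[a_xi^4].
rng = random.Random(1)
samples = [rng.gauss(0.0, 1.0) for _ in range(10000)]  # draws of a_xi

def stoch_grad(x, a):
    return a * a * x

x, y = 2.0, -1.0
# empirical E_xi || grad F(x; xi) - grad F(y; xi) ||^2
lhs = sum((stoch_grad(x, a) - stoch_grad(y, a)) ** 2 for a in samples) / len(samples)
L_sq = sum(a ** 4 for a in samples) / len(samples)     # empirical L^2 = E[a_xi^4]
assert lhs <= L_sq * (x - y) ** 2 + 1e-9
```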
Algorithm and Analysis
The algorithm adds a correction to the momentum term of standard SGD with momentum; the iteration reads:
\[\begin{align*} v_t =& (1-\alpha) v_{t-1} + \alpha \nabla F(x_t;\xi_t) + (1-\alpha) ( \nabla F(x_t; \xi_t) - \nabla F(x_{t-1}; \xi_t) ) \\ x_{t+1} = & x_t - \eta v_t. \end{align*}\]We start from the classical descent lemma (note that averaged smoothness together with Jensen's inequality implies $f$ itself is $L$-smooth): if the stepsize satisfies $\eta \le 1/(2L)$, then
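A minimal 1-D sketch of the iteration (the stepsize, momentum parameter, and noisy quadratic objective below are my own illustrative choices): the crucial detail is that the correction term evaluates the *same* sample $\xi_t$ at both $x_t$ and $x_{t-1}$.

```python
import random

def storm_1d(grad, x0, eta=0.05, alpha=0.3, steps=2000, seed=0):
    """Sketch of the momentum-based variance-reduced iteration (1-D)."""
    rng = random.Random(seed)
    v = grad(x0, rng.gauss(0.0, 1.0))      # v_0: plain stochastic gradient
    x_prev, x = x0, x0 - eta * v           # x_1 = x_0 - eta * v_0
    for _ in range(steps):
        xi = rng.gauss(0.0, 1.0)           # one fresh sample per step ...
        g_new, g_old = grad(x, xi), grad(x_prev, xi)  # ... reused at x_t and x_{t-1}
        # v_t = (1-a) v_{t-1} + a grad(x_t; xi_t) + (1-a)(grad(x_t; xi_t) - grad(x_{t-1}; xi_t))
        v = (1 - alpha) * v + alpha * g_new + (1 - alpha) * (g_new - g_old)
        x_prev, x = x, x - eta * v
    return x

# toy problem f(x) = x^2 / 2 with additive gradient noise: grad F(x; xi) = x + xi
noisy_grad = lambda x, xi: x + xi
x_final = storm_1d(noisy_grad, 5.0)
```

For additive noise the correction $\nabla F(x_t;\xi_t) - \nabla F(x_{t-1};\xi_t) = x_t - x_{t-1}$ is noiseless, so the only fresh noise entering $v_t$ is scaled by $\alpha$.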
\[\begin{align*} f(x_{t+1}) \le & f(x_t) - \eta v_t^\top \nabla f(x_t) + \frac{\eta^2 L}{2} \Vert v_t \Vert^2 \\ =& f(x_t) - \frac{\eta(1- \eta L)}{2} \Vert v_t \Vert^2 - \frac{\eta}{2} \Vert \nabla f(x_t) \Vert^2 + \frac{\eta}{2} \Vert v_t - \nabla f(x_t) \Vert^2 \\ \le& f(x_t) - \frac{\eta}{2} \Vert \nabla f(x_t) \Vert^2 - \frac{\eta}{4} \Vert v_t \Vert^2 + \frac{\eta}{2} \Vert v_t - \nabla f(x_t) \Vert^2, \end{align*}\]where the last step uses $\eta \le 1/(2L)$. We next bound the expectation of the last term. Starting from the identity
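The equality in the second line rewrites the inner product via $-2 v^\top g = \Vert v-g \Vert^2 - \Vert v \Vert^2 - \Vert g \Vert^2$; a quick numeric check of this identity:

```python
import random

rng = random.Random(0)
v = [rng.uniform(-5.0, 5.0) for _ in range(10)]  # stand-in for v_t
g = [rng.uniform(-5.0, 5.0) for _ in range(10)]  # stand-in for grad f(x_t)

sq_norm = lambda u: sum(ui * ui for ui in u)
# -2 <v, g> == ||v - g||^2 - ||v||^2 - ||g||^2
lhs = -2.0 * sum(vi * gi for vi, gi in zip(v, g))
rhs = sq_norm([vi - gi for vi, gi in zip(v, g)]) - sq_norm(v) - sq_norm(g)
assert abs(lhs - rhs) < 1e-9
```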
\[\begin{align*} v_t - \nabla f(x_t) =& (1-\alpha) v_{t-1} + \alpha \nabla F(x_t;\xi_t) - \nabla f(x_t) \\ &+ (1-\alpha) ( \nabla F(x_t; \xi_t) - \nabla F(x_{t-1}; \xi_t) ) \\ =& (1-\alpha) (v_{t-1} - \nabla f(x_{t-1})) - \alpha (\nabla f(x_t) - \nabla F(x_t;\xi_t)) \\ & +(1-\alpha ) (\nabla f(x_{t-1}) - \nabla f(x_t) + \nabla F(x_t; \xi_t) - \nabla F(x_{t-1}; \xi_t) ), \end{align*}\]we take expectations. Conditioned on the past, the first term is deterministic while the last two are mean-zero by unbiasedness $E[\nabla F(x_t; \xi_t)]= \nabla f(x_t)$, so their cross terms with the first vanish; applying $\Vert a+b \Vert^2 \le 2\Vert a \Vert^2 + 2 \Vert b \Vert^2$ to the two mean-zero terms then gives
\[\begin{align*} E \Vert v_t - \nabla f(x_t) \Vert^2 \le& (1-\alpha)^2 E \Vert v_{t-1} - \nabla f(x_{t-1}) \Vert^2 + 2\alpha^2 E \Vert \nabla f(x_t)- \nabla F(x_t;\xi_t) \Vert^2 \\ &+ 2(1-\alpha)^2 E \Vert \nabla f(x_{t-1}) - \nabla f(x_t) + \nabla F(x_t; \xi_t) - \nabla F(x_{t-1}; \xi_t) \Vert^2. \end{align*}\]Using the variance bound $E \Vert \nabla f(x) - \nabla F(x;\xi) \Vert^2 \le \sigma^2$, the fact that a variance is at most the corresponding second moment, and averaged smoothness, we obtain
\[\begin{align*} E \Vert v_t - \nabla f(x_t) \Vert^2 \le& (1-\alpha)^2 E \Vert v_{t-1} - \nabla f(x_{t-1}) \Vert^2 + 2\alpha^2 \sigma^2 + 2(1-\alpha)^2 L^2 E \Vert x_t - x_{t-1} \Vert^2. \end{align*}\]Unrolling this recursion and using $x_{k+1} - x_k = -\eta v_k$ gives
\[\begin{align*} E \Vert v_t - \nabla f(x_t) \Vert^2 \le \frac{(1-\alpha)^{2t} \sigma^2}{b_0} + \frac{2\alpha \sigma^2}{2-\alpha} + 2 (1-\alpha)^2 L^2 \eta^2 \sum_{k=0}^{t-1} (1-\alpha)^{2(t-1-k)} E \Vert v_k \Vert^2, \end{align*}\]where the middle term uses $\sum_{j \ge 0} (1-\alpha)^{2j} \le 1/(\alpha(2-\alpha))$. Summing over $t$ yields
\[\begin{align*} \sum_{t=0}^{T-1} E \Vert v_t - \nabla f(x_t) \Vert^2 \le \frac{\sigma^2}{b_0\alpha(2-\alpha)} + \frac{2\alpha \sigma^2 T}{2-\alpha}+ \frac{2 (1-\alpha)^2 L^2 \eta^2}{\alpha(2 - \alpha)} \sum_{t=0}^{T-1} E \Vert v_t \Vert^2, \end{align*}\]where $b_0$ is the initial batch size, chosen so that $E \Vert v_0 - \nabla f(x_0) \Vert^2 \le \sigma^2 / b_0$.
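Both geometric sums above rely on $\sum_{j=0}^{t-1}(1-\alpha)^{2j} \le 1/(\alpha(2-\alpha))$, which follows from $1-(1-\alpha)^2 = \alpha(2-\alpha)$; a quick numeric sanity check:

```python
# sum_{j=0}^{t-1} (1 - alpha)^(2j) <= 1 / (alpha * (2 - alpha)),
# since the infinite geometric series sums to 1 / (1 - (1 - alpha)^2)
for alpha in (0.5, 0.1, 0.01):
    for t in (10, 100, 1000):
        partial = sum((1.0 - alpha) ** (2 * j) for j in range(t))
        assert partial <= 1.0 / (alpha * (2.0 - alpha)) + 1e-12
```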
Combining this variance bound with the descent lemma: as long as we choose $\eta \le \sqrt{\alpha}/ (4 L)$ and $\alpha \le 1/2$, the $-\frac{\eta}{4} \Vert v_t \Vert^2$ term in the descent lemma absorbs the corresponding $\sum_t E \Vert v_t \Vert^2$ term contributed by the variance bound, and summing the descent lemma over $t$ and telescoping gives
\[\begin{align*} \frac{1}{T} \sum_{t=0}^{T-1} E \Vert \nabla f(x_t) \Vert^2 \le \frac{2\Delta}{\eta T} + \frac{2\sigma^2}{\alpha b_0 T } +2 \alpha \sigma^2 = \mathcal{O}(T^{-2/3}), \end{align*}\]where $\Delta = f(x_0) - \inf_x f(x)$, after choosing $\alpha = 1/T^{2/3}$, $b_0 = T^{1/3}$, and $\eta = \sqrt{\alpha}/(4L)$.
In other words, the complexity of finding an $\epsilon$-stationary point is $\mathcal{O}(\epsilon^{-3})$.
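The $T^{-2/3}$ scaling can be sanity-checked numerically: with $L = \sigma = \Delta = 1$ (illustrative values) every term of the bound above scales as $T^{-2/3}$, so multiplying $T$ by $8$ should shrink the bound by $8^{2/3} = 4$; solving $T^{-2/3} = \epsilon^2$ then gives $T = \epsilon^{-3}$.

```python
def storm_bound(T, L=1.0, sigma=1.0, delta=1.0):
    """Right-hand side of the rate bound under the stated parameter choices."""
    alpha = T ** (-2.0 / 3.0)
    b0 = T ** (1.0 / 3.0)
    eta = alpha ** 0.5 / (4.0 * L)
    return (2.0 * delta / (eta * T)
            + 2.0 * sigma ** 2 / (alpha * b0 * T)
            + 2.0 * alpha * sigma ** 2)

# multiplying T by 8 shrinks the bound by 8^(2/3) = 4
ratio = storm_bound(1000.0) / storm_bound(8000.0)
assert abs(ratio - 4.0) < 1e-6
```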
