Time-Space Lower Bound for Parity Learning

10 minute read

Published:

Paper Reading: [FOCS 2015] A time-space lower bound for a large class of learning problems.

本文考虑Time-Space tradeoff的问题,本文最关键的结论是,考虑维度为 $n$ 的Parity learning 问题,如果内存不超过 $n^2$, 那么需要指数多的时间才能求解该问题。

Problem Setup

首先我们介绍Parity learning问题,这是理论计算机中的一个经典的问题,在密码学等领域也具有重要应用。

其问题定义如下,首先从随机采样一个secret $x \in {-1,1 }^n$ 隐藏起来。Leaner希望通过每次随机选取 $a \in { -1,1}^n$, 然后观测到 $b_i = \langle a, x \rangle_{\rm XOR}$,希望学习到正确的secret $x$. 其中 $b$ 表示 $a$ 与 x$ 中符号相同的元素的个数的奇偶性。

本文将问题建模成更general的一个矩阵学习问题,给定矩阵 $M :A \times X \rightarrow {-1,1 }$, 矩阵的大小为 $n \times n$, 其中矩阵的每个元素表示 $M(a,x) = \langle a, x\rangle_{\rm XOR}$. 可以证明对于这个矩阵成立 $\Vert M \Vert \le 2^{n/2}$.

推广的矩阵学习问题是,首先从集合 $X$ 中随机采样一个secret 隐藏起来。Leaner希望通过每次随机选取 $a \in A$, 然后观测到 $b = M(a,x)$, 希望学习到正确的secret $x$.

Branching Program

文章考虑内存有界的确定程序,这样的程序可以由如下的有向图建模。

图中的每个结点表示内存的状态,图表示状态的转移。该程序的长度为 $m$, 宽度为 $d$, 分为 $m+1$ 层。程序的长度表示时间的转移,宽度表示内存的多少。如果一个程序有 $r$ 个bit 运行时间为 $m$, 对应的长度为 $n$, 宽度为 $2^r$. 第0层只有一个结点,称为初始结点,表示程序的开始状态。图中出度为0的结点称为叶节点,最后一层均为叶结点,但是在程序的中间也可以由叶节点,叶结点表示程序的终止状态,每个叶结点都有一个输出 $\tilde x(v)$ 表示该结点输出对于secret的预测值。每个非叶结点都存在着 $2 \vert A \vert$ 条边,表示在该状态的时候如果给定样本 $(a,b) \in A \times {-1,1}$ 程序应该如何转移到下一个状态。如果随机选取 $x$ 后,每一步随机选取 $a$, 程序沿着定义的边不断进行转移直到叶结点,如果程序输出了正确的 $x$, 那么程序运行成功。

Main Theorem

文章的主要定理叙述如下。

令 $X,A$ 为有限集合。令 $n = \log \vert X \vert$. 矩阵 $M :A \times X \rightarrow {-1,1 }$ 满足 $\Vert M \Vert \le 2^{2 \gamma n}$.

存在常数 $c,\epsilon$, 使得任意长度为 $2^{\epsilon n}$ 宽度为 $2^{c n^2}$ 的程序对于上述矩阵问题输出正确答案的概率不超过 $O(2^{-\epsilon n})$.

注意到对于Parity leaning问题,我们有 $\gamma = \frac{1}{2}$.

Proof

下面沿着原论文的思路进行证明。

首先对于两个函数 $f:X \rightarrow \mathbb{R}$, $g:X \rightarrow \mathbb{R}$, 我们定义其内积 $\langle f, g \rangle = \mathbb{E}_x[ f(x) g(x) ]$.

由此内积我们也可以诱导出对应的范数。

Analysis of the Truncated Program

对于任意一个程序,我们定义 $P_{x \mid v}$ 为程序运行到结点$v$ 的时候 $x$ 的分布,这包含着程序对于secret的信息。

首先我们需要如下的一些定义:我们称结点 $v$ 是显著的,如果 $\Vert P_{x \mid v} \Vert > 2^{\delta n} 2^{-n}$, 其中 $\delta \in (0,1)$ 是一个待定参数.

我们定义一个程序的截断路径,其运行到一个显著结点就会停止。

当程序运行到一个非显著结点,如果对于secret $x$ 我们有 $P_{x \mid v}(x)> 2^{2 (\delta+\epsilon) } 2^{-n}$ , 那么我们也让程序终止,这种情况称为 $x \in {\rm Sig}(v)$

或者在一个非显著结点的时候,下一步采样 $a$, 但是 $\vert M P_{x \mid v}(a) \vert > 2^{(\delta+ \gamma + \epsilon)} 2^{-n}$, 那么我们也让程序终止, 该情况称为 $ a \in {\rm Bad}(v)$.

这样截断后的程序与原程序并不相同,但是我们可以证明,这与原程序不同的概率为 $O(2^{-\epsilon n})$, 因此我们可以仅分析截断后的程序。

第一种情况比较复杂,我们将其分析留在后面。我们先分析第二和第三种情况,

对于情况二,我们知道对于任意的非显著结点 $v$, 有

\[\begin{align*} \mathbb{E}_{x \sim P_{x \mid v}} [ P_{x \mid v} (x) ] = \sum_{x \in X} P_{x \mid v}(x)^2 = 2^n \mathbb{E} [ P_{x \mid v} (x)^2] \le 2^{2 \delta n} 2^{-n}. \end{align*}\]

那么根据Markov不等式,有

\[\begin{align*} P[ x \in {\rm Sig}(v) \mid v] = P( P_{x \mid v}(x) > 2^{2 \epsilon n } 2^{2 \delta n} 2^{-n} \mid v) \le 2^{-2 \epsilon n}. \end{align*}\]

假设程序运行前随机采样了 $x, a_1,\cdots,a_m$,这样决定了程序的路径为 $v_1,\cdots,v_m’$ $(m’ \le m)$ .

这样我们可以得到由于第二种情况程序终止的概率为

\[\begin{align*} P( \text{Terminate by Reason II} ) &\le \sum_{x , a_1,\cdots,a_m} P(x,a_1,\cdots,a_m) \sum_{i=1}^{m '} I( x \in {\rm Sig}(v_i) \mid v_i) \\ &\le \sum_{a_1,\cdots,a_m} P(a_1,\cdots,a_m) \sum_{i=1}^{m'} P( x \in {\rm Sig}(v_i) \mid v_i) \\ &\le \sum_{a_1,\cdots,a_m} P(a_1,\cdots,a_m) 2^{-\epsilon n } = 2^{-\epsilon n}, \end{align*}\]

其中第一个不等式用到了联合界,第二个不等式用到了 $m = 2^{\epsilon n}$。

类似地,由于$M$ 是一个有界矩阵,由于第三种情况终止的也可以被如下控制住, 对于任意的非显著结点 $v$,

\[\begin{align*} \mathbb{E}_a [ \vert M P_{x \mid v}(a) \vert^2] = \Vert M P_{x \mid v} \Vert^2 \le \Vert M \Vert^2 \Vert P_{x \mid v} \Vert^2 \le 2^{2 \gamma n} 2^{2 \delta n} 2^{-2n}. \end{align*}\]

根据Markov不等式,

\[\begin{align*} P(a \in {\rm Bad}(v) \mid v) &= P( \vert M P_{x \mid v} (a) \vert > 2^{(\delta+\gamma+\epsilon) n } 2^{-n} ) \le 2^{-2 \epsilon n }. \end{align*}\]

这样我们可以得到由于第二种情况程序终止的概率为

\[\begin{align*} P( \text{Terminate by Reason III} ) &\le \sum_{x , a_1,\cdots,a_m} P(x,a_1,\cdots,a_m) \sum_{i=1}^m I( a_i \in {\rm Sig}(v_i) \mid v_i) \\ &= \sum_x P(x) \sum_{i=1}^m P(a_i \in {\rm Sig}(v_i) \mid v_i, a_1,\cdots,a_{i-1}) \\ &\le \sum_{x , a_1,\cdots,a_m} P(x,a_1,\cdots,a_m) 2^{-\epsilon n } = 2^{-\epsilon n}, \end{align*}\]

其中第一个不等式用到了联合界,第二个不等式用到了 $m = 2^{\epsilon n}$。

假设我们也可以证明程序由于第一种情况终止的概率为 $O(2^{-\epsilon n})$, 那么我们知道以 $O(2^{-\epsilon n})$ 的概率截断后的程序和原程序完全相同。所以我们可以仅分析截断后的程序,但由于所有显著结点已经被截断,程序只可能终止于非显著结点,那么由于非显著结点不能包含足够的关于secret的信息,我们可以得到程序高概率情况下不能得到正确的输出。分析如下:

对于任意的非显著结点 $v$ , 根据其定义,有

\[\begin{align*} \mathbb{E}_x[P_{x \mid v} (x)^2] \le 2^{2 \delta n} 2^{-2 n }. \end{align*}\]

那么对于任意的 $x’ \in X$, 有

\[\begin{align*} P(x = x' \mid v) = P_{x \mid v}(x') \le \left(\sum_{x} P_{x \mid v} (x)^2 \right)^{1/2} \le 2^{\delta n} 2^{-n/2}. \end{align*}\]

假设程序运行前随机采样了 $x, a_1,\cdots,a_m$,这样决定了程序的路径为 $v_1,\cdots,v_m’$ $(m’ \le m)$.

这样一来程序得到准确输出的概率为

\[\begin{align*} P(\text{Program Sucess}) &= \sum_{x , a_1,\cdots,a_m} P(x,a_1,\cdots,a_m) I( x = \tilde{x}(v_{m'}) \mid v_{m'} ) \\ &= \sum_{a_1,\cdots,a_m} P(a_1,\cdots,a_m) P(x = \tilde x(v_{m'}) \mid v_{m'}) \\ &\le 2^{\delta n} 2^{-n/2} \le 2^{-\epsilon n}, \end{align*}\]

只需要选择 $\epsilon$ 相对于 $\delta$ 足够小。

在下一个部分,我们将证明本文的核心,也即 第一种情况终止的概率为 $O(2^{-\epsilon n})$。

这就完成了整篇文章的证明。

Upper Bound the Probability of Reaching a Significant Vertex is Low

在本小节里面,我们完全本文的最终的部分,也即证明程序运行到一个显著结点的概率为 $O(2^{-\epsilon n})$, 换句话说,如果运行的时间不够多,那么程序将不能学习到关于 secret $x$ 的信息。

给定程序,对于程序中的每一个显著的结点 $s$, 我们定义Progress函数

\[\begin{align*} Z_i &= \sum_{v \in L_i} P(v) \langle P_{x \mid v}, P_{x \mid s} \rangle^{\beta n } \\ Z_{i}' &= \sum_{e \in \Gamma_i} P(e) \langle P_{x \mid e}, P_{x \mid e} \rangle^{\beta n}. \end{align*}\]

其中 $L_i$ 表示程序的第 $i$ 层中所有结点的集合, $\Gamma_i$ 表示从第$i-1$ 层到第$i$ 层的所有边的集合。

我们希望证明 $Z_i$ 相对于 $Z_{i-1}$ 只会有很小的增量,我们将借助 $Z_i’$ 作为桥梁,首先我们证明 $Z_i \le Z_i’ $.

该结论可以直接由凸性得到, 令 $\Gamma_{\rm in}(v)$ 为 $\Gamma_i$ 中所有指向 $v$ 的边的集合,有

\[\begin{align*} \sum_{e \in \Gamma_{\rm in}(v)} P(e) = P(v). \end{align*}\]

根据全概率公式,有

\[\begin{align*} P_{x \mid v}(x') = \sum_{e \in \Gamma_{\rm in}(v)} \frac{P(e)}{P(v)} P_{x \mid e} (x') \end{align*}\]

根据progress的定义以及Jesen不等式,有

\[\begin{align*} \langle P_{x \mid v}, P_{x \mid s} \rangle^{\beta n} \le \sum_{e \in \Gamma_{\rm in}(v)} \frac{P(e)}{P(v)} \langle P_{x \mid e}, P_{x \mid s } \rangle^{\beta n} \end{align*}\]

对集合 $L_i$ 中的所有结点求和,得到

\[\begin{align*} Z_i \le \sum_{v \in L_i} P(v)\sum_{e \in \Gamma_{\rm in}(v)} \frac{P(e)}{P(v)} \langle P_{x \mid e}, P_{x \mid s } \rangle^{\beta n} = Z_i' \end{align*}\]

这就得到了 $Z_i \le Z_i’ $. 因此关键在于证明 $Z_i’$ 相对于 $Z_{i-1}$ 仅有很小的增量。

为了证明所需要的结论,我们需要研究条件分布 $P_{x \mid e}$ 以及 $P_{x \mid v}$ 之间的关系。

回顾我们定义 $P_{x \mid e}$ 为截断程序经过边 $e = (a,b)$ 的概率,我们知道如果程序不终止, 也即 $x’ \notin {\rm Sig}(v)$ 并且 $M (a,x’) = b$, 那么

\[\begin{align*} P_{x \mid e}(x') = P_{x \mid v} (x') c_e^{-1} \end{align*}\]

其中归一化常数 $c_e$ 为程序不终止的概率, 由于我们知道 $a \notin {\rm Bad}(v)$, 有

\[\begin{align*} P(M(a,x) \ne b \mid v) &\le \frac{1}{2} (1 + \vert P( M(a,x)=1 \mid v) - P( M(a,x)=-1 \mid v) \vert ) \\ &= \frac{1}{2} (1 + \vert M P_{x \mid v} (a) \vert) \\ &\le \frac{1}{2} (1+ 2^{(\delta+ \gamma+ \epsilon)n} 2^{-n}) \le \frac{1}{2} + 2^{-2 \epsilon n}, \end{align*}\]

最后一个不等式只需要选取 $\epsilon$ 在给定 $\delta,\gamma$ 的前提下足够小即可。

结合我们之前所证明的 $P(x \in {\rm Sig}(v) \mid v) \le 2^{-2\epsilon n}$ 的结论我们知道

\[\begin{align*} c_e \ge \frac{1}{2} - 2 \cdot 2^{-2\epsilon n}. \end{align*}\]

根据上述结论,我们可以控制范数 $\Vert P_{x \mid e} \Vert$,

只要选择 $n$ 足够大,成立 $c_e > \frac{1}{4}$, 由于截断路径中经过的节点都为非显著结点,那么

\[\begin{align*} \Vert P_{x \mid e} \Vert \le c_e^{-1} \Vert P_{x \mid} \Vert \le 4 \cdot 2^{\delta n} 2^{-n}. \end{align*}\]

下面作为简单的推论,令 $\Gamma_{\rm in}(s)$ 为指向结点 $s$ 的所有边的集合,我们知道

\[\begin{align*} \sum_{e \in \Gamma_{\rm in}(s)} P(e) = P(s). \end{align*}\]

根据全概率公式,有

\[\begin{align*} P_{x \mid s}(x') = \sum_{e \in \Gamma_{\rm in}(s)} \frac{P(e)}{P(s)} P_{x \mid e} (x') \end{align*}\]

根据Jesen不等式,有

\[\begin{align*} \Vert P_{x \mid s} \Vert^2 \le \sum_{e \in \Gamma_{\rm in}(v)} \frac{P(e)}{P(v)} \Vert P_{x \mid e} \Vert^2 \le (4 \cdot 2^{\delta n} 2^{-n})^2. \end{align*}\]

注意到上述分析实际上与之前证明的 $Z_i \le Z_i’ $ 完全类似。

有了上述的准备工作,我们开始证明 $Z_i’$ 相对于 $Z_{i-1}$ 仅有很小的增量。

令 $\Gamma_{\rm out}(v)$ 为从 $v$ 流出的所有边的集合,我们考虑如下量作为桥梁:

\[\begin{align*} Y_i' = \sum_{e \in \Gamma_{\rm out}(v)} \frac{P(e)}{P(v)} \langle P_{x \mid e} , P_{x \mid s} \rangle^{\beta n}. \end{align*}\]

我们定义

\[\begin{align*} f(x') = P_{x \mid v} (x') I( x' \notin {\rm Sig}(v)) P_{x \mid s} (x'). \end{align*}\]

对于上述定义,回顾 ${\rm Sig}(v)$ 的定义, 我们知道

\[\begin{align*} \Vert f \Vert \le 2^{2(\delta+ \epsilon) n} 2^{-n} \Vert P_{x \mid s} (x') \Vert \le 2^{(3 \delta+ 2 \epsilon) n + 2} 2^{-2n}. \end{align*}\]

那么我们知道

\[\begin{align*} P_{x \mid e}(x') = c_e^{-1} f(x') I(M(a,x') = b), \end{align*}\]

其中 $c_e$ 为之前计算过的归一化常数。那么,

\[\begin{align*} \langle P_{x \mid e}, P_{x \mid s} \rangle &= \mathbb{E}_{x' \in X} [P_{x \mid e}(x') P_{x \mid s} (x')] \\ &= c_{e}^{-1} 2^{-n} \sum_{x' : M(a,x') = b} f(x') \\ &= c_{e}^{-1} 2^{-n} \frac{F + b (Mf)(a) }{2} \\ &\le 2^{-n} (1+ 2^{-2 \epsilon n+2}) (F + \vert (Mf)(a) \vert), \end{align*}\]

其中我们定义 $F = \sum_{x’ \in X} f(x’)$. 下面我们分别考虑两种情况,

第一种情况是 $F \le 2^{-n}$, 此时选择足够大的 $n$ , 我们有

\[\begin{align*} \langle P_{x \mid e}, P_{x \mid s} \rangle \le 4 \cdot 2^{-2n}. \end{align*}\]

回顾 $Y_i’$ 的定义, 这意味着 $Y_i’ \le ( 2^{-2n +2})^{\beta n}$.

第二种情况的计算较为繁琐,此时 $F > 2^{-n}$. 我们根据原文定义如下的辅助函数

\[\begin{align*} t(a) = \left( \frac{(Mf)(a)}{F} \right)^2. \end{align*}\]

代入上述的定义和刚刚得到的不等式,有

\[\begin{align*} \langle P_{x \mid e}, P_{x \mid s} \rangle^{\beta n} &\le (2^{-n} F)^{\beta n } \left(1 + \sqrt{t(a)} \right)^{\beta n} (1+ 2^{-2 \epsilon n+2})^{\beta n} \\ &\le \langle P_{x \mid v}, P_{x \mid s} \rangle^{ \beta n} \left(1 + \sqrt{t(a)} \right)^{\beta n} (1+ 2^{-2 \epsilon n+2})^{\beta n} \end{align*}\]

考虑到 $Y_i’$ 的定义,求和后得到

\[\begin{align*} Y_i' \le \langle P_{x \mid v}, P_{x \mid s} \rangle^{ \beta n} \mathbb{E}_{a \in A} \left(1 + \sqrt{t(a)} \right)^{\beta n} (1+ 2^{-2 \epsilon n+2})^{\beta n} \end{align*}\]

上式最难分析的在于中间项。

根据得到的 $F$ , $f$ 的条件,我们首先知道中间项的期望并不太大:

\[\begin{align*} \mathbb{E}_{a \in A} [t(a)] \le \frac{ \Vert M \Vert^2 \Vert f \Vert^2}{F^2} \le 2^{- (2 - 2 \gamma - 6 \delta - 4 \epsilon) n + 4}. \end{align*}\]

我们自然希望使用Jesen不等式,我们研究函数 $g(t) = (1+\sqrt{t})^{\beta n}$, 可以发现其在区间 $0 \le t \le \frac{1}{(\beta n-2)^2}$ 上是凹函数,因此我们将求和分为两个部分, 其中第一个区间内的 $a$ 我们使用Jesen不等式,第二个区间内的 $a$ 我们采用Markov不等式,得到

\[\begin{align*} &\quad \mathbb{E}_{a \in A} \left(1 + \sqrt{t(a)} \right)^{\beta n} \\ &\le \left( 1 + \sqrt{\mathbb{E}_{a \in A} [t(a)]} \right)^{\beta n} + 2^{\beta n} P\left( t(a) > \frac{1}{(\beta n-2)^2} \right) \\ &\le \left( 1 + \frac{\beta n}{2}\mathbb{E}_{a \in A} [t(a)] \right) + 2^{\beta n} (\beta n -2)^2 \mathbb{E}_{a \in A} [t(a)] \end{align*}\]

代入关于 $\mathbb{E} [t(a)]$ 的上界,并且注意到我们可以选取 $\beta$ 足够小,之后 $\epsilon$ 足够小, $n$ 足够大,可以满足

\[\begin{align*} Y_i' &\le \langle P_{x \mid v}, P_{x \mid s} \rangle^{ \beta n} (1+ 2^{-2 \epsilon n+2})^{\beta n+1} \\ &\le \langle P_{x \mid v}, P_{x \mid s} \rangle^{ \beta n} (1+ 2^{-1.9 \epsilon n}) \end{align*}\]

综合我们关于 $F$ 的两种情况的讨论我们得到了,

\[\begin{align*} Y_i' \le \langle P_{x \mid v}, P_{x \mid s} \rangle^{ \beta n} (1+ 2^{-1.9 \epsilon n}) + ( 2^{-2n +2})^{\beta n} \end{align*}\]

通过 $Y_i’$ 为桥梁,我们可以得到

\[\begin{align*} Z_i' &= \sum_{e \in \Gamma_i} P(e) \langle P_{x \mid e}, P_{x \mid e} \rangle^{\beta n} \\ &= \sum_{v \in L_{i-1}} P(v) \sum_{e \in \Gamma_{\rm out}(v)} Y_i' \\ &\le \sum_{v \in L_{i-1}} P(v) \left( \langle P_{x \mid v}, P_{x \mid s} \rangle^{ \beta n} (1+ 2^{-1.9 \epsilon n}) + ( 2^{-2n +2})^{\beta n} \right) \\ &\le (1+ 2^{-1.9 \epsilon n}) Z_{i-1} + ( 2^{-2n +2})^{\beta n}. \end{align*}\]

求解上述的递归式,首先注意到按照定义 $Z_0 = (2^{-2n})^{\beta n}$. 根据等比数列求和公式,有

\[\begin{align*} Z_i \le \frac{(1 + 2^{-1.9\epsilon n)^{m}}}{2^{-1.9 \epsilon n}} ( 2^{-2n +2})^{\beta n} \le 2^{-2 \beta n^2} 2^{2 (\beta + \epsilon) n}. \end{align*}\]

最终我们证明上面意味着 $P(s)$ 的概率很小,假设 $s$ 在 第$i$ 层,那么根据progress的上界,我们知道

\[\begin{align*} Z_i \ge P(s) \langle P_{x \mid s} , P_{x \mid s} \rangle^{\beta n} \ge P(s) 2^{2 \delta \beta n^2} 2^{-2 \beta n^2}. \end{align*}\]

这也就是说对于任意一个显著的结点 $s$ ,其到达概率都很小:

\[\begin{align*} P(s) \le 2^{2 (\beta + \epsilon) n} 2^{-2 \delta \beta n^2}. \end{align*}\]

考虑到整个程序中至多只有 $2^{\epsilon n } 2^{c n^2}$ 个结点,我们对于所有的显著的结点取联合界,得到

\[\begin{align*} P (\text{Reach a significant vertex}) \le 2^{\epsilon n } 2^{c n^2} 2^{2 (\beta + \epsilon) n} 2^{-2 \delta \beta n^2} \end{align*}\]

观察上式的主项,我们发现只要 $c \le 2\delta \beta$ 条件满足,该概率一定小于 $2^{-\epsilon n}$.

这就证明了所需要的结论,从而完成本文主要定理的证明。