input: Matrices $\boldsymbol{A}\in\mathbb{R}^{n \times d}$, $\boldsymbol{B}\in\mathbb{R}^{n \times m}$. Number $k$ of subsamples. Probabilities $p_1,\ldots,p_n$.
output: Sketched matrix $\boldsymbol{C}\in\mathbb{R}^{d \times m}$.
Sample indices $s_1,\ldots,s_k\in[n]$ i.i.d. according to $p_1,\ldots,p_n$.
Build the sample-and-rescale matrix $\boldsymbol{S}\in\mathbb{R}^{k \times n}$:
Row $t$ of $\boldsymbol{S}$ has the form $\begin{bmatrix}0&0&\cdots&0&\frac{1}{\sqrt{k p_{s_t}}}&0&\cdots&0\end{bmatrix}$, where the nonzero entry sits at index $s_t$.
Return $\boldsymbol{C} = (\boldsymbol{S}\boldsymbol{A})^\intercal(\boldsymbol{S}\boldsymbol{B})$.
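In code, the whole procedure is just a few lines. Below is a minimal NumPy sketch; the function name `approx_matmul` is my own, and it forms $\boldsymbol{S}\boldsymbol{A}$ and $\boldsymbol{S}\boldsymbol{B}$ implicitly by indexing and rescaling rows rather than materializing $\boldsymbol{S}$:

```python
import numpy as np

def approx_matmul(A, B, k, p, rng=None):
    """Approximate A.T @ B by sampling k rows of A and B i.i.d. according to p."""
    rng = np.random.default_rng(rng)
    n = A.shape[0]
    # Sample k row indices i.i.d. according to p_1, ..., p_n.
    s = rng.choice(n, size=k, p=p)
    # Rescale each sampled row by 1 / sqrt(k * p_{s_t}); this gives S @ A and S @ B
    # without ever building the k x n matrix S.
    scale = 1.0 / np.sqrt(k * p[s])
    SA = scale[:, None] * A[s, :]   # shape (k, d)
    SB = scale[:, None] * B[s, :]   # shape (k, m)
    return SA.T @ SB                # shape (d, m), costs O(kdm)
```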
Since $(\boldsymbol{S}\boldsymbol{A})^\intercal \in \mathbb{R}^{d \times k}$ and $\boldsymbol{S}\boldsymbol{B}\in\mathbb{R}^{k \times m}$, we can compute $(\boldsymbol{S}\boldsymbol{A})^\intercal(\boldsymbol{S}\boldsymbol{B})$ in $O(kdm)$ time instead of $O(ndm)$ time, using only naive matrix multiplication. We show the following:
Fix $\varepsilon>0$ and $\delta\in(0,1)$. Let $\boldsymbol{C}$ be the result of the fast matrix multiplication above with $k \geq \frac{1}{\varepsilon^2\delta}$ and $p_\ell = \frac{\|\mathbf{a}_\ell\|_2^2}{\|\boldsymbol{A}\|_F^2}$, where $\mathbf{a}_\ell$ is the $\ell^{th}$ row of $\boldsymbol{A}$. Then with probability $1-\delta$,
$$\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F \leq \varepsilon\|\boldsymbol{A}\|_F \|\boldsymbol{B}\|_F.$$
Notably, we are not hiding any constants when we say $k \geq \frac{1}{\varepsilon^2 \delta}$. We prove the result in two steps: first we show a bound for arbitrary sampling probabilities, and then we prove Theorem 1. There is a lot of indexing in this analysis, so to keep things clean we consistently use $t \in [k]$, $\ell \in [n]$, $i \in [d]$, and $j \in [m]$.
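As a quick sanity check of Theorem 1 (not from the references, just a rough experiment), the snippet below feeds the `approx_matmul` sketch above the row-norm probabilities $p_\ell = \|\mathbf{a}_\ell\|_2^2/\|\boldsymbol{A}\|_F^2$ and $k = \lceil 1/(\varepsilon^2\delta) \rceil$ on random Gaussian matrices:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m = 10_000, 50, 40
A = rng.standard_normal((n, d))
B = rng.standard_normal((n, m))

eps, delta = 0.25, 0.1
k = int(np.ceil(1.0 / (eps**2 * delta)))                 # k >= 1/(eps^2 delta), no hidden constants
p = np.sum(A**2, axis=1) / np.linalg.norm(A, "fro")**2   # p_l = ||a_l||^2 / ||A||_F^2

C = approx_matmul(A, B, k, p, rng=rng)
err = np.linalg.norm(C - A.T @ B, "fro")
bound = eps * np.linalg.norm(A, "fro") * np.linalg.norm(B, "fro")
print(f"error = {err:.3g}, bound = {bound:.3g}")         # error <= bound w.p. >= 1 - delta
```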
For any sampling probabilities $p_1,\ldots,p_n$, we have
$$\mathbb{E}[\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F^2] \leq \frac1k \sum_{\ell=1}^n \frac1{p_\ell} \|\mathbf{a}_\ell\|_2^2 \|\mathbf{b}_\ell\|_2^2,$$
where $\mathbf{a}_\ell$ and $\mathbf{b}_\ell$ are the $\ell^{th}$ rows of $\boldsymbol{A}$ and $\boldsymbol{B}$ respectively.
Proof. Let $\boldsymbol{R}_t = \frac1{k p_{s_t}} \mathbf{a}_{s_t} \mathbf{b}_{s_t}^\intercal$, so that $\boldsymbol{C} = \sum_{t=1}^k \boldsymbol{R}_t$:
$$\begin{aligned} \boldsymbol{C} &= (\boldsymbol{S}\boldsymbol{A})^\intercal(\boldsymbol{S}\boldsymbol{B}) \\ &= \sum_{t=1}^k \frac1{\sqrt{k p_{s_t}}} \mathbf{a}_{s_t} \cdot \frac1{\sqrt{k p_{s_t}}} \mathbf{b}_{s_t}^\intercal \\ &= \sum_{t=1}^k \boldsymbol{R}_t \end{aligned}$$
In particular, we see that
$\mathbb{E}[\boldsymbol{R}_t] = \sum_{\ell=1}^n p_\ell \frac{1}{kp_\ell} \mathbf{a}_\ell\mathbf{b}_\ell^\intercal = \frac1k \boldsymbol{A}^\intercal\boldsymbol{B}$, which in turn implies $\mathbb{E}[\boldsymbol{C}] = \boldsymbol{A}^\intercal\boldsymbol{B}$. We can then expand and simplify using independence, linearity of variance, and the bound $\text{Var}[x] \leq \mathbb{E}[x^2]$:
$$\begin{aligned} \mathbb{E}[\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F^2] &= \sum_{i=1}^d\sum_{j=1}^m \mathbb{E}\left[ ([\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}]_{i,j})^2 \right] \\ &= \sum_{i=1}^d\sum_{j=1}^m \mathbb{E}\left[ \left(\textstyle{\sum_{t=1}^k [\boldsymbol{R}_t]_{i,j} - \mathbb{E}[\boldsymbol{R}_t]_{i,j}}\right)^2 \right] \\ &= \sum_{i=1}^d\sum_{j=1}^m \text{Var}\left[ \textstyle{\sum_{t=1}^k [\boldsymbol{R}_t]_{i,j}} \right] \\ &= k \sum_{i=1}^d\sum_{j=1}^m \text{Var}\left[ [\boldsymbol{R}_1]_{i,j} \right] \\ &\leq k \sum_{i=1}^d\sum_{j=1}^m \mathbb{E}\left[ \left([\boldsymbol{R}_1]_{i,j}\right)^2 \right] \\ &= k\, \mathbb{E}\left[ \|\boldsymbol{R}_1\|_F^2 \right] \end{aligned}$$
Since
$\boldsymbol{R}_1$ is rank-one, it is simple to compute its Frobenius norm:
$$\begin{aligned} \|\boldsymbol{R}_1\|_F^2 &= \text{tr}(\boldsymbol{R}_1^\intercal\boldsymbol{R}_1) \\ &= \frac1{k^2 p_{s_1}^2} \text{tr}((\mathbf{a}_{s_1} \mathbf{b}_{s_1}^\intercal)^\intercal(\mathbf{a}_{s_1}\mathbf{b}_{s_1}^\intercal)) \\ &= \frac1{k^2 p_{s_1}^2} \text{tr}(\mathbf{b}_{s_1}\mathbf{a}_{s_1}^\intercal\mathbf{a}_{s_1}\mathbf{b}_{s_1}^\intercal) \\ &= \frac{\|\mathbf{a}_{s_1}\|_2^2 \|\mathbf{b}_{s_1}\|_2^2}{k^2 p_{s_1}^2} \end{aligned}$$
And overall, we conclude that
$$\begin{aligned} \mathbb{E}[\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F^2] &\leq k\, \mathbb{E}\left[ \|\boldsymbol{R}_1\|_F^2 \right] \\ &= k \cdot \sum_{\ell=1}^n p_\ell \frac{\|\mathbf{a}_\ell\|_2^2 \|\mathbf{b}_\ell\|_2^2}{k^2 p_\ell^2} \\ &= \frac1k \cdot \sum_{\ell=1}^n \frac{\|\mathbf{a}_\ell\|_2^2 \|\mathbf{b}_\ell\|_2^2}{p_\ell} \end{aligned}$$
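Before moving on, here is a rough Monte Carlo check of Lemma 1 (my own illustration, reusing the hypothetical `approx_matmul` sketch from above) that compares the empirical mean squared error under uniform sampling against the right-hand side of the lemma:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, m, k = 2_000, 20, 15, 200
A = rng.standard_normal((n, d))
B = rng.standard_normal((n, m))
p = np.full(n, 1.0 / n)   # uniform sampling: any valid distribution works in Lemma 1

trials = 200
sq_errs = [np.linalg.norm(approx_matmul(A, B, k, p, rng=rng) - A.T @ B, "fro")**2
           for _ in range(trials)]
# Lemma 1 bound: (1/k) * sum_l ||a_l||^2 ||b_l||^2 / p_l
bound = np.sum(np.sum(A**2, axis=1) * np.sum(B**2, axis=1) / p) / k
print(f"empirical E[error^2] ~ {np.mean(sq_errs):.3g}, Lemma 1 bound = {bound:.3g}")
```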
With this core technical claim in hand, Theorem 1 follows as a short corollary by plugging in the chosen sampling probabilities. In fact, we prove something slightly broader:

Let $\tilde\tau_1,\ldots,\tilde\tau_n$ be numbers such that $\tilde\tau_\ell \geq \|\mathbf{a}_\ell\|_2^2$ for all $\ell\in[n]$, and let $T := \sum_{\ell=1}^n \tilde\tau_\ell$. Set $p_\ell := \frac{\tilde\tau_\ell}{T}$ and run the algorithm from Theorem 1 with these probabilities. Then, so long as $k \geq \frac{1}{\varepsilon^2 \delta} \cdot \frac{T}{\|\boldsymbol{A}\|_F^2}$, with probability $1-\delta$ we get
$$\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F \leq \varepsilon\|\boldsymbol{A}\|_F \|\boldsymbol{B}\|_F.$$
Proof. We first bound the expected error using Lemma 1, which gives
$$\begin{aligned} \mathbb{E}[\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F^2] &\leq \frac1k \cdot \sum_{\ell=1}^n \frac{\|\mathbf{a}_\ell\|_2^2 \|\mathbf{b}_\ell\|_2^2}{p_\ell} \\ &= \frac{T}{k} \cdot \sum_{\ell=1}^n \frac{\|\mathbf{a}_\ell\|_2^2}{\tilde\tau_\ell} \|\mathbf{b}_\ell\|_2^2 \\ &\leq \frac{T}{k} \cdot \sum_{\ell=1}^n \|\mathbf{b}_\ell\|_2^2 \\ &= \frac{T}{k} \|\boldsymbol{B}\|_F^2 \end{aligned}$$
We then apply Markov's inequality, which tells us that
$$\begin{aligned} \Pr[\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F^2 > \varepsilon^2\|\boldsymbol{A}\|_F^2\|\boldsymbol{B}\|_F^2] &\leq \frac{\mathbb{E}[\|\boldsymbol{C}-\boldsymbol{A}^\intercal\boldsymbol{B}\|_F^2]}{\varepsilon^2\|\boldsymbol{A}\|_F^2\|\boldsymbol{B}\|_F^2} \\ &\leq \frac{\frac{T}{k}\|\boldsymbol{B}\|_F^2}{\varepsilon^2\|\boldsymbol{A}\|_F^2\|\boldsymbol{B}\|_F^2} \\ &= \frac{T}{k\varepsilon^2\|\boldsymbol{A}\|_F^2} \end{aligned}$$
which is at most $\delta$ when $k \geq \frac{1}{\delta\varepsilon^2} \cdot \frac{T}{\|\boldsymbol{A}\|_F^2}$. Taking square roots inside the event gives the claimed bound.
Note that when we compute the norms exactly, so that $\tilde\tau_\ell = \|\mathbf{a}_\ell\|_2^2$ for all $\ell$, we get $T = \sum_\ell \tilde\tau_\ell = \|\boldsymbol{A}\|_F^2$, which recovers Theorem 1 exactly.
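To illustrate the overestimate version (again a hypothetical example, not from the references), the snippet below takes each $\tilde\tau_\ell$ to be $\|\mathbf{a}_\ell\|_2^2$ rounded up to the next power of two, so the oversampling factor $T/\|\boldsymbol{A}\|_F^2$ is at most $2$:

```python
import numpy as np

rng = np.random.default_rng(2)
n, d, m = 5_000, 30, 25
A = rng.standard_normal((n, d))
B = rng.standard_normal((n, m))

row_norms_sq = np.sum(A**2, axis=1)
tau = 2.0 ** np.ceil(np.log2(row_norms_sq))     # tau_l >= ||a_l||^2 for every row
T = tau.sum()
p = tau / T

eps, delta = 0.25, 0.1
oversample = T / np.linalg.norm(A, "fro")**2    # extra factor paid relative to Theorem 1
k = int(np.ceil(oversample / (eps**2 * delta)))
C = approx_matmul(A, B, k, p, rng=rng)
print(np.linalg.norm(C - A.T @ B, "fro") <=
      eps * np.linalg.norm(A, "fro") * np.linalg.norm(B, "fro"))
```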
The analysis here is a blend of Drineas, Mahoney (2018) and Nelson (2015) , as well as some personal notes by Christopher Musco. Note that randomized matrix multiplication often does not use the exact sampling probabilities discussed here, and the references below discuss a variety of slightly different schemes.
Here are some relevant papers:
Drineas, Kannan, Mahoney (2006) is (afaik) the original paper on the topic. Table 1 of this paper compares a great variety of sampling schemes.
Drineas, Mahoney (2018) is a book with a section that this page partially copies.
Nelson (2015) are lecture notes that this page partially copies.
Avron et al. (2019) generalizes this to approximate linear operator multiplication in Claim 45, the "Approximate Operator Application".
Let me know if anything is missing