The Fixed Point Theorem is a fundamental tool in functional analysis and is widely used in the analysis of nonlinear functions.
Before stating the theorem, we first introduce two concepts.
The first is the notion of a fixed point.
Definition
For a function $f:\mathbb{R}^n\to\mathbb{R}^n$, if a point $x^*\in\mathbb{R}^n$ satisfies
$$f(x^*)=x^*$$
then we call $x^*$ a fixed point of $f$.
Next is the concept of a contraction mapping.
Definition
For a function $f:\mathbb{R}^n\to\mathbb{R}^n$, if there exists $\gamma\in(0,1)$ such that
$$\|f(x_1)-f(x_2)\| \leq \gamma \|x_1-x_2\|,\quad \forall\, x_1,x_2\in\mathbb{R}^n$$
then we call $f$ a contraction mapping. Here $\|\cdot\|$ is a vector norm on $\mathbb{R}^n$.
We can now state the fixed point theorem.
Theorem
Given the equation $x=f(x)$, where $f:\mathbb{R}^n\to\mathbb{R}^n$, if $f$ is a contraction mapping, then $f$ has the following properties:
Existence: there exists a fixed point $x^*\in\mathbb{R}^n$ satisfying $f(x^*)=x^*$.
Uniqueness: the fixed point $x^*$ is unique.
Algorithm: for any $x_0\in\mathbb{R}^n$, the sequence $\{x_k\}_{k=0}^{\infty}$ generated by the iteration $x_{k+1}=f(x_k)$ converges to the fixed point $x^*$, and the convergence rate is exponential.
The proof relies on the notion of a Cauchy sequence.
Definition
A sequence $x_1,x_2,\dots$ is called a Cauchy sequence if and only if for every $\epsilon>0$ there exists $N>0$ such that
$$\|x_m-x_n\| <\epsilon,\quad \forall\, m, n>N$$
An important property: since $\mathbb{R}^n$ is complete, every Cauchy sequence in $\mathbb{R}^n$ is a convergent sequence.
Proof
We first show that the sequence $\{x_k\}_{k=1}^{\infty}$ generated by $x_k=f(x_{k-1})$ converges; we do so by showing that it is a Cauchy sequence.
Since $f$ is a contraction mapping, we have
$$\|x_{k+1}-x_k\| = \|f(x_k)-f(x_{k-1})\|\leq \gamma \|x_k-x_{k-1}\|$$
Applying this bound repeatedly, we obtain
$$\|x_{k+1}-x_k\|\leq \gamma \|x_k-x_{k-1}\|\leq\cdots\leq \gamma^k \|x_1-x_{0}\|$$
Now we show that $\{x_k\}_{k=1}^{\infty}$ is a Cauchy sequence: for $m>n$,
$$\begin{aligned}
\|x_m-x_n\| &= \|x_m-x_{m-1}+x_{m-1}-\cdots-x_{n+1}+x_{n+1}-x_n\|\\
&\leq \sum_{i=n}^{m-1}\|x_{i+1}-x_i\|\\
&\leq \sum_{i=n}^{m-1}\gamma^i \|x_{1}-x_0\|\\
&\leq \frac{\gamma^n}{1-\gamma}\|x_{1}-x_0\|.
\end{aligned}$$
Since $\gamma\in(0,1)$, the bound $\frac{\gamma^n}{1-\gamma}\|x_1-x_0\|$ can be made arbitrarily small by taking $n$ large, so the sequence $\{x_k\}_{k=1}^{\infty}$ is a Cauchy sequence and therefore convergent.
Next, we show that $x^*=\lim_{k\to\infty}x_k$ is a fixed point of $f$. Note that
$$\|f(x_k)-x_k\| = \|x_{k+1}-x_k\|\leq \gamma^k\|x_1-x_0\| \to 0, \quad k\to\infty$$
so $\lim_{k\to\infty}f(x_k)=\lim_{k\to\infty}x_k$. Since a contraction mapping is necessarily continuous, we conclude that $f(x^*)=x^*$.
Next, we show that the fixed point is unique. Suppose there were another fixed point $x'\neq x^*$ satisfying $f(x')=x'$. Then
$$\|x'-x^*\| = \|f(x')-f(x^*)\| \leq \gamma \|x'-x^*\|$$
Since $\gamma\in(0,1)$, this inequality can hold only if $\|x'-x^*\|=0$, which contradicts the assumption $x'\neq x^*$. Hence the fixed point is unique.
Finally, we show that the iteration $x_{k+1}=f(x_k)$ converges at an exponential rate. Note that
$$\|x^*-x_n\| = \lim_{m\to\infty}\|x_m-x_n\| \leq \frac{\gamma^n}{1-\gamma}\|x_1-x_0\|$$
Since $\gamma <1$, the error decays like $\gamma^n$, i.e., the convergence rate is exponential.
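To make the convergence behavior concrete, here is a minimal numerical sketch under an assumed choice of $f$: we take $f(x)=\cos(x)$, which maps $\mathbb{R}$ into $[-1,1]$ and satisfies $|f'(x)|\le\sin(1)<1$ there, so it is a contraction on that interval.

```python
import numpy as np

def fixed_point_iteration(f, x0, num_iters=50):
    """Iterate x_{k+1} = f(x_k) and record the trajectory."""
    xs = [x0]
    for _ in range(num_iters):
        xs.append(f(xs[-1]))
    return np.array(xs)

# f(x) = cos(x) is a contraction on [-1, 1]; its unique fixed point
# is the Dottie number, approximately 0.739085.
xs = fixed_point_iteration(np.cos, x0=0.0)
errors = np.abs(xs - xs[-1])  # use the last iterate as a proxy for x*
for k in (0, 5, 10, 20):
    print(f"k={k:2d}  x_k={xs[k]:.6f}  error={errors[k]:.2e}")
```

The printed errors shrink by roughly a constant factor per step, matching the geometric bound $\frac{\gamma^n}{1-\gamma}\|x_1-x_0\|$.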
Maximum likelihood estimation (MLE) is a method for estimating the parameters of a distribution. Its core idea is that the model parameters should maximize the probability of the observed samples.
Suppose we have a parametric distribution $p(x\mid \theta)$, where $\theta$ denotes the parameters, e.g., the mean and variance of a normal distribution. We sample from $p(x\mid \theta)$ to obtain i.i.d. data $X=\{x_1,\dots,x_n\}$.
The likelihood function is defined as the joint probability of the given data $X$:
$$\mathcal{L}(\theta\mid X) = P(X\mid \theta)$$
Since $X=\{x_1,\dots,x_n\}$ is i.i.d., we can rewrite this as:
$$\mathcal{L}(\theta\mid X) = \prod_{i=1}^n p(x_i\mid \theta)$$
Our optimization objective is then (taking the logarithm is valid because $\log$ is strictly increasing and so preserves the maximizer):
$$\begin{aligned}
\theta_{MLE}^* &= \arg\max_{\theta} \mathcal{L}(\theta\mid X)\\
&= \arg\max_{\theta} \prod_{i=1}^n p(x_i\mid \theta)\\
&= \arg\max_{\theta} \log\prod_{i=1}^n p(x_i\mid \theta)\\
&=\arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)
\end{aligned}$$
that is,
$$\theta_{MLE}^* = \arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)$$
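As a small illustration (assuming a Gaussian model, which is not part of the derivation above), this objective can be maximized numerically and compared against the well-known closed-form Gaussian MLE, the sample mean and standard deviation:

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.5, size=1000)  # samples from N(2, 1.5^2)

def neg_log_likelihood(params):
    # parameterize sigma via log(sigma) so it stays positive; the additive
    # constant n/2 * log(2*pi) is dropped since it does not affect the argmax
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    return 0.5 * np.sum(((data - mu) / sigma) ** 2) + data.size * np.log(sigma)

res = minimize(neg_log_likelihood, x0=[0.0, 0.0])
print(res.x[0], np.exp(res.x[1]))  # numerical MLE for (mu, sigma)
print(data.mean(), data.std())     # closed-form MLE: sample mean and std
```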
The KL divergence measures how different a probability distribution $Q(x)$ is from a probability distribution $P(x)$. We can interpret it as: if we replace $P(x)$ by $Q(x)$, how much information is lost?
For continuous distributions, the KL divergence is defined as
$$D_{KL}(P\mid\mid Q) =\int P(x)\log\left(\frac{P(x)}{Q(x)}\right)dx$$
For discrete distributions, it is defined as
$$D_{KL}(P\mid\mid Q) = \sum_{x} P(x)\log\left(\frac{P(x)}{Q(x)}\right)$$
The KL divergence has two key properties:
Non-negativity: $D_{KL}(P\mid\mid Q)\geq0$, and $D_{KL}(P\mid\mid Q)=0$ if and only if $P(x)=Q(x)$ for all $x$.
Asymmetry: in general, $D_{KL}(P\mid\mid Q)\neq D_{KL}(Q\mid\mid P)$; both properties are easy to check numerically, as in the sketch below.
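Here is a small sketch for the discrete case (the zero-probability convention $0\log 0 = 0$ is handled explicitly):

```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(P || Q) for discrete distributions given as probability vectors."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0  # terms with P(x) = 0 contribute 0 by convention
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

p = [0.5, 0.3, 0.2]
q = [0.4, 0.4, 0.2]
print(kl_divergence(p, q))  # non-negative
print(kl_divergence(q, p))  # generally different: KL is asymmetric
print(kl_divergence(p, p))  # 0, since P == Q
```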
Suppose $p_{data}(x)$ is the true distribution of the data $X$. We want to find suitable parameters $\theta$, and the corresponding distribution $p(x\mid \theta)$, to approximate $p_{data}(x)$. We can use the KL divergence as our objective function:
$$\theta_{KL}^* = \arg\min_{\theta}D_{KL}(p_{data}(x)\mid\mid p(x\mid \theta))$$
Expanding this expression, we get
$$\begin{aligned}
\theta_{KL}^* &= \arg\min_{\theta}D_{KL}(p_{data}(x)\mid\mid p(x\mid \theta))\\
&= \arg\min_{\theta} \int p_{data}(x)\log\frac{p_{data}(x)}{p(x\mid \theta)} dx\\
&= \arg\min_{\theta}\left[\int p_{data}(x)\log p_{data}(x) dx - \int p_{data}(x)\log p(x\mid \theta)dx\right] \\
&= \arg\min_{\theta} \left[- \int p_{data}(x)\log p(x\mid \theta)dx\right] \\
&= \arg\max_{\theta} \int p_{data}(x)\log p(x\mid \theta)dx
\end{aligned}$$
where the second-to-last step drops $\int p_{data}(x)\log p_{data}(x)dx$, since it does not depend on $\theta$.
In practice, the true data distribution $p_{data}(x)$ is unknown; we only have a batch of samples $X=\{x_1,\dots,x_n\}\sim p_{data}(x)$ drawn from it.
By the law of large numbers, we have
$$\frac{1}{n}\sum_{i=1}^n\log p(x_i\mid \theta)\to\mathbb{E}_{x\sim p_{data}}[\log p(x\mid \theta)] = \int p_{data}(x)\log p(x\mid \theta)dx, \quad n\to \infty$$
This connects maximum likelihood estimation with minimizing the KL divergence:
$$\begin{aligned}
\theta_{MLE}^*&=\arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)\\
&= \arg\max_{\theta} \int p_{data}(x)\log p(x\mid \theta)dx\\
&= \theta_{KL}^*, \quad n\to\infty.
\end{aligned}$$
In other words, when we have enough samples, maximum likelihood estimation and minimizing the KL divergence are equivalent.
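A quick numerical check of this limit (under the added assumption that both $p_{data}$ and the model are Gaussian, so the expectation has a closed form): as $n$ grows, the empirical average of the log-likelihood approaches $\int p_{data}(x)\log p(x\mid\theta)dx$.

```python
import numpy as np

rng = np.random.default_rng(0)
mu0, s0 = 0.0, 1.0  # p_data = N(0, 1)
mu, s = 0.5, 1.2    # a fixed candidate model p(x|theta) = N(0.5, 1.2^2)

def log_p(x):       # log density of the model
    return -0.5 * np.log(2 * np.pi * s**2) - (x - mu)**2 / (2 * s**2)

# closed form of E_{x ~ p_data}[log p(x|theta)] for two Gaussians
exact = -0.5 * np.log(2 * np.pi * s**2) - (s0**2 + (mu0 - mu)**2) / (2 * s**2)

for n in (10, 1_000, 100_000):
    x = rng.normal(mu0, s0, size=n)
    print(f"n={n:>6}: empirical={log_p(x).mean():.4f}  exact={exact:.4f}")
```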
A simple note to understand the sigmoid loss in SigLIP. Supported by DeepSeek.
Suppose we want to solve a binary classification problem with label $y\in\{0, 1\}$. A common option is the binary cross-entropy loss:
$$\mathcal{L}(x, y) = -[y\log (\sigma(z)) + (1-y)\log (1-\sigma(z))]$$
where $z=f_\theta(x)$ is the logit predicted by our model $f_\theta$, and $\sigma$ is the sigmoid function:
σ ( z ) : = 1 1 + e − z \sigma(z) := \frac{1}{1 + e^{-z}} σ ( z ) := 1 + e − z 1
The sigmoid function satisfies the identity:
$$\sigma(-z) = \frac{1}{1 + e^{z}} = \frac{e^{-z}}{1 + e^{-z}} = 1 - \frac{1}{1 + e^{-z}} = 1- \sigma(z)$$
Substituting $\sigma(-z)=1-\sigma(z)$ into the loss function, we obtain:
$$\mathcal{L}(x, y) = -[y\log (\sigma(z)) + (1-y)\log (\sigma(-z))]$$
Note that $y\in\{0, 1\}$, so for each instance there are two cases:
If $y=0$, then $\mathcal{L}(x, y) =-\log (\sigma(-z))$.
If $y=1$, then $\mathcal{L}(x, y) =-\log (\sigma(z))$.
Now we want a unified expression covering both cases: we need a multiplier on $z$ that equals $-1$ when $y=0$ and $+1$ when $y=1$, i.e., a curve passing through the two points $(0, -1)$ and $(1, 1)$. The simplest such curve is the straight line $y=2x-1$. So we can further simplify the loss to:
$$\mathcal{L}(x, y) = -\log\left[\sigma((2y-1)z)\right]$$
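We can verify this equivalence numerically; the following sketch compares the unified form against PyTorch's built-in binary cross-entropy on random logits:

```python
import torch
import torch.nn.functional as F

z = torch.randn(8)                     # logits
y = torch.randint(0, 2, (8,)).float()  # labels in {0, 1}

bce = F.binary_cross_entropy_with_logits(z, y, reduction='none')
unified = -F.logsigmoid((2 * y - 1) * z)  # -log sigma((2y - 1) z)

print(torch.allclose(bce, unified, atol=1e-6))  # True
```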
Now we recall the sigmoid loss in SigLIP:
$$\mathcal{L}(\{\mathbf{x}_i, \mathbf{y}_i\}_{i=1}^N)=-\frac{1}{N}\sum_{i=1}^N\sum_{j=1}^N\log \frac{1}{1+\exp\left[z_{ij}(-t\,\mathbf{x}_i\cdot \mathbf{y}_j+b)\right]}$$
where $t, b$ are learnable parameters, and $z_{ij}=1$ if $i=j$, $z_{ij}=-1$ otherwise.
To understand the sigmoid loss, notice that $z_{ij}=2\,\mathbb{I}_{i=j}-1$, which exactly matches the $(2y-1)$ form we derived earlier, with the indicator $\mathbb{I}_{i=j}$ playing the role of the label $y$.
Compared with implementing the two cases separately, the unified expression is:
More stable: it avoids evaluating $\log 0$.
More efficient: the sigmoid is computed only once.
More concise: one line of code without condition checking.
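Putting the pieces together, here is a minimal sketch of the pairwise sigmoid loss (an illustration, not the official SigLIP code). It assumes the image/text embeddings are already L2-normalized and that the bias enters the pairwise logit as $t\,\mathbf{x}_i\cdot\mathbf{y}_j+b$, as in the paper's pseudocode:

```python
import torch
import torch.nn.functional as F

def sigmoid_loss(img_emb, txt_emb, t, b):
    """Pairwise sigmoid loss over the N x N similarity matrix (a sketch)."""
    n = img_emb.shape[0]
    logits = t * img_emb @ txt_emb.T + b  # t, b are learnable scalars in SigLIP
    labels = 2 * torch.eye(n) - 1         # z_ij: +1 on the diagonal, -1 elsewhere
    # -log sigmoid(z_ij * logit_ij), summed over all pairs, averaged over N
    return -F.logsigmoid(labels * logits).sum() / n

img = F.normalize(torch.randn(4, 16), dim=-1)  # toy L2-normalized embeddings
txt = F.normalize(torch.randn(4, 16), dim=-1)
print(sigmoid_loss(img, txt, t=10.0, b=-10.0))  # init values as reported in the paper
```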
The softmax function converts $K$ real numbers into a $K$-dimensional probability distribution. Concretely, it first exponentiates every element, i.e., computes $e^x$, and then divides each element by the sum of all the exponentials. That is,
$$\begin{aligned}
\mathrm{softmax}:\mathbb{R}^K&\to (0,1)^K\\
\mathrm{softmax}(\mathbf{z}) &=\left[\frac{e^{z_1}}{\sum_{j=1}^Ke^{z_j}},\dots,\frac{e^{z_K}}{\sum_{j=1}^Ke^{z_j}}\right]
\end{aligned}$$
The first property of softmax is shift invariance: for any scalar $c\in\mathbb{R}$,
$$\mathrm{softmax}(\mathbf{z}+c) = \mathrm{softmax}(\mathbf{z})$$
The proof is straightforward:
$$\mathrm{softmax}(\mathbf{z}+c)_i = \frac{e^{z_i+c}}{\sum_{j=1}^Ke^{z_j+c}} = \frac{e^ce^{z_i}}{e^c\sum_{j=1}^Ke^{z_j}} = \frac{e^{z_i}}{\sum_{j=1}^Ke^{z_j}}=\mathrm{softmax}(\mathbf{z})_i,\quad i=1,\dots,K$$
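A two-line numerical check of shift invariance (the standard max-subtraction trick used below is itself an application of this property):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # subtracting the max is itself a shift
    return e / e.sum()

z = np.array([1.0, 2.0, 3.0])
print(np.allclose(softmax(z), softmax(z + 100.0)))  # True
```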
Deriving the Jacobian matrix of the softmax function for vector inputs
Let the input be the vector $\mathbf{z} = [z_1, z_2, \dots, z_d]^\top \in \mathbb{R}^d$, and let the softmax output be the vector $\mathbf{a} = [a_1, a_2, \dots, a_d]^\top \in \mathbb{R}^d$, with each element defined as:
$$a_j = \mathrm{softmax}(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^d e^{z_k}}$$
Write the denominator (the normalization factor) as $S = \sum_{k=1}^d e^{z_k}$, so that $a_j = e^{z_j}/S$.
We compute $\frac{\partial a_j}{\partial z_k}$ in two cases:
When $j = k$, we take the partial derivative of $a_j$ with respect to its own input $z_j$:
$$\frac{\partial a_j}{\partial z_j} = \frac{\partial}{\partial z_j} \left( \frac{e^{z_j}}{S} \right) = \frac{e^{z_j}S-e^{z_j}e^{z_j}}{S^2}=\frac{e^{z_j}}{S}\left(1-\frac{e^{z_j}}{S}\right)=a_j(1-a_j)$$
When $j \neq k$, the partial derivative of $a_j$ with respect to $z_k$ is:
$$\frac{\partial a_j}{\partial z_k} = \frac{\partial}{\partial z_k} \left( \frac{e^{z_j}}{S} \right) = \frac{0\cdot S-e^{z_j}e^{z_k}}{S^2}=-\frac{e^{z_j}}{S}\cdot\frac{e^{z_k}}{S}=-a_ja_k$$
Combining the two cases, the Jacobian matrix $\mathbf{J}$ can be written as:
$$\mathbf{J} = \mathrm{diag}(\mathbf{a}) - \mathbf{a} \mathbf{a}^\top$$
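We can sanity-check the formula $\mathbf{J} = \mathrm{diag}(\mathbf{a}) - \mathbf{a}\mathbf{a}^\top$ against a finite-difference Jacobian:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([0.5, -1.0, 2.0])
a = softmax(z)
J = np.diag(a) - np.outer(a, a)  # closed-form Jacobian

# central finite differences: column k approximates da/dz_k
eps = 1e-6
cols = [(softmax(z + eps * np.eye(3)[k]) - softmax(z - eps * np.eye(3)[k])) / (2 * eps)
        for k in range(3)]
J_num = np.stack(cols, axis=1)
print(np.allclose(J, J_num, atol=1e-8))  # True
```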
softmax is a smooth approximation of argmax, so softmax really means "soft argmax". To prove this, we first define the following function
$$\mathrm{softmax}(\mathbf{z};\tau) =\mathrm{softmax}(\mathbf{z}/\tau)=\left[\frac{e^{z_1/\tau}}{\sum_{j=1}^Ke^{z_j/\tau}},\dots,\frac{e^{z_K/\tau}}{\sum_{j=1}^Ke^{z_j/\tau}}\right]$$
Clearly, $\mathrm{softmax}(\mathbf{z})=\mathrm{softmax}(\mathbf{z};1)$. Moreover, $\mathrm{softmax}$ is a smooth function.
We define the notion of smooth approximation as follows.
Definition
If $\lim_{\tau\to0^+}\mathrm{softmax}(\mathbf{z};\tau)=\mathbb{1}_{\arg\max(\mathbf{z})}$, then we say $\mathrm{softmax}(\cdot;\tau)$ is a smooth approximation of $\arg\max$; in particular, $\mathrm{softmax}(\cdot)$ is a smooth approximation of $\arg\max$.
Here $\arg\max(\mathbf{z})=\arg\max_k z_k$ is the index of the maximum, and $\mathbb{1}\in\{0,1\}^K$ is the indicator vector, i.e., $\mathbb{1}_{\arg\max(\mathbf{z})}[i]=1$ if and only if $z_i=\max_jz_j$.
We now give the proof. Assume first that the maximum is unique, with index $m$, i.e., $z_m = \max_i z_i$. By the shift invariance property above, we have:
$$\mathrm{softmax}(\mathbf{z};\tau) = \mathrm{softmax}(\mathbf{z}-z_m;\tau) =\left[\frac{e^{(z_1-z_m)/\tau}}{\sum_{j=1}^Ke^{(z_j-z_m)/\tau}},\dots,\frac{e^{(z_K-z_m)/\tau}}{\sum_{j=1}^Ke^{(z_j-z_m)/\tau}}\right]$$
Since $z_i-z_m<0$ for every $i\neq m$, each off-maximum term $e^{(z_i-z_m)/\tau}$ tends to $0$ as $\tau\to0^+$, while the term for $i=m$ equals $1$. Hence
$$\lim_{\tau\to0^+}\mathrm{softmax}(\mathbf{z};\tau)_i = \begin{cases}
1, &\text{if }i = m\\
0, &\text{otherwise}
\end{cases}$$
When the maximum is not unique, let $\mathcal{I} = \{i\in[K]\mid z_i=\max_j z_j\}$. By an argument similar to the one above, the limit becomes
$$\lim_{\tau\to0^+}\mathrm{softmax}(\mathbf{z};\tau)_i = \begin{cases}
1/|\mathcal{I}|, &\text{if }i \in \mathcal{I}\\
0, &\text{otherwise}
\end{cases}$$
This proves that softmax is a smooth approximation of the argmax function.
We introduced the function $\mathrm{softmax}(\mathbf{z};\tau)$ above. The parameter $\tau$ is called the temperature; it controls the variance of the scaled input $\mathbf{z}/\tau$. The larger $\tau$ is, the lower the variance of the scaled input and the closer the output is to a uniform distribution; the smaller $\tau$ is, the higher the variance and the closer the output is to a one-hot distribution.
We already proved the latter; now we prove the former. The idea is simple: as $\tau\to+\infty$, $e^{z_i/\tau}\to 1$ for every $i$, so
$$\lim_{\tau\to+\infty}\mathrm{softmax}(\mathbf{z};\tau)_i =\frac1K,\quad i=1,\dots,K$$
Below is code to visualize this, along with the result.
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

num_elements = 15
indices = np.arange(num_elements)
logits = np.linspace(-3.5, 3.5, num_elements)
scales = [0.01, 0.1, 1.0, 5.0, 10.0, 100.0]

plt.figure(figsize=(10, 6))
for s in scales:
    probs = softmax(logits * s)
    x_smooth = np.linspace(indices.min(), indices.max(), 300)
    spl = make_interp_spline(indices, probs, k=3)
    y_smooth = np.clip(spl(x_smooth), 0, None)  # clip to avoid negative spline artifacts
    plt.plot(x_smooth, y_smooth, label=f'Scale = {s}', linewidth=2)

uniform_prob = 1.0 / len(indices)
plt.axhline(y=uniform_prob, color='black', linestyle=':', alpha=0.6, label='Uniform distribution')
plt.xticks(indices)
plt.xlabel('Logit Index', fontsize=12)
plt.ylabel('Softmax Probability', fontsize=12)
plt.title('Impact of Variance Scaling on Softmax Distribution', fontsize=14)
plt.legend(title="Variance Scale")
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()
```
As we can see, when the variance is small the output distribution is close to uniform, and as the variance grows the output distribution approaches a one-hot distribution.
The softmax function also appears in the computation of attention. To shield the softmax from the effect of variance, a normalization layer is now often inserted before the softmax to normalize the inputs in advance; see QK-norm.
Since $e^x$ overflows very easily in practice, implementations must take numerical stability into account. In fact, modern softmax implementations are essentially built on logsumexp, which is defined as follows
$$\mathrm{logsumexp}(\mathbf{z}) = \log \left(\sum_{i=1}^K e^{z_i}\right)$$
The relation between softmax and logsumexp is:
$$\begin{aligned}
\mathrm{softmax}(\mathbf{z}) &=\exp\log\left(\frac{e^{\mathbf{z}}}{\sum_{j=1}^Ke^{z_j}}\right)\\
&= \exp\left(\mathbf{z} - \log\left(\sum_{i=1}^K e^{z_i}\right)\right)\\
&= \exp(\mathbf{z} - \mathrm{logsumexp}(\mathbf{z}))
\end{aligned}$$
To address the overflow of $e^x$ mentioned above, the input is first shifted by subtracting its maximum, $c=\max_i z_i$. We then have
$$\mathrm{softmax}(\mathbf{z}) = \mathrm{softmax}(\mathbf{z}-c) = \exp((\mathbf{z}-c) - \mathrm{logsumexp}(\mathbf{z}-c))$$
Here we used the shift invariance property derived earlier. The corresponding implementation:
```python
import torch

def softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # shift by the max along `dim` so that exp() cannot overflow
    x = x - x.max(dim=dim, keepdim=True).values
    log_sum_exp = torch.log(torch.sum(torch.exp(x), dim=dim, keepdim=True))
    return torch.exp(x - log_sum_exp)
```
Note that computing softmax requires loading all of $\mathbf{z}$. When $\mathbf{z}$ is very large, this causes frequent memory reads and writes that hurt overall efficiency. Flash attention therefore uses the online softmax algorithm to reduce memory-access overhead.
Concretely, suppose the input is split into blocks, $\mathbf{z}=[\mathbf{z}^1;\dots;\mathbf{z}^n]\in\mathbb{R}^K$, where $\mathbf{z}^i\in\mathbb{R}^{K/n}$ (assuming $n$ divides $K$).
For $\mathbf{z}\in\mathbb{R}^K$, flash attention defines the following quantities
$$m(\mathbf{z}) = \max_i z_i,\quad f(\mathbf{z}) = [e^{z_1-m(\mathbf{z})},\dots,e^{z_K-m(\mathbf{z})}],\quad \ell(\mathbf{z})=\sum_if(\mathbf{z})_i,\quad \mathrm{softmax}(\mathbf{z}) = \frac{f(\mathbf{z})}{\ell(\mathbf{z})}$$
For $\mathbf{z}=[\mathbf{z}^1;\dots;\mathbf{z}^n]\in\mathbb{R}^K$, these quantities can now be computed block by block:
$$\begin{aligned}
m_i(\mathbf{z}) &= \max([\mathbf{z}^1;\dots;\mathbf{z}^i]) = \max(m_{i-1}(\mathbf{z}),m(\mathbf{z}^i))\\
\ell_i(\mathbf{z}) &= \sum_{j=1}^i\sum_k e^{z^j_k-m_i(\mathbf{z})} = e^{m_{i-1}(\mathbf{z}) - m_i(\mathbf{z})}\,\ell_{i-1}(\mathbf{z}) + e^{m(\mathbf{z}^i)-m_i(\mathbf{z})}\,\ell(\mathbf{z}^i)
\end{aligned}$$
Therefore, if we additionally keep track of the two quantities $m(\cdot)$ and $\ell(\cdot)$, we can process the softmax one block at a time. After the final block, $m_n(\mathbf{z})$ and $\ell_n(\mathbf{z})$ are exactly the global max and the global denominator.
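Here is a minimal sketch of the idea in plain NumPy (an illustration of the recurrence, not flash attention's actual kernel): a single streaming pass maintains the running max $m$ and running denominator $\ell$, after which one more elementwise pass produces the softmax.

```python
import numpy as np

def online_softmax(z, block_size):
    """Blockwise softmax keeping only the running max m and denominator l."""
    m, l = -np.inf, 0.0
    for start in range(0, len(z), block_size):
        block = z[start:start + block_size]
        m_new = max(m, block.max())
        # rescale the old denominator to the new max, then add the new block
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return np.exp(z - m) / l  # m and l are now the global max and denominator

z = np.random.randn(1000)
full = np.exp(z - z.max()) / np.exp(z - z.max()).sum()
print(np.allclose(online_softmax(z, block_size=128), full))  # True
```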
This concludes our review of the basic definition and properties of the softmax function in machine learning.