Fix Point Theorem

不动点定理

Author

Published

2026-03-09 17:16:02+08:00

Fix Point Theorem, 即不动点定理,是泛函分析中的基本工具,被广泛应用于非线性函数的分析。

在介绍不动点定理之前,我们先介绍两个概念

首先是不动点的概念。

Definition

对于函数 f:RnRnf:\mathbb{R}^n\to\mathbb{R}^n, 如果一个点 xRnx^*\in\mathbb{R}^n 满足

f(x)=xf(x^*)=x^*

则我们称 xx^* 是函数 ff 的不动点。

接下来是 contraction mapping 的概念

Definition

对于函数 f:RnRnf:\mathbb{R}^n\to\mathbb{R}^n, 如果存在 γ(0,1)\gamma\in(0,1) 满足

f(x1)f(x2)γx1x2, x1,x2Rn\|f(x_1)-f(x_2)\| \leq \gamma \|x_1-x_2\|,\forall\ x_1,x_2\in\mathbb{R}^n

则我们称 ff 是一个 contraction mapping. 这里 \|\cdot\| 是一个 matrix norm.

接下来,我们介绍不动点定理

Theorem

给定 equation x=f(x)x=f(x), 其中 f:RnRnf:\mathbb{R}^n\to\mathbb{R}^n, 如果 ff 是一个 contraction mapping, 则 ff 具有如下性质

  1. Existence: 存在 fixed point xRnx^*\in\mathbb{R}^n 满足 f(x)=xf(x^*)=x^*.
  2. Uniqueness: fixed point xx^* 唯一。
  3. Algorithm: 对任意 x0Rnx^0\in\mathbb{R}^n, 使用迭代算法 xk+1=f(xk)x_{k+1}=f(x_k) 产生的序列 {xk}k=0\{x_k\}_{k=0}^{\infty} 收敛到 fixed point xx*, 且收敛速度为指数级。

证明需要用到柯西列的概念。

Definition

一个序列 x1,x2,x_1,x_2,\dots 被称为柯西列 (Cauchy sequence) 当且仅当对任意 ϵ>0\epsilon>0, 都存在 N>0N>0, 使得

xmxn<ϵ,m,n>N\|x_m-x_n\| <\epsilon,\forall m, n>N

柯西列的一个重要性质为柯西列一定是收敛列。

证明

我们首先证明由 xk=f(xk=1)x_k=f(x_{k=1}) 产生的序列 {xk}k=1\{x_k\}_{k=1}^{\infty} 是收敛的,我们通过证明序列 {xk}k=1\{x_k\}_{k=1}^{\infty} 是一个柯西列来证明这一点。

注意到 ff 是一个 contraction mapping, 因此

xk+1xk=f(xk)f(xk1)γxkxk1\|x_{k+1}-x_k\| = \|f(x_k)-f(x_{k-1})\|\leq \gamma \|x_k-x_{k-1}\|

迭代下去,我们就得到

xk+1xkγxkxk1γkx1x0\|x_{k+1}-x_k\|\leq \gamma \|x_k-x_{k-1}\|\leq\cdots\leq \gamma^k \|x_1-x_{0}\|

现在我们证明序列 {xk}k=1\{x_k\}_{k=1}^{\infty} 是一个柯西列:

xmxn=xmxm1+xm1xn+1+xn+1xni=nm1xi+1xii=nm1γix1x0γn1γx1x0.\begin{aligned} \|x_m-x_n\| &= \|x_m-x_{m-1}+x_{m-1}-\cdots-x_{n+1}+x_{n+1}-x_n\|\\ &\leq \sum_{i=n}^{m-1}\|x_{i+1}-x_i\|\\ &\leq \sum_{i=n}^{m-1}\gamma^i \|x_{1}-x_0\|\\ &\leq \frac{\gamma^n}{1-\gamma}\|x_{1}-x_0\|. \end{aligned}

从而,序列 {xk}k=1\{x_k\}_{k=1}^{\infty} 是一个柯西列, 因此也是一个收敛列。

接下来,我们证明 x=limkxkx^*=\lim_{k\to\infty}x_kf(x)f(x) 的不动点,注意到

f(xk)xk=xk+1xkγkx1x00,k\|f(x_k)-x_k\| = \|x_{k+1}-x_k\|\leq \gamma^k\|x_1-x_0\| \to 0, k\to\infty

我们有 limkf(xk)=limkxk\lim_{k\to\infty}f(x_k)=\lim_{k\to\infty}x_k , 由于 contraction mapping 一定是连续的,因此我们就可以得到 f(x)=xf(x^*)=x^*.

然后,我们证明不动点唯一。假设还存在一个另外一个不动点 xxx'\neq x^* 满足 f(x)=xf(x')=x', 那么

xx=f(x)f(x)γxx\|x'-x^*\| = \|f(x')-f(x')\| \leq \gamma \|x'-x^*\|

由于 γ(0,1)\gamma\in(0,1), 因此上述等式当且仅当 xx=0\|x'-x^*\|=0, 这与前面假设矛盾,因而不动点是唯一的

最后,我们证明 xk+1=f(xk)x_{k+1}=f(x_k) 这个算法的收敛速度为指数级,注意到

xxn=limm]xmxnγn1γx1x0\|x^*-x_n\| = \lim_{m\to\infty}]\|x_m-x_n\| \leq \frac{\gamma^n}{1-\gamma}\|x_1-x_0\|

因为 γ<1\gamma <1, 因此收敛速度为指数级

MLE

最大似然估计,即MLE (maximum likelihood estimation), 是一个估计参数分布的方法,其核心思想是:模型的参数,应该让观察样本出现的概率最大。

假设我们有一个参数分布 p(xθ)p(x\mid \theta), 其中 θ\theta 是参数,如正态分布中的均值和方差。我们从p(xθ)p(x\mid \theta)进行采样得到 i.i.d.i.i.d. 的数据 X={x1,,xn}X=\{x_1,\dots,x_n\}.

似然函数 (likelihood function) 定义为给定数据 XX 的联合分布,即:

L(θX)=P(Xθ)\mathcal{L}(\theta\mid X) = P(X\mid \theta)

由于 X={x1,,xn}X=\{x_1,\dots,x_n\}i.i.d.i.i.d., 因此,我们可以将上式改写为:

L(θX)=i=1np(xiθ)\mathcal{L}(\theta\mid X) = \prod_{i=1}^n p(x_i\mid \theta)

这样我们的优化目标就是

θMLE=argmaxθL(θX)=argmaxθi=1np(xiθ)=argmaxθlogi=1np(xiθ)=argmaxθi=1nlogp(xiθ)\begin{aligned} \theta_{MLE}^* &= \arg\max_{\theta} \mathcal{L}(\theta\mid X)\\ &= \arg\max_{\theta} \prod_{i=1}^n p(x_i\mid \theta)\\ &= \arg\max_{\theta} \log\prod_{i=1}^n p(x_i\mid \theta)\\ &=\arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)\\ \end{aligned}

θMLE=argmaxθi=1nlogp(xiθ)\theta_{MLE}^* = \arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)

KL divergence

KL divergence 用于衡量概率分布 Q(x)Q(x) 到概率分布 P(x)P(x) 的不同程度,我们可以将其理解为:如果我们用 Q(x)Q(x)来替换 P(x)P(x), 会有多大的信息损失?

连续概率分布的KL divergence的定义如下

DKL(PQ)=P(x)log(P(x)Q(x))dxD_{KL}(P\mid\mid Q) =\int P(x)\log\left(\frac{P(x)}{Q(x)}\right)dx

离散概率分布的KL divergence定义如下

DKL(PQ)=xP(x)log(P(x)Q(x))D_{KL}(P\mid\mid Q) = \sum_{x} P(x)\log\left(\frac{P(x)}{Q(x)}\right)

KL divergence有两个关键性质:

  1. 非负性:DKL(PQ)0D_{KL}(P\mid\mid Q)\geq0, 且 DKL(PQ)=0D_{KL}(P\mid\mid Q)=0 当且仅当 P(x)=Q(x)P(x)=Q(x) 对任意 xx成立
  2. 非对称性: 一般情况下,DKL(PQ)DKL(QP)D_{KL}(P\mid\mid Q)\neq D_{KL}(Q\mid\mid P).

MLE和KL Divergence的等价性

我们假设 pdata(x)p_{data}(x) 是数据XX的真实分布, 我们现在需要找到合适的参数 θ\theta 以及其对应的分布 p(xθ)p(x\mid \theta) 来近似 pdata(x)p_{data}(x), 此时我们可以用KL Divergence作为我们的目标函数,即

θKL=argminθDKL(pdata(x)p(xθ))\theta_{KL} = \arg\min_{\theta}D_{KL}(p_{data}(x)\mid\mid p(x\mid \theta))

我们将上面的式子进行展开得到

θKL=argminθDKL(pdata(x)p(xθ))=argminθpdata(x)pdata(x)p(xθ)dx=argminθpdata(x)logpdata(x)dxpdata(x)logp(xθ)dx=argminθpdata(x)logp(xθ)dx=argmaxθpdata(x)logp(xθ)dx\begin{aligned} \theta_{KL}^* &= \arg\min_{\theta}D_{KL}(p_{data}(x)\mid\mid p(x\mid \theta))\\ &= \arg\min_{\theta} \int p_{data}(x)\frac{p_{data}(x)}{p(x\mid \theta)} dx\\ &= \arg\min_{\theta}\int p_{data}(x)\log p_{data}(x) dx - \int p_{data}(x)\log p(x\mid \theta)dx \\ &= \arg\min_{\theta} - \int p_{data}(x)\log p(x\mid \theta)dx \\ &= \arg\max_{\theta} \int p_{data}(x)\log p(x\mid \theta)dx \end{aligned}

实际上,真实的数据分布 pdata(x)p_{data}(x) 是未知的,我们只有从 pdata(x)p_{data}(x) 采样得到的一批数据 X={x1,,xn}pdata(x)X=\{x_1,\dots,x_n\}\sim p_{data}(x).

基于大数定律,我们有

1ni=1nlogp(θiθ)=Expdata[logp(xθ)]=pdata(x)logp(xθ)dx,n\frac{1}{n}\sum_{i=1}^n\log p(\theta_i\mid \theta)=\mathbb{E}_{x\sim p_{data}}[\log p(x\mid \theta)] = \int p_{data}(x)\log p(x\mid \theta)dx, n\to \infty

这样,最大似然估计就与最小化KL divergence构建起了联系:

θMLE=argmaxθi=1nlogp(xiθ)=argmaxθpdata(x)logp(xθ)dx=θKL,n.\begin{aligned} \theta_{MLE}^*&=\arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)\\ &= \arg\max_{\theta} \int p_{data}(x)\log p(x\mid \theta)dx\\ &= \theta_{KL}^*, n\to\infty. \end{aligned}

也就是说,当采样样本足够多的时候,最大似然估计和最小KL divergence是等价的。

Introduction

A simple note to understand Sigmoid Loss in SigLip 1. Supported by DeepSeek2

Binary cross entropy loss

Suppose we want to solve the binary classification problem, with label y{0,1}y\in\{0, 1\}, a common option is to use binary cross entropy loss:

L(x,y)=[ylog(σ(z))+(1y)log(1σ(z))]\mathcal{L}(x, y) = -[y\log (\sigma(z)) + (1-y)\log (1-\sigma(z))]

where z=fθ(x)z=f_\theta(x) is the logits predicted by our model fθf_\theta, and σ\sigma is the sigmoid function:

σ(z):=11+ez\sigma(z) := \frac{1}{1 + e^{-z}}

Let σ()\sigma(\cdot) be the sigmoid function, then we have:

σ(z)=11+ez=ez1+ez=111+ez=1σ(z)\sigma(-z) = \frac{1}{1 + e^{z}} = \frac{e^{-z}}{1 + e^{-z}} = 1 - \frac{1}{1 + e^{-z}} = 1- \sigma(z)

Now we substitute σ(z)=1σ(z)\sigma(-z)=1-\sigma(z) into the loss function, we obtain:

L(x,y)=[ylog(σ(z))+(1y)log(σ(z))]\mathcal{L}(x, y) = -[y\log (\sigma(z)) + (1-y)\log (\sigma(-z))]

Note that y{0,1}y\in\{0, 1\} thus for each instance, there are two cases:

Now we want to use a unified expression to express these two cases. Note that this requires fitting a curve that passes two points (0,1)(0, -1) and (1,1)(1, 1). The simplest curve is a straight line y=2x1y=2x-1. So, we can further simplify the loss expression into:

L(x,y)=log[σ((2y1)z)]\mathcal{L}(x, y) = -\log\left[\sigma((2y-1)z)\right]

Sigmoid Loss in SigLip

Now we recall the sigmoid loss in SigLip:

L({x,y}i=1N)=1Ni=1Nj=1Nlog11+exp[zij(txiyj+b)]\mathcal{L}(\{\bm{x}, \bm{y}\}_{i=1}^N)=-\frac{1}{N}\sum_{i=1}^N\sum_{j=1}^N\log \frac{1}{1+\exp\left[z_{ij}(-t\bm{x}_i\cdot \bm{y_j}+b)\right]}

where t,bt, b are learnable parameters, and zij=1z_{ij}=1 if i=ji=j and zij=1z_{ij}=-1 otherwise.

To understand Sigmoid loss, notice that zij=2Ii=j1z_{ij}=2\mathbb{I}_{i=j}-1, which exactly matches the form we derived earlier.

Why Use Sigmoid Loss?

  1. More stable: avoids log0\log 0.
  2. More efficient: Compute Sigmoid once.
  3. More Precise: one line of code without condition checking.

    Footnotes

    1. SigLip

    2. DeepSeek

    Introduction

    softmax 函数用于将 KK 个实数转换为一个 KK 维概率分布。其具体做法是先对所有元素指数化,即求 exe^x, 然后每个元素除以所有指数的和。即

    softmax:RK(0,1)Ksoftmax(z)=[ez1j=1Kezj,,ezKj=1Kezj]\begin{aligned} \mathrm{softmax}:\mathbb{R}^K&\to (0,1)^K\\ \mathrm{softmax}(\mathbf{z}) &=\left[\frac{e^{z_1}}{\sum_{j=1}^Ke^{z_j}},\dots,\frac{e^{z_K}}{\sum_{j=1}^Ke^{z_j}}\right] \end{aligned}

    Analysis

    Properties

    softmax 的第一个性质是 shift invariance, 即

    softmax(z+c)=softmax(z)\mathrm{softmax}(\mathbf{z}+c) = \mathrm{softmax}(\mathbf{z})

    证明比较容易:

    softmax(z+c)i=ezi+cj=1Kezj+c=eceziecj=1Kezj=ezij=1Kezj=softmax(z)i, i=1,,K\mathrm{softmax}(\mathbf{z}+c)_i = \frac{e^{z_i+c}}{\sum_{j=1}^Ke^{z_j+c}} = \frac{e^ce^{z_i}}{e^c\sum_{j=1}^Ke^{z_j}} = \frac{e^{z_i}}{\sum_{j=1}^Ke^{z_j}}=\mathrm{softmax}(\mathbf{z})_i,\ i=1,\dots,K

    Gradient

    向量输入下 Softmax 函数的 Jacobian 矩阵推导

    设输入为向量 z=[z1,z2,,zd]Rd\mathbf{z} = [z_1, z_2, \dots, z_d]^\top \in \mathbb{R}^d,Softmax 函数的输出为向量 a=[a1,a2,,ad]Rd\mathbf{a} = [a_1, a_2, \dots, a_d]^\top \in \mathbb{R}^d,其中每个元素定义为:

    aj=softmax(z)j=ezjk=1dezka_j = \text{softmax}(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^d e^{z_k}}

    记分母(归一化因子)为 S=k=1dezkS = \sum_{k=1}^d e^{z_k},则 aj=ezj/Sa_j = e^{z_j}/S.

    我们分两种情况计算 ajzk\frac{\partial a_j}{\partial z_k}

    j=kj = k 时, 此时求 aja_j 对自身输入 zjz_j 的偏导数:

    ajzj=zj(ezjS)=ezjSezjezjS2=ezjS(1ezjS)=aj(1aj)\frac{\partial a_j}{\partial z_j} = \frac{\partial}{\partial z_j} \left( \frac{e^{z_j}}{S} \right) = \frac{e^{z_j}S-e^{z_j}e^{z_j}}{S^2}=\frac{e^{z_j}}{S}\left(1-\frac{e^{z_j}}{S}\right)=a_j(1-a_j)

    jkj \neq k 时, 此时求 aja_j 对输入 zkz_k 的偏导数有:

    ajzk=zk(ezjS)=0SezjezkS2=ezjezkS=ajaj\frac{\partial a_j}{\partial z_k} = \frac{\partial}{\partial z_k} \left( \frac{e^{z_j}}{S} \right) = \frac{0\cdot S-e^{z_j}e^{z_k}}{S^2}=-\frac{e^{z_j}e^{z_k}}{S}=-a_ja_j

    综合以上两种情况,Jacobian 矩阵 J\mathbf{J} 可表示为:

    J=diag(a)aa\mathbf{J} = \text{diag}(\mathbf{a}) - \mathbf{a} \mathbf{a}^\top

    Interpretation

    Soft Argmax

    softmax 是 argmax 的 smooth approximation, 所以实际上 softmax 指的是 “soft argmax”. 为了证明这一点,我们首先定义如下函数

    softmax(z;τ)=softmax(z/τ)=[ez1/τj=1Kezj/τ,,ezK/τj=1Kezj/τ]\mathrm{softmax}(\mathbf{z};\tau) =\mathrm{softmax}(\mathbf{z}/\tau)=\left[\frac{e^{z_1/\tau}}{\sum_{j=1}^Ke^{z_j/\tau}},\dots,\frac{e^{z_K/\tau}}{\sum_{j=1}^Ke^{z_j/\tau}}\right]

    易知, softmax(z)=softmax(z;1)\mathrm{softmax}(\mathbf{z})=\mathrm{softmax}(\mathbf{z};1). 并且,softmax\mathrm{softmax} 还是一个光滑函数

    我们定义 smooth approximation 为

    Definition 如果 limτ0+softmax(z;τ)=1argmax(z)\lim_{\tau\to0^+}\mathrm{softmax}(\mathbf{z};\tau)=\mathbb{1}_{\arg\max(\mathbf{z})}, 则我们说 softmax(;τ)\mathrm{softmax}(\cdot;\tau)argmax\arg\max 的光滑近似,特别地,softmax()\mathrm{softmax}(\cdot)argmax\arg\max 的光滑近似。 这里 argmax(z)=argmaxkzk\arg\max(\mathbf{z})=\arg\max_k z_k 是最大值的索引, 1{0,1}K\mathbb{1}\in\{0,1\}^K 是示性函数 (indicator function), 即 1argmax(z)[i]=1\mathbb{1}_{\arg\max(\mathbf{z})}[i]=1 当且仅当 zi=maxjzjz_i=\max_jz_j.

    我们下面来进行证明。我们不妨假设最大值唯一,其 index 为 mm, 即 zm=maxiziz_m = \max_i z_i. 由前面的性质,我们有:

    softmax(z;τ)=softmax(zzm;τ)=[e(z1zm)/τj=1Ke(zjzm)/τ,,e(zKzm)/τj=1Ke(zjzm)/τ]\mathrm{softmax}(\mathbf{z};\tau) = \mathrm{softmax}(\mathbf{z}-z_m;\tau) =\left[\frac{e^{(z_1-z_m)/\tau}}{\sum_{j=1}^Ke^{(z_j-z_m)/\tau}},\dots,\frac{e^{(z_K-z_m)/\tau}}{\sum_{j=1}^Ke^{(z_j-z_m)/\tau}}\right]

    此时,我们有

    limτ0+softmax(z;τ)i={1,if i=m0,otherwise\lim_{\tau\to0^+}\mathrm{softmax}(\mathbf{z};\tau)_i = \begin{cases} 1, &\text{if }i = m\\ 0, &\text{otherwise} \end{cases}

    当最大值不唯一的时候,我们记 I={i[K]zi=maxjzj}\mathcal{I} = \{i\in[K]\mid z_i=\max_j z_j\}, 与上面方法类似,最终 softmax(;τ)\mathrm{softmax}(\cdot;\tau) 的结果为

    limτ0+softmax(z;τ)i={1/I,if iI0,otherwise\lim_{\tau\to0^+}\mathrm{softmax}(\mathbf{z};\tau)_i = \begin{cases} 1/|\mathcal{I}|, &\text{if }i \in \mathcal{I}\\ 0, &\text{otherwise} \end{cases}

    因此,我们就证明了 softmax 是 argmax 函数的 smooth approximation.

    Statistical Mechanics

    Temperature

    我们前面介绍了 softmax(z;τ)\mathrm{softmax}(\mathbf{z};\tau) 函数,这里的 τ\tau 实际上被称为温度 (temperature), 它控制了输入的 variance, TT 越大,输入的 variance 越低,输出就倾向于均匀分布,而 TT 越小,则说明输入的 variance 越高,输出就倾向于 one-hot 分布。

    我们前面已经证明了后者,现在我们来证明一下前者,证明思路也很简单,T+T\to+\infty 时,ex/T1e^{x/T}\to 1, 因而

    limτ+softmax(z;τ)i=1K, i=1,,K\lim_{\tau\to+\infty}\mathrm{softmax}(\mathbf{z};\tau)_i =\frac1K,\ i=1,\dots,K

    下面是可视化的代码以及结果

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.interpolate import make_interp_spline
    
    def softmax(x):
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()
    
    num_elements = 15
    indices = np.arange(num_elements)
    logits = np.linspace(-3.5, 3.5, num_elements)
    scales = [0.01, 0.1, 1.0, 5.0, 10.0, 100.0]
    
    plt.figure(figsize=(10, 6))
    
    for s in scales:
        probs = softmax(logits * s)
        
        x_smooth = np.linspace(indices.min(), indices.max(), 300)
        spl = make_interp_spline(indices, probs, k=3)
        y_smooth = np.clip(spl(x_smooth), 0, None) # Clip to ensure no negative artifacts
        
        plt.plot(x_smooth, y_smooth, label=f'Scale = {s}', linewidth=2)
    
    uniform_prob = 1.0 / len(indices)
    plt.axhline(y=uniform_prob, color='black', linestyle=':', alpha=0.6, label=f'Uniform distribution')
    
    plt.xticks(indices)
    plt.xlabel('Logit Index', fontsize=12)
    plt.ylabel('Softmax Probability', fontsize=12)
    plt.title('Impact of Variance Scaling on Softmax Distribution', fontsize=14)
    plt.legend(title="Variance Scale")
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    
    plt.show()
    
    impact of variance on softmax

    可以看到,当 variance 比较小的时候,输出的分布接近于均匀分布,而 variance 越大,输出的分布越接近 One-hot 分布。

    在 attention 的计算过程中,我们也有 softmax 函数,为了在 softmax 过程中避免 variance 的影响,现在会在计算 softmax 之前加入 normalization layer 来提前进行归一化。见 QK-norm.

    Algorithms

    Implementation

    由于 exe^x 在实际计算时,非常容易溢出,因此在实现的时候,我们往往会考虑其数值稳定性。实际上,现在的 softmax 函数基本由 logsumexp 实现,logsumexp 函数定义如下

    logsumexp(z)=log(i=1Kezi)\mathrm{logsumexp}(\mathbf{z}) = \log \left(\sum_{i=1}^K e^{z_i}\right)

    softmax 函数与 logsumexp 函数的关系如下

    softmax(z)=explog(ezj=1Kezj)=exp(zlog(i=1Kezi))=exp(zlogsumexp(z))\begin{aligned} \mathrm{softmax}(\mathbf{z}) &=\exp\log\left(\frac{e^{\mathbf{z}}}{\sum_{j=1}^Ke^{z_j}}\right)\\ &= \exp\left(\mathbf{z} - \log\left(\sum_{i=1}^K e^{z_i}\right)\right)\\ &= \exp(\mathbf{z} - \mathrm{logsumexp}(\mathbf{z})) \end{aligned}

    考虑前面提到的 exe^x 数值溢出的问题,我们的输入会先经过 shift, 减掉最大值。此时我们有

    softmax(z)=softmax(zc)=exp((zc)logsumexp(zc))\mathrm{softmax}(\mathbf{z}) = \mathrm{softmax}(\mathbf{z}-c) = \exp((\mathbf{z}-c) - \mathrm{logsumexp}(\mathbf{z}-c))

    这里我们使用了前面推导出来的 shift invariance 性质。对应的代码实现如下:

    def softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        x = x - x.max(dim=dim, keepdim=True).values
        log_sum_exp = torch.log(torch.sum(torch.exp(x), dim=dim, keepdim=True))
        return torch.exp(x - log_sum_exp)
    

    Gumbel-softmax Reparametrization Trick

    TODO

    Online Softmax

    注意到我们在计算 softmax 时,需要加载 z\mathbf{z} 的全部信息,如果 z\mathbf{z} 非常大的话,会产生频繁的内存读写进而影响整体效率。因此 flash attention 中提出了 online softmax 算法来减少内存访问开销。

    其具体做法是假设我们的输入被分为若干个 block, 即 z=[z1;,zn]RK\mathbf{z}=[\mathbf{z}^1;\dots,\mathbf{z}^n]\in\mathbb{R}^K, 这里 ziRK/n\mathbf{z}^i\in\mathbb{R}^{K/n} (Kmodn=0K\mod n=0).

    对于 zRK\mathbf{z}\in\mathbb{R}^K, flash attention 定义如下结果

    m(z)=maxizi, f(z)=[ez1m(z),,ezKm(z)], (z)=if(z)i, softmax(z)=f(z)(z)m(\mathbf{z}) = \max_i z_i,\ f(\mathbf{z}) = [e^{z_1-m(\mathbf{z})},\dots,e^{z_K-m(\mathbf{z})}], \ \ell(\mathbf{z})=\sum_if(z)_i, \ \mathrm{softmax}(\mathbf{z}) = \frac{f(\mathbf{z})}{\ell(\mathbf{z})}

    对于 z=[z1;,zn]RK\mathbf{z}=[\mathbf{z}^1;\dots,\mathbf{z}^n]\in\mathbb{R}^K, 我们现在的计算方式为

    mi(z)=max([z1;;zi])=max(mi1(z),m(zi))i(z)=j=1if(zj)=exp(mi1(z)mi(z))(zi1)+exp(zimi(z))\begin{aligned} m_i(\mathbf{z}) &= \max([\mathbf{z}^1;\dots;\mathbf{z}^i]) = \max(m_{i-1}(\mathbf{z}),m(\mathbf{z}^i))\\ \ell_i(\mathbf{z}) &= \sum_{j=1}^if(\mathbf{z}^j) = \exp(m_{i-1}(\mathbf{z}) - m_i(\mathbf{z}))\ell(\mathbf{z}^{i-1}) + \exp(\mathbf{z}^i-m_i(\mathbf{z})) \end{aligned}

    因此,如果我们额外记录 m(x)m(x) 以及 (x)\ell(x) 这两个量,那么我们可以每次仅计算 softmax 的一个 block. 计算完毕之后,mi(z)m_i(\mathbf{z})i(z)\ell_i(\mathbf{z}) 就分别代表了 global max 和 global denominator.

    Conclusion

    我们回顾了机器学习中 softmax function 的基本定义与性质