Differentiable Permutation Layer

Last updated: May 12, 2025

Introduction

Early in 2024, when I was still working on computer vision (CV), Mamba had just been introduced, and I saw many attempts to apply it to images. However, models like Mamba or RNN/SSM inherently carry causal and positional inductive biases. If we process images by splitting them into patches and feeding them sequentially (as in ViT), it imposes a top-left to bottom-right reading order on the image—a prior that doesn’t naturally exist in images. Many papers at the time were exploring better scanning methods.

A natural question arises:

Does an optimal scanning order exist? And if so, can the model learn it automatically?

Method

For sequence modeling, the input is typically a tensor $X \in \mathbb{R}^{B \times T \times D}$. Ignoring the batch dimension, an image can be represented as a matrix $X \in \mathbb{R}^{T \times D}$, where each row vector corresponds to the embedding of a patch. After passing through a patch embedding layer, we obtain an ordered sequence $X$, where the order is determined by the position of each row in the matrix.

To find an optimal scanning order, we need an elementary row transformation matrix $I_{\text{row}} \in \mathbb{R}^{T \times T}$ such that:

$$X' = I_{\text{row}} \times X$$

Thus, we need a function $f_{\theta}(\cdot)$ that can predict this transformation matrix. Suppose we have a function $h_\theta(\cdot): \mathbb{R}^{T \times D} \rightarrow \mathbb{R}^T$ that outputs an “importance score” for each patch. Using Tensor.sort(), we can obtain a sorted index $i \in \mathbb{N}^T$, where the $j$-th element gives the original position of the patch that should be moved to position $j$. One-hot encoding this index yields a matrix $I \in \mathbb{N}^{T \times T}$, which can serve as the row transformation matrix.

Example

$$\begin{aligned} a = [1, 3, 2, 5, 4] &\overset{\text{sort}}{\rightarrow} i = [0, 2, 1, 4, 3] \\ &\overset{\text{one\_hot}}{\rightarrow} I = \begin{bmatrix} 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 1 & 0 \end{bmatrix} \end{aligned}$$

It’s easy to verify that $a_{\text{sorted}} = I \times a^T$.
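For concreteness, here is a minimal PyTorch sketch of this sort-and-one-hot construction (my own illustration, not code from the post); `score_head` stands in for a hypothetical $h_\theta$, and the shapes are arbitrary.

```python
import torch

# Sketch: build the permutation matrix I from importance scores via
# sort + one-hot, then reorder the patch embeddings.
T, D = 5, 16
X = torch.randn(T, D)                        # patch embeddings (batch dim omitted)
score_head = torch.nn.Linear(D, 1)           # stand-in for h_theta: R^{T x D} -> R^T

e = score_head(X).squeeze(-1)                # importance scores, shape (T,)
_, idx = torch.sort(e)                       # idx[j] = original row that moves to position j
I = torch.nn.functional.one_hot(idx, num_classes=T).float()   # (T, T) permutation matrix

X_perm = I @ X                               # same reordering as X[idx]
assert torch.allclose(X_perm, X[idx])
```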

However, a major issue is that the indices returned by Tensor.sort() are discrete, so no gradient flows through the sort-and-one-hot step back into $h_\theta$. Is there a fully differentiable way to obtain the transformation matrix?

Differentiable Approximation

Consider a simple input $\mathbf{x} = [x_1, x_2, x_3]$ where $x_i \in \mathbb{R}^D$. Applying $h_\theta(\mathbf{x})$ gives importance scores $\mathbf{e} = [e_1, e_2, e_3]$. After sorting, suppose we get $\mathbf{e}' = [e_3, e_1, e_2]$, corresponding to $i = [2, 0, 1]$ and:

$$I = \begin{bmatrix} 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}$$

Note that:

$$\mathbf{e}'^T \times \mathbf{e} = \begin{bmatrix} e_1 e_3 & e_2 e_3 & e_3^2 \\ e_1^2 & e_1 e_2 & e_1 e_3 \\ e_1 e_2 & e_2^2 & e_2 e_3 \end{bmatrix}$$

The squared terms ($e_i^2$) indicate the positions of the 1s in the elementary matrix.

A naive way to approximate $I$ is:

$$\hat I_{\text{row}} = \text{softmax}\left(t \cdot (\mathbf{e}'^T \times \mathbf{e} - \mathbf{e}^2)\right)$$

Here, $\mathbf{e}^2$ is broadcast as a row vector, so subtracting it zeros out the squared terms; scaling by a temperature $t$ (possibly taking negative absolute values of the cross terms so that the zeroed entries become the row-wise maxima) followed by a row-wise softmax then approximates the elementary matrix. However, optimization might be tricky, since the softmax only ever yields a soft approximation of the permutation.
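As a sanity check of this relaxation, here is a hedged sketch (mine, not the post's code) that builds the soft permutation matrix: the outer product of the sorted scores with the original scores, minus the broadcast $\mathbf{e}^2$, vanishes exactly at the permutation positions, and applying $-t\,\lvert\cdot\rvert$ before a row-wise softmax makes those entries the row maxima. The function name and the negative-absolute-value reading of the parenthetical above are assumptions.

```python
import torch

def soft_permutation(e: torch.Tensor, t: float = 200.0) -> torch.Tensor:
    """Differentiable approximation of the sort permutation matrix (sketch)."""
    e_sorted, _ = torch.sort(e)                        # e' (sorted values only)
    outer = e_sorted.unsqueeze(1) * e.unsqueeze(0)     # e'^T x e, shape (T, T)
    gap = outer - (e ** 2).unsqueeze(0)                # squared terms become 0
    return torch.softmax(-t * gap.abs(), dim=-1)       # each row peaks where I has a 1

e = torch.tensor([0.3, 0.9, 0.1])
print(soft_permutation(e))   # rows are close to one-hot: indices [2, 0, 1]
```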

Alternatively, if we treat $\mathbf{e}'^T \times \mathbf{e}$ as containing the learnable information, we can use the stop-gradient trick:

$$\hat I_{\text{row}} = \exp(\mathbf{e}'^T \times \mathbf{e} - \mathbf{e}^2)$$

Let $I$ be the one-hot encoded matrix from Tensor.sort(). Then:

$$X' = \left(\hat I_{\text{row}} + \text{sg}(I - \hat I_{\text{row}})\right) \times X$$

where $\text{sg}(\cdot)$ stops gradients. In the forward pass the $\hat I_{\text{row}}$ terms cancel, so the exact permutation $I$ is applied; in the backward pass, gradients flow through $\hat I_{\text{row}}$ back into $h_\theta$.
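Below is a minimal sketch of the complete layer with this stop-gradient (straight-through) formulation, under the same illustrative assumptions as before (a linear score head for $h_\theta$, arbitrary shapes): the forward pass applies the exact permutation $I$, while gradients reach the score function through $\hat I_{\text{row}}$.

```python
import torch

# Straight-through permutation layer (sketch): forward uses the exact one-hot
# permutation I; backward flows through the relaxed matrix exp(e'^T x e - e^2).
T, D = 6, 16
X = torch.randn(T, D)
score_head = torch.nn.Linear(D, 1)                        # stand-in for h_theta

e = score_head(X).squeeze(-1)                             # importance scores (T,)
e_sorted, idx = torch.sort(e)                             # e' and the hard indices
I = torch.nn.functional.one_hot(idx, num_classes=T).float()

I_hat = torch.exp(e_sorted.unsqueeze(1) * e.unsqueeze(0) - (e ** 2).unsqueeze(0))
I_ste = I_hat + (I - I_hat).detach()                      # forward == I, backward uses I_hat

X_perm = I_ste @ X                                        # reordered patches
X_perm.sum().backward()
print(score_head.weight.grad is not None)                 # True: gradients reach h_theta
```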

Conclusion

This was a rough idea I came up with to let models learn their own scanning order. I originally planned to experiment further, but later shifted focus to LLMs and text-to-video work, leaving Mamba/SSM behind. While preliminary, I wanted to document it—maybe someone else will find it useful.

