Differentiable Permutation Layer
Last updated: May 12, 2025
Introduction
Early in 2024, when I was still working on computer vision (CV), Mamba had just been introduced, and I saw many attempts to apply it to images. However, models like Mamba and other RNNs/SSMs inherently carry causal and positional inductive biases. Splitting an image into patches and feeding them to the model sequentially (as in ViT) imposes a top-left to bottom-right reading order on the image, a prior that doesn't naturally exist in images. Many papers at the time were exploring better scanning methods.
A natural question arises:
Does an optimal scanning order exist? And if so, can the model learn it automatically?
Method
For sequence modeling, the input is typically a tensor $X \in \mathbb{R}^{B \times L \times D}$. Ignoring the batch dimension, an image can be represented as a matrix $X \in \mathbb{R}^{L \times D}$, where each row vector $x_i \in \mathbb{R}^{D}$ corresponds to the embedding of a patch. After passing through a patch embedding layer, we obtain an ordered sequence $(x_1, x_2, \dots, x_L)$, where the order is determined by the position of each row in the matrix.
To find an optimal scanning order, we need to find an elementary row transformation (permutation) matrix $P \in \{0, 1\}^{L \times L}$ such that:

$$X' = P X$$

where the rows of $X'$ are the patches in the desired scanning order.
Thus, we need a function that can predict this transformation matrix. Suppose we have a function $f$ that outputs an "importance score" $s_i = f(x_i)$ for each patch, collected into a score vector $s \in \mathbb{R}^{L}$. Using `Tensor.sort()`, we can obtain a sorted index $\pi$, where $\pi_i$ is the original position of the patch that should be placed at position $i$. One-hot encoding this index yields a matrix $P$ with $P_{i,\pi_i} = 1$, which can serve as the row transformation matrix.
Example

For instance, with $L = 3$ and sorted index $\pi = (2, 3, 1)$, one-hot encoding gives

$$P = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 0 \end{pmatrix}$$

It's easy to verify that $P X = (x_2, x_3, x_1)^{\top}$, i.e., the rows of $X$ are rearranged into the desired order.
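To make this concrete, here is a minimal PyTorch sketch of the hard version described above (the descending sort order and the toy shapes are just illustrative choices):

```python
import torch
import torch.nn.functional as F

def hard_permutation(x, scores):
    """Build the row transformation matrix from Tensor.sort() and apply it.

    x:      (L, D) patch embeddings
    scores: (L,)   importance score per patch, e.g. produced by a scoring function f
    """
    # indices[i] = original position of the patch that goes to position i
    _, indices = scores.sort(descending=True)
    # One-hot encode the index vector -> (L, L) permutation matrix P
    P = F.one_hot(indices, num_classes=x.shape[0]).to(x.dtype)
    # (P @ x)[i] == x[indices[i]]: rows of x rearranged into the new order
    return P @ x, P

# Toy usage
x = torch.randn(4, 8)      # 4 patches, embedding dim 8
scores = torch.randn(4)    # stand-in for f(x)
x_sorted, P = hard_permutation(x, scores)
```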
However, a major issue is that the indices returned by `Tensor.sort()` are non-differentiable: gradients cannot flow through the one-hot matrix back to the scores. Is there a fully differentiable way to obtain the transformation matrix?
Differentiable Approximation
Consider a simple input $X = (x_1, x_2, x_3)^{\top}$ with $L = 3$. Applying $f$ gives importance scores $s = (s_1, s_2, s_3)^{\top}$. After sorting, suppose we get $\hat{s} = (s_2, s_3, s_1)^{\top}$, corresponding to the index $\pi = (2, 3, 1)$ and:

$$P = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 0 \end{pmatrix}$$
Note that:

$$\hat{s}\, s^{\top} = \begin{pmatrix} s_2 \\ s_3 \\ s_1 \end{pmatrix} \begin{pmatrix} s_1 & s_2 & s_3 \end{pmatrix} = \begin{pmatrix} s_2 s_1 & s_2^2 & s_2 s_3 \\ s_3 s_1 & s_3 s_2 & s_3^2 \\ s_1^2 & s_1 s_2 & s_1 s_3 \end{pmatrix}$$

The squared terms ($s_1^2, s_2^2, s_3^2$) indicate the positions of the "1"s in the elementary matrix $P$.
A naive way to approximate $P$ is:

$$P \approx \operatorname{softmax}\!\left( -\frac{\big|\, \hat{s}\, s^{\top} - (\hat{s} \odot \hat{s})\, \mathbf{1}^{\top} \big|}{\tau} \right)$$

where the softmax is applied row-wise. Here, subtracting $(\hat{s} \odot \hat{s})\, \mathbf{1}^{\top}$ zeros out the squared terms; taking absolute values of the cross terms, negating, and scaling by a temperature $\tau$ makes each row of the softmax peak at its zero entry, approximating the elementary matrix. However, optimization might be tricky due to the softmax.
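A rough PyTorch sketch of this naive approximation, assuming the outer-product construction above (the default temperature is an arbitrary choice):

```python
import torch

def soft_permutation(scores, tau=0.01):
    """Differentiable approximation of the permutation matrix.

    scores: (L,) importance scores s = f(x)
    tau:    temperature; smaller -> closer to a hard permutation
    """
    s_sorted, _ = scores.sort(descending=True)       # sorted values are differentiable w.r.t. scores
    outer = s_sorted[:, None] * scores[None, :]      # entry (i, j) = s_sorted[i] * s[j]
    # Subtracting the squared terms makes the "1" positions exactly zero;
    # |.| keeps the cross terms positive, so after negating and scaling by tau
    # the row-wise softmax peaks at those zero entries.
    diff = (outer - s_sorted[:, None] ** 2).abs()
    return torch.softmax(-diff / tau, dim=-1)

# Sanity check: gradients reach the scores through the soft matrix
scores = torch.tensor([0.3, 0.9, 0.1], requires_grad=True)
P_soft = soft_permutation(scores)
x = torch.randn(3, 4)
(P_soft @ x).sum().backward()
```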
Alternatively, if we treat the outer product $\hat{s}\, s^{\top}$ as the carrier of the learnable information, we can use the stop-gradient trick. Let $P_{\text{hard}}$ be the one-hot encoded matrix obtained from `Tensor.sort()`. Then:

$$P = P_{\text{hard}} + \hat{s}\, s^{\top} - \operatorname{sg}\!\left(\hat{s}\, s^{\top}\right)$$

where $\operatorname{sg}(\cdot)$ stops gradients. In the forward pass the last two terms cancel, so $P$ equals the exact permutation matrix $P_{\text{hard}}$; in the backward pass, gradients flow through $\hat{s}\, s^{\top}$ back to the scores and hence to $f$.
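A rough PyTorch sketch of this trick, with `Tensor.detach()` playing the role of $\operatorname{sg}$; using the outer product $\hat{s}\, s^{\top}$ as the gradient carrier is just one possible choice:

```python
import torch
import torch.nn.functional as F

def ste_permutation(scores):
    """Hard permutation in the forward pass, soft gradients in the backward pass.

    scores: (L,) importance scores s = f(x)
    """
    L = scores.shape[0]
    s_sorted, indices = scores.sort(descending=True)
    P_hard = F.one_hot(indices, num_classes=L).to(scores.dtype)  # exact permutation, no gradient
    carrier = s_sorted[:, None] * scores[None, :]                # differentiable function of the scores
    # Numerically P == P_hard (the last two terms cancel), but gradients
    # w.r.t. P flow into `carrier` and hence back to the scoring function.
    return P_hard + carrier - carrier.detach()

scores = torch.tensor([0.3, 0.9, 0.1], requires_grad=True)
x = torch.randn(3, 4)
P = ste_permutation(scores)
(P @ x).sum().backward()
print(scores.grad)   # non-zero: the scorer receives a training signal
```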
Conclusion
This was a rough idea I came up with to let models learn their own scanning order. I originally planned to experiment further, but later shifted focus to LLMs and text-to-video work, leaving Mamba/SSM behind. While preliminary, I wanted to document it—maybe someone else will find it useful.