MiniMax представила MSA — sparse attention, который режет квадратичную сложность до константы
MiniMax представила MSA (MiniMax Sparse Attention) — разреженный метод внимания, построенный поверх Grouped Query Attention (GQA). Основная проблема, которую решает MSA, — квадратичная стоимость softmax-внимания при длинных контекстах. Исследователи протестировали его внутри 109B-параметрической Mixture-of-Experts-модели, обученной на мультимодальных данных. Они также открыли исходный код инференс-ядра и выпустили продакшн-модель MiniMax-M3.
MSA делит внимание на две ветви: Index Branch и Main Branch. Index Branch решает, какие блоки ключ-значение должен читать каждый запрос. Main Branch выполняет точное softmax-внимание только по выбранным блокам. Размер блока по умолчанию — 128 токенов, каждый запрос и группа GQA держат k = 16 блоков, что фиксирует бюджет на 2048 токенов на запрос. В отличие от плотного GQA, где сложность растёт линейно с длиной контекста, MSA остаётся константной.
Обучение MSA нетривиально: Top-k-выборка недифференцируема, поэтому используется KL-функция потерь, которая согласует распределение Index Branch с паттерном внимания Main Branch. Для стабилизации применяются три механизма: Gradient Detach (останавливает градиенты для Index Branch), Indexer Warmup (полное внимание на первых итерациях) и принудительный локальный блок. MSA поддерживает два режима обучения: с нуля (MSA-PT) и конвертацию плотного GQA-чека (MSA-CPT) с дообучением на 400 млрд токенов.
Теоретическая разреженность не даёт ускорения без оптимизированных GPU-ядер. MSA включает два кастомных ядра: exp-free Top-k (в 5,1 раза быстрее torch.topk при контексте 128K и k = 16) и второе ядро, которое обходит стандартные реализации. Исходный код и веса модели доступны на GitHub.