People have tweaked the Transformer architecture enough in 7 years that we’re apparently now calling the current recipe “Transformer++”.

The changes between Vaswani et al 2017 and Radford et al 2019 are well-known (see Figure 3 here): removing the encoder, clipping gradients, not using dropout, and changing the activation function to something weighted or gated. So I focus on post-GPT-2 tweaks.

I also focus on architecture changes rather than data changes (curation or generation), training changes (infrastructure and hyperparameters), post-training, or inference optimisations (e.g. the various kinds of data and model parallelism across multiple devices, which arose largely after 2019). I won’t get into multimodal architectures.

Inclusion criterion: at least three of the strong open-source architectures from 2024 considered here (LLaMA 3, Gemma 2, Qwen2.5, DeepSeek-V2, Hunyuan-Large) use the tweak.

(The term “Transformer++” was coined in this sense by Gu and Dao 2023, who also noted several of the main tweaks.)



The Transformer++

Let the “Transformer++” be a Transformer with

  • A fused attention implementation (naive scaled dot-product attention -> FlashAttention). Linear rather than quadratic memory in input sequence length, since the full attention matrix is never materialised. Practically: can roughly double GPU utilization and so halve training time; also enables longer contexts and speeds up inference on long-context input. (Sketch after this list.)
  • Rotary position embedding (sinusoidal -> learned absolute position embeddings -> RoPE). (Sketch after this list.)
  • Removing attention’s redundant key heads and value heads (vanilla MHA -> MQA -> GQA). (Sketch after this list.)
  • Regularized / preconditioned optimizer (Adam -> AdamW -> SOAP)
  • Normalise before each sublayer (post-LayerNorm -> pre-LayerNorm)
  • When doing layer normalization: just rescale, don’t centre (LayerNorm -> RMSNorm)
  • Divine activation function for the MLP (GeLU -> … -> SwiGLU or GeGLU). (A pre-norm block combining RMSNorm and SwiGLU is sketched after this list.)
  • Tied embeddings. An oldie but goodie.
  • Fix logit drift (query/key normalization)
  • Fixing that one softmax off-by-one (fixed in some places around 2021)
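
A few minimal sketches of the tweaks above, in PyTorch. First, the fused-attention swap, assuming PyTorch >= 2.0: scaled_dot_product_attention dispatches to a FlashAttention-style fused kernel when one is available, so the full attention matrix is never materialised.

```python
import torch
import torch.nn.functional as F

# Shapes: (batch, heads, seq, head_dim). The fused FlashAttention backend
# needs a CUDA device; on CPU PyTorch falls back to a slower math backend
# with the same semantics.
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(1, 8, 2048, 64, device=device, dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Naive path: materialises a (seq, seq) score matrix per head, O(seq^2) memory.
scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
mask = torch.full((2048, 2048), float("-inf"), device=device, dtype=q.dtype).triu(1)
naive = torch.softmax(scores + mask, dim=-1) @ v

# Fused path: same output up to floating-point error, O(seq) extra memory.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```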
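
Next, the grouped-query-attention bookkeeping (function name and shapes are mine, not from any particular codebase): n_q query heads share n_kv key/value heads, so the KV cache shrinks by a factor of n_q / n_kv; n_kv = 1 recovers MQA and n_kv = n_q recovers vanilla MHA.

```python
import torch
import torch.nn.functional as F

def gqa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: (batch, n_q, seq, head_dim); k, v: (batch, n_kv, seq, head_dim)
    n_q, n_kv = q.shape[1], k.shape[1]
    # Broadcast each KV head to its group of query heads, then attend as usual.
    k = k.repeat_interleave(n_q // n_kv, dim=1)
    v = v.repeat_interleave(n_q // n_kv, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# 16 query heads sharing 4 KV heads: the KV cache is 4x smaller than MHA.
out = gqa(torch.randn(1, 16, 128, 64), torch.randn(1, 4, 128, 64), torch.randn(1, 4, 128, 64))
```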
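
RoPE, in the interleaved-pair convention of the original paper (production code typically precomputes and caches the cos/sin tables, and LLaMA-style code uses the equivalent "rotate-half" layout instead):

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate each (even, odd) channel pair of x (shape ..., seq, head_dim) by an
    angle proportional to its position, so relative offsets appear in q.k dot products."""
    seq, dim = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-torch.arange(0, dim, 2, device=x.device, dtype=torch.float32) / dim)  # (dim/2,)
    angles = torch.arange(seq, device=x.device, dtype=torch.float32)[:, None] * inv_freq       # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2].float(), x[..., 1::2].float()
    out = torch.empty_like(x)
    out[..., 0::2] = (x1 * cos - x2 * sin).to(x.dtype)
    out[..., 1::2] = (x1 * sin + x2 * cos).to(x.dtype)
    return out

# Applied to queries and keys (not values) before the attention dot product.
q = rope(torch.randn(1, 8, 128, 64))
```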
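
And the normalisation and MLP tweaks assembled into one pre-norm block (module names are mine; real implementations differ in detail):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """LayerNorm minus the mean-centering and bias: rescale by the RMS, apply a learned gain."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class SwiGLU(nn.Module):
    """Gated MLP: down( silu(gate(x)) * up(x) ); swap silu for gelu to get GeGLU."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))

class Block(nn.Module):
    """Pre-norm residual block: normalise the input of each sublayer (pre-LN),
    rather than its output (post-LN)."""
    def __init__(self, dim: int, attn: nn.Module, hidden: int):
        super().__init__()
        self.attn_norm, self.mlp_norm = RMSNorm(dim), RMSNorm(dim)
        self.attn, self.mlp = attn, SwiGLU(dim, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
```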



Less well-established tweaks

  • Sparsification. I could mention the turn to sparse Mixtures of Experts, but that turn was more of a cost-saving thing, and LLaMA is still dense.
  • BPE -> tiktoken / SentencePiece. Still byte-pair encoding underneath, but about 25% better compression.
  • Quantization. FP16 -> BF16 for training; int8 post-training for inference.
  • No bias on the QKV projections or the layernorms. I’m not sure about this one, since some architectures (e.g. Qwen2.5) put the biases back into attention.
  • Sliding Window Attention, which enables e.g. a rolling-buffer KV cache. (Mask sketch after this list.)
  • Cross-Layer Attention shrinks the KV cache
  • WARP
  • Regularizing outputs (“soft-capping logits”). (Sketch after this list.)
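
A sketch of the sliding-window mask, using the boolean-mask convention of PyTorch’s scaled_dot_product_attention (True means “may attend”): each query sees only the most recent W keys, which is what allows a fixed-size rolling-buffer KV cache.

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean (seq, seq) mask: position i may attend to positions i-window+1 .. i."""
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

# Pass as attn_mask to F.scaled_dot_product_attention; with window=3 and
# seq_len=5 each row has at most three True entries.
print(sliding_window_mask(5, 3).int())
```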
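
And soft-capping as described for Gemma 2: squash the logits smoothly into (-cap, cap) with a tanh rather than letting them grow without bound (the default cap value here is illustrative).

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # tanh keeps the mapping smooth and monotone, unlike hard clipping.
    return cap * torch.tanh(logits / cap)
```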



Top open architectures by tweak

| Component | Tweak | LLaMA 3 | Gemma 2 | Qwen2.5 | DeepSeek-V2 | Hunyuan-Large |
|---|---|---|---|---|---|---|
| Attention | Attention kernel | FlashAttention-2? | Eager attention | FlashAttention-2 | FlashAttention-2 | FlashAttention-2 |
| Attention | Sliding window attention | No? | Local-Global SWA | Both | No? | No? |
| Attention | Removing KV heads | GQA | GQA | GQA | MLA | GQA |
| Attention | Cross-Layer Attention | No | No | No | No | CLA |
| Attention | Prefill KV cache | Yes | ? | ? | No? | ? |
| Attention | Low-rank KV cache compression | No | No | No | Yes | No |
| Attention | Biases in QKV projection | No? | ? | Yes | ? | No |
| Attention | QK normalization | No? | No? | ? | No? | No |
| Block sequence | Parallel layers | No | No? | No? | No? | No? |
| Embedding | Position encoding | RoPE | RoPE | RoPE | Decoupled RoPE | DynamicNTK RoPE (*6) |
| Embedding | Tied embeddings | “Shared” (*1) | Tied | Tied (*4) | ? | Tied (*7) |
| Optimizer | Regularized / preconditioned | AdamW | AdamW (*3) | ? | AdamW | AdamW |
| Activation normalization | Post- or pre-layernorm | Pre | Both | ? | Pre | ? |
| Activation normalization | Don’t center | RMSNorm (*2) | RMSNorm | RMSNorm | RMSNorm | RMSNorm |
| Output normalization | Soft-capped logits | No? | Yes | No? | No? | No? |
| Activation function | Gated linear unit | SwiGLU | GeGLU | SwiGLU | SwiGLU | SwiGLU (*5) |
| Sparsification | Sparse? | Dense | Dense? | Dense | MoE | MoE |
| Weights quantization | BF16 training | Yes | No, FP32 | Yes | Yes | Yes |
| Weights quantization | 8-bit post-training | In one version | No | No | No | In one version |


(*1) In the 3.2 models anyway
(*2) Llama 2 uses RMSNorm anyway
(*3) They recommend AdamW for fine-tuning, unsure for training
(*4) Only the smaller models
(*5) Code says “silu”
(*6) “Credits to the Reddit users /u/bloc97 and /u/emozilla”
(*7) https://huggingface.co/tencent/Tencent-Hunyuan-Large/blob/main/Hunyuan-A52B-Pretrain/modeling_hunyuan.py#L1419



Caveats

  • The above ignores the much more important changes since 2017 to data “collection” (curation and synthesis), cluster infrastructure, post-training, and scaffolding.
  • The public tokenizers still use byte-pair encoding.
  • Some models have absurdly high embedding-parameter counts. This is unlikely to be a performance optimisation; I conjecture it is instead a tradeoff that allows underreporting the Transformer-parameter count and so entering a lesser model class (“7B”).
  • A lot of this doesn’t improve absolute performance that much, but it does make models a lot cheaper to train and run.
  • And this is just the public architectures. There are probably also public methods which we haven’t yet realised are improvements.



I thank Kushal Thaman for helpful comments.



See also

  • https://arxiv.org/html/2410.16682v1
  • https://openreview.net/forum?id=d8w0pmvXbZ



Bibliography