To read how they actually did it, here is their model page: https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k
Approach:
- meta-llama/Meta-Llama-3-8B-Instruct as the base
- NTK-aware interpolation [1] to initialize an optimal schedule for RoPE theta, followed by empirical RoPE theta optimization
- Progressive training on increasing context lengths, similar to Large World Model [2] (See details below)
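For intuition on the RoPE part, here is a minimal sketch of the standard NTK-aware base scaling, theta' = theta * s^(d/(d-2)), where s is the context extension factor and d the head dimension. It assumes Llama 3 8B's public config values (rope_theta = 500000, head_dim = 128, 8192-token original context). The stage lengths are the four sizes the post mentions (65k, 262k, 524k, 1048k), read as powers of two; that reading and the helper names are my own assumptions. Per the model card, the NTK value is only the starting point before empirical tuning, so these numbers are not Gradient's final thetas.

```python
import math

# Assumed Llama 3 8B defaults (public config): rope_theta, head_dim, context length.
LLAMA3_THETA = 500_000.0
HEAD_DIM = 128
ORIG_CTX = 8192

# Hypothetical progressive schedule: the context sizes mentioned in the post.
STAGES = [65_536, 262_144, 524_288, 1_048_576]


def rope_inv_freq(theta: float, head_dim: int) -> list[float]:
    """Per-pair RoPE rotation frequencies: theta ** (-2i / d) for i in [0, d/2)."""
    return [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def ntk_aware_theta(theta: float, head_dim: int, orig_ctx: int, target_ctx: int) -> float:
    """NTK-aware base scaling: theta' = theta * s ** (d / (d - 2)), s = target / orig.

    The lowest frequency shrinks by exactly 1/s (so its period covers the longer
    context) while the highest frequency (i = 0) stays at 1.
    """
    s = target_ctx / orig_ctx
    if s <= 1.0:
        return theta
    return theta * s ** (head_dim / (head_dim - 2))


def progressive_schedule(theta=LLAMA3_THETA, head_dim=HEAD_DIM,
                         orig_ctx=ORIG_CTX, stages=STAGES):
    """(context_length, initial_theta) per stage; each stage would be trained,
    with theta tuned empirically, before moving on to the next length."""
    return [(ctx, ntk_aware_theta(theta, head_dim, orig_ctx, ctx)) for ctx in stages]


if __name__ == "__main__":
    for ctx, th in progressive_schedule():
        print(f"{ctx:>9} tokens -> initial rope_theta ~ {th:.3e}")
```

The point of scaling the base instead of linearly squeezing positions is that the high-frequency dimensions, which encode local token order, are left almost untouched.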
Infra
We build on top of the EasyContext Blockwise RingAttention library [3] to train scalably and efficiently on contexts of up to 1048k tokens on a Crusoe Energy high-performance L40S cluster.
Notably, we layered parallelism on top of Ring Attention with a custom network topology to better leverage large GPU clusters despite the network bottlenecks caused by passing many KV blocks between devices. This gave us a 33x speedup in model training (compare 524k and 1048k to 65k and 262k in the table on the model page).
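To make "passing KV blocks between devices" concrete, here is a single-process simulation of blockwise ring attention. It is a sketch, not EasyContext's code: the function names are hypothetical, it uses plain Python lists, one head, and no causal mask (a simplification; LLM training would mask). Each simulated device keeps one query block, KV blocks hop one device around the ring per step, and each device folds the visiting block into a running (online) softmax, so it never materializes the full attention matrix.

```python
import math
import random


def full_attention(q, k, v):
    """Reference single-head softmax attention (no causal mask)."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def ring_attention(q, k, v, n_devices):
    """Simulate ring attention: queries stay put, KV blocks rotate around the ring."""
    n, d, dv = len(q), len(q[0]), len(v[0])
    assert n % n_devices == 0, "sequence must split evenly across devices"
    bs = n // n_devices

    def split(x):
        return [x[i * bs:(i + 1) * bs] for i in range(n_devices)]

    q_blocks = split(q)
    kv_blocks = list(zip(split(k), split(v)))  # kv_blocks[i] is currently on device i
    # Per device, per query row: (running max, running denominator, running numerator).
    state = [[(-math.inf, 0.0, [0.0] * dv) for _ in range(bs)] for _ in range(n_devices)]

    for _ in range(n_devices):
        for dev in range(n_devices):
            kb, vb = kv_blocks[dev]
            for r, qi in enumerate(q_blocks[dev]):
                m, z, acc = state[dev][r]
                scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in kb]
                m_new = max(m, max(scores))
                corr = math.exp(m - m_new)  # rescale old sums to the new max (0.0 on first block)
                w = [math.exp(s - m_new) for s in scores]
                z = z * corr + sum(w)
                acc = [a * corr + sum(wj * vj[c] for wj, vj in zip(w, vb))
                       for c, a in enumerate(acc)]
                state[dev][r] = (m_new, z, acc)
        # One ring hop: device i receives the KV block that device i-1 held.
        kv_blocks = [kv_blocks[(dev - 1) % n_devices] for dev in range(n_devices)]

    return [[a / z for a in acc] for dev_state in state for (_, z, acc) in dev_state]


if __name__ == "__main__":
    random.seed(0)
    q, k, v = ([[random.gauss(0, 1) for _ in range(4)] for _ in range(8)] for _ in range(3))
    ref = full_attention(q, k, v)
    out = ring_attention(q, k, v, n_devices=4)
    print(max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb)))
```

In a real cluster each hop is a network send/receive overlapped with the block computation, which is why interconnect bandwidth becomes the bottleneck the model card describes.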
Data
For training data, we generate long contexts by augmenting SlimPajama. We also fine-tune on a chat dataset based on UltraChat [4], following a data-augmentation recipe similar to [2].
It is Llama-3-8B, so it is not out of the question, but I am not sure how much memory you would actually need to go all the way to a 1M context window. They use Ring Attention to reach such a long context. I am not familiar with it, but it seems to greatly lower the per-GPU memory requirements, since the sequence is split into blocks across devices instead of being held in full on each one.
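To put a rough number on the memory question, here is a back-of-the-envelope KV-cache estimate for inference only. It assumes Llama 3 8B's public architecture (32 layers, 8 key/value heads via grouped-query attention, head_dim 128) and a bf16/fp16 cache at 2 bytes per element; the function name is just for illustration. It ignores weights (roughly 16 GB for ~8B parameters at 2 bytes each), activations, and everything training adds (gradients, optimizer state).

```python
GIB = 2**30


def kv_cache_bytes(n_tokens: int, n_layers: int = 32, n_kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    """Bytes to cache K and V for n_tokens; defaults assume Llama 3 8B in bf16."""
    # Factor 2 = one K tensor plus one V tensor per layer.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * n_tokens


ctx = 1_048_576
total = kv_cache_bytes(ctx)
print(f"KV cache per token: {kv_cache_bytes(1) // 1024} KiB")    # 128 KiB
print(f"KV cache at {ctx} tokens: {total / GIB:.0f} GiB")        # 128 GiB
# If the sequence is sharded across devices (the idea behind Ring Attention),
# each device holds only its share, e.g. with 8 GPUs:
print(f"per GPU with 8-way sharding: {total / 8 / GIB:.0f} GiB")  # 16 GiB
```

So a single consumer GPU is out of reach at the full 1M tokens, but the cache scales linearly, and at, say, 128k tokens it drops to about 16 GiB.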