https://github.com/adonath/nanogptx

🚀✨ A nanoGPT implementation in pure JAX and some infrastructure you can use as template for your own small scale LLM research projects 🚀✨

Science Score: 26.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.5%) to scientific vocabulary
Last synced: 9 months ago · JSON representation

Repository

Basic Info
  • Host: GitHub
  • Owner: adonath
  • License: bsd-3-clause
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 626 KB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created over 1 year ago · Last pushed 10 months ago
Metadata Files
Readme License

README.md

NanoGPTX: A "nanoGPT" Implementation in Pure JAX

Purpose of this Repository

The purpose of this repository is mostly to document my own learning progress on recent developments in AI. I first wanted to learn more about transformers and the process of training LLMs, and at the same time I wanted to learn about JAX. Given these goals, a reasonable project was to re-implement nanoGPT in pure JAX. In the process I found that I typically ended up with much cleaner code compared to PyTorch. So I decided to split the code base up into smaller, reusable and more modular parts and release it. Now it can be used for educational purposes, or as a clean and hackable starting point for small scale experiments on modified architectures, training strategies or interpretability. I think cooking a new experiment needs to start from a clean lab, so happy cooking!

Note: if you need minimal production-grade implementations of LLMs, you might rather want to check out the official JAX LLM examples. For large scale experiments and training, check out Levanter, which is based on Equinox.

Getting started

This repository comes with multiple pre-defined environments in a pixi.toml file. This makes it very convenient to run the model in CPU, GPU and even TPU (soon...) environments. To get started, first install pixi using:

```bash
curl -fsSL https://pixi.sh/install.sh | sh
```

And then proceed with one of the options:

(a) Training a Small Model on "Tiny Shakespeare" and CPU

To train a small transformer model with character level encoding on the "Tiny Shakespeare" dataset you can use:

```bash
pixi run download --dataset shakespeare
pixi run --environment prepare prepare --dataset shakespeare --encoding char --shard-size 1000000 --shards-val 1
pixi run --environment cpu train train-shakespeare-char
pixi run --environment cpu sample --init-from resume --sampler.max-new-tokens 500 --sampler.num-samples 5
```

The workflow always consists of these four steps. The training should finish in under 2 minutes on an M1-type machine. All sub-commands have a --help option that shows the available configuration options.

(b) Training a GPT2 124m Model on Fineweb10b and multiple GPUs

To train a GPT2 124m model on the Fineweb10b dataset on two GPUs you can use for example:

```bash
pixi run download --dataset fineweb_10b
pixi run --environment prepare prepare --dataset fineweb_10b
pixi run --environment gpu train train-fineweb-10b --sharding.devices cuda:0,cuda:1 --loading.sharding.devices cuda:0,cuda:1
pixi run --environment gpu sample --init-from resume --sampler.max-new-tokens 500 --sampler.num-samples 5
```

nanogptx supports a simple SPMD (single program multiple data) distribution strategy, meaning groups of batches are evaluated in parallel on the configured devices.
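To make the SPMD idea concrete, here is a minimal sketch of data-parallel sharding with the public `jax.sharding` API. It is not taken from nanogptx: the mesh axis name `"data"`, the toy `forward` function and the array sizes are assumptions chosen for illustration. It also runs on a single CPU device, where the mesh simply has one entry.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; "data" is a hypothetical axis name.
devices = jax.devices()
n_devices = len(devices)
mesh = Mesh(np.array(devices), axis_names=("data",))

# Split the leading (batch) axis across devices, replicate the weights.
batch_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

# The batch size must be divisible by the number of devices.
batch = jax.device_put(jnp.ones((4 * n_devices, 16)), batch_sharding)
weights = jax.device_put(jnp.ones((16, 8)), replicated)


@jax.jit
def forward(w, x):
    # The same program runs on every device, each on its own slice of x.
    return x @ w


out = forward(weights, batch)
print(out.shape)
```

Each device only sees its slice of the batch, while the jitted function is written as if it operated on the full array.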

NanoGPTX Features

Here are some of the features of the nanogptx implementation:

  • 🗄️ Hierarchical configuration: I do like hierarchical configuration as long as it is not too deep (3-4 levels max). The configuration system is based on TOML and dataclasses, combined with JAX pytree operations and dacite for serialization and deserialization. Every configuration field requires a default, and on I/O values are coerced to the type of that default, to catch errors early.
  • 💻 CLI: after careful consideration I have decided to support a CLI via tyro. The code overhead is minimal, as all the configuration is in dataclasses anyway. And tyro offers a nice default interface as well as nested commands.
  • Abstract evaluation and lazy initialization: I think it is useful to not fully instantiate a model on creation, but rather instantiate an abstract description of the array shapes, dtypes and shardings. This allows for an abstract evaluation which catches shape, dtype and sharding errors early without using any flops.
  • 🗃️ Minimal provenance: The implementation supports minimal provenance of model configs, datasets and training. This includes logging of which batch is trained on, saving configs in model files and verifying data hashes.
  • 🐍 Support for Pixi environments: this repository includes a pixi.toml with pre-defined environments for many scenarios such as CPU, GPU and even TPU. I have also tried to support MPS via jax-metal, but ran into multiple issues with missing support for operations.
  • 💯 Sharding strategies: Currently only SPMD is supported, other strategies might follow...
  • 📚 Data preprocessing pipeline: A minimal function based pre-processing pipeline for tokenization and custom document cleaning / pre-processing.
  • 📇 Logging: just as the original nanoGPT this project uses WandB for logging. I have considered alternatives (especially local solutions), but found other solutions introduced more complexity with fewer features.

How to work with this Repository

This repository can be used as a template for your own small to mid-scale research and educational projects. You can explore different training strategies, modified architectures etc. As everything is in pure JAX, you can modify any small component in the model without needing to implement whole new layers. nanogptx still provides the whole scalable infrastructure.
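When modifying a component, the abstract evaluation idea from the feature list can be tried in isolation with `jax.eval_shape`. The sketch below is an assumption-laden toy, not the nanogptx model: `init_params` and `forward` are hypothetical functions for a tiny MLP block. It shows how shapes and dtypes can be checked without allocating parameters or spending any flops.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_embd=64, n_hidden=256):
    """Hypothetical initializer for a tiny MLP block."""
    k1, k2 = jax.random.split(key)
    return {
        "w_in": jax.random.normal(k1, (n_embd, n_hidden)) * 0.02,
        "w_out": jax.random.normal(k2, (n_hidden, n_embd)) * 0.02,
    }


def forward(params, x):
    h = jax.nn.gelu(x @ params["w_in"])
    return h @ params["w_out"]


# Abstract "initialization": only shapes and dtypes are computed.
abstract_params = jax.eval_shape(init_params, jax.random.PRNGKey(0))

# Abstract forward pass on a description of the input, not on real data.
abstract_x = jax.ShapeDtypeStruct((8, 64), jnp.float32)
abstract_out = jax.eval_shape(forward, abstract_params, abstract_x)
print(abstract_out.shape, abstract_out.dtype)

# A mismatched input is rejected during tracing, before any real compute.
caught = False
try:
    jax.eval_shape(forward, abstract_params, jax.ShapeDtypeStruct((8, 32), jnp.float32))
except (TypeError, ValueError) as error:
    caught = True
    print("shape error caught early:", type(error).__name__)
```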

Profiling

There is a dedicated option and environment for profiling available. The approach follows the programmatic capture section in the JAX docs. You can enable it using:

```bash
pixi run --environment cpu-profile train train-shakespeare-char --profile
pixi run --environment gpu-profile train train-fineweb-10b --sharding.devices cuda:0,cuda:1 --loading.sharding.devices cuda:0,cuda:1 --profile
```

⚠️ Note: the profiling requires tensorflow to be available in the same environment. This potentially leaves you with a different version of JAX that is used for the profiling.
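For orientation, programmatic capture in JAX looks roughly like the sketch below, which wraps a few calls of a jitted function in `jax.profiler.trace`. How the `--profile` flag is wired up inside nanogptx is not shown here. The `step` function and the temporary log directory are assumptions made for this example.

```python
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Stand-in for a training step; any jitted computation works here.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))
step(x).block_until_ready()  # warm-up call so compilation is not part of the trace

log_dir = tempfile.mkdtemp(prefix="jax-trace-")  # hypothetical output location
with jax.profiler.trace(log_dir):
    for _ in range(3):
        result = step(x)
    result.block_until_ready()  # wait for async dispatch before the trace closes

print("trace written to", log_dir)
```

The resulting trace directory can then be opened with a trace viewer such as TensorBoard's profile plugin.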

Adding a new Dataset

If you would like to add a new dataset follow these steps:

  • Add a new entry to the DatasetEnum in src/utils.py with a short identifier of your dataset
  • Add the download URLs in src/download.py; decompression (unzip / untar) should also happen at this step
  • Add a custom read function in src/prepare.py as needed.
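The three steps above can be sketched as follows. This is only an outline under stated assumptions: the real layout of `DatasetEnum`, the URL registry and the read functions in nanogptx is not shown in this README. The names `my_corpus`, `DOWNLOAD_URLS`, `READ_FUNCTIONS`, `read_my_corpus` and the example URL are hypothetical.

```python
from enum import Enum
from pathlib import Path


class DatasetEnum(str, Enum):
    """Short identifiers for supported datasets (hypothetical subset)."""

    shakespeare = "shakespeare"
    my_corpus = "my_corpus"  # step 1: the new entry


# Step 2: hypothetical registry of download URLs per dataset.
DOWNLOAD_URLS = {
    DatasetEnum.my_corpus: ["https://example.com/my_corpus.txt"],
}


# Step 3: custom read function turning a raw file into a list of documents.
def read_my_corpus(path: Path) -> list[str]:
    """Split a plain text file into documents at blank lines."""
    text = Path(path).read_text(encoding="utf-8")
    return [doc.strip() for doc in text.split("\n\n") if doc.strip()]


# Hypothetical lookup table from dataset identifier to its reader.
READ_FUNCTIONS = {DatasetEnum.my_corpus: read_my_corpus}
```

Keeping the identifier, the URLs and the reader keyed by the same enum value means a typo in a dataset name fails at lookup time instead of silently reading the wrong files.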

Adding a new Model

TODO:

Adding a new Config

TODO:

Acknowledgements

Thanks to @fcrespo82 for the names list from the Ubuntu Name generator.

Owner

  • Name: Axel Donath
  • Login: adonath
  • Kind: user
  • Location: Cambridge, MA
  • Company: Center for Astrophysics | Harvard & Smithsonian

I'm a Postdoc researcher at Center for Astrophysics. I work on statistical methods for analysis of low counts astronomical data.

GitHub Events

Total
  • Push event: 2
Last Year
  • Push event: 2

Dependencies

.github/workflows/ci.yml actions
  • actions/checkout v4 composite