Key Features:
- Modular, readable, and extensible codebase
- Implementations of GPT-2 and LLaMA 3 in pure JAX/Flax
- Accelerated training with XLA + Optax
- Google Colab support (TPU-ready)
- Hugging Face dataset integration
- Upcoming support for fine-tuning, Mistral, and DeepSeek-R

This is primarily an educational resource, but it's written with performance in mind and can be adapted for more serious use. Contributions are welcome, whether you’re improving performance, adding new models, or experimenting with different attention mechanisms.