Implementing LLaMA3 in 100 Lines of Pure Jax
(saurabhalone.com)
In this post, we'll implement LLaMA3 from scratch using pure JAX in just 100 lines of code. Why JAX? Because I think it has good aesthetics. JAX also looks like a NumPy wrapper, but it has some cool features like XLA (Accelerated Linear Algebra, a compiler), jit, vmap, pmap, etc., which make your training go brr brr.
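To give a quick feel for jit and vmap before diving in, here is a minimal sketch. The `rms_norm` helper and the toy input are just illustrative assumptions, not code from the actual implementation: we write a function for a single vector, let `jax.vmap` batch it, and let `jax.jit` compile it through XLA.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration: normalize ONE vector by its root-mean-square.
def rms_norm(x, eps=1e-6):
    return x / jnp.sqrt(jnp.mean(x ** 2) + eps)

# vmap lifts the single-vector function over a leading batch axis;
# jit traces it once and compiles it with XLA.
batched_rms_norm = jax.jit(jax.vmap(rms_norm))

x = jnp.arange(1.0, 7.0).reshape(2, 3)  # toy batch: 2 vectors of size 3
y = batched_rms_norm(x)
print(y.shape)  # same shape as the input batch
```

Notice that the function itself never mentions a batch dimension; vmap handles that part, which is a big reason JAX code tends to stay short.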