Announcing NNX 0.0.4 (Beyond Pytrees)
Major update:
- Simplified APIs
- Modules are now regular mutable python classes
- Added support for Module graphs (RIP pytrees)
NNX is now more Pythonic, simpler, and still as powerful as Flax.
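To make the "RIP pytrees" point concrete, here is a tiny illustration of why shared state and pytrees don't mix (plain JAX plus a bare-bones stand-in class, not the actual NNX API):

```python
import jax
import jax.numpy as jnp

# Two "modules" that share one embedding table.
shared = {"table": jnp.ones((10, 4))}
model = {"encoder": {"embed": shared}, "decoder": {"embed": shared}}

# Flattened as a pytree, the shared table becomes TWO independent leaves,
# so tied weights silently stop being tied after any tree_map update.
print(len(jax.tree_util.tree_leaves(model)))  # 2, not 1

# With regular mutable Python objects, sharing is just aliasing:
class Module:  # bare-bones stand-in, NOT the NNX API
    pass

enc, dec = Module(), Module()
enc.embed = dec.embed = shared
assert enc.embed is dec.embed  # one object, a real graph edge
```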
I'm getting tired of Sam Altman's ego
I'd rather have a future where everybody has their own LLaMA at home running on
@__tinygrad__
's box than be at the mercy of a corporation.
Copilot is AMAZING at creating Matplotlib plots; it's one of the only times I just put a comment and accept multi-line suggestions. Plotting is tedious and has hard-to-remember APIs, a perfect target for automation.
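For instance, a single comment like the one below is usually enough to get a complete multi-line suggestion (hypothetical example with made-up data):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 200)

# plot sin(x) and cos(x) on shared axes with labels, a grid, and a legend
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, np.sin(x), label="sin(x)")
ax.plot(x, np.cos(x), label="cos(x)")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.grid(True)
ax.legend()
plt.show()
```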
Happy to announce I'll be joining
@GoogleDeepMind
in the next few weeks. I'll continue my work on JAX/Flax, but I'll also be working on a cool new thing™
Pretty excited about the new horizons! 🇬🇧
Lambda Networks: SOTA on ImageNet. Once again, Transformer-like architectures are dominating a new field. CNNs had a good run.
Paper:
Awesome video from
@ykilcher
:
The EU AI Act passes in the European Parliament! 523 voted for, 46 against, 49 abstained. I'm grateful for being in Strasbourg today, witnessing one of the last hurdles before the first-ever comprehensive AI regulation, the AI Act, becomes an actual law in the EU.
Hey! The JAX team (
@shoyer
) recently released Tree Math 🌲🔢, a small library that lets you operate on pytrees as if they were ndarrays. Under the hood everything reduces to a tree_map, but it makes certain kinds of code look very clean.
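For example, an SGD step over an arbitrary pytree of params (a quick sketch from memory of the tree_math API, so double-check the details):

```python
import jax
import jax.numpy as jnp
import tree_math as tm

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((3, 2), 0.1), "b": jnp.full((2,), 0.1)}

# plain JAX: one tree_map per arithmetic expression
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# tree_math: wrap once, write the update as if these were ndarrays
p, g = tm.Vector(params), tm.Vector(grads)
new_params_tm = (p - 0.01 * g).tree  # .tree unwraps back into the pytree
```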
Hey! Sharing this notebook that showcases how to implement various parallelism strategies like data parallel and model parallel in JAX. It implements a very simple model in pure JAX using shard_map and jit with different sharding configurations.
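The core pattern looks roughly like this (my own minimal data-parallel sketch, not code copied from the notebook):

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# one mesh axis "data" spanning all local devices
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P(), P("data")), out_specs=P())
def grad_step(w, batch):
    # each device sees its own shard of the batch...
    grads = jax.grad(loss_fn)(w, batch)
    # ...and the local gradients are averaged across the "data" axis
    return jax.lax.pmean(grads, axis_name="data")

w = jnp.zeros((4, 1))
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))  # batch divisible by #devices
print(grad_step(w, (x, y)).shape)  # (4, 1), fully replicated
```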
I am shocked 🤯 Deep Learning in JAX gets stupidly simple if Modules are Pytrees.
Turns out jit is aware of the static parts of a Pytree and recompiles when they change:
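A minimal demo of what I mean (illustrative, not from any particular library): the static field goes into the pytree's aux data, so changing it gives jit a new cache key.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass

@dataclass
class Dense:
    w: jax.Array        # dynamic: a pytree leaf, traced
    activation: str     # static: aux data, part of jit's cache key

jax.tree_util.register_pytree_node(
    Dense,
    lambda m: ((m.w,), m.activation),
    lambda act, leaves: Dense(leaves[0], act),
)

traces = 0

@jax.jit
def apply(m, x):
    global traces
    traces += 1  # only runs while tracing
    y = x @ m.w
    return jnp.tanh(y) if m.activation == "tanh" else jax.nn.relu(y)

x = jnp.ones((1, 3))
apply(Dense(jnp.ones((3, 3)), "tanh"), x)
apply(Dense(jnp.zeros((3, 3)), "tanh"), x)  # new weights: cached, no retrace
apply(Dense(jnp.ones((3, 3)), "relu"), x)   # static part changed: retraces!
print(traces)  # 2
```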
Since there is a lot of buzz around Mojo, maybe it would be good to take a look at the Codon project:
"Codon is a high-performance Python compiler that compiles Python code to native machine code without any runtime overhead."
JAX code I ❤️
#2
In the old days you could code a nice pairwise formula, but vectorizing it added a lot of unpleasant artifacts (tiling, broadcasting). Using a double vmap, however, you can teach your beautiful function to operate over sets without changing a single line 🔥
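The trick, as a minimal sketch:

```python
import jax
import jax.numpy as jnp

def sq_dist(a, b):
    # the "nice pairwise formula": written for two single points
    return jnp.sum((a - b) ** 2)

# inner vmap sweeps b over ys, outer vmap sweeps a over xs
pairwise = jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))

xs, ys = jnp.ones((5, 3)), jnp.zeros((7, 3))
print(pairwise(xs, ys).shape)  # (5, 7): all pairs, zero manual broadcasting
```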
Hey Twitter! Very excited to announce the JAX Global Meetup! It's a fully online meetup for people across the world passionate about JAX, Deep Learning, and Scientific Computing.
I have the pleasure of co-hosting it with
@bhutanisanyam1
A dev at a friend's workplace committed the company's GCP keys to a personal GitHub repo 🤦
"Hackers" turned on 200+ VMs with 4 T4 GPUs each over the weekend, amounting to 25K USD.
Any advice for them? (apart from don't do it again)
Friendly reminder that if you love einops but are tired of having to switch from op to op,
@MilesCranmer
created einop, which infers which op you need so you don't have to think.
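A quick taste (from memory of einop's API, so treat the details as approximate):

```python
import numpy as np
from einop import einop

x = np.random.rand(4, 8)

y = einop(x, "i j -> j i")                  # no axes added/removed: rearrange
z = einop(x, "i j -> i j k", k=3)           # new axis k: repeat
s = einop(x, "i j -> i", reduction="sum")   # axis j dropped: reduce
```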
Today, weβre announcing that
@Amazon
will invest up to $4 billion in Anthropic. The agreement is part of a broader collaboration to develop reliable and high-performing foundation models.
Created my first 🤗
@huggingface
dataset:
The Point Cloud MNIST
It's a toy dataset for messing around with point clouds and playing with architectures like DeepSets, Transformers, or fancier Geometric DL techniques without consuming too many resources.
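Loading it should be a one-liner (dataset id and column names from memory, check the hub page to be sure):

```python
from datasets import load_dataset

ds = load_dataset("cgarciae/point-cloud-mnist", split="train")
sample = ds[0]
points, label = sample["points"], sample["label"]  # assumed column names
```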
Pleased to announce Elegy! A Keras-like Deep Learning framework based on JAX + Haiku. It's still in a very early stage; we would love to get feedback so we can explore all the possibilities that JAX brings!
Check out my latest video tutorial on JAX's lax.scan function, which is useful for producing trajectories by autoregressively evaluating time steppers.
We will conveniently repurpose the KS solver from the last few weeks.
🧵 What I like to use this function for
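The gist of it, as a minimal sketch with a toy stepper standing in for the KS solver:

```python
import jax
import jax.numpy as jnp

def stepper(u, dt=0.01):
    # toy time stepper standing in for the KS solver
    return u + dt * (u - u**3)

def rollout(u0, n_steps):
    def step(u, _):
        u_next = stepper(u)
        return u_next, u_next  # new carry, and one slice of the trajectory
    _, trajectory = jax.lax.scan(step, u0, xs=None, length=n_steps)
    return trajectory

u0 = jnp.linspace(-1.0, 1.0, 64)
print(rollout(u0, 100).shape)  # (100, 64): the autoregressive trajectory
```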
We took ChatGPT offline Monday to fix a bug in an open source library that allowed some users to see titles from other users' chat history. Our investigation has also found that 1.2% of ChatGPT Plus users might have had personal data revealed to another user. 1/2
Made this new post about Quantile Regression! It's a very useful technique for estimating uncertainty, and it's very easy to understand and implement. Example in JAX below.
Blog:
Notebook:
Repo:
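The whole trick is the pinball loss; in JAX it's a few lines (a sketch of the idea, not the exact code from the post):

```python
import jax
import jax.numpy as jnp

def pinball_loss(q, y_true, y_pred):
    # asymmetric penalty: its minimizer is the q-th quantile of y_true
    err = y_true - y_pred
    return jnp.mean(jnp.maximum(q * err, (q - 1.0) * err))

def loss_fn(params, x, y, q=0.9):  # fit the 90th percentile
    w, b = params
    return pinball_loss(q, y, x @ w + b)

params = (jnp.zeros(3), 0.0)
x, y = jnp.ones((16, 3)), jnp.ones(16)
grads = jax.grad(loss_fn)(params, x, y)  # train as usual with any optimizer
```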
DeepMind recently came out with Long Range Arena, a benchmark suite to pit Efficient Transformers against each other in long-range tasks 🏃‍♀️🏃‍♂️.
Do we finally have a winner?
Paper:
📣 New 'Transfer Learning' Flax guide!
We added a new guide that shows how to use models from
@huggingface
's transformers library, perform parameter surgery, and freeze parameters with optax or fine-tune with different learning rates.
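The freezing part boils down to optax.multi_transform; here is a minimal sketch with hypothetical backbone/head parameter groups:

```python
import jax
import jax.numpy as jnp
import optax

params = {"backbone": {"w": jnp.ones((3, 3))}, "head": {"w": jnp.ones((3, 2))}}

# route each param group to its own transform: frozen ones get zero updates
tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    param_labels={"backbone": "frozen", "head": "trainable"},
)
opt_state = tx.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)  # backbone stays untouched
```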
MaxText is probably the best way to train LLMs in JAX/Flax outside of Google.
They provide configs for:
gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-70b, llama2-7b, mistral-7b, mixtral-8x7b
Babe wake up, a new JAX library just dropped.
Jeometric is a new GNN library for JAX, based on Flax (you love to see it). It offers both common GNN layers and a data format for the inputs. Check it out!
Exciting news! (for me) I've just published Jeometric, a new Python library for graph neural networks in JAX!
If you use PyTorch Geometric, it will look familiar!
It's still early-stage and I'll be building it in public. Pull requests very welcome ❤️
I wrote this (Deep) Quantile Regression tutorial in JAX/Elegy a while back; it was a lot of fun.
I read you can adapt QR to conformal theory, so maybe I'm on board with the conformal agenda. I'll start getting angry at Bayesians soon :p
Getting started with Deep Learning in JAX with Treex in 5 tweets
If you are JAX-curious but don't want to stray too far from the PyTorch way,
Treex is here to save the day 🌳
🧵
Announced by Mark Zuckerberg this morning β today we're releasing DINOv2, the first method for training computer vision models that uses self-supervised learning to achieve results matching or exceeding industry standards.
More on this new work ➡️
The more I learn about model parallelism, the more mind-blowing 🤯 jax.Array / pjit become. With relative ease you can try out techniques from ZeRO and other papers.
(picture from Megatron)
I was going to write a blog post on parallel JAX, but then they updated the API a lot. Now it is so shockingly easy that it might not be worth it.
Below is the diff between my single-GPU training script and one that I just tested on 8x TPUs in data parallel. 🔥
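Essentially the whole diff is data placement. A self-contained sketch of the pattern (not the actual script):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

@jax.jit
def train_step(w, batch):
    x, y = batch
    grads = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - 0.01 * grads

w = jnp.zeros((4, 1))
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))

# the entire "multi-device diff": params replicated, batch sharded along
# its leading axis; jit propagates the shardings through the step
w = jax.device_put(w, NamedSharding(mesh, P()))
x = jax.device_put(x, NamedSharding(mesh, P("data")))
y = jax.device_put(y, NamedSharding(mesh, P("data")))

w = train_step(w, (x, y))  # same function as the single-GPU script
```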
🔥 Amazed at the simplicity of JAX + Haiku! The following code implements a cumulative accuracy metric. Haiku hooks make implementing Deep Learning code so much easier.
JAX is the future!
gist:
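It reads something like this (a reconstruction of the idea, not the exact gist):

```python
import haiku as hk
import jax.numpy as jnp

def cumulative_accuracy(logits, labels):
    # hk.get_state / hk.set_state are the hooks: stateful code, no classes
    correct = hk.get_state("correct", shape=[], dtype=jnp.float32, init=jnp.zeros)
    total = hk.get_state("total", shape=[], dtype=jnp.float32, init=jnp.zeros)
    correct = correct + jnp.sum(jnp.argmax(logits, -1) == labels)
    total = total + labels.shape[0]
    hk.set_state("correct", correct)
    hk.set_state("total", total)
    return correct / total

metric = hk.transform_with_state(cumulative_accuracy)

logits = jnp.array([[2.0, 1.0], [0.1, 3.0]])
labels = jnp.array([0, 1])
params, state = metric.init(None, logits, labels)
acc, state = metric.apply(params, state, None, logits, labels)  # 1.0
```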
We need a dedicated collection of Toy Datasets for Machine Learning:
1. They can be more interesting than real datasets, especially if designed to be hard for certain algorithms.
2. They are more useful for teaching / learning.
Maybe
@huggingface
/
@kaggle
can help with this?
JAX Global Meetup is back!
Join us this Friday Oct 7,
@_arohan_
will be talking about second order optimizers, deep learning, and JAX.
@borisdayma
and I will be hosting the event.
Event link:
Join the JAX Meetup to be notified of all future events!
Hey JAX users, here is a slightly better pattern for splitting your RNG keys than the one commonly used (sketch below). It has some nice properties:
- the RNG has longer cycles
- no need to iteratively update `key`
- easier checkpointing
(credit to
@froystig
)
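A minimal sketch of the pattern as I understand it: fold_in per step instead of a running split.

```python
import jax

root = jax.random.PRNGKey(0)

for step in range(1000):
    # derive this step's key directly from the root key and the step number:
    # no mutable `key` threading, and resuming from a checkpoint only needs
    # `root` plus the step counter
    step_key = jax.random.fold_in(root, step)
    noise = jax.random.normal(step_key, (4,))
    # split step_key further if a step needs several independent streams
```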
Easily the best tutorial on distributed training I've seen 🔥
Uses JAX/Flax and shows how to use the low-level communication primitives to teach the basic concepts. Even if the compiler can do most of this for you, I think it's super valuable to learn how things work.
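For a taste of those primitives, here is my own one-primitive sketch (not code from the tutorial):

```python
from functools import partial

import jax
import jax.numpy as jnp

# the primitive at the heart of data parallelism: an all-reduce across
# the devices participating in the mapped axis
@partial(jax.pmap, axis_name="devices")
def all_sum(x):
    return jax.lax.psum(x, axis_name="devices")

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # one scalar per device
print(all_sum(xs))  # every device ends up with the same global sum
```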
JAX's success in RL is due to a couple of reasons:
1. environments are being ported to run on device (the jax.numpy API might help here)
2. once ported, environments can be easily parallelized/distributed to run along with the agent + trainer
3. synchronization is easy with
@cgarciae88
Do you think there's a technical reason for this that's specific to RL, or just that in the absence of tech debt it's a good way to go?
Of course, RL community momentum will increasingly play a role.
jax.jit will soon be able to tell you why a function is recompiling/retracing!
This is a nice quality-of-life update when debugging JAX programs. Props to
@SingularMattrix
and
@yashk2810
for the awesome log messages.
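In the meantime there is already a knob that logs when compilation happens; the new messages add the why. For example:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_log_compiles", True)  # existing flag: logs each compile

@jax.jit
def f(x):
    return x * 2.0

f(jnp.ones(3))
f(jnp.ones(3))  # same shape/dtype: cache hit, nothing logged
f(jnp.ones(4))  # new shape: retrace + recompile, and that gets logged
```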
Hey JAXers, I've been searching for an abstraction that could let us create libraries as powerful as Flax but as simple as Equinox.
This is what I've found:
🧵
Working on NNX's readthedocs page and adding a Quick Start guide.
NNX is a neural network library for JAX that aims to be Pythonic and to support object-oriented patterns.
Can we have the power of Flax with the simplicity of Equinox?
Introducing NNX: Neural Networks for JAX
A highly experimental 🧪 proof-of-concept framework that provides Pytree Modules with:
* Shared state
* Tractable mutability
* Semantic partitioning (collections)
🚨 Exclusive: a report commissioned by the U.S. government says advanced AI could pose an "extinction-level threat to the human species" and calls for urgent, sweeping new regulations
Hey Twitter! I am happy to announce that next week I will be joining the
@quansightai
team! I will be working with our partners to solve real-world Data Science / ML problems. Additionally, I'll be supporting two Open Source projects: Pandas and Elegy ♥
A very nice result from Open Assistant! Not on par with ChatGPT (expected) but still very decent for an Open Source effort.
This could be huge for startups and companies that want to fine-tune it on their own data or need to run it offline.