https://github.com/atomicarchitects/fusionfail
Profile showing 3 layers of NequIP using e3nn-jax
Science Score: 10.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ○ codemeta.json file
- ○ .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org, nature.com
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (7.6%) to scientific vocabulary
Repository
Statistics
- Stars: 0
- Watchers: 4
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Fusion Fail

What are we looking at?
The function `train_step` corresponds to a forward and backward pass through a 3-layer NequIP model, implemented with e3nn-jax, on a simple Tetris dataset. Thanks to @ameya98 and @mariogeiger for the code!
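For readers new to JAX, a minimal sketch of what a jitted `train_step` can look like may help. This is not the repository's code: the 3-layer MLP below (`init_params`, `model`, `loss_fn`, all hypothetical) is a non-equivariant stand-in for NequIP, the inputs are random rather than the Tetris dataset, and the optimizer is plain SGD with an assumed learning rate. What it does show is the structure: `jax.value_and_grad` produces the forward pass (loss) and backward pass (gradients), and `jax.jit` hands the whole step to XLA as one computation.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # hypothetical; the repo's optimizer settings are not shown here


def init_params(key, sizes=(3, 16, 16, 1)):
    """Hypothetical 3-layer MLP used as a stand-in for NequIP (it is NOT equivariant)."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def model(params, x):
    # Three dense layers with SiLU activations between them.
    for w, b in params[:-1]:
        x = jax.nn.silu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # Forward pass (loss) and backward pass (gradients) are traced together,
    # so XLA compiles both into a single executable.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update applied leaf-by-leaf over the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Usage on random data (a placeholder for the Tetris dataset).
params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (32, 3))
y = jnp.sum(x, axis=1, keepdims=True)
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

Everything traced inside `train_step` is what XLA then tries to fuse into as few GPU kernels as possible.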
What's happening?
Here's a brief summary of what happens under the hood:
XLA is unable to pattern-match the computation into, or generate, a small set of fused kernels (see arXiv:2301.13062 for how XLA works). Instead, it is left with ~300 kernels (half of which are cuBLAS/CUTLASS calls) that it needs to execute at runtime (the small chunks below `Thunk:#hlo_op` in the `TSL` row). This makes the compiler fall back to CUDA Graphs, which batch the execution of these kernels. However, the execution graph needs to be updated with new inputs at runtime (~30% runtime overhead before `Graph 7` is launched on the GPU). This overhead (notice the `CUDA API` row) increases with the size of the computation graph.
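One way to get a feel for how many kernels XLA produced, without running a profiler, is to count instructions in the optimized HLO. The sketch below uses JAX's ahead-of-time API (`jax.jit(...).lower(...).compile().as_text()`); the helper `count_hlo_ops` and the function `toy_step` are hypothetical. It assumes that, on GPU, each `fusion` roughly becomes one generated kernel and cuBLAS/cuDNN calls appear as `custom-call` instructions, so the counts are only a proxy for kernel launches, not the exact figure Nsight Systems reports.

```python
import jax
import jax.numpy as jnp


def count_hlo_ops(fn, *args):
    """Rough count of fusions and custom calls in the optimized HLO XLA emits for `fn`.

    Assumption: on GPU, each fusion usually becomes one generated kernel and
    library calls (e.g. cuBLAS GEMMs) show up as custom calls, so the two
    counts approximate the kernels run per call. This is not an exact launch count.
    """
    compiled = jax.jit(fn).lower(*args).compile()
    hlo = compiled.as_text() or ""  # as_text() may return None on some backends
    # Instruction lines assign a result, e.g. "%fusion.3 = f32[8]{0} fusion(...)".
    instrs = [line for line in hlo.splitlines() if " = " in line]
    fusions = sum(" fusion(" in line for line in instrs)
    custom_calls = sum(" custom-call(" in line for line in instrs)
    return fusions, custom_calls


def toy_step(x, w):
    # A small matmul followed by elementwise ops (hypothetical example function).
    return jnp.tanh(x @ w) * 2.0 + 1.0


x = jnp.ones((64, 32))
w = jnp.ones((32, 16))
fusions, custom_calls = count_hlo_ops(toy_step, x, w)
print(f"fusions={fusions}, custom_calls={custom_calls}")
```

Applied to a real training step on a GPU backend, a large and growing count is a hint that XLA has split the step into many small kernels.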
What's the alternative?
Ideally, the compiler (or a human) should give us one fused kernel for the forward pass and one for the backward pass of our computation (see FlashAttention).
Packages
```bash
pip install -r requirements.txt
```
To reproduce the profile shown above, install NVIDIA Nsight Systems and run the following command (borrowed from JAX-Toolbox):
```bash
nsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop -o nequip_profile_disable_cudagraph -f true python train.py
```
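With `--capture-range=cudaProfilerApi` and `--capture-range-end=stop`, Nsight Systems only records between calls to `cudaProfilerStart()` and `cudaProfilerStop()`. How `train.py` issues those calls isn't shown here, so the following is a minimal sketch under assumptions: a hypothetical context manager `cuda_profiler_range` that loads the CUDA runtime through `ctypes` under one of a few guessed library names (with the pip-installed `nvidia-cuda-runtime-cu12` wheel you may need to pass the library's full path instead) and silently does nothing if it can't be loaded.

```python
import contextlib
import ctypes


def _load_cudart():
    """Try to load the CUDA runtime library; return None if it is not available.

    Assumption: the library is reachable under one of these names on the loader path.
    """
    for name in ("libcudart.so", "libcudart.so.12", "libcudart.so.11.0"):
        try:
            return ctypes.CDLL(name)
        except OSError:
            continue
    return None


@contextlib.contextmanager
def cuda_profiler_range():
    """Bracket a region with cudaProfilerStart/cudaProfilerStop.

    Under `nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop`,
    only work issued inside this block is recorded. Falls back to a no-op
    (yielding False) when the CUDA runtime cannot be loaded, e.g. on a CPU-only machine.
    """
    cudart = _load_cudart()
    if cudart is not None:
        cudart.cudaProfilerStart()
    try:
        yield cudart is not None
    finally:
        if cudart is not None:
            cudart.cudaProfilerStop()


# Usage sketch: run (and compile) the step once outside the range, then profile
# only steady-state steps so XLA compilation does not appear in the trace.
with cuda_profiler_range() as active:
    total = sum(range(10))  # placeholder for a few train_step calls
```

Keeping the first (compiling) call outside the range keeps the trace focused on the repeated CUDA Graph updates and launches discussed above.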
TODO
- [ ] Add an MLP equivalent to show what non-CUDA-Graph fusion should look like
- More profiling:
  - [ ] Add `TensorProduct`, `TensorProductLinear` and `TensorProductLinearGate`
  - [ ] Allegro-JAX and MACE-JAX
- [ ] Add
Owner
- Name: The Atomic Architects
- Login: atomicarchitects
- Kind: organization
- Location: United States of America
- Website: https://atomicarchitects.github.io/
- Twitter: AtomArchitects
- Repositories: 2
- Profile: https://github.com/atomicarchitects
Research Group of Prof. Tess Smidt
Dependencies
- PyYAML ==6.0.1
- Pygments ==2.17.2
- absl-py ==2.1.0
- attrs ==23.2.0
- chex ==0.1.86
- e3nn-jax ==0.20.6
- etils ==1.8.0
- flax ==0.8.2
- fsspec ==2024.3.1
- importlib_resources ==6.4.0
- jax ==0.4.25
- jaxlib ==0.4.25
- jraph ==0.0.6.dev0
- markdown-it-py ==3.0.0
- mdurl ==0.1.2
- ml-dtypes ==0.3.2
- mpmath ==1.3.0
- msgpack ==1.0.8
- nest-asyncio ==1.6.0
- numpy ==1.26.4
- nvidia-cublas-cu12 ==12.4.2.65
- nvidia-cuda-cupti-cu12 ==12.4.99
- nvidia-cuda-nvcc-cu12 ==12.4.99
- nvidia-cuda-runtime-cu12 ==12.4.99
- nvidia-cudnn-cu12 ==9.0.0.312
- nvidia-cufft-cu12 ==11.2.0.44
- nvidia-cusolver-cu12 ==11.6.0.99
- nvidia-cusparse-cu12 ==12.3.0.142
- nvidia-nccl-cu12 ==2.20.5
- nvidia-nvjitlink-cu12 ==12.4.99
- nvtx ==0.2.10
- opt-einsum ==3.3.0
- optax ==0.2.2
- orbax-checkpoint ==0.5.7
- protobuf ==5.26.0
- rich ==13.7.1
- scipy ==1.12.0
- sympy ==1.12
- tensorstore ==0.1.56
- toolz ==0.12.1
- tqdm ==4.66.2
- typing_extensions ==4.10.0
- zipp ==3.18.1