https://github.com/atomicarchitects/fusionfail

Profile showing 3 layers of NequIP using e3nn-jax



Basic Info
  • Host: GitHub
  • Owner: atomicarchitects
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 1.89 MB
Statistics
  • Stars: 0
  • Watchers: 4
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created about 2 years ago · Last pushed about 2 years ago
Fusion Fail

[Image: nequip_profile]

What are we looking at?

The function train_step corresponds to a forward and backward pass through a 3-layer NequIP model, implemented with e3nn-jax, acting on a simple Tetris dataset. Thanks to @ameya98 and @mariogeiger for the code!
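To make the shape of such a function concrete, here is a minimal sketch of a jitted train_step. It is not the repository's code: a toy linear model stands in for NequIP, and init_params, loss_fn and LEARNING_RATE are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, not taken from the repository


def init_params(key, in_dim=4, out_dim=2):
    # Toy linear model standing in for the 3-layer NequIP network.
    return {
        "w": 0.1 * jax.random.normal(key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
    }


def loss_fn(params, x, y):
    # Forward pass followed by a mean-squared-error loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # value_and_grad traces the forward and backward pass together,
    # so XLA compiles both into a single executable.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update applied leaf by leaf.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


key_params, key_x, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(key_params)
x = jax.random.normal(key_x, (32, 4))
y = x @ jax.random.normal(key_w, (4, 2))  # synthetic regression targets

losses = []
for _ in range(20):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

Because the whole step sits under one jax.jit, everything XLA does with it (fusion, library calls, CUDA Graph capture) shows up in the profile as part of train_step.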

What's happening?

Here's a brief summary of what is happening under the hood:

  • XLA is unable to pattern-match or generate a small set of fused kernels for the computation (see arXiv:2301.13062 to understand how XLA works). Instead, it's left with roughly 300 kernels (half of which are cuBLAS/CUTLASS calls) that it needs to execute at runtime (the small chunks below Thunk:#hlo_op in the TSL row).

  • This makes the compiler fall back to CUDA Graphs, which batch the execution of these kernels. However, the execution graph needs to be updated with new inputs at runtime (~30% runtime overhead before Graph 7 is launched on the GPU). This overhead (notice the CUDA API row) increases with the size of the computation graph.
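One rough way to see how fragmented a compiled function is, without opening a profiler, is to inspect the optimized HLO that XLA emits. The sketch below counts fusion and custom-call instructions for a small stand-in function. The names toy_fn, optimized_hlo and count_lines_containing are hypothetical, and the counts are only a proxy: they are not guaranteed to match the number of kernels launched at runtime, and they differ between CPU and GPU backends.

```python
import jax
import jax.numpy as jnp


def toy_fn(x, w):
    # Hypothetical stand-in for the model: a matmul followed by elementwise ops.
    return jnp.sum(jnp.tanh(x @ w) ** 2)


def optimized_hlo(fn, *args):
    # Lower and compile ahead of time, then return the post-optimization HLO text.
    text = jax.jit(fn).lower(*args).compile().as_text()
    return text or ""


def count_lines_containing(hlo_text, needle):
    # Count the HLO lines that contain the given substring.
    return sum(needle in line for line in hlo_text.splitlines())


x = jnp.ones((8, 16))
w = jnp.ones((16, 4))

# Differentiate with respect to w so the backward pass is included.
hlo = optimized_hlo(jax.grad(toy_fn, argnums=1), x, w)

n_fusions = count_lines_containing(hlo, " fusion(")
n_custom_calls = count_lines_containing(hlo, " custom-call(")
print(f"fusion instructions: {n_fusions}, custom calls: {n_custom_calls}")
```

On a GPU backend, library calls such as cuBLAS GEMMs typically appear as custom-call instructions, so a high count of both kinds hints at the many-small-kernels situation described above.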

What's the alternative?

Ideally, the compiler (or a human) should give us one fused forward kernel and one fused backward kernel for our computation (see FlashAttention).

Packages

    pip install -r requirements.txt

To reproduce the profile shown above, install NVIDIA Nsight Systems and run the following command (borrowed from JAX-Toolbox):

    nsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop -o nequip_profile_disable_cudagraph -f true python train.py
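With --capture-range=cudaProfilerApi, Nsight Systems only records between calls to cudaProfilerStart and cudaProfilerStop, so the profiled script has to issue those calls around the steps of interest. How train.py does this is not shown here. The following is a minimal sketch of one way to do it through ctypes, assuming the CUDA runtime can be found under one of the listed library names; cuda_profiler_range and _load_cudart are hypothetical helpers, and the range is a no-op when the library is missing.

```python
import contextlib
import ctypes


def _load_cudart():
    # Try to load the CUDA runtime; return None when it is not available.
    # The library names are assumptions and may differ per installation.
    for name in ("libcudart.so", "libcudart.so.12"):
        try:
            return ctypes.CDLL(name)
        except OSError:
            continue
    return None


@contextlib.contextmanager
def cuda_profiler_range():
    # Bracket a region with cudaProfilerStart/cudaProfilerStop so that
    # nsys --capture-range=cudaProfilerApi records only that region.
    cudart = _load_cudart()
    if cudart is not None:
        cudart.cudaProfilerStart()
    try:
        # Yield whether the profiler calls were actually issued.
        yield cudart is not None
    finally:
        if cudart is not None:
            cudart.cudaProfilerStop()


with cuda_profiler_range() as active:
    # Placeholder for the profiled work, e.g. a few train_step calls.
    # With JAX, call .block_until_ready() on the results before leaving
    # the range, since dispatch is asynchronous.
    steps_run = sum(1 for _ in range(3))
```

Running a few warm-up steps before entering the range keeps compilation out of the captured window.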

TODO

  • [ ] Add an MLP equivalent to show what non-CUDAGraph fusion should look like
  • More profiling:
    • [ ] Add TensorProduct, TensorProductLinear and TensorProductLinearGate
    • [ ] Allegro-JAX and MACE-JAX

Owner

  • Name: The Atomic Architects
  • Login: atomicarchitects
  • Kind: organization
  • Location: United States of America

Research Group of Prof. Tess Smidt


Dependencies

requirements.txt pypi
  • PyYAML ==6.0.1
  • Pygments ==2.17.2
  • absl-py ==2.1.0
  • attrs ==23.2.0
  • chex ==0.1.86
  • e3nn-jax ==0.20.6
  • etils ==1.8.0
  • flax ==0.8.2
  • fsspec ==2024.3.1
  • importlib_resources ==6.4.0
  • jax ==0.4.25
  • jaxlib ==0.4.25
  • jraph ==0.0.6.dev0
  • markdown-it-py ==3.0.0
  • mdurl ==0.1.2
  • ml-dtypes ==0.3.2
  • mpmath ==1.3.0
  • msgpack ==1.0.8
  • nest-asyncio ==1.6.0
  • numpy ==1.26.4
  • nvidia-cublas-cu12 ==12.4.2.65
  • nvidia-cuda-cupti-cu12 ==12.4.99
  • nvidia-cuda-nvcc-cu12 ==12.4.99
  • nvidia-cuda-runtime-cu12 ==12.4.99
  • nvidia-cudnn-cu12 ==9.0.0.312
  • nvidia-cufft-cu12 ==11.2.0.44
  • nvidia-cusolver-cu12 ==11.6.0.99
  • nvidia-cusparse-cu12 ==12.3.0.142
  • nvidia-nccl-cu12 ==2.20.5
  • nvidia-nvjitlink-cu12 ==12.4.99
  • nvtx ==0.2.10
  • opt-einsum ==3.3.0
  • optax ==0.2.2
  • orbax-checkpoint ==0.5.7
  • protobuf ==5.26.0
  • rich ==13.7.1
  • scipy ==1.12.0
  • sympy ==1.12
  • tensorstore ==0.1.56
  • toolz ==0.12.1
  • tqdm ==4.66.2
  • typing_extensions ==4.10.0
  • zipp ==3.18.1