e3nn.c

Pure C implementation of e3nn

https://github.com/teddykoker/e3nn.c

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 3 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (5.9%) to scientific vocabulary
Last synced: 6 months ago

Repository

Pure C implementation of e3nn

Basic Info
  • Host: GitHub
  • Owner: teddykoker
  • License: mit
  • Language: C
  • Default Branch: main
  • Homepage:
  • Size: 1.27 MB
Statistics
  • Stars: 18
  • Watchers: 3
  • Forks: 5
  • Open Issues: 4
  • Releases: 0
Created over 1 year ago · Last pushed 12 months ago
Metadata Files
Readme License Citation

README.md

e3nn.c


Pure C implementation of e3nn. This was done mostly for pedagogical reasons, but similar code could be used for C/C++ implementations of e3nn-based models for inference, or for CUDA kernels that speed up operations within Python libraries.

Currently the only operations implemented are the tensor product, spherical harmonics, and the linear (self-interaction) layer.

Single-thread CPU performance of the tensor product on an Intel i5 Desktop Processor.

Message Computation

```c
#include <stdio.h>

#include "e3nn.h"

// example.c
int main(void) {
    // spherical harmonics (l = 0, 1, 2) of the node position (1, 2, 3)
    float node_position_sh[9] = { 0 };
    Irreps* node_irreps = irreps_create("1x0e + 1x1o + 1x2e");
    spherical_harmonics(node_irreps, 1, 2, 3, node_position_sh);

    printf("sh [");
    for (int i = 0; i < 9; i++) { printf("%.2f, ", node_position_sh[i]); }
    printf("]\n");
    irreps_free(node_irreps);

    // tensor product of the spherical harmonics with a neighbor feature
    float neighbor_feature[] = { 7, 8, 9 };
    float product[27] = { 0 };
    Irreps* node_sh_irreps = irreps_create("1x0e + 1x1o + 1x2e");
    Irreps* neighbor_feature_irreps = irreps_create("1x1e");
    Irreps* product_irreps = irreps_create("1x0o + 1x1o + 2x1e + 1x2e + 1x2o + 1x3e");
    tensor_product(node_sh_irreps, node_position_sh,
                   neighbor_feature_irreps, neighbor_feature,
                   product_irreps, product);

    printf("product [");
    for (int i = 0; i < 27; i++) { printf("%.2f, ", product[i]); }
    printf("]\n");
    irreps_free(node_sh_irreps);
    irreps_free(neighbor_feature_irreps);

    // linear (self-interaction) layer applied to the product
    float weights[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    //                [1 x 1 weight] [1 x 1 weight] [2 x 2 weight] [1 x 1 weight] [1 x 1 weight] [1 x 1 weight]
    float output[27] = { 0 };
    Irreps* output_irreps = irreps_create("1x0o + 1x1o + 2x1e + 1x2e + 1x2o + 1x3e");
    linear(product_irreps,
           product,
           weights,
           output_irreps,
           output);

    printf("output [");
    for (int i = 0; i < 27; i++) { printf("%.2f, ", output[i]); }
    printf("]\n");
    irreps_free(product_irreps);
    irreps_free(output_irreps);

    return 0;
}
```

```shell
$ make example && ./example
sh [1.00, 0.46, 0.93, 1.39, 0.83, 0.55, -0.16, 1.66, 1.11, ]
product [13.36, -1.96, 3.93, -1.96, 7.00, 8.00, 9.00, 2.63, 9.50, 16.36, -2.71, 0.00, 4.69, 2.71, -1.36, 9.82, 7.20, -0.38, 13.75, 6.55, 10.76, 13.42, 2.58, -9.40, 5.91, 11.50, 2.93, ]
output [13.36, -3.93, 7.86, -3.93, 24.13, 50.54, 76.95, 30.94, 62.91, 94.88, -18.97, 0.00, 32.86, 18.97, -9.49, 78.56, 57.61, -3.02, 109.98, 52.37, 96.83, 120.75, 23.18, -84.62, 53.18, 103.50, 26.41, ]
```

This writes the same values to the output buffer as the following Python code:

```python
import jax.numpy as jnp
import e3nn_jax as e3nn

node_position = jnp.asarray([1.0, 2.0, 3.0])
node_position_sh = e3nn.spherical_harmonics(
    "1x0e + 1x1o + 1x2e", node_position, normalize=True, normalization="component"
)
print("sh", node_position_sh.array)

neighbor_feature = e3nn.IrrepsArray("1x1e", jnp.asarray([7.0, 8.0, 9.0]))
tp = e3nn.tensor_product(node_position_sh, neighbor_feature)
print("product", tp.array)

linear = e3nn.flax.Linear(
    "1x0o + 1x1o + 2x1e + 1x2e + 1x2o + 1x3e",
    "1x0o + 1x1o + 2x1e + 1x2e + 1x2o + 1x3e",
)
weights = {
    "params": {
        "w[0,0] 1x0o,1x0o": jnp.asarray([[1]]),
        "w[1,1] 1x1o,1x1o": jnp.asarray([[2]]),
        "w[2,2] 2x1e,2x1e": jnp.asarray([[3, 4], [5, 6]]),
        "w[3,3] 1x2e,1x2e": jnp.asarray([[7]]),
        "w[4,4] 1x2o,1x2o": jnp.asarray([[8]]),
        "w[5,5] 1x3e,1x3e": jnp.asarray([[9]]),
    }
}
message = linear.apply(weights, tp)
print("output", message.array)
```

Tetris

See tetris.c, which implements a full E(3)-equivariant neural network for classifying tetrominoes. The model can be trained with python train_tetris.py, which saves the model weights in a binary format to tetris.bin. For inference, supply the model with 4 xyz coordinates on the command line. python run_tetris.py should produce the same outputs using the JAX implementation.

```shell
$ make tetris

# usage
$ ./tetris
usage: ./tetris x1 y1 z1 x2 y2 z2 x3 y3 z3 x4 y4 z4

# zigzag
$ ./tetris 0 0 0 1 0 0 1 1 0 2 1 0
logits: chiral 1 -0.00000 chiral 2 0.00000 square 4.51680 line 1.20807 corner 5.59851 L 4.09760 T 5.82929 zigzag 6.47695

# line
$ ./tetris 0 0 0 0 0 1 0 0 2 0 0 3
logits: chiral 1 -0.00000 chiral 2 0.00000 square 0.72002 line 8.12406 corner -1.34077 L 6.86459 T 4.45846 zigzag 1.23425

# rotated line
$ ./tetris 0 0 0 1 0 0 2 0 0 3 0 0
logits: chiral 1 -0.00000 chiral 2 0.00000 square 0.72002 line 8.12406 corner -1.34077 L 6.86459 T 4.45846 zigzag 1.23425

# line with python
$ python run_tetris.py 0 0 0 0 0 1 0 0 2 0 0 3
chiral 1 -0.00000 chiral 2 0.00000 square 0.72002 line 8.12406 corner -1.34077 L 6.86459 T 4.45846 zigzag 1.23425
```

Usage

See the example above and message_example.c. Run with:

```bash
make message_example
./message_example
```

Currently the output irreps must be defined manually. They could be computed on the fly at minimal computational cost, but I am not sure what the best API for this would be. Additionally, only component normalization is currently implemented, and the tensor product will not work correctly if the output irreps do not match the full, simplified output irreps (i.e., no filtering); see Todo.

Benchmarking

```bash
python -m venv venv
source venv/bin/activate
pip install -r extra/requirements.txt

make benchmark
```

e3nn.c contains several tensor product implementations, each improving on the runtime of the previous one.

v1

tensor_product_v1 is a naive implementation that computes the full tensor product by iterating over all Clebsch-Gordan coefficients:

```math
(u \otimes v)^{(l)}_m = \sum_{m_1 = -l_1}^{l_1} \sum_{m_2 = -l_2}^{l_2} C^{(l, m)}_{(l_1, m_1)(l_2, m_2)} \, u^{(l_1)}_{m_1} v^{(l_2)}_{m_2}
```

To minimize the overhead of computing the Clebsch-Gordan coefficients, they are precomputed up to L_MAX and cached the first time the tensor product is called, which incurs a one-time startup cost.

v2

The tensor_product_v2 implementation exploits the fact that, even after conversion to the real basis, the Clebsch-Gordan coefficients are generally sparse, with many entries equal to 0. To take advantage of this, we precompute a data structure that stores only the non-zero entries of $C$ for each $l_1$, $l_2$, $l$, together with their corresponding indices $m_1$, $m_2$, $m$. This significantly improves performance by eliminating needless iterations over zero-valued coefficients. The just-in-time (JIT) compilers built into JAX and PyTorch are likely able to perform this optimization as well.

v3

tensor_product_v3 forgoes runtime computation of Clebsch-Gordan coefficients altogether and instead generates C code that computes the partial tensor product for every $l_1$, $l_2$, $l$ combination up to L_MAX. This eliminates the need to iterate over any coefficients, allowing each value in the output to be written in a single step. Because the code is generated before compilation, the C compiler can also optimize these operations. See tp_codegen.py, which generates tp.c, containing all of the tensor product paths.

Todo:

  • [X] Benchmark against e3nn and e3nn-jax
  • [X] Sparse Clebsch-Gordan implementation
  • [X] Implement Spherical Harmonics
  • [X] Implement Linear/Self-interaction operation
  • [ ] Implement filter_ir_out and irrep_normalization="norm" for tensor product
  • [ ] Full Nequip, Allegro, or ChargE3Net implementation
  • [ ] Implement integral, norm, and no normalization for spherical harmonics
  • [ ] ...

See also

  • e3nn PyTorch
  • e3nn-jax
  • The e3nn paper: https://arxiv.org/abs/2207.09453
  • Numerical Recipes in C, 2nd Edition (Press et al.) - helpful formulae and reference implementations for Legendre polynomials, Bessel functions
  • karpathy/llama.c - inspiration for this work

Owner

  • Name: Teddy Koker
  • Login: teddykoker
  • Kind: user
  • Location: Boston, MA
  • Company: MIT Lincoln Laboratory

Machine Learning @mit-ll

Citation (CITATION.CFF)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Koker"
  given-names: "Teddy"
  orcid: "https://orcid.org/0000-0001-8861-9788"
title: "e3nn.c"
version: 0.1.0
doi: 10.5281/zenodo.14183951
date-released: 2024-11-18
url: "https://github.com/teddykoker/e3nn.c"

GitHub Events

Total
  • Create event: 1
  • Issues event: 1
  • Release event: 1
  • Watch event: 7
  • Issue comment event: 3
  • Push event: 1
  • Fork event: 1
Last Year
  • Create event: 1
  • Issues event: 1
  • Release event: 1
  • Watch event: 7
  • Issue comment event: 3
  • Push event: 1
  • Fork event: 1

Dependencies

extra/requirements.txt pypi
  • Jinja2 ==3.1.4
  • MarkupSafe ==2.1.5
  • PyYAML ==6.0.1
  • Pygments ==2.18.0
  • absl-py ==2.1.0
  • attrs ==23.2.0
  • chex ==0.1.86
  • contourpy ==1.2.1
  • cycler ==0.12.1
  • e3nn ==0.5.1
  • e3nn-jax ==0.20.6
  • etils ==1.7.0
  • filelock ==3.15.2
  • flax ==0.8.4
  • fonttools ==4.53.0
  • fsspec ==2024.6.0
  • importlib_resources ==6.4.0
  • jax ==0.4.29
  • jaxlib ==0.4.29
  • jraph ==0.0.6.dev0
  • kiwisolver ==1.4.5
  • markdown-it-py ==3.0.0
  • matplotlib ==3.9.0
  • mdurl ==0.1.2
  • ml-dtypes ==0.4.0
  • mpmath ==1.3.0
  • msgpack ==1.0.8
  • nest-asyncio ==1.6.0
  • networkx ==3.3
  • numpy ==2.0.0
  • opt-einsum ==3.3.0
  • opt-einsum-fx ==0.1.4
  • optax ==0.2.2
  • orbax-checkpoint ==0.5.16
  • packaging ==24.1
  • pillow ==10.3.0
  • protobuf ==5.27.1
  • pyparsing ==3.1.2
  • python-dateutil ==2.9.0.post0
  • rich ==13.7.1
  • scipy ==1.13.1
  • six ==1.16.0
  • sympy ==1.12.1
  • tensorstore ==0.1.61
  • toolz ==0.12.1
  • torch ==2.3.1
  • triton ==2.3.1
  • typing_extensions ==4.12.2
  • zipp ==3.19.2