yax

Yet Another X: JAX/FLAX module tracing, modification, and evaluation.

https://github.com/daskol/yax

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.2%) to scientific vocabulary

Keywords

flax jax mox yax
Last synced: 8 months ago

Repository

Yet Another X: JAX/FLAX module tracing, modification, and evaluation.

Basic Info
  • Host: GitHub
  • Owner: daskol
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 273 KB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 1
  • Releases: 0
Topics
flax jax mox yax
Created over 1 year ago · Last pushed over 1 year ago
Metadata Files
Readme License Citation

README.md


YAX

Yet Another X: JAX/FLAX module tracing, modification, and evaluation.

Overview

Deep learning frameworks like PyTorch, Keras, and JAX/FLAX usually provide a "module-level" API, which abstracts a layer, that is, an architectural unit of a neural network. While modules are descriptive and easy to use, they can be inconvenient to work with programmatically. Specifically, it is hard to modify a model's architecture on the fly, even though changing its weight structure dynamically is comparatively easy. So why can't we work with modules in the same flexible way?

YAX is a library within the JAX/FLAX ecosystem for building, evaluating, and modifying the intermediate representation of a neural network's modular structure. Modular structures are represented with the help of MoX, a Module eXpression, which is an extension of JAX expressions (Jaxpr). MoX is pronounced [mokh] and means "moss" in Russian.

```bash
pip install git+https://github.com/daskol/yax.git
```

Usage

Module expressions (MoX) are useful in a number of situations. For example, they make it possible to apply custom LoRA-like adapters or performance optimizations such as activation functions with quantized gradients (see fewbit). The sections below briefly show what YAX/MoX can do, using the following ResBlock as a running example.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import yax

class ResBlock(nn.Module):
    @nn.compact
    def __call__(self, xs):
        # Residual connection around a single fully-connected layer.
        return xs + nn.Dense(10)(xs)

mod = ResBlock()
batch = jnp.empty((1, 10))
params = jax.jit(mod.init)(jax.random.PRNGKey(42), batch)
```

Tracing

First, we need to build a module representation (a MoX). This is done similarly to creating a Jaxpr (see jax.make_jaxpr).

```python
mox = yax.make_mox(mod.apply)(params, batch)
print(mox)
```

Pretty printing is not very pretty for MoX at the moment, but the output looks like the following. MoX can also be serialized to XML and YSON (see the XML and YSON sections below).

```jaxpr
{ lambda ; a:f32[10] b:f32[10,10] c:f32[1,10]. let
    d:f32[1,10] = module_call {
      lambda ; a:f32[10] b:f32[10,10] c:f32[1,10]. let
        d:f32[1,10] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c b
        e:f32[1,10] = reshape[dimensions=None new_sizes=(1, 10)] a
        f:f32[1,10] = add d e
      in (f,)
    }
    e:f32[1,10] = add d a
  in (e,) }
```
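For comparison, the plain Jaxpr that MoX extends can be obtained with jax.make_jaxpr. The sketch below traces a hand-written, module-free version of the ResBlock; the function `res_block` and its argument order are assumptions for illustration, and it uses `lax` primitives directly so that each operation becomes exactly one equation. The result is a flat list of primitives, without the `module_call` nesting that a MoX adds on top.

```python
import jax
import jax.numpy as jnp
from jax import lax

def res_block(bias, kernel, xs):
    # Hypothetical module-free ResBlock: xs + (xs @ kernel + bias).
    hidden = lax.dot_general(xs, kernel, (((0,), (0,)), ((), ())))
    return lax.add(xs, lax.add(hidden, bias))

bias = jnp.zeros(10, dtype=jnp.float32)
kernel = jnp.eye(10, dtype=jnp.float32)
xs = jnp.ones(10, dtype=jnp.float32)

# Trace the function into a ClosedJaxpr and list its primitive equations.
closed = jax.make_jaxpr(res_block)(bias, kernel, xs)
print(closed)
primitives = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitives)
```

In the MoX printout, the `dot_general` and the bias `add` sit inside a `module_call` node instead, which is what makes module-level queries and substitutions possible.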

Evaluation

MoX can be evaluated similarly to Jaxpr, but the most important feature is that yax.eval_mox can be composed with common JAX transformations such as jax.jit, as shown below.

```python
def apply(params, batch):
    return yax.eval_mox(mox, params, batch)

_ = apply(params, batch)  # Eager evaluation.
_ = jax.jit(apply)(params, batch)  # JIT-compiled execution.
```
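Why does an expression evaluator compose with jax.jit at all? Because it only re-binds primitives, so under jit the very same code runs on tracers. The sketch below is a minimal Jaxpr interpreter in the style of JAX's "Writing custom Jaxpr interpreters" tutorial. It is not YAX's eval_mox; the names `eval_jaxpr` and `res_block` are illustrative, and higher-order primitives (such as pjit) would simply be re-bound as a whole rather than interpreted recursively.

```python
import jax
import jax.numpy as jnp
from jax import lax

def eval_jaxpr(jaxpr, consts, *args):
    """Evaluate a Jaxpr by re-binding each primitive in program order."""
    env = {}

    def read(var):
        # Literals carry their value inline; variables are looked up.
        return var.val if hasattr(var, "val") else env[var]

    def write(var, value):
        env[var] = value

    for var, value in zip(jaxpr.constvars, consts):
        write(var, value)
    for var, value in zip(jaxpr.invars, args):
        write(var, value)
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, value in zip(eqn.outvars, outvals):
            write(var, value)
    return [read(v) for v in jaxpr.outvars]

def res_block(bias, kernel, xs):
    hidden = lax.dot_general(xs, kernel, (((0,), (0,)), ((), ())))
    return lax.add(xs, lax.add(hidden, bias))

bias = jnp.full(10, 0.5, dtype=jnp.float32)
kernel = jnp.eye(10, dtype=jnp.float32)
xs = jnp.arange(10, dtype=jnp.float32)
closed = jax.make_jaxpr(res_block)(bias, kernel, xs)

def apply(bias, kernel, xs):
    (out,) = eval_jaxpr(closed.jaxpr, closed.consts, bias, kernel, xs)
    return out

eager = apply(bias, kernel, xs)            # interpreted op by op
jitted = jax.jit(apply)(bias, kernel, xs)  # same interpreter, traced and compiled
```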

Querying

MoX provides tools for exploring and examining a model. Specifically, MoX can help answer questions like: "Which nn.Dense modules have 10 features?"

```python
from typing import Sequence

modules: Sequence[yax.Mox] = yax.query('//module_call[@features=10]', mox)
```

We use XPath (the familiar XML Path expression language) for writing queries. XPath is a concise and convenient way to express search conditions. In fact, the module tree can be represented similarly to a DOM structure, which effectively models the nested structure of a neural network as well as the module attributes in its internal nodes.
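The DOM analogy can be tried out with nothing but the Python standard library. The sketch below encodes a hypothetical module tree as XML and queries it with xml.etree.ElementTree. It illustrates the idea only and is not YAX's query engine: ElementTree supports just a subset of XPath, so a query starts with `.//` instead of `//` and attribute values must be quoted (`@features='10'`).

```python
import xml.etree.ElementTree as ET

# A hypothetical module tree: module attributes become XML attributes.
TREE = """
<module_call type="ResNet" name="root">
  <module_call type="Dense" name="Dense_0" features="10">
    <dot_general/>
    <add/>
  </module_call>
  <module_call type="Block" name="Block_0">
    <module_call type="Dense" name="Dense_1" features="10"/>
    <module_call type="Dense" name="Dense_2" features="32"/>
    <pjit name="relu"/>
  </module_call>
</module_call>
"""

root = ET.fromstring(TREE)

# "Which Dense modules have 10 features?" at any depth below the root.
dense_10 = root.findall(".//module_call[@type='Dense'][@features='10']")
print([m.get("name") for m in dense_10])

# Every activation implemented as a pjit named "relu".
relus = root.findall(".//pjit[@name='relu']")
print(len(relus))
```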

Modification

With such an expressive query language, modifying an original model on the fly becomes easy. For example, one can replace all ReLU activation functions with GELU or substitute all nn.Dense layers with LoRA adapters.

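```python
# Replace ReLU with GELU. Here `inputs` stands for an example input array.
gelu_mox = yax.make_mox(nn.gelu)(inputs)
modified_mox = yax.sub('//pjit[@name="relu"]', gelu_mox, mox)

# Apply LoRA adapters to all fully-connected layers. Here `lora` stands for
# a LoRA adapter module and `params` for its variables (defined elsewhere).
lora_mox = yax.make_mox(lora.apply)(params, inputs)
modified_mox = yax.sub('//module_call[@type="Dense"]', lora_mox, mox)
```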
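Conceptually, a substitution finds nodes with an XPath query and splices a replacement into their parents. The standard-library sketch below mimics that on a plain XML tree. The helper `substitute` is hypothetical and, unlike YAX, it does not check or rewire inputs, outputs, or parameters (see Limitations).

```python
import copy
import xml.etree.ElementTree as ET

TREE = """
<module_call type="Block" name="root">
  <module_call type="Dense" name="Dense_0" features="10"/>
  <pjit name="relu"/>
  <module_call type="Dense" name="Dense_1" features="10"/>
  <pjit name="relu"/>
</module_call>
"""

def substitute(root, path, replacement):
    """Replace every element matching `path` with a copy of `replacement`."""
    # ElementTree has no parent pointers, so build a child -> parent map first.
    parents = {child: parent for parent in root.iter() for child in parent}
    for node in root.findall(path):
        parent = parents[node]
        index = list(parent).index(node)
        parent.remove(node)
        # Insert at the same position to preserve evaluation order.
        parent.insert(index, copy.deepcopy(replacement))
    return root

root = ET.fromstring(TREE)
gelu = ET.Element("pjit", name="gelu")
substitute(root, ".//pjit[@name='relu']", gelu)
print([(child.tag, child.get("name")) for child in root])
```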

Module Expression (MoX)

XML

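The funniest part about MoX is that it can be serialized to XML. Hardly anyone uses XML nowadays outside the Java ecosystem and some legacy projects. However, XML is actually a good fit here: nested elements mirror the nesting of modules, module attributes map naturally onto element attributes, and the result can be queried with XPath.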

```xml
<module_call type="flax.nn.Dense" name="Dense_0" features="10">
  <input type="fp32[10]">a</input>
  <input type="fp32[10,10]">b</input>
  <input type="fp32[10]">c</input>
  <dot_general dimension_numbers="(([0], [0]), ([], []))">
    <input type="fp32[10,10]">b</input>
    <input type="fp32[10]">c</input>
    <output type="fp32[10,10]">d</output>
  </dot_general>
  <pjit jaxpr="{ lambda ; a:f32[10], b:f32[10]. let c:f32[10] = add a b in (c,) }">
    <input type="fp32[10]">d</input>
    <input type="fp32[10]">a</input>
    <output type="fp32[10]">e</output>
  </pjit>
  <outputs type="fp32[10]">e</outputs>
</module_call>
```
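A document of this shape can be produced with xml.etree.ElementTree from an in-memory description of a module call. In the sketch below, the `to_xml` helper and its dictionary layout are assumptions, not YAX's serializer; module attributes become XML attributes, and inputs and outputs become child elements.

```python
import xml.etree.ElementTree as ET

def to_xml(node):
    """Convert a nested dict describing a module call or primitive to XML."""
    attrs = {key: str(value) for key, value in node.get("attrs", {}).items()}
    elem = ET.Element(node["tag"], attrs)
    for name, ty in node.get("inputs", []):
        ET.SubElement(elem, "input", type=ty).text = name
    for child in node.get("children", []):
        elem.append(to_xml(child))
    for name, ty in node.get("outputs", []):
        ET.SubElement(elem, "output", type=ty).text = name
    return elem

dense = {
    "tag": "module_call",
    "attrs": {"type": "flax.nn.Dense", "name": "Dense_0", "features": 10},
    "inputs": [("a", "fp32[10]"), ("b", "fp32[10,10]"), ("c", "fp32[10]")],
    "children": [{
        "tag": "dot_general",
        "inputs": [("b", "fp32[10,10]"), ("c", "fp32[10]")],
        "outputs": [("d", "fp32[10]")],
    }],
    "outputs": [("d", "fp32[10]")],
}

xml_text = ET.tostring(to_xml(dense), encoding="unicode")
print(xml_text)
```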

YSON

YSON stands for Yandex Serialization Object Notation. It is similar to JSON in its compact notation but is more expressive: any node can carry a map of attributes, written in angle brackets before the node. In terms of representational expressiveness, this makes YSON comparable to XML.

```yson
[
  #;
  <primitive="pjit";
   inputs={d="fp32[10]"; a="fp32[10]"};
   outputs={e="fp32[10]"};
   jaxpr="{ lambda ; a:f32[10], b:f32[10]. let c:f32[10] = add a b in (c,) }";
  ;
]
```
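To make the comparison with XML concrete, here is a hedged sketch of a tiny writer for a subset of YSON text syntax: maps `{k=v; ...}`, lists `[a; b]`, the entity `#`, and the `<...>` attribute prefix that any node may carry. The helper `to_yson` is hypothetical; it quotes every string, rejects booleans, and assumes map and attribute keys are plain identifiers that need no quoting.

```python
def to_yson(value, attrs=None):
    """Serialize None, int, str, list, or dict (optionally with attributes) to YSON text."""
    prefix = ""
    if attrs:
        pairs = "; ".join(f"{key}={to_yson(val)}" for key, val in attrs.items())
        prefix = f"<{pairs}>"
    if value is None:
        body = "#"  # the YSON entity: a node with no value of its own
    elif isinstance(value, bool):
        raise TypeError("booleans are not covered by this sketch")
    elif isinstance(value, int):
        body = str(value)
    elif isinstance(value, str):
        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
        body = f'"{escaped}"'
    elif isinstance(value, list):
        body = "[" + "; ".join(to_yson(item) for item in value) + "]"
    elif isinstance(value, dict):
        # Keys are assumed to be plain identifiers that need no quoting.
        body = "{" + "; ".join(f"{k}={to_yson(v)}" for k, v in value.items()) + "}"
    else:
        raise TypeError(f"unsupported type: {type(value).__name__}")
    return prefix + body

# An attribute-only node, similar to the pjit entry above.
node = to_yson(None, attrs={
    "primitive": "pjit",
    "inputs": {"d": "fp32[10]", "a": "fp32[10]"},
    "outputs": {"e": "fp32[10]"},
})
print(node)
```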

Limitations

Substitution requires the preservation of some invariants.

  • Inputs and outputs are reused.
  • New outputs are prohibited for now.
  • New inputs are propagated up to the root node. Here leaves (Jaxpr) and internal nodes (MoX) are handled differently:

    • [Jaxpr] New inputs are appended to all parents.
    • [MoX] Internal nodes have two kinds of input parameters: plain inputs and weight params. FLAX requires weight params to be the first input parameter, so the old subtree has to be updated with the new one.
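The Jaxpr-style propagation rule above can be sketched with a hypothetical `Node` class that only tracks input names and a parent pointer. It appends new inputs to a node and to every ancestor up to the root, skipping names an ancestor already has; it does not model FLAX's params-first ordering for internal nodes.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    """A hypothetical expression node with named inputs and a parent link."""
    name: str
    inputs: List[str] = field(default_factory=list)
    parent: Optional["Node"] = None

    def ancestors(self):
        # Yield enclosing nodes from the innermost one up to the root.
        node = self.parent
        while node is not None:
            yield node
            node = node.parent

def add_inputs(node, new_inputs):
    """Append new inputs to `node` and propagate them to every ancestor."""
    node.inputs.extend(new_inputs)
    for ancestor in node.ancestors():
        for var in new_inputs:
            if var not in ancestor.inputs:
                ancestor.inputs.append(var)

root = Node("root", ["x"])
block = Node("block", ["x"], parent=root)
dense = Node("dense", ["x"], parent=block)
add_inputs(dense, ["lora_a", "lora_b"])
print(root.inputs)
```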

To update input parameters, we must update in_tree as well. Similarly, an update to weight params requires an update to var_tree. Note that inputs and params are handled differently for the root node: there, params are passed explicitly, while for all internal expressions params are part of the closure context. Any modification of in_tree or var_tree also requires updating the input symbols.
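Trees such as in_tree and var_tree presumably correspond to JAX pytree structures (treedefs). The sketch below, with a hypothetical FLAX-style parameter dict and made-up LoRA factor names, shows why adding weights invalidates a previously recorded structure: flattening the extended params yields a different treedef with more leaves, and only the new treedef can rebuild them.

```python
import jax
import jax.numpy as jnp

# Hypothetical FLAX-style variables of a single Dense layer.
params = {"params": {"Dense_0": {"kernel": jnp.zeros((10, 10)),
                                 "bias": jnp.zeros(10)}}}
old_tree = jax.tree_util.tree_structure(params)

# Hypothetical LoRA factors added next to the original weights.
extended = {"params": {"Dense_0": {**params["params"]["Dense_0"],
                                   "lora_a": jnp.zeros((10, 2)),
                                   "lora_b": jnp.zeros((2, 10))}}}
new_leaves, new_tree = jax.tree_util.tree_flatten(extended)

# The recorded structure no longer describes the extended params.
print(old_tree.num_leaves, new_tree.num_leaves)
print(old_tree == new_tree)

# Rebuilding the params from flat leaves requires the new treedef.
restored = jax.tree_util.tree_unflatten(new_tree, new_leaves)
```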

Note that all parents should be marked as ephemeral. Also, the inputs and outputs of a replacement should be type-checked against its predecessors and successors, respectively.
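A type check of this kind can be sketched over signature strings such as "fp32[1,10]". The helper `check_replacement` is hypothetical and encodes assumptions drawn from the invariants above: the replacement must accept the old inputs first (extra inputs may follow, Jaxpr-style, and are propagated to the parents), and it must produce exactly the old outputs, since new outputs are prohibited.

```python
from typing import Sequence

def check_replacement(old_inputs: Sequence[str], old_outputs: Sequence[str],
                      new_inputs: Sequence[str], new_outputs: Sequence[str]) -> None:
    """Raise TypeError if a replacement does not fit the node it substitutes."""
    # Predecessors: the old inputs must still be accepted, in the same order.
    if list(new_inputs[:len(old_inputs)]) != list(old_inputs):
        raise TypeError(f"inputs {list(new_inputs)} do not start with {list(old_inputs)}")
    # Successors: outputs must match exactly, since new outputs are prohibited.
    if list(new_outputs) != list(old_outputs):
        raise TypeError(f"outputs {list(new_outputs)} differ from {list(old_outputs)}")

# ReLU -> GELU: identical signature.
check_replacement(["fp32[1,10]"], ["fp32[1,10]"], ["fp32[1,10]"], ["fp32[1,10]"])
# Dense -> Dense with LoRA: extra weight inputs are appended.
check_replacement(["fp32[1,10]"], ["fp32[1,10]"],
                  ["fp32[1,10]", "fp32[10,2]", "fp32[2,10]"], ["fp32[1,10]"])
# A replacement with a different output shape is rejected.
try:
    check_replacement(["fp32[1,10]"], ["fp32[1,10]"], ["fp32[1,10]"], ["fp32[1,20]"])
except TypeError as err:
    print("rejected:", err)
```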

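```python
def substitute(parents, expr):
    # Walk the enclosing module calls from the innermost one outwards
    # (assuming `parents` is ordered from the root down) and update each
    # parameter tree; `update_param_tree` is a helper not shown here.
    for parent in reversed(parents):
        update_param_tree(parent, expr)
```

  • Compositionality with jax.lax.scan, jax.vmap, and jax.pmap is not verified.
  • Proper pretty printing of module expressions is not available yet.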

Container

```shell
docker pull ghcr.io/daskol/yax
```

Owner

  • Name: Daniel Bershatsky
  • Login: daskol
  • Kind: user
  • Location: Russia, Moscow
  • Company: @skoltech-ai

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Bershatsky"
  given-names: "Daniel"
  orcid: "https://orcid.org/0000-0001-8917-8187"
title: "YAX: JAX/FLAX Module Tracing, Evaluation, and Mutation"
version: 0.0.0
date-released: 2024-10-30
url: "https://github.com/daskol/yax"

GitHub Events

Total
  • Issues event: 1
  • Delete event: 1
  • Push event: 23
  • Public event: 1
  • Pull request event: 2
  • Create event: 2
Last Year
  • Issues event: 1
  • Delete event: 1
  • Push event: 23
  • Public event: 1
  • Pull request event: 2
  • Create event: 2

Committers

Last synced: 10 months ago

All Time
  • Total Commits: 56
  • Total Committers: 1
  • Avg Commits per committer: 56.0
  • Development Distribution Score (DDS): 0.0
Past Year
  • Commits: 56
  • Committers: 1
  • Avg Commits per committer: 56.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Daniel Bershatsky d****y@g****m 56

Issues and Pull Requests

Last synced: 11 months ago

All Time
  • Total issues: 1
  • Total pull requests: 1
  • Average time to close issues: N/A
  • Average time to close pull requests: 4 minutes
  • Total issue authors: 1
  • Total pull request authors: 1
  • Average comments per issue: 0.0
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 1
  • Pull requests: 1
  • Average time to close issues: N/A
  • Average time to close pull requests: 4 minutes
  • Issue authors: 1
  • Pull request authors: 1
  • Average comments per issue: 0.0
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • daskol (1)
Pull Request Authors
  • daskol (2)
Top Labels
Issue Labels
bug (1)
Pull Request Labels
bug (2)

Dependencies

.github/workflows/on-push.yml actions
  • actions/cache v4 composite
  • actions/checkout v4 composite
  • actions/setup-python f677139bbe7f9c59b41e40162b753c062f5d49a3 composite
  • actions/setup-python v4 composite
  • actions/upload-artifact v4 composite
  • pre-commit/action v3.0.0 composite
.github/workflows/on-release.yml actions
  • actions/checkout 692973e3d937129bcbf40652eb9f2f61becf3332 composite
  • actions/download-artifact fa0a91b85d4f404e444e00e005971372dc801d16 composite
  • actions/setup-python 39cd14951b08e74b54015e9e001cdefcf80e669f composite
  • actions/upload-artifact 834a144ee995460fba8ed112a2fc961b36a5ec5a composite
  • pypa/gh-action-pypi-publish ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 composite
.github/workflows/on-schedule.yml actions
  • actions/cache v4 composite
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
  • actions/setup-python v4 composite
  • actions/upload-artifact v4 composite
  • pre-commit/action v3.0.0 composite
pyproject.toml pypi
  • flax *
  • jax *
  • numpy *
Dockerfile docker
  • base latest build
  • nvidia/cuda 12.6.2-cudnn-devel-ubuntu24.04 build
examples/nonlinearity/pyproject.toml pypi