https://github.com/dbraun/dac-jax

JAX Implementations of Descript Audio Codec and EnCodec

Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.4%) to scientific vocabulary

Keywords

audio audio-codec audio-compression jax machine-learning
Last synced: 5 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: DBraun
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 268 KB
Statistics
  • Stars: 24
  • Watchers: 2
  • Forks: 2
  • Open Issues: 0
  • Releases: 0
Created about 2 years ago · Last pushed 11 months ago
Metadata Files
Readme License

README.md

DAC-JAX and EnCodec-JAX

This repository holds unofficial JAX implementations of Descript's DAC and Meta's EnCodec. We are not affiliated with Descript or Meta.

You can read the DAC-JAX paper on arXiv (arXiv:2405.11554).

Background

In 2022, Meta published "High Fidelity Neural Audio Compression". They eventually open-sourced the code inside AudioCraft.

In 2023, Descript published a related work "High-Fidelity Audio Compression with Improved RVQGAN" and released their code under the name DAC (Descript Audio Codec).

Both EnCodec and DAC are neural audio codecs which use residual vector quantization inside a fully convolutional encoder-decoder architecture.
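To make residual vector quantization concrete, here is a toy sketch in plain Python, so it runs without JAX installed. The codebooks, the 2-D vectors, and the helper names `rvq_encode` and `rvq_decode` are illustrative assumptions. The real codecs learn their codebooks and operate on batched latent arrays.

```python
from typing import List, Sequence, Tuple

Vector = Tuple[float, ...]

# Hypothetical hand-made codebooks: a coarse stage followed by a finer stage.
TOY_CODEBOOKS: List[List[Vector]] = [
    [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)],
    [(0.0, 0.0), (0.25, 0.0), (0.0, 0.25), (0.25, 0.25)],
]


def nearest_index(codebook: Sequence[Vector], x: Vector) -> int:
    """Return the index of the codeword closest to x (squared Euclidean distance)."""
    def sq_dist(c: Vector) -> float:
        return sum((a - b) ** 2 for a, b in zip(c, x))
    return min(range(len(codebook)), key=lambda i: sq_dist(codebook[i]))


def rvq_encode(x: Vector, codebooks: Sequence[Sequence[Vector]]) -> List[int]:
    """Quantize x stage by stage; each stage quantizes what the previous one missed."""
    residual = tuple(x)
    codes = []
    for codebook in codebooks:
        idx = nearest_index(codebook, residual)
        codes.append(idx)
        # Subtract the chosen codeword so the next stage sees only the leftover error.
        residual = tuple(r - c for r, c in zip(residual, codebook[idx]))
    return codes


def rvq_decode(codes: Sequence[int], codebooks: Sequence[Sequence[Vector]]) -> Vector:
    """Reconstruct by summing the selected codeword from every stage."""
    out = [0.0] * len(codebooks[0][0])
    for idx, codebook in zip(codes, codebooks):
        out = [o + c for o, c in zip(out, codebook[idx])]
    return tuple(out)


codes = rvq_encode((0.9, 0.3), TOY_CODEBOOKS)   # -> [1, 2]
approx = rvq_decode(codes, TOY_CODEBOOKS)       # -> (1.0, 0.25)
```

Each extra stage refines the approximation, so the number of codebooks used trades bitrate against reconstruction quality.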

Usage

Installation

  1. Upgrade pip and setuptools:

    pip install --upgrade pip setuptools

  2. Install the CPU version of PyTorch. We strongly suggest the CPU version because trying to install a GPU version can conflict with JAX's CUDA-related installation. PyTorch is required because it's used to load pretrained model weights.

  3. Install JAX (with GPU support).

  4. Install DAC-JAX with one of the following:

    pip install git+https://github.com/DBraun/DAC-JAX

    Or, from a local clone of the repository:

    python -m pip install .

    Or, if you intend to contribute, clone and do an editable install:

    python -m pip install -e ".[dev]"

Weights

The original Descript repository releases model weights under the MIT license. These weights are for models that natively support 16 kHz, 24 kHz, and 44.1 kHz sampling rates. Our scripts download these PyTorch weights and load them into JAX. Weights are automatically downloaded when you first run an encode or decode command. You can download them in advance with one of the following commands:

    python -m dac_jax download_model # downloads the default 44kHz variant
    python -m dac_jax download_model --model_type 44khz --model_bitrate 16kbps # downloads the 44kHz 16 kbps variant
    python -m dac_jax download_model --model_type 44khz # downloads the 44kHz variant
    python -m dac_jax download_model --model_type 24khz # downloads the 24kHz variant
    python -m dac_jax download_model --model_type 16khz # downloads the 16kHz variant
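The bitrate in a variant name such as the 16 kbps model follows from how many codes the codec emits per second. The sketch below shows that arithmetic. The specific numbers (a hop length of 512 samples and 9 codebooks of 1024 entries for the default 44 kHz model) are assumptions taken from the DAC paper, not from this README, and `codec_bitrate_bps` is a hypothetical helper that is not part of `dac_jax`.

```python
import math


def codec_bitrate_bps(sample_rate: int, hop_length: int,
                      num_codebooks: int, codebook_size: int) -> float:
    """Bits per second = frames per second * codebooks per frame * bits per code."""
    frames_per_second = sample_rate / hop_length
    bits_per_code = math.log2(codebook_size)
    return frames_per_second * num_codebooks * bits_per_code


# Assumed configuration of the default 44 kHz model (see the lead-in above).
default_bps = codec_bitrate_bps(44100, 512, num_codebooks=9, codebook_size=1024)
print(f"{default_bps / 1000:.2f} kbps")  # roughly 7.75 kbps, marketed as 8 kbps
```

Under the same assumptions, doubling the number of codebooks doubles the bitrate, which is consistent with a 16 kbps variant of the 44 kHz model.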

EnCodec weights can be downloaded similarly. The following command downloads the 32 kHz EnCodec used in MusicGen:

    python -m dac_jax download_encodec

For both DAC and EnCodec, the default download location is ~/.cache/dac_jax. You can change the location by setting the environment variable DAC_JAX_CACHE to an absolute path. For example, on macOS/Linux:

    export DAC_JAX_CACHE=/Users/admin/my-project/dac_jax_models

If you do this, make sure DAC_JAX_CACHE is still set whenever you later call the load_model function, otherwise the weights will be looked up in the default location.
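The lookup described above can be pictured with a small sketch. This is only an illustration of the documented behavior (use DAC_JAX_CACHE if set, otherwise ~/.cache/dac_jax). `resolve_cache_dir` is a hypothetical helper, and how `dac_jax` resolves the path internally is not shown in this README.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_cache_dir(env: Optional[Mapping[str, str]] = None) -> Path:
    """Return the directory where weights would be cached.

    Hypothetical helper: mirrors the documented rule, not the library's code.
    """
    env = os.environ if env is None else env
    custom = env.get("DAC_JAX_CACHE")
    if custom:
        # An explicit setting wins over the default location.
        return Path(custom)
    return Path.home() / ".cache" / "dac_jax"


print(resolve_cache_dir({"DAC_JAX_CACHE": "/Users/admin/my-project/dac_jax_models"}))
print(resolve_cache_dir({}))  # falls back to ~/.cache/dac_jax
```

Because the variable is read from the process environment, a value exported in one shell session does not carry over to a new session unless you export it again.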

Compress audio

python -m dac_jax encode /path/to/input --output /path/to/output/codes

This command creates .dac files with the same names as the input files. It also preserves the directory structure relative to the input root and re-creates it in the output directory. Run python -m dac_jax encode --help for more options.
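The directory mirroring described above can be sketched with `pathlib`. The helper name `mirrored_output_path` and the example paths are hypothetical, and this shows only the path mapping, not the encoder's own code.

```python
from pathlib import PurePosixPath


def mirrored_output_path(input_file: str, input_root: str,
                         output_root: str, suffix: str = ".dac") -> PurePosixPath:
    """Map an input file to its output path, keeping sub-directories intact."""
    # Path of the file relative to the input root, e.g. "album/track01.wav".
    relative = PurePosixPath(input_file).relative_to(input_root)
    # Re-create that relative path under the output root with the new extension.
    return PurePosixPath(output_root) / relative.with_suffix(suffix)


out = mirrored_output_path("/data/in/album/track01.wav", "/data/in", "/data/out")
print(out)  # /data/out/album/track01.dac
```

The same mapping with `suffix=".wav"` describes what the decode command does in the opposite direction.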

Reconstruct audio from compressed codes

python -m dac_jax decode /path/to/output/codes --output /path/to/reconstructed_input

This command creates .wav files with the same names as the input files. It also preserves the directory structure relative to the input root and re-creates it in the output directory. Run python -m dac_jax decode --help for more options.

Programmatic usage (DAC and EnCodec)

Here we use jax.jit for optimized encoding and decoding. This does not do sample-rate conversion or volume normalization in the encoder or decoder.

```python
from functools import partial

import jax
from jax import numpy as jnp
import librosa

import dac_jax

model, variables = dac_jax.load_model(model_type="44khz")

# If you want to use pretrained 32 kHz EnCodec from Meta's MusicGen, use this:
# model, variables = dac_jax.load_encodec_model()

@jax.jit
def encode_to_codes(x: jnp.ndarray):
    codes, scale = model.apply(
        variables,
        x,
        method="encode",
    )
    return codes, scale

@partial(jax.jit, static_argnums=(1, 2))
def decode_from_codes(codes: jnp.ndarray, scale, length: int = None):
    recons = model.apply(
        variables,
        codes,
        scale,
        length,
        method="decode",
    )
    return recons

# Load a mono audio file with the correct sample rate
signal, sample_rate = librosa.load('input.wav', sr=model.sample_rate, mono=True, duration=.5)

# Add leading axes until the signal is shaped [B, C, T]
signal = jnp.array(signal, dtype=jnp.float32)
while signal.ndim < 3:
    signal = jnp.expand_dims(signal, axis=0)

original_length = signal.shape[-1]

codes, scale = encode_to_codes(signal)
assert codes.shape[1] == model.num_codebooks

recons = decode_from_codes(codes, scale, original_length)
```

DAC with Binding

Here we use DAC-JAX as a "bound" module, freeing us from repeatedly passing variables as an argument and using .apply. Note that bound modules are not meant to be used in fine-tuning.

```python
import dac_jax
from dac_jax import DACFile

from jax import numpy as jnp
import librosa

# Download a model and bind variables to it.
model, variables = dac_jax.load_model(model_type="44khz")
model = model.bind(variables)

# Load a mono audio file
signal, sample_rate = librosa.load('input.wav', sr=44100, mono=True, duration=.5)

signal = jnp.array(signal, dtype=jnp.float32)
while signal.ndim < 3:
    signal = jnp.expand_dims(signal, axis=0)

# Encode audio signal as one long file (may run out of GPU memory on long files).
# This performs resampling to the codec's sample rate and volume normalization.
dac_file = model.encode_to_dac(signal, sample_rate)

# Save to a file
dac_file.save("dacfile_001.dac")

# Load a file
dac_file = DACFile.load("dacfile_001.dac")

# Decode audio signal. Since we're passing a dac_file, this undoes the
# previous sample rate conversion and volume normalization.
y = model.decode(dac_file)

# Calculate mean-square error of reconstruction in time-domain
mse = jnp.square(y - signal).mean()
```

DAC compression with constant GPU memory regardless of input length:

```python
import dac_jax

import jax
import jax.numpy as jnp
import librosa

# Download a model and set padding to False because we will use the chunk functions.
model, variables = dac_jax.load_model(model_type="44khz", padding=False)

# Load a mono audio file at any sample rate
signal, sample_rate = librosa.load('input.wav', sr=None, mono=True)

signal = jnp.array(signal, dtype=jnp.float32)
while signal.ndim < 3:
    # signal will eventually be shaped [B, C, T]
    signal = jnp.expand_dims(signal, axis=0)

# Jit-compile these functions because they're used inside a loop over chunks.
@jax.jit
def compress_chunk(x):
    return model.apply(variables, x, method='compress_chunk')

@jax.jit
def decompress_chunk(c):
    return model.apply(variables, c, method='decompress_chunk')

win_duration = 0.5  # Adjust based on your GPU's memory size
dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)

# Save and load to and from disk
dac_file.save("compressed.dac")
dac_file = dac_jax.DACFile.load("compressed.dac")

# Decompress it back to audio
y = model.decompress(decompress_chunk, dac_file)
```
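The reason chunking keeps memory constant is that only one fixed-size window is processed at a time, whatever the total length. The sketch below shows a simplified, non-overlapping version of that idea. `chunk_bounds` is a hypothetical helper, and the chunked methods in the codec also have to handle the model's delay and chunk boundaries, which is omitted here.

```python
from typing import Iterator, Tuple


def chunk_bounds(num_samples: int, sample_rate: int,
                 win_duration: float) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) sample indices of consecutive windows.

    Simplified illustration: windows do not overlap and the last one may be short.
    """
    window = max(1, int(round(win_duration * sample_rate)))
    for start in range(0, num_samples, window):
        yield start, min(start + window, num_samples)


bounds = list(chunk_bounds(num_samples=50000, sample_rate=44100, win_duration=0.5))
print(bounds)  # [(0, 22050), (22050, 44100), (44100, 50000)]

# The largest window, not the file length, determines peak memory per step.
largest = max(stop - start for start, stop in bounds)
```

A smaller `win_duration` lowers peak memory at the cost of more loop iterations.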

DAC Training

The baseline model configuration can be trained using the following command:

    python scripts/train.py --args.load conf/final/44khz.yml --train.ckpt_dir="/tmp/dac_jax_runs"

From the root directory, monitor training with TensorBoard (runs will appear next to scripts):

    tensorboard --logdir="/tmp/dac_jax_runs"

Testing

python -m pytest tests

Limitations

Pull requests—especially ones which address any of the limitations below—are welcome.

  • We implement the "chunked" compress/decompress methods from the PyTorch repository, although this technique has some problems outlined here.
  • We have not run all evaluation scripts in the scripts directory. For some of them, it makes sense to just keep using PyTorch instead of JAX.
  • The model architecture code (model/dac.py) has many static methods to help with finding DAC's delay and output_length. Please help us refactor this so that code is not so duplicated and at risk of typos.
  • In audio_utils.py we use DM_AUX's STFT function instead of jax.scipy.signal.stft. We believe this is faster but requires more memory.
  • The source code of DAC-JAX has some todo: markings which indicate (mostly minor) improvements we'd like to have.
  • We don't have a Docker image yet like the original DAC repository does.
  • Please check the limitations of argbind.
  • We don't provide a training script for EnCodec.
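Regarding the limitation about duplicated `output_length` helpers: one common way to centralize such bookkeeping is a single function that applies the standard 1-D convolution length formula layer by layer. The sketch below assumes a made-up stack of layers and is not the layer configuration of DAC. Both function names are hypothetical.

```python
from typing import Dict, Iterable


def conv1d_output_length(length: int, kernel_size: int, stride: int = 1,
                         padding: int = 0, dilation: int = 1) -> int:
    """Standard output-length formula for a 1-D convolution."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def stack_output_length(length: int, layers: Iterable[Dict[str, int]]) -> int:
    """Propagate a length through a sequence of convolution layers."""
    for layer in layers:
        length = conv1d_output_length(length, **layer)
    return length


# Hypothetical layer stack, for illustration only.
toy_layers = [
    {"kernel_size": 7, "padding": 3},                # keeps the length
    {"kernel_size": 4, "stride": 2, "padding": 1},   # halves the length
    {"kernel_size": 16, "stride": 8, "padding": 4},  # roughly divides by 8
]

print(stack_output_length(44100, toy_layers))  # 2756
```

Describing each layer as data and deriving lengths from it in one place avoids repeating the formula in several static methods.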

Citation

If you use this repository in your work, please cite EnCodec:

    @article{defossez2022high,
      title={High fidelity neural audio compression},
      author={D{\'e}fossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
      journal={arXiv preprint arXiv:2210.13438},
      year={2022}
    }

DAC:

    @article{kumar2024high,
      title={High-fidelity audio compression with improved rvqgan},
      author={Kumar, Rithesh and Seetharaman, Prem and Luebs, Alejandro and Kumar, Ishaan and Kumar, Kundan},
      journal={Advances in Neural Information Processing Systems},
      volume={36},
      year={2024}
    }

and DAC-JAX:

    @misc{braun2024dacjax,
      title={{DAC-JAX}: A {JAX} Implementation of the Descript Audio Codec},
      author={David Braun},
      year={2024},
      eprint={2405.11554},
      archivePrefix={arXiv},
      primaryClass={cs.SD}
    }

Owner

  • Name: David Braun
  • Login: DBraun
  • Kind: user
  • Company: DIRT Design

Do It Real-Time | Audiovisual ML, Faust, TouchDesigner | alum @ccrma

GitHub Events

Total
  • Watch event: 8
  • Delete event: 1
  • Issue comment event: 2
  • Push event: 5
  • Pull request event: 2
  • Fork event: 2
  • Create event: 2
Last Year
  • Watch event: 8
  • Delete event: 1
  • Issue comment event: 2
  • Push event: 5
  • Pull request event: 2
  • Fork event: 2
  • Create event: 2

Committers

Last synced: over 1 year ago

All Time
  • Total Commits: 29
  • Total Committers: 1
  • Avg Commits per committer: 29.0
  • Development Distribution Score (DDS): 0.0
Past Year
  • Commits: 29
  • Committers: 1
  • Avg Commits per committer: 29.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
David Braun 2****n 29

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 1
  • Total pull requests: 4
  • Average time to close issues: about 24 hours
  • Average time to close pull requests: 1 minute
  • Total issue authors: 1
  • Total pull request authors: 1
  • Average comments per issue: 2.0
  • Average comments per pull request: 0.0
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 1
  • Pull requests: 2
  • Average time to close issues: about 24 hours
  • Average time to close pull requests: 1 minute
  • Issue authors: 1
  • Pull request authors: 1
  • Average comments per issue: 2.0
  • Average comments per pull request: 0.0
  • Merged pull requests: 2
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • shamuiscoding (1)
Pull Request Authors
  • DBraun (6)

Dependencies

pyproject.toml pypi