https://github.com/csteinmetz1/auraloss

Collection of audio-focused loss functions in PyTorch


Science Score: 33.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
    1 of 7 committers (14.3%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.6%) to scientific vocabulary

Keywords

audio loss-functions pytorch

Keywords from Contributors

archival projection parallel profiles distribution transformers sequences optimizing-compiler generic interactive
Last synced: 5 months ago

Repository

Collection of audio-focused loss functions in PyTorch

Basic Info
  • Host: GitHub
  • Owner: csteinmetz1
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 130 KB
Statistics
  • Stars: 802
  • Watchers: 16
  • Forks: 72
  • Open Issues: 26
  • Releases: 1
Topics
audio loss-functions pytorch
Created over 5 years ago · Last pushed over 1 year ago
Metadata Files
Readme License

README.md

# auraloss

A collection of audio-focused loss functions in PyTorch. [[PDF](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf)]

Setup

pip install auraloss

If you want to use MelSTFTLoss() or FIRFilter(), you will need to install the optional extras (librosa and scipy).

pip install auraloss[all]

Usage

```python
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

input = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(input, target)
```
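To make the idea behind `MultiResolutionSTFTLoss` concrete, here is a NumPy sketch of a multi-resolution STFT loss built from the two terms of Arik et al., 2018 (spectral convergence and log-magnitude L1), averaged over several STFT resolutions. It is a reading aid under stated assumptions, not auraloss's implementation: the helper names, the symmetric Hann window, the `eps` value, the resolutions, and averaging (rather than summing) across resolutions are all illustrative choices.

```python
import numpy as np


def stft_mag(x, fft_size, hop_size, win_length):
    """Magnitude STFT of a 1-D signal (Hann window, no padding of the signal)."""
    # Centre a symmetric Hann window inside an fft_size frame when win_length < fft_size.
    left = (fft_size - win_length) // 2
    window = np.pad(np.hanning(win_length), (left, fft_size - win_length - left))
    n_frames = 1 + (len(x) - fft_size) // hop_size  # assumes len(x) >= fft_size
    frames = np.stack([x[i * hop_size:i * hop_size + fft_size] for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames * window, axis=-1))  # (n_frames, fft_size // 2 + 1)


def stft_loss(pred, target, fft_size, hop_size, win_length, eps=1e-8):
    """Spectral convergence plus log-magnitude L1 at one STFT resolution."""
    P = stft_mag(pred, fft_size, hop_size, win_length)
    T = stft_mag(target, fft_size, hop_size, win_length)
    # Spectral convergence: Frobenius norm of the error relative to the target.
    sc = np.linalg.norm(T - P) / (np.linalg.norm(T) + eps)
    # Log-magnitude distance: mean absolute difference of log magnitudes.
    log_mag = np.mean(np.abs(np.log(T + eps) - np.log(P + eps)))
    return sc + log_mag


def multi_resolution_stft_loss(pred, target,
                               fft_sizes=(512, 1024, 2048),
                               hop_sizes=(128, 256, 512),
                               win_lengths=(512, 1024, 2048)):
    """Average the single-resolution loss over several (fft, hop, window) settings."""
    losses = [stft_loss(pred, target, n, h, w)
              for n, h, w in zip(fft_sizes, hop_sizes, win_lengths)]
    return float(np.mean(losses))
```

The usual motivation for several resolutions is that small FFT sizes resolve short-time detail while large ones resolve fine frequency detail, so the combined loss depends less on any single choice of window.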

NEW: Perceptual weighting with mel-scaled spectrograms.

```python
import torch
import auraloss

bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    scale="mel",
    n_bins=128,
    sample_rate=sample_rate,
    perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)
```
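With `perceptual_weighting=True`, both signals are filtered before the spectrogram is computed so that errors in perceptually important bands weigh more. Which weighting filter auraloss applies internally is not stated here, so the sketch below assumes the simplest pre-emphasis from Wright & Välimäki, 2019: a first-order high-pass y[n] = x[n] - a·x[n-1]. The coefficient a = 0.85 and the function names are illustrative assumptions.

```python
import numpy as np


def first_order_preemphasis(x, coef=0.85):
    """First-order high-pass pre-emphasis: y[n] = x[n] - coef * x[n-1], with x[-1] = 0."""
    x = np.asarray(x, dtype=float)
    y = x.copy()
    # Subtract the scaled previous sample along the last (time) axis.
    y[..., 1:] -= coef * x[..., :-1]
    return y


def magnitude_response(coef, freqs_hz, sample_rate):
    """|H(e^{jw})| of H(z) = 1 - coef * z^-1 evaluated at the given frequencies."""
    w = 2 * np.pi * np.asarray(freqs_hz, dtype=float) / sample_rate
    return np.abs(1 - coef * np.exp(-1j * w))
```

With a = 0.85 the gain is 1 - 0.85 = 0.15 at DC (about -16.5 dB) and 1 + 0.85 = 1.85 at Nyquist (about +5.3 dB), so applying the same filter to prediction and target de-emphasises low-frequency errors before the STFT loss sees them.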

Citation

If you use this code in your work, please consider citing us.

```bibtex
@inproceedings{steinmetz2020auraloss,
    title={auraloss: {A}udio focused loss functions in {PyTorch}},
    author={Steinmetz, Christian J. and Reiss, Joshua D.},
    booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
    year={2020}
}
```

Loss functions

We categorize the loss functions as either time-domain or frequency-domain approaches. Additionally, we include perceptual transforms.

| Loss function | Interface | Reference |
| --- | --- | --- |
| **Time domain** | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| **Frequency domain** | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| **Perceptual transforms** | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |

* Wang et al., 2019 also propose a multi-resolution spectral loss (which Engel et al., 2020 follow), but it does not include both the log-magnitude (L1 distance) and spectral convergence terms introduced in Arik et al., 2018 and later extended to the multi-resolution case in Yamamoto et al., 2019.
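The time-domain losses in the table above follow fairly standard definitions. The NumPy sketch below writes them out for single 1-D signals as a reading aid; it is not auraloss's code, and the `EPS` value, the zero-mean step in the SDR variants, and the absence of batch reduction are assumptions. SNR and SDR are returned as metrics in dB (higher is better); to minimise them as losses you would negate them.

```python
import numpy as np

EPS = 1e-8  # assumed numerical floor to avoid division by zero


def esr(pred, target):
    """Error-to-signal ratio: error energy over target energy."""
    return np.sum((target - pred) ** 2) / (np.sum(target ** 2) + EPS)


def dc_error(pred, target):
    """DC error: squared difference of the means, normalised by target power."""
    return (np.mean(target) - np.mean(pred)) ** 2 / (np.mean(target ** 2) + EPS)


def log_cosh(pred, target):
    """Mean log(cosh(error)), computed stably as |e| + log1p(exp(-2|e|)) - log(2)."""
    e = np.abs(pred - target)
    return np.mean(e + np.log1p(np.exp(-2 * e)) - np.log(2))


def snr_db(pred, target):
    """Signal-to-noise ratio in dB, treating (pred - target) as noise."""
    noise = pred - target
    return 10 * np.log10(np.sum(target ** 2) / (np.sum(noise ** 2) + EPS) + EPS)


def sdr_db(pred, target, scale_invariant=True):
    """SI-SDR (scale_invariant=True) or SD-SDR (False), after Le Roux et al., 2018."""
    pred = pred - np.mean(pred)        # zero-mean both signals
    target = target - np.mean(target)
    alpha = np.dot(pred, target) / (np.dot(target, target) + EPS)  # optimal target scale
    scaled_target = alpha * target
    # SI-SDR measures error against the scaled target; SD-SDR against the unscaled one.
    noise = pred - scaled_target if scale_invariant else pred - target
    return 10 * np.log10(np.sum(scaled_target ** 2) / (np.sum(noise ** 2) + EPS) + EPS)
```

The two SDR variants differ only in the denominator. For a prediction that is the target doubled, SI-SDR treats it as essentially perfect (very large dB value), whereas SD-SDR evaluates to 10·log10(4) ≈ 6.02 dB because the rescaling is counted as error.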

Examples

Currently, we include an example that uses a set of these loss functions to train a TCN to model an analog dynamic range compressor. For details, refer to examples/compressor. We provide pre-trained models, evaluation scripts to compute the metrics in the paper, and scripts to retrain the models.

There are some more advanced things you can do with the STFTLoss class. For example, you can compute both linear- and log-scaled STFT magnitude errors, as in Engel et al., 2020. In this case, we do not include the spectral convergence term.

```python
import auraloss

stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0,
    w_lin_mag=1.0,
    w_sc=0.0,
)
```
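The weights `w_sc`, `w_log_mag` and `w_lin_mag` scale separate distance terms that are summed into one loss. The sketch below shows that combination on precomputed magnitude spectrograms in NumPy; the function name, `eps`, and mean reduction are assumptions, and auraloss's internals may differ in such details. Setting `w_sc=0.0` simply drops the spectral convergence term, leaving the log- plus linear-magnitude combination described above.

```python
import numpy as np


def combined_stft_loss(pred_mag, target_mag, w_sc=1.0, w_log_mag=1.0, w_lin_mag=0.0, eps=1e-8):
    """Weighted sum of spectral convergence, log-magnitude L1 and linear-magnitude L1.

    pred_mag and target_mag are non-negative magnitude spectrograms of equal shape.
    """
    # Spectral convergence: relative Frobenius-norm error.
    sc = np.linalg.norm(target_mag - pred_mag) / (np.linalg.norm(target_mag) + eps)
    # L1 distance between log magnitudes (emphasises quiet bins).
    log_mag = np.mean(np.abs(np.log(target_mag + eps) - np.log(pred_mag + eps)))
    # L1 distance between linear magnitudes (emphasises loud bins).
    lin_mag = np.mean(np.abs(target_mag - pred_mag))
    return w_sc * sc + w_log_mag * log_mag + w_lin_mag * lin_mag
```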

There is also a Mel-scaled STFT loss, which has some special requirements: you must set the sample rate and specify the correct device.

```python
import auraloss

sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
```
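The sample rate is needed because mel scaling maps FFT bins (indexed by sample) to frequencies in Hz and then groups them into bands that are equally spaced on the mel scale. The NumPy sketch below builds such a triangular filterbank using the HTK mel formula m = 2595·log10(1 + f/700). It is illustrative only: the function names are hypothetical, and since auraloss relies on librosa for this (hence the extra install), its filters presumably follow librosa's defaults, which use the Slaney mel variant and will not match this sketch bin-for-bin.

```python
import numpy as np


def hz_to_mel(f):
    """HTK mel scale: m = 2595 * log10(1 + f / 700)."""
    return 2595.0 * np.log10(1.0 + np.asarray(f, dtype=float) / 700.0)


def mel_to_hz(m):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (np.asarray(m, dtype=float) / 2595.0) - 1.0)


def mel_filterbank(sample_rate, fft_size, n_mels):
    """Triangular filters, shape (n_mels, fft_size // 2 + 1), equally spaced in mel."""
    n_bins = fft_size // 2 + 1
    fft_freqs = np.linspace(0.0, sample_rate / 2, n_bins)  # centre frequency of each bin
    # n_mels + 2 edge points: filter i spans (edge i, edge i+1, edge i+2).
    mel_points = np.linspace(hz_to_mel(0.0), hz_to_mel(sample_rate / 2), n_mels + 2)
    hz_points = mel_to_hz(mel_points)
    fb = np.zeros((n_mels, n_bins))
    for i in range(n_mels):
        left, center, right = hz_points[i], hz_points[i + 1], hz_points[i + 2]
        rising = (fft_freqs - left) / (center - left)
        falling = (right - fft_freqs) / (right - center)
        fb[i] = np.maximum(0.0, np.minimum(rising, falling))
    return fb
```

Multiplying the filterbank by a magnitude spectrogram of shape `(fft_size // 2 + 1, frames)` gives a mel spectrogram of shape `(n_mels, frames)`.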

You can also easily build a multi-resolution Mel-scaled STFT loss with 64 bins. Make sure you pass the device on which the tensors you are comparing will reside.

```python
import auraloss

sample_rate = 44100
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel",
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda",
)
```

If you are computing a loss on stereo audio, you may want to consider the sum and difference (mid/side) loss. Below is an example of using this loss function with perceptual weighting and mel scaling for further perceptual relevance.

```python
import torch
import auraloss

target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)

loss = loss_fn(pred, target)
```
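The idea behind the sum and difference loss is to compare stereo signals in a mid/side representation rather than channel by channel. The NumPy sketch below uses plain L+R and L-R (no 1/2 or 1/√2 normalisation) and accepts any single-channel base loss; the function names, the `w_sum`/`w_diff` weights, and the mean-squared-error base loss in the example are illustrative assumptions, not auraloss's implementation.

```python
import numpy as np


def sum_and_difference(x):
    """Mid/side-style transform of stereo audio shaped (..., 2, samples).

    Returns (left + right, left - right), each shaped (..., samples).
    """
    left, right = x[..., 0, :], x[..., 1, :]
    return left + right, left - right


def sum_and_difference_loss(pred, target, base_loss, w_sum=1.0, w_diff=1.0):
    """Apply a single-channel loss to the sum and difference signals and combine them."""
    pred_sum, pred_diff = sum_and_difference(pred)
    target_sum, target_diff = sum_and_difference(target)
    return w_sum * base_loss(pred_sum, target_sum) + w_diff * base_loss(pred_diff, target_diff)
```

One consequence of this design: swapping the left and right channels leaves the sum signal unchanged and only flips the sign of the difference signal, so such an error is invisible to the sum term and penalised entirely by the difference term.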

Development

Run tests locally with pytest.

python -m pytest

Owner

  • Name: Christian J. Steinmetz
  • Login: csteinmetz1
  • Kind: user
  • Location: London, UK
  • Company: @aim-qmul

Machine learning for Hi-Fi audio. PhD Researcher at C4DM.

GitHub Events

Total
  • Issues event: 1
  • Watch event: 65
  • Issue comment event: 7
  • Pull request event: 1
  • Fork event: 9
Last Year
  • Issues event: 1
  • Watch event: 65
  • Issue comment event: 7
  • Pull request event: 1
  • Fork event: 9

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 97
  • Total Committers: 7
  • Avg Commits per committer: 13.857
  • Development Distribution Score (DDS): 0.144
Past Year
  • Commits: 0
  • Committers: 0
  • Avg Commits per committer: 0.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
csteinmetz1 c****1@g****m 83
Ben Hayes b****s@q****k 4
Simon Schwär s****r@a****e 3
Joseph Turian t****n@g****m 3
christhetree c****e@g****m 2
dependabot[bot] 4****] 1
Leo Auri c****e@l****m 1

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 37
  • Total pull requests: 42
  • Average time to close issues: 2 months
  • Average time to close pull requests: 9 days
  • Total issue authors: 24
  • Total pull request authors: 12
  • Average comments per issue: 1.78
  • Average comments per pull request: 0.67
  • Merged pull requests: 31
  • Bot issues: 0
  • Bot pull requests: 2
Past Year
  • Issues: 2
  • Pull requests: 2
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 2
  • Pull request authors: 2
  • Average comments per issue: 1.5
  • Average comments per pull request: 0.5
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • turian (9)
  • csteinmetz1 (4)
  • sevagh (2)
  • fncode246 (1)
  • IvanDSM (1)
  • LearnedVector (1)
  • leoauri (1)
  • Kinyugo (1)
  • happyTonakai (1)
  • jarredou (1)
  • drscotthawley (1)
  • dlfqhsdugod1106 (1)
  • bluenote10 (1)
  • ben-hayes (1)
  • sdatkinson (1)
Pull Request Authors
  • csteinmetz1 (25)
  • sai-soum (3)
  • simonschwaer (3)
  • jmoso13 (2)
  • turian (2)
  • dependabot[bot] (2)
  • renared (2)
  • christhetree (2)
  • sdatkinson (1)
  • leoauri (1)
  • bluenote10 (1)
  • ben-hayes (1)
Top Labels
Issue Labels
Pull Request Labels
dependencies (2)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 30,926 last-month
  • Total docker downloads: 123
  • Total dependent packages: 4
  • Total dependent repositories: 20
  • Total versions: 6
  • Total maintainers: 1
pypi.org: auraloss

Collection of audio-focused loss functions in PyTorch.

  • Versions: 6
  • Dependent Packages: 4
  • Dependent Repositories: 20
  • Downloads: 30,926 Last month
  • Docker Downloads: 123
Rankings
Dependent packages count: 1.9%
Downloads: 2.6%
Stargazers count: 2.6%
Dependent repos count: 3.2%
Average: 3.3%
Docker downloads count: 4.1%
Forks count: 5.5%
Maintainers (1)
Last synced: 6 months ago