https://github.com/csteinmetz1/auraloss
Collection of audio-focused loss functions in PyTorch
Science Score: 33.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found codemeta.json file)
- ○ .zenodo.json file
- ○ DOI references
- ✓ Academic publication links (links to: arxiv.org)
- ✓ Committers with academic emails (1 of 7 committers, 14.3%, from academic institutions)
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity, 9.6%, to scientific vocabulary)
Keywords
Keywords from Contributors
Repository
Collection of audio-focused loss functions in PyTorch
Basic Info
Statistics
- Stars: 802
- Watchers: 16
- Forks: 72
- Open Issues: 26
- Releases: 1
Topics
Metadata Files
README.md
Setup
pip install auraloss
If you want to use `MelSTFTLoss()` or `FIRFilter()`, you will need to specify the extra install (librosa and scipy).
pip install auraloss[all]
Usage
```python
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

input = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(input, target)
```
NEW: Perceptual weighting with mel-scaled spectrograms.
```python
bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    scale="mel",
    n_bins=128,
    sample_rate=sample_rate,
    perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)
```
Citation
If you use this code in your work, please consider citing us.
```bibtex
@inproceedings{steinmetz2020auraloss,
    title={auraloss: {A}udio focused loss functions in {PyTorch}},
    author={Steinmetz, Christian J. and Reiss, Joshua D.},
    booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
    year={2020}
}
```
Loss functions
We categorize the loss functions as either time-domain or frequency-domain approaches. Additionally, we include perceptual transforms.
| Loss function | Interface | Reference |
|---|---|---|
| Time domain | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| Frequency domain | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| Perceptual transforms | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |
* Wang et al., 2019 also propose a multi-resolution spectral loss (that Engel et al., 2020 follow), but they do not include both the log magnitude (L1 distance) and spectral convergence terms, introduced in Arik et al., 2018, and then extended for the multi-resolution case in Yamamoto et al., 2019.
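For intuition, here is a minimal sketch of those two per-resolution terms (spectral convergence and the L1 distance between log magnitudes) written directly with torch.stft. The `stft_terms` helper and its default sizes are our own illustration, not the auraloss implementation, which weights and aggregates these terms internally.

```python
import torch

def stft_terms(pred, target, fft_size=1024, hop_size=256, win_length=1024, eps=1e-8):
    # pred, target: (batch, samples) waveforms
    window = torch.hann_window(win_length)

    def mag(x):
        # complex STFT -> magnitude spectrogram, clamped so the log below stays finite
        X = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
        return torch.clamp(X.abs(), min=eps)

    P, T = mag(pred), mag(target)
    sc = torch.norm(T - P, p="fro") / torch.norm(T, p="fro")           # spectral convergence
    log_mag = torch.nn.functional.l1_loss(torch.log(P), torch.log(T))  # log-magnitude L1
    return sc, log_mag
```

A multi-resolution loss then combines a weighted sum of these terms over several (fft_size, hop_size, win_length) configurations.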
Examples
Currently we include an example using a set of the loss functions to train a TCN for modeling an analog dynamic range compressor.
For details, please refer to examples/compressor.
We provide pre-trained models, evaluation scripts to compute the metrics in the paper, as well as scripts to retrain models.
There are some more advanced things you can do based upon the STFTLoss class.
For example, you can compute both linear and log scaled STFT errors as in Engel et al., 2020.
In this case we do not include the spectral convergence term.
```python
stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0,
    w_lin_mag=1.0,
    w_sc=0.0,
)
```
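The configured loss is then called on an (input, target) pair of waveform tensors like the other losses; a quick sanity check, with random tensors standing in for real audio and arbitrary shapes:

```python
input = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)
loss = stft_loss(input, target)
```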
There is also a Mel-scaled STFT loss, which has some special requirements.
This loss requires that you set the sample rate and specify the correct device.
```python
sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
```
You can also build a multi-resolution Mel-scaled STFT loss with 64 bins easily.
Make sure you pass the correct device where the tensors you are comparing will be.
```python
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel",
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda",
)
```
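As noted above, the device argument should match where the tensors you compare actually live. A minimal usage sketch, assuming a CUDA device is available:

```python
device = "cuda"
pred = torch.rand(8, 1, 131072, device=device)
target = torch.rand(8, 1, 131072, device=device)
loss = loss_fn(pred, target)
```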
If you are computing a loss on stereo audio, you may want to consider the sum and difference (mid/side) loss. Below is an example of using this loss with perceptual weighting and mel scaling for further perceptual relevance.
```python
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)

loss = loss_fn(pred, target)
```
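For intuition, the mid/side idea amounts to comparing the sum (mid) and difference (side) of the left and right channels rather than the raw channels. Below is a rough sketch of that transform as our own illustration; `mid_side` is a hypothetical helper, and auraloss.perceptual.SumAndDifference handles this internally and may scale the signals differently.

```python
def mid_side(stereo):
    # stereo: (batch, 2, samples); slicing keeps the channel dim so results stay (batch, 1, samples)
    left, right = stereo[:, 0:1, :], stereo[:, 1:2, :]
    return left + right, left - right  # mid, side

mid_t, side_t = mid_side(target)
mid_p, side_p = mid_side(pred)
```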
Development
Run tests locally with pytest.
python -m pytest
Owner
- Name: Christian J. Steinmetz
- Login: csteinmetz1
- Kind: user
- Location: London, UK
- Company: @aim-qmul
- Website: christiansteinmetz.com
- Twitter: csteinmetz1
- Repositories: 79
- Profile: https://github.com/csteinmetz1
Machine learning for Hi-Fi audio. PhD Researcher at C4DM.
GitHub Events
Total
- Issues event: 1
- Watch event: 65
- Issue comment event: 7
- Pull request event: 1
- Fork event: 9
Last Year
- Issues event: 1
- Watch event: 65
- Issue comment event: 7
- Pull request event: 1
- Fork event: 9
Committers
Last synced: 9 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| csteinmetz1 | c****1@g****m | 83 |
| Ben Hayes | b****s@q****k | 4 |
| Simon Schwär | s****r@a****e | 3 |
| Joseph Turian | t****n@g****m | 3 |
| christhetree | c****e@g****m | 2 |
| dependabot[bot] | 4****] | 1 |
| Leo Auri | c****e@l****m | 1 |
Committer Domains (Top 20 + Academic)
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 37
- Total pull requests: 42
- Average time to close issues: 2 months
- Average time to close pull requests: 9 days
- Total issue authors: 24
- Total pull request authors: 12
- Average comments per issue: 1.78
- Average comments per pull request: 0.67
- Merged pull requests: 31
- Bot issues: 0
- Bot pull requests: 2
Past Year
- Issues: 2
- Pull requests: 2
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 2
- Pull request authors: 2
- Average comments per issue: 1.5
- Average comments per pull request: 0.5
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- turian (9)
- csteinmetz1 (4)
- sevagh (2)
- fncode246 (1)
- IvanDSM (1)
- LearnedVector (1)
- leoauri (1)
- Kinyugo (1)
- happyTonakai (1)
- jarredou (1)
- drscotthawley (1)
- dlfqhsdugod1106 (1)
- bluenote10 (1)
- ben-hayes (1)
- sdatkinson (1)
Pull Request Authors
- csteinmetz1 (25)
- sai-soum (3)
- simonschwaer (3)
- jmoso13 (2)
- turian (2)
- dependabot[bot] (2)
- renared (2)
- christhetree (2)
- sdatkinson (1)
- leoauri (1)
- bluenote10 (1)
- ben-hayes (1)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 1
- Total downloads: 30,926 last month (pypi)
- Total docker downloads: 123
- Total dependent packages: 4
- Total dependent repositories: 20
- Total versions: 6
- Total maintainers: 1
pypi.org: auraloss
Collection of audio-focused loss functions in PyTorch.
- Homepage: https://github.com/csteinmetz1/auraloss
- Documentation: https://auraloss.readthedocs.io/
- License: Apache License 2.0
- Latest release: 0.4.0 (published almost 3 years ago)