torch-frft

PyTorch implementation of the fractional Fourier transform with trainable transform order.

https://github.com/tunakasif/torch-frft

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 1 DOI reference(s) in README
  • Academic publication links
    Links to: ieee.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.1%) to scientific vocabulary

Keywords

fractional-fourier-transfrom pytorch
Last synced: 6 months ago

Repository

PyTorch implementation of the fractional Fourier transform with trainable transform order.

Basic Info
  • Host: GitHub
  • Owner: tunakasif
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 1.38 MB
Statistics
  • Stars: 51
  • Watchers: 2
  • Forks: 1
  • Open Issues: 5
  • Releases: 13
Topics
fractional-fourier-transfrom pytorch
Created over 2 years ago · Last pushed 8 months ago
Metadata Files
Readme Changelog Funding License Citation

README.md

Trainable Fractional Fourier Transform


A differentiable fractional Fourier transform (FRFT) implementation with layers that can be trained end-to-end with the rest of the network. This package provides fast implementations of both the continuous FRFT and the discrete FRFT (DFRFT), along with pre-configured layers ready for use in neural networks.

The fast transform approximates the continuous FRFT and is based on the paper *Digital computation of the fractional Fourier transform*. The DFRFT is based on the paper *The discrete fractional Fourier transform*. MATLAB implementations of both approaches are provided on Haldun M. Özaktaş's page as fracF.m and dFRT.m, respectively.

This package implements these approaches in PyTorch with specific optimizations and, most notably, adds the ability to apply the transform along a particular tensor dimension.
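
As background (and independent of this package's internals), applying an N×N transform matrix along an arbitrary tensor dimension can be sketched in NumPy as follows. The helper `apply_along_dim` is hypothetical, written only to illustrate the idea; it is checked against the ordinary unitary DFT:

```python
import numpy as np

def apply_along_dim(mat: np.ndarray, x: np.ndarray, dim: int = -1) -> np.ndarray:
    """Apply an (N, N) transform matrix along one dimension of x."""
    x_moved = np.moveaxis(x, dim, -1)  # bring the target axis to the end
    y = x_moved @ mat.T                # batched matrix-vector products
    return np.moveaxis(y, -1, dim)     # restore the original axis order

# Sanity check against an ordinary unitary DFT along dim 0
N = 4
k = np.arange(N)
F = np.exp(-2j * np.pi * np.outer(k, k) / N) / np.sqrt(N)
X = np.random.default_rng(0).random((N, 3))
Y = apply_along_dim(F, X, dim=0)
assert np.allclose(Y, np.fft.fft(X, axis=0, norm="ortho"))
```

The same move-multiply-move-back pattern works for any linear transform matrix, including a fractional Fourier transform matrix.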

We provide ready-made layers that extend torch.nn.Module for both the continuous and discrete transforms; an example custom layer implementation is also provided in this README.

We developed this project for the Trainable Fractional Fourier Transform paper, published in IEEE Signal Processing Letters. You can also access the paper's GitHub page for experiments and example usage. If you find this package useful, please consider citing as follows:

```bibtex
@article{trainable-frft-2024,
  author   = {Koç, Emirhan and Alikaşifoğlu, Tuna and Aras, Arda Can and Koç, Aykut},
  journal  = {IEEE Signal Processing Letters},
  title    = {Trainable Fractional Fourier Transform},
  year     = {2024},
  volume   = {31},
  number   = {},
  pages    = {751-755},
  keywords = {Vectors;Convolution;Training;Task analysis;Computational modeling;Time series analysis;Feature extraction;Machine learning;neural networks;FT;fractional FT;deep learning},
  doi      = {10.1109/LSP.2024.3372779}
}
```


Installation

For Usage

You can install the package directly from PyPI:

```sh
pip install torch-frft
```

or from conda-forge:

```sh
conda install -c conda-forge torch-frft
```

For Development

This codebase uses uv for package management. To install the dependencies:

```sh
uv sync
```

Alternatively, install the dependencies listed in requirements.txt with pip or conda, e.g.:

```sh
pip install -r requirements.txt
```

Usage

Transforms

> [!WARNING]
> Transforms are applied on the same device as the input tensor: if the input tensor is on a GPU, the transform is also computed on the GPU.

The package provides transform functions that operate on the $n^{th}$ dimension of an input tensor: frft() and dfrft(), which correspond to the fast computation of the continuous fractional Fourier transform (FRFT) and the discrete fractional Fourier transform (DFRFT), respectively. It also provides dfrftmtx(), which computes the DFRFT matrix for a given length and order, analogous to MATLAB's dftmtx() function for the ordinary DFT matrix. Note that frft() only operates on even-length inputs, as in the original MATLAB implementation fracF.m.

```python
import torch

from torch_frft.dfrft_module import dfrft, dfrftmtx
from torch_frft.frft_module import frft

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

N = 128
a = 0.5
X = torch.rand(N, N, device=device)
Y1 = frft(X, a)  # equivalent to dim=-1
Y2 = frft(X, a, dim=0)

# 2D FRFT
a0, a1 = 1.25, 0.75
Y3 = frft(frft(X, a0, dim=0), a1, dim=1)
```
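
For intuition on how a DFRFT-style matrix can be built (a simplified sketch, not this package's dfrftmtx() implementation), one can take real orthonormal eigenvectors of a symmetric matrix that commutes with the DFT and raise the eigenvalues to a fractional power. The `dfrft_matrix` helper below is illustrative only, and its eigenvalue ordering is deliberately naive, so it will not reproduce dfrftmtx() exactly; it does, however, exhibit the defining properties of unitarity and index additivity:

```python
import numpy as np

def dfrft_matrix(N: int, a: float) -> np.ndarray:
    """Simplified DFRFT-style matrix via a real orthonormal eigenbasis
    and fractional eigenvalues (naive eigenvalue ordering)."""
    n = np.arange(N)
    # Dickinson-Steiglitz "S" matrix: symmetric, nearly tridiagonal,
    # and commuting with the DFT matrix.
    S = np.diag(2.0 * np.cos(2.0 * np.pi * n / N))
    S += np.diag(np.ones(N - 1), 1) + np.diag(np.ones(N - 1), -1)
    S[0, -1] = S[-1, 0] = 1.0
    _, V = np.linalg.eigh(S)                 # real orthonormal eigenvectors
    lam = np.exp(-1j * np.pi / 2.0 * a * n)  # fractional eigenvalues
    return (V * lam) @ V.T

F1 = dfrft_matrix(16, 0.4)
F2 = dfrft_matrix(16, 0.6)
assert np.allclose(F1 @ F2, dfrft_matrix(16, 1.0))  # index additivity
assert np.allclose(F1 @ F1.conj().T, np.eye(16))    # unitarity
```

Any matrix of this form $V \operatorname{diag}(\lambda^a) V^T$ with real orthonormal $V$ is unitary and additive in the order $a$, which is what makes the transform order a well-behaved trainable parameter.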

Pre-configured Layers

The package also provides two differentiable FRFT layers, FrFTLayer and DFrFTLayer, which can be used as follows:

```python
import torch
import torch.nn as nn

from torch_frft.dfrft_module import dfrft
from torch_frft.layer import DFrFTLayer, FrFTLayer

# FRFT with initial order 1.25, operating on the last dimension
model = nn.Sequential(FrFTLayer(order=1.25, dim=-1))

# DFRFT with initial order 0.75, operating on the first dimension
model = nn.Sequential(DFrFTLayer(order=0.75, dim=0))
```

Then, the simplest toy example to train the layer is as follows:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_samples, seq_length = 100, 16
a_original = 1.1
a_initial = 1.25
X = torch.randn(num_samples, seq_length, dtype=torch.float32, device=device)
Y = dfrft(X, a_original)

model = DFrFTLayer(order=a_initial).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 1000

for epoch in range(1 + epochs):
    optim.zero_grad()
    loss = torch.norm(Y - model(X))
    loss.backward()
    optim.step()

print("Original a:", a_original)
print("Estimated a:", model.order.item())
```

One can also place these layers directly into a torch.nn.Sequential. Note that these transforms produce complex-valued outputs, so one may need to convert them back to real values, e.g., by taking the real part or the absolute value. For example, the following code snippet implements a simple fully connected network that takes the real parts of the FrFTLayer and DFrFTLayer outputs in between:

```python
class Real(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.real


model = nn.Sequential(
    nn.Linear(16, 6),
    nn.ReLU(),
    DFrFTLayer(1.35, dim=-1),
    Real(),
    nn.ReLU(),
    nn.Linear(6, 1),
    FrFTLayer(0.65, dim=0),
    Real(),
    nn.ReLU(),
)
```

Custom Layers

Creating custom layers with the provided frft() and dfrft() functions is also possible. The example below uses Linear and ReLU layers in between, with the same fractional order for the forward and inverse DFRFT transforms:

```python
import torch
import torch.nn as nn

from torch_frft.dfrft_module import dfrft


class CustomLayer(nn.Module):
    def __init__(self, in_feat: int, out_feat: int, *, order: float = 1.0, dim: int = -1) -> None:
        super().__init__()
        self.in_features = in_feat
        self.out_features = out_feat
        self.order = nn.Parameter(torch.tensor(order, dtype=torch.float32), requires_grad=True)
        self.dim = dim
        # Instantiate the linear layers once here so that their weights are
        # registered as trainable parameters instead of being re-initialized
        # on every forward pass.
        self.linear1 = nn.Linear(in_feat, in_feat, dtype=torch.cfloat)
        self.linear2 = nn.Linear(in_feat, out_feat, dtype=torch.cfloat)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = dfrft(x, self.order, dim=self.dim)
        a1 = nn.ReLU()(x1.real) + 1j * nn.ReLU()(x1.imag)
        x2 = self.linear1(a1)
        a2 = nn.ReLU()(x2.real) + 1j * nn.ReLU()(x2.imag)
        x3 = dfrft(a2, -self.order, dim=self.dim)
        a3 = nn.ReLU()(x3.real) + 1j * nn.ReLU()(x3.imag)
        x4 = self.linear2(a3)
        return nn.ReLU()(x4.real)
```

Then, a simple training example for this CustomLayer is as follows:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_samples, seq_length, out_length = 100, 32, 5
X = torch.rand(num_samples, seq_length, device=device)
y = torch.rand(num_samples, out_length, device=device)

model = CustomLayer(seq_length, out_length, order=1.25).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
epochs = 1000

for epoch in range(epochs):
    optim.zero_grad()
    output = loss_fn(model(X), y)
    output.backward()
    optim.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch:4d} | Loss {output.item():.4f}")

print("Final a:", model.order)
```

FRFT Shift

Note that the fast computation of the continuous FRFT is defined on the centered grid $\left[-\lfloor\frac{N}{2}\rfloor, \lfloor\frac{N-1}{2}\rfloor\right]$. Therefore, fftshift() is needed to create equivalence with the ordinary FFT when the transform order is exactly $1$. This package also provides a shifted version of the fast FRFT computation, frft_shifted(), which assumes the grid $[0, N-1]$. The latter is not the default behavior, to stay consistent with the original MATLAB implementation. All three computations below are equivalent:

```python
import torch
from torch.fft import fft, fftshift

from torch_frft.frft_module import frft, frft_shifted

torch.manual_seed(0)
x = torch.rand(100)
y1 = fft(x, norm="ortho")
y2 = fftshift(frft(fftshift(x), 1.0))
y3 = frft_shifted(x, 1.0)

assert torch.allclose(y1, y2, atol=1e-5)
assert torch.allclose(y1, y3, atol=1e-5)
```
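
To make the grid convention concrete, the order-$1$ equivalence can be checked against an explicitly centered DFT matrix built on the grid $[-\lfloor N/2\rfloor, \lfloor(N-1)/2\rfloor]$. This is a standalone NumPy sketch, independent of this package:

```python
import numpy as np

N = 8                       # even length, matching frft()'s requirement
n = np.arange(N) - N // 2   # centered grid [-N/2, N/2 - 1]

# Unitary DFT matrix indexed on the centered grid
W = np.exp(-2j * np.pi * np.outer(n, n) / N) / np.sqrt(N)

x = np.random.default_rng(0).random(N)
y_centered = W @ x

# For even N, fftshift == ifftshift, so this matches the
# fftshift(frft(fftshift(x), 1.0)) pattern with orthonormal scaling
y_fft = np.fft.fftshift(np.fft.fft(np.fft.fftshift(x), norm="ortho"))
assert np.allclose(y_centered, y_fft)
```

The sandwiching fftshift() calls are exactly what moves the FFT's $[0, N-1]$ indexing onto the centered grid that the fast FRFT algorithm assumes.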

Owner

  • Name: Tuna Alikaşifoğlu
  • Login: tunakasif
  • Kind: user
  • Location: Ankara / Turkey
  • Company: Bilkent University

Graduate Student at Bilkent University, Department of Electrical and Electronics Engineering

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
preferred-citation:
  authors:
    - family-names: Koç
      given-names: Emirhan
    - family-names: Alikaşifoğlu
      given-names: Tuna
    - family-names: Aras
      given-names: Arda Can
    - family-names: Koç
      given-names: Aykut
  doi: 10.1109/lsp.2024.3372779
  identifiers:
    - type: doi
      value: 10.1109/lsp.2024.3372779
    - type: url
      value: http://dx.doi.org/10.1109/LSP.2024.3372779
    - type: other
      value: urn:issn:1070-9908
  title: Trainable Fractional Fourier Transform
  url: http://dx.doi.org/10.1109/LSP.2024.3372779
  database: Crossref
  date-published: 2024-03-04
  year: 2024
  issn: 1070-9908
  journal: IEEE Signal Processing Letters
  publisher:
    name: Institute of Electrical and Electronics Engineers (IEEE)
  start: '751'
  end: '755'
  type: article
  volume: '31'

GitHub Events

Total
  • Create event: 12
  • Release event: 1
  • Issues event: 2
  • Watch event: 19
  • Delete event: 8
  • Issue comment event: 7
  • Push event: 5
  • Pull request event: 16
  • Fork event: 1
Last Year
  • Create event: 12
  • Release event: 1
  • Issues event: 2
  • Watch event: 19
  • Delete event: 8
  • Issue comment event: 7
  • Push event: 5
  • Pull request event: 16
  • Fork event: 1

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 0
  • Total pull requests: 9
  • Average time to close issues: N/A
  • Average time to close pull requests: 27 days
  • Total issue authors: 0
  • Total pull request authors: 2
  • Average comments per issue: 0
  • Average comments per pull request: 0.22
  • Merged pull requests: 2
  • Bot issues: 0
  • Bot pull requests: 8
Past Year
  • Issues: 0
  • Pull requests: 9
  • Average time to close issues: N/A
  • Average time to close pull requests: 27 days
  • Issue authors: 0
  • Pull request authors: 2
  • Average comments per issue: 0
  • Average comments per pull request: 0.22
  • Merged pull requests: 2
  • Bot issues: 0
  • Bot pull requests: 8
Top Authors
Issue Authors
  • JiacLuo (2)
  • kravrolens (1)
  • zzzikun (1)
Pull Request Authors
  • dependabot[bot] (8)
Top Labels
Issue Labels
Pull Request Labels
dependencies (8) python (1)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 117 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 1
  • Total versions: 8
  • Total maintainers: 1
pypi.org: torch-frft

PyTorch implementation of the fractional Fourier transform with trainable transform order.

  • Versions: 8
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 117 Last month
Rankings
Dependent packages count: 10.1%
Dependent repos count: 21.5%
Average: 25.0%
Downloads: 43.3%
Maintainers (1)
Last synced: 6 months ago