torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data

https://github.com/torchgeo/torchgeo

Science Score: 49.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 5 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (13.8%) to scientific vocabulary

Keywords

computer-vision datasets deep-learning earth-observation geospatial models pytorch remote-sensing satellite-imagery torchvision transforms
Last synced: 6 months ago

Repository

Statistics
  • Stars: 3,620
  • Watchers: 56
  • Forks: 471
  • Open Issues: 173
  • Releases: 16
Topics
computer-vision datasets deep-learning earth-observation geospatial models pytorch remote-sensing satellite-imagery torchvision transforms
Created almost 5 years ago · Last pushed 6 months ago
Metadata Files
Readme Contributing License Code of conduct Citation Security Support Governance Maintainers

README.md


TorchGeo is a PyTorch domain library, similar to torchvision, providing datasets, samplers, transforms, and pre-trained models specific to geospatial data.

The goal of this library is to make it simple:

  1. for machine learning experts to work with geospatial data, and
  2. for remote sensing experts to explore machine learning solutions.

Community: slack osgeo huggingface pytorch

Packaging: pypi conda spack

Testing: docs style tests codecov

Installation

The recommended way to install TorchGeo is with pip:

```sh
pip install torchgeo
```

For conda and spack installation instructions, see the documentation.

Documentation

You can find the documentation for TorchGeo on ReadTheDocs. This includes API documentation, contributing instructions, and several tutorials. For more details, check out our paper, blog post, and YouTube videos (below).

Example Usage

The following sections give basic examples of what you can do with TorchGeo.

First, we'll import various classes and functions used in the following sections:

```python
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader

from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import SemanticSegmentationTask
```

Geospatial datasets and samplers

Many remote sensing applications involve working with geospatial datasets: datasets with geographic metadata. These datasets can be challenging to work with due to the sheer variety of data. Geospatial imagery is often multispectral, with a different number of spectral bands and spatial resolution for every satellite. In addition, each file may be in a different coordinate reference system (CRS), requiring the data to be reprojected into a matching CRS.
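For intuition about what reprojection involves, here is a minimal sketch that converts a WGS 84 longitude/latitude point (EPSG:4326) to spherical Web Mercator (EPSG:3857) using the standard closed-form formulas. It is an illustration only: the helper name `lonlat_to_web_mercator` is hypothetical and not part of TorchGeo, which handles reprojection for you. Real workflows would normally delegate this to a projection library such as pyproj.

```python
import math

# WGS 84 semi-major axis in meters; Web Mercator uses it as the radius of a sphere.
EARTH_RADIUS_M = 6378137.0
# Web Mercator is conventionally clipped at this latitude, where the map becomes square.
MAX_LAT_DEG = 85.0511287798


def lonlat_to_web_mercator(lon_deg: float, lat_deg: float) -> tuple[float, float]:
    """Hypothetical helper: project a lon/lat point (degrees) to Web Mercator meters."""
    if abs(lat_deg) > MAX_LAT_DEG:
        raise ValueError("latitude outside the Web Mercator range")
    x = EARTH_RADIUS_M * math.radians(lon_deg)
    y = EARTH_RADIUS_M * math.log(math.tan(math.pi / 4 + math.radians(lat_deg) / 2))
    return x, y


# The same location expressed in a second coordinate reference system.
x, y = lonlat_to_web_mercator(-88.2434, 40.1164)
```

The point is that identical ground locations have very different numeric coordinates in different CRSs, which is why files must be brought into a matching CRS before they can be combined.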

Example application in which we combine Landsat and CDL and sample from both

In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets.

```python
landsat7 = Landsat7(paths="...", bands=["B1", ..., "B7"])
landsat8 = Landsat8(paths="...", bands=["B2", ..., "B8"])
landsat = landsat7 | landsat8
```

Next, we take the intersection between this dataset and the CDL dataset. We want to take the intersection instead of the union to ensure that we only sample from regions that have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different coordinate reference systems (CRS) or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used.

```python
cdl = CDL(paths="...", download=True, checksum=True)
dataset = landsat & cdl
```

This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire continental United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of samplers. In this example, we'll use a random sampler that returns 256 x 256 pixel images and 10,000 samples per epoch. We also use a custom collation function to combine each sample dictionary into a mini-batch of samples.

```python
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
```
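Conceptually, a random geospatial sampler draws fixed-size query boxes at random positions inside the dataset's bounds. The pure-Python sketch below is a simplified, hypothetical stand-in for illustration, not TorchGeo's `RandomGeoSampler` implementation. The `Box` type, the `random_geo_boxes` function, and the assumption of a single rectangular extent at one fixed resolution are all ours.

```python
import random
from dataclasses import dataclass
from typing import Iterator


@dataclass(frozen=True)
class Box:
    """Hypothetical axis-aligned bounding box in CRS units (e.g. meters)."""
    minx: float
    maxx: float
    miny: float
    maxy: float


def random_geo_boxes(
    bounds: Box, size_px: int, res: float, length: int, seed: int = 0
) -> Iterator[Box]:
    """Lazily yield `length` random square boxes of `size_px` pixels at resolution `res`.

    Raises ValueError on first iteration if the patch does not fit inside `bounds`.
    """
    width = size_px * res  # patch side length in CRS units
    if bounds.maxx - bounds.minx < width or bounds.maxy - bounds.miny < width:
        raise ValueError("patch is larger than the dataset extent")
    rng = random.Random(seed)
    for _ in range(length):
        # Choose the lower-left corner so that the whole patch stays inside the bounds.
        minx = rng.uniform(bounds.minx, bounds.maxx - width)
        miny = rng.uniform(bounds.miny, bounds.maxy - width)
        yield Box(minx, minx + width, miny, miny + width)


# 256 x 256 pixel patches at an assumed 30 m resolution inside a toy 100 km extent.
extent = Box(0.0, 100_000.0, 0.0, 100_000.0)
patches = list(random_geo_boxes(extent, size_px=256, res=30.0, length=5))
```

Each box is a query in geographic coordinates; the dataset is then responsible for reading and reprojecting whatever pixels fall inside it.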

This data loader can now be used in your normal training/evaluation pipeline.

```python
for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]

    # train a model, or make predictions using a pre-trained model
```

Many applications involve intelligently composing datasets based on geospatial metadata like this. For example, users may want to:

  • Combine datasets for multiple image sources and treat them as equivalent (e.g., Landsat 7 and 8)
  • Combine datasets for disparate geospatial locations (e.g., Chesapeake NY and PA)

These combinations require that all queries are present in at least one dataset, and can be created using a UnionDataset. Similarly, users may want to:

  • Combine image and target labels and sample from both simultaneously (e.g., Landsat and CDL)
  • Combine datasets for multiple image sources for multimodal learning or data fusion (e.g., Landsat and Sentinel)

These combinations require that all queries are present in both datasets, and can be created using an IntersectionDataset. TorchGeo automatically composes these datasets for you when you use the intersection (&) and union (|) operators.
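The difference between the two operators can be illustrated with plain bounding-box logic. In this hypothetical sketch (not TorchGeo's `UnionDataset` or `IntersectionDataset` code), each dataset is reduced to a list of file footprints in a shared CRS, and a query counts as present in a dataset if it overlaps any of that dataset's footprints. A union accepts a query present in at least one dataset; an intersection accepts it only if it is present in every dataset.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    """Hypothetical bounding box (minx, maxx, miny, maxy) in a shared CRS."""
    minx: float
    maxx: float
    miny: float
    maxy: float

    def overlaps(self, other: "Box") -> bool:
        return (
            self.minx < other.maxx and other.minx < self.maxx
            and self.miny < other.maxy and other.miny < self.maxy
        )


def present(query: Box, footprints: list[Box]) -> bool:
    """True if the query overlaps any file footprint of a single dataset."""
    return any(query.overlaps(fp) for fp in footprints)


def union_accepts(query: Box, *datasets: list[Box]) -> bool:
    # Union: the query must be present in at least one dataset (e.g. Landsat 7 OR 8).
    return any(present(query, fps) for fps in datasets)


def intersection_accepts(query: Box, *datasets: list[Box]) -> bool:
    # Intersection: the query must be present in every dataset (e.g. Landsat AND CDL).
    return all(present(query, fps) for fps in datasets)


# Toy footprints: imagery covers x in [0, 50], labels cover x in [40, 100].
imagery = [Box(0, 50, 0, 100)]
labels = [Box(40, 100, 0, 100)]
west = Box(10, 20, 10, 20)    # only inside the imagery footprint
middle = Box(42, 48, 10, 20)  # inside both footprints
```

With these footprints, `west` is accepted by the union but rejected by the intersection, while `middle` is accepted by both.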

Benchmark datasets

TorchGeo includes a number of benchmark datasets: datasets that include both input images and target labels. These cover tasks like image classification, regression, semantic segmentation, object detection, instance segmentation, change detection, and more.

If you've used torchvision before, these datasets should seem very familiar. In this example, we'll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.

```python
from torch.utils.data import DataLoader

from torchgeo.datamodules.utils import collate_fn_detection
from torchgeo.datasets import VHR10

# Initialize the dataset
dataset = VHR10(root="...", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn_detection,
)

# Training loop
for batch in dataloader:
    image = batch["image"]  # list of images
    bbox_xyxy = batch["bbox_xyxy"]  # list of boxes
    label = batch["label"]  # list of labels
    mask = batch["mask"]  # list of masks

    # train a model, or make predictions using a pre-trained model
```

Example predictions from a Mask R-CNN model trained on the VHR-10 dataset

All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each TorchGeo sample is a dictionary, with one key for each PyTorch Tensor (for example, "image" and "mask").
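Because each sample is a dictionary, batching becomes a per-key operation. The pure-Python sketch below uses hypothetical helpers, with nested lists standing in for tensors, to show the two strategies that appear in the examples above: values that share a shape across samples can be stacked along a new batch dimension (the fixed-size patches combined by stack_samples), whereas values whose shape varies, such as the number of boxes per image in the VHR-10 example, have to stay as a per-sample list. It mirrors the idea only, not TorchGeo's implementation.

```python
from typing import Any


def shape_of(value: Any) -> tuple[int, ...]:
    """Shape of a nested list (a stand-in for a tensor); scalars have shape ()."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def collate_dicts(samples: list[dict[str, Any]]) -> tuple[dict[str, Any], set[str]]:
    """Hypothetical collate: group values per key and report which keys are stackable.

    Keys whose values share one shape could be stacked into a single dense array of
    shape (batch, *shape); the others must remain a plain list of per-sample values.
    """
    batch: dict[str, Any] = {}
    stackable: set[str] = set()
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        batch[key] = values  # new leading batch dimension
        if len({shape_of(v) for v in values}) == 1:
            stackable.add(key)  # in PyTorch this is where torch.stack would apply
    return batch, stackable


# Two toy samples: same-sized images, but a different number of boxes per image.
samples = [
    {"image": [[1, 2], [3, 4]], "boxes": [[0, 0, 1, 1]]},
    {"image": [[5, 6], [7, 8]], "boxes": [[0, 0, 1, 1], [2, 2, 3, 3]]},
]
batch, stackable = collate_dicts(samples)
```

Here only "image" is stackable; "boxes" stays a list of two variable-length entries, which is why detection datasets need a custom collate function.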

Pre-trained Weights

Pre-trained weights have proven to be tremendously beneficial for transfer learning tasks in computer vision. Practitioners usually use models pre-trained on the ImageNet dataset, which contains RGB images. However, remote sensing data often goes beyond RGB, with additional multispectral channels that can vary across sensors. TorchGeo is the first library to support models pre-trained on different multispectral sensors, and it adopts torchvision's multi-weight API. A summary of currently available weights can be seen in the docs. To create a timm ResNet-18 model with weights that have been pre-trained on Sentinel-2 imagery, you can do the following:

```python
import timm
from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
```

These weights can also be used directly in the TorchGeo Lightning modules shown in the following section, via the weights argument. For a notebook example, see this tutorial.

Reproducibility with Lightning

In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created Lightning datamodules with well-defined train-val-test splits and trainers for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the Inria Aerial Image Labeling dataset is as easy as a few imports and four lines of code.

```python
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=True,
    in_channels=3,
    task="binary",
    loss="bce",
)
trainer = Trainer(default_root_dir="...")

trainer.fit(model=task, datamodule=datamodule)
```
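The datamodules' preprocessing relies on channel statistics calculated ahead of time. As a hedged illustration of that idea (not TorchGeo's code; `channel_stats` and `normalize` are hypothetical names, and standardization to zero mean and unit variance is assumed as the example transform), the sketch below computes each channel's mean and standard deviation once over a toy training set, then standardizes images with those fixed values. Images are nested lists of shape channels × height × width.

```python
import math

Image = list[list[list[float]]]  # channels x height x width


def channel_stats(images: list[Image]) -> tuple[list[float], list[float]]:
    """Per-channel mean and population standard deviation over a set of images."""
    means, stds = [], []
    for c in range(len(images[0])):
        values = [px for img in images for row in img[c] for px in row]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        means.append(mean)
        stds.append(math.sqrt(var))
    return means, stds


def normalize(img: Image, means: list[float], stds: list[float]) -> Image:
    """Standardize each channel with precomputed statistics: (x - mean) / std."""
    return [
        # Fall back to a divisor of 1.0 for constant channels to avoid dividing by zero.
        [[(px - means[c]) / (stds[c] if stds[c] > 0 else 1.0) for px in row] for row in img[c]]
        for c in range(len(img))
    ]


# Toy two-channel "training set" of 2 x 2 images.
train = [
    [[[0.0, 2.0], [4.0, 6.0]], [[10.0, 10.0], [30.0, 30.0]]],
    [[[2.0, 4.0], [6.0, 8.0]], [[20.0, 20.0], [40.0, 40.0]]],
]
means, stds = channel_stats(train)
normed = [normalize(img, means, stds) for img in train]
```

Computing the statistics once, on training data only, keeps validation and test preprocessing identical and avoids leaking information from the held-out splits.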

Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset

TorchGeo also supports command-line interface training using LightningCLI. It can be invoked in two ways:

```sh
# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
python3 -m torchgeo
```

It supports command-line configuration or YAML/JSON config files. Valid options can be found in the help messages:

```sh
# See valid stages
torchgeo --help
# See valid trainer options
torchgeo fit --help
# See valid model options
torchgeo fit --model.help ClassificationTask
# See valid data options
torchgeo fit --data.help EuroSAT100DataModule
```

Using the following config file:

```yaml
trainer:
  max_epochs: 20
model:
  class_path: ClassificationTask
  init_args:
    model: 'resnet18'
    in_channels: 13
    num_classes: 10
data:
  class_path: EuroSAT100DataModule
  init_args:
    batch_size: 8
  dict_kwargs:
    download: true
```

we can see the script in action:

```sh
# Train and validate a model
torchgeo fit --config config.yaml
# Validate-only
torchgeo validate --config config.yaml
# Calculate and report test accuracy
torchgeo test --config config.yaml --ckpt_path=...
```
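Since the CLI accepts JSON as well as YAML config files, the same configuration can also be generated from Python with only the standard library. This is a minimal sketch: the keys are copied from the YAML example, while the `write_config` helper and the config.json file name are hypothetical choices for illustration.

```python
import json
from pathlib import Path

# The same settings as the YAML config file, expressed as a Python dictionary.
config = {
    "trainer": {"max_epochs": 20},
    "model": {
        "class_path": "ClassificationTask",
        "init_args": {"model": "resnet18", "in_channels": 13, "num_classes": 10},
    },
    "data": {
        "class_path": "EuroSAT100DataModule",
        "init_args": {"batch_size": 8},
        "dict_kwargs": {"download": True},
    },
}


def write_config(path: str, cfg: dict) -> Path:
    """Hypothetical helper: write the config as pretty-printed JSON and return its path."""
    out = Path(path)
    out.write_text(json.dumps(cfg, indent=2))
    return out
```

For example, calling `write_config("config.json", config)` and then running `torchgeo fit --config config.json` should be equivalent to the YAML workflow, which is convenient when sweeping hyperparameters from a script.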

It can also be imported and used in a Python script if you need to extend it to add new features:

```python
from torchgeo.main import main

main(["fit", "--config", "config.yaml"])
```

See the Lightning documentation for more details.

Citation

If you use this software in your work, please cite our paper:

```bibtex
@article{Stewart_TorchGeo_Deep_Learning_2024,
  author = {Stewart, Adam J. and Robinson, Caleb and Corley, Isaac A. and Ortiz, Anthony and Lavista Ferres, Juan M. and Banerjee, Arindam},
  doi = {10.1145/3707459},
  journal = {ACM Transactions on Spatial Algorithms and Systems},
  month = dec,
  title = {{TorchGeo}: Deep Learning With Geospatial Data},
  url = {https://doi.org/10.1145/3707459},
  year = {2024}
}
```

Contributing

This project welcomes contributions and suggestions. If you would like to submit a pull request, see our Contribution Guide for more information.

This project has adopted the Contributor Covenant Code of Conduct. For more information, see the Contributor Covenant Code of Conduct FAQ or contact @adamjstewart on Slack with any additional questions or comments.

Owner

  • Name: TorchGeo
  • Login: torchgeo
  • Kind: organization

TorchGeo: deep learning with geospatial data

GitHub Events

Total
  • Create event: 5
  • Issues event: 4
  • Watch event: 47
  • Delete event: 5
  • Issue comment event: 27
  • Member event: 1
  • Push event: 11
  • Pull request event: 28
  • Pull request review event: 52
  • Pull request review comment event: 70
  • Fork event: 5
Last Year
  • Create event: 5
  • Issues event: 4
  • Watch event: 47
  • Delete event: 5
  • Issue comment event: 27
  • Member event: 1
  • Push event: 11
  • Pull request event: 28
  • Pull request review event: 52
  • Pull request review comment event: 70
  • Fork event: 5

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 4
  • Total pull requests: 18
  • Average time to close issues: 5 days
  • Average time to close pull requests: about 5 hours
  • Total issue authors: 4
  • Total pull request authors: 8
  • Average comments per issue: 1.0
  • Average comments per pull request: 0.39
  • Merged pull requests: 6
  • Bot issues: 0
  • Bot pull requests: 7
Past Year
  • Issues: 4
  • Pull requests: 18
  • Average time to close issues: 5 days
  • Average time to close pull requests: about 5 hours
  • Issue authors: 4
  • Pull request authors: 8
  • Average comments per issue: 1.0
  • Average comments per pull request: 0.39
  • Merged pull requests: 6
  • Bot issues: 0
  • Bot pull requests: 7
Top Authors
Issue Authors
  • adamjstewart (1)
  • gcaria (1)
  • forrestfwilliams (1)
  • f-schi (1)
Pull Request Authors
  • dependabot[bot] (7)
  • adamjstewart (2)
  • calebrob6 (2)
  • isaaccorley (2)
  • MUYang99 (1)
  • blaz-r (1)
  • yichiac (1)
  • forrestfwilliams (1)
Top Labels
Issue Labels
documentation (2) testing (1)
Pull Request Labels
dependencies (7) python (6) documentation (4) models (3) testing (3) datasets (2) github_actions (1) trainers (1) datamodules (1)

Dependencies

.github/workflows/labeler.yml actions
  • actions/labeler v4.3.0 composite
.github/workflows/release.yaml actions
  • actions/cache v3.3.2 composite
  • actions/checkout v4.0.0 composite
  • actions/setup-python v4.7.0 composite
.github/workflows/style.yaml actions
  • actions/cache v3.3.2 composite
  • actions/checkout v4.0.0 composite
  • actions/setup-python v4.7.0 composite
.github/workflows/tests.yaml actions
  • actions/cache v3.3.2 composite
  • actions/checkout v4.0.0 composite
  • actions/setup-python v4.7.0 composite
  • codecov/codecov-action v3.1.4 composite
  • pyvista/setup-headless-display-action v2 composite
.github/workflows/tutorials.yaml actions
  • actions/cache v3.3.2 composite
  • actions/checkout v4.0.0 composite
  • actions/setup-python v4.7.0 composite
docs/requirements.txt pypi
pyproject.toml pypi
  • einops >=0.3
  • fiona >=1.8.19
  • kornia >=0.6.9
  • lightly >=1.4.4
  • lightning >=1.8
  • matplotlib >=3.3.3
  • numpy >=1.19.3
  • pillow >=8
  • pyproj >=3
  • rasterio >=1.2
  • rtree >=1
  • segmentation-models-pytorch >=0.2
  • shapely >=1.7.1
  • timm >=0.4.12
  • torch >=1.12
  • torchmetrics >=0.10
  • torchvision >=0.13
requirements/datasets.txt pypi
  • h5py ==3.9.0
  • laspy ==2.5.1
  • opencv-python ==4.8.0.76
  • pandas ==2.1.1
  • pycocotools ==2.0.7
  • pyvista ==0.42.2
  • radiant-mlhub ==0.4.1
  • rarfile ==4.1
  • scikit-image ==0.21.0
  • scipy ==1.11.2
  • zipfile-deflate64 ==0.2.0
requirements/docs.txt pypi
  • ipywidgets ==8.1.1
  • nbsphinx ==0.9.3
  • sphinx ==5.3.0
requirements/required.txt pypi
  • einops ==0.6.1
  • fiona ==1.9.4.post1
  • kornia ==0.7.0
  • lightly ==1.4.19
  • lightning ==2.0.9
  • matplotlib ==3.8.0
  • numpy ==1.26.0
  • pillow ==10.0.1
  • pyproj ==3.6.1
  • rasterio ==1.3.8
  • rtree ==1.0.1
  • segmentation-models-pytorch ==0.3.3
  • setuptools ==68.2.0
  • shapely ==2.0.1
  • timm ==0.9.2
  • torch ==2.0.1
  • torchmetrics ==1.2.0
  • torchvision ==0.15.2
requirements/style.txt pypi
  • black ==23.9.1
  • flake8 ==6.1.0
  • isort ==5.12.0
  • pydocstyle ==6.3.0
  • pyupgrade ==3.12.0
requirements/tests.txt pypi
  • hydra-core ==1.3.2 test
  • mypy ==1.5.1 test
  • nbmake ==1.4.3 test
  • omegaconf ==2.3.0 test
  • pytest ==7.4.2 test
  • pytest-cov ==4.1.0 test
  • tensorboard ==2.14.0 test