https://github.com/blutjens/geospatial_unet_pytorch

Barebone code for training UNet to segment remote sensing imagery


Science Score: 13.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.9%) to scientific vocabulary
Last synced: 10 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: blutjens
  • License: cc0-1.0
  • Language: Jupyter Notebook
  • Default Branch: main
  • Homepage:
  • Size: 22.3 MB
Statistics
  • Stars: 4
  • Watchers: 1
  • Forks: 2
  • Open Issues: 0
  • Releases: 0
Created over 1 year ago · Last pushed over 1 year ago
Metadata Files
Readme License

README.md

geospatial_unet_pytorch

Barebone code for training and evaluating a UNet for segmentation or downscaling of geospatial data, including remote sensing (satellite or aerial) imagery and weather/climate data. The code works with datasets that consist of a collection of large-scale tifs. These tifs can contain NaNs, differ in extent, and be quite large (e.g., 10,000 px x 6,000 px x 8 channels, ~500 MB each). During training, the dataloader dynamically loads small tiles from the full-scale tifs into memory. After training, predict.py creates predictions for a new set of large-scale tifs by sweeping the model across the full area of each tif.
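As a rough illustration of the tile-based loading described above, the sketch below draws a random tile from an in-memory (C, H, W) array and redraws tiles that are mostly NaN. It is not the repository's dataloader: `sample_tile`, `max_nan_frac`, and the rejection rule are hypothetical. A disk-backed version would read only the tile window, e.g. with rasterio's `src.read(window=Window(col_off, row_off, width, height))`, instead of holding the full tif in memory.

```python
import numpy as np


def sample_tile(image, tile_size, rng, max_nan_frac=0.5, max_tries=20):
    """Randomly crop a (C, tile_size, tile_size) tile from a (C, H, W) array.

    Hypothetical helper: tiles whose NaN fraction exceeds max_nan_frac are
    rejected and redrawn, up to max_tries attempts. If no tile qualifies,
    the last draw is returned anyway.
    """
    _, height, width = image.shape
    if height < tile_size or width < tile_size:
        raise ValueError("tile_size is larger than the image")
    if max_tries < 1:
        raise ValueError("max_tries must be at least 1")
    for _ in range(max_tries):
        # rng.integers excludes the upper bound, so +1 keeps the last valid offset.
        row = int(rng.integers(0, height - tile_size + 1))
        col = int(rng.integers(0, width - tile_size + 1))
        tile = image[:, row:row + tile_size, col:col + tile_size]
        if np.isnan(tile).mean() <= max_nan_frac:
            break
    return tile, (row, col)


# Usage: a 3-channel image whose top half is NaN (e.g. outside the tif's footprint).
rng = np.random.default_rng(0)
image = np.zeros((3, 100, 80), dtype=np.float32)
image[:, :50, :] = np.nan
tile, (row, col) = sample_tile(image, 32, rng, max_nan_frac=0.0, max_tries=1000)
```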

Installation

We recommend installing conda and then setting up the project with the following lines. Installing PyTorch, CUDA, and GDAL can be tricky, but the lines below worked on our machines:

```
# click 'use this template' -> set -> click 'private'
git clone git@github.com:/.git
cd
conda create -n
conda activate
conda install python==3.10
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install "numpy>1.0.0" wheel "setuptools>=67"
pip install jupyter
pip install --find-links=https://girder.github.io/large_image_wheels --no-cache GDAL
pip install -r requirements.txt
pip install -e . # installs geospatial_unet_pytorch as python module
```

Test the PyTorch-CUDA and GDAL installation:

```
# If the following line returns 'True', the torch -- GPU connection seems to work.
python -c 'import torch; print(torch.cuda.is_available())'
# This line should not throw an error
python -c 'from osgeo import gdal'
```
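The two installation checks above can also be run from a single Python script. This is a hypothetical helper (`check_install` is not part of the repository); it imports the same modules as the one-liners but reports what it finds instead of raising an error.

```python
import importlib


def check_install():
    """Report whether torch, a CUDA device, and GDAL's Python bindings are usable."""
    status = {"torch": False, "cuda": False, "gdal": False}
    try:
        torch = importlib.import_module("torch")
        status["torch"] = True
        # Same check as: python -c 'import torch; print(torch.cuda.is_available())'
        status["cuda"] = bool(torch.cuda.is_available())
    except (ImportError, OSError):
        # OSError can occur when torch is installed but its CUDA libraries fail to load.
        pass
    try:
        # Same check as: python -c 'from osgeo import gdal'
        importlib.import_module("osgeo.gdal")
        status["gdal"] = True
    except ImportError:
        pass
    return status


if __name__ == "__main__":
    for name, ok in check_install().items():
        print(f"{name}: {'OK' if ok else 'NOT available'}")
```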

Train a model

```
python geospatial_unet_pytorch/train.py --cfgpath runs/unet_smp/demo_run/config/config.yaml --no_wandb
# Remove --no_wandb to monitor the training curves on Weights & Biases
```

Successful training after ~1 hr on the hrmelt sample dataset could look like this: [screenshot of wandb training progress]

wandb also logs predictions across the large-scale tif: [screenshot of wandb predictions across the large-scale tif]

Use the trained model to create predictions

```
python geospatial_unet_pytorch/predict.py --load 'runs/unet_smp/demo_run/checkpoints/checkpoint_epoch10.pth'
```
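predict.py sweeps the model across the full area of each tif. The sketch below shows one common way to do such a sweep on an in-memory (C, H, W) array: overlapping tiles, with the last row and column of tiles shifted back so they end exactly at the border, and overlapping outputs averaged. The function names, tile size, stride, and averaging rule are assumptions; the repository's predict.py may stitch tiles differently (e.g. without overlap, or with explicit NaN handling). `predict_fn` stands in for a trained model that maps a (C, h, w) patch to an (h, w) prediction.

```python
import numpy as np


def tile_offsets(length, tile, stride):
    """Start offsets so tiles of size `tile` cover [0, length).

    The final tile is shifted back to end exactly at the border,
    so no padding is needed.
    """
    if stride < 1 or stride > tile:
        raise ValueError("need 1 <= stride <= tile for full coverage")
    if length <= tile:
        return [0]
    offsets = list(range(0, length - tile + 1, stride))
    if offsets[-1] != length - tile:
        offsets.append(length - tile)
    return offsets


def sweep_predict(image, predict_fn, tile=256, stride=192):
    """Apply predict_fn to overlapping tiles of a (C, H, W) array and
    average the overlapping outputs into an (H, W) map."""
    _, height, width = image.shape
    total = np.zeros((height, width), dtype=np.float64)
    count = np.zeros((height, width), dtype=np.float64)
    for r in tile_offsets(height, tile, stride):
        for c in tile_offsets(width, tile, stride):
            patch = image[:, r:r + tile, c:c + tile]
            pred = predict_fn(patch)  # expected shape: patch.shape[1:]
            total[r:r + tile, c:c + tile] += pred
            count[r:r + tile, c:c + tile] += 1
    return total / count  # every pixel is covered at least once


# Usage with a stand-in "model" that averages the input channels.
image = np.random.default_rng(0).random((2, 50, 70))
prediction = sweep_predict(image, lambda p: p.mean(axis=0), tile=16, stride=12)
```

Averaging overlaps is a simple way to soften seams at tile borders; a stride equal to the tile size gives a non-overlapping sweep at lower cost.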

How to add a custom dataset

- We recommend getting this repository running with demo_run and dataset_hrmelt before adding a custom dataset.
- Write a function that creates the data splits and saves them as train.csv, val.csv, test.csv, and periodical_eval.csv, with one filepath entry per full-scale tif. Each filepath should occur only once in each file. (The train.csv in unet_smp/demo_run/config currently contains the same filepath multiple times, only to illustrate that the model can overfit on the training set.) The filepaths in train.csv, val.csv, and test.csv should be mutually exclusive. We recommend making periodical_eval.csv a subselection of ~50 tifs from val.csv.
- Run the function to create train.csv, val.csv, test.csv, and periodical_eval.csv.
- Then, create a new file, geospatial_unet_pytorch/dataset/dataset_yourdataset.py, similar to the hrmelt dataset.
- Edit train.py to use the new dataset class.
- Compute the mean and standard deviation of each input channel across the dataset, in a language of your choosing, and insert the values into config.yaml.
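A minimal sketch of the split-writing step and the per-channel statistics, assuming each CSV holds a single column of filepaths (the header name `filepath` is a guess; match whatever your dataset class reads). `write_splits` and `channel_mean_std` are hypothetical helpers, the split fractions are arbitrary, and periodical_eval is taken as the first `n_periodical` entries of the shuffled validation split. The statistics ignore NaN pixels and use a running sum of squares, which is fine for a sketch but less numerically robust than Welford's algorithm on very large datasets.

```python
import csv
import random
from pathlib import Path

import numpy as np


def write_splits(filepaths, out_dir, val_frac=0.1, test_frac=0.1,
                 n_periodical=50, seed=0):
    """Write mutually exclusive train/val/test CSVs plus periodical_eval.csv
    (a subset of val), with one row per unique tif path."""
    paths = sorted(set(map(str, filepaths)))  # each filepath occurs once
    random.Random(seed).shuffle(paths)
    n_val = int(len(paths) * val_frac)
    n_test = int(len(paths) * test_frac)
    splits = {
        "val": paths[:n_val],
        "test": paths[n_val:n_val + n_test],
        "train": paths[n_val + n_test:],
    }
    # val is already shuffled, so its first entries form a random subset.
    splits["periodical_eval"] = splits["val"][:n_periodical]
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for name, rows in splits.items():
        with open(out_dir / f"{name}.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["filepath"])  # hypothetical column name
            writer.writerows([p] for p in rows)
    return splits


def channel_mean_std(arrays):
    """NaN-aware per-channel mean and (population) std over (C, H, W) arrays."""
    total = sq_total = count = 0.0
    for a in arrays:
        a = np.asarray(a, dtype=np.float64)
        total = total + np.nansum(a, axis=(1, 2))
        sq_total = sq_total + np.nansum(a ** 2, axis=(1, 2))
        count = count + np.sum(~np.isnan(a), axis=(1, 2))
    mean = total / count
    # Clamp tiny negative values caused by floating-point round-off.
    std = np.sqrt(np.maximum(sq_total / count - mean ** 2, 0.0))
    return mean, std
```

In practice `arrays` would be the training tifs read one at a time (e.g. with rasterio), so only one tif is held in memory at once.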

Optional: Rename the src folder into

```
# Rename geospatial_unet_pytorch folder into
# Replace geospatial_unet_pytorch in pyproject.toml with
# Replace geospatial_unet_pytorch in all .py and .ipynb files with
```
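The renaming steps above can be scripted. In the sketch below, `rename_package` is hypothetical and `my_project` in the usage line is a placeholder for whatever name you choose. It does a plain string replacement, skips hidden directories such as .git or a .venv inside the repository, and is best run on a clean checkout so the changes are easy to review with `git diff`.

```python
from pathlib import Path

OLD_NAME = "geospatial_unet_pytorch"


def rename_package(repo_root, new_name, old_name=OLD_NAME):
    """Rename the package folder, then replace old_name with new_name in
    pyproject.toml and all .py / .ipynb files. Returns the edited files."""
    root = Path(repo_root)
    old_dir = root / old_name
    if old_dir.is_dir():
        old_dir.rename(root / new_name)
    # Search after the rename so files in the renamed folder are included.
    targets = [root / "pyproject.toml"]
    targets += list(root.rglob("*.py")) + list(root.rglob("*.ipynb"))
    edited = []
    for path in targets:
        if not path.is_file():
            continue
        # Skip hidden directories (e.g. .git, .venv) inside the repository.
        if any(part.startswith(".") for part in path.relative_to(root).parts):
            continue
        text = path.read_text(encoding="utf-8")
        if old_name in text:
            path.write_text(text.replace(old_name, new_name), encoding="utf-8")
            edited.append(path)
    return edited


# Usage (from the repository root): rename_package(".", "my_project")
```

Because the package is installed in editable mode, it likely needs to be reinstalled with `pip install -e .` after renaming.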

Features and functionality

```
Implemented:
- UNet from segmentation_models.pytorch (smp) library
- Pretrained model weights from timm via the smp library
- Logging and monitoring runs via Weights & Biases
- Tested reproducibility and random seeds
- Training on a single GPU
- Parallel batches during training and prediction
- L1 and dice loss on NaN-masked targets
- Periodic evaluation of the model during training on predictions across the full-scale tif

Not implemented:
- Evaluating predictions of multiple models with one script
- Python typing in all functions
- Multi-GPU
- Integration with pretrained models for geospatial data, e.g., SatCLIP
- Mixed precision
- Other loss functions on NaN-masked targets
```
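The feature list mentions L1 and dice loss on NaN-masked targets. The NumPy sketch below only illustrates the masking idea: pixels whose target is NaN are dropped before the loss is reduced. The function names, the mean reduction, and the soft-dice form with an `eps` smoothing term are assumptions; the repository implements its losses in PyTorch and may reduce or smooth them differently.

```python
import numpy as np


def masked_l1(pred, target):
    """Mean absolute error over pixels where the target is not NaN."""
    valid = ~np.isnan(target)
    if not valid.any():
        return 0.0  # nothing to supervise in this tile
    return float(np.abs(pred[valid] - target[valid]).mean())


def masked_dice_loss(prob, target, eps=1e-6):
    """Soft dice loss (1 - dice) for a binary target, ignoring NaN pixels.

    prob holds predicted foreground probabilities in [0, 1];
    target holds 0/1 labels, with NaN where no label exists.
    """
    valid = ~np.isnan(target)
    p, t = prob[valid], target[valid]
    intersection = (p * t).sum()
    dice = (2.0 * intersection + eps) / (p.sum() + t.sum() + eps)
    return float(1.0 - dice)


# Usage: the NaN pixel is ignored, so the large prediction there has no effect.
target = np.array([[1.0, np.nan], [0.0, 1.0]])
pred = np.array([[0.5, 100.0], [0.0, 1.0]])
l1 = masked_l1(pred, target)
```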

Owner

  • Name: Björn Lütjens (he/him)
  • Login: blutjens
  • Kind: user
  • Company: MIT

Postdoctoral Associate in tackling climate change with AI @ MIT. Project overview at https://blutjens.github.io/

GitHub Events

Total
  • Watch event: 1
  • Push event: 11
Last Year
  • Watch event: 1
  • Push event: 11

Issues and Pull Requests

Last synced: over 1 year ago

All Time
  • Total issues: 1
  • Total pull requests: 0
  • Average time to close issues: 17 days
  • Average time to close pull requests: N/A
  • Total issue authors: 1
  • Total pull request authors: 0
  • Average comments per issue: 2.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 1
  • Pull requests: 0
  • Average time to close issues: 17 days
  • Average time to close pull requests: N/A
  • Issue authors: 1
  • Pull request authors: 0
  • Average comments per issue: 2.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • melisandeteng (1)

Dependencies

pyproject.toml pypi
requirements.txt pypi
  • gdal *
  • jupyter *
  • matplotlib *
  • numpy *
  • pandas *
  • pip *
  • rasterio *
  • scikit-image *
  • scikit-learn *
  • segmentation-models-pytorch *
  • tqdm *