https://github.com/blutjens/geospatial_unet_pytorch
Barebone code for training UNet to segment remote sensing imagery
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (12.9%) to scientific vocabulary
Repository
Basic Info
Statistics
- Stars: 4
- Watchers: 1
- Forks: 2
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
geospatial_unet_pytorch
Bare-bones code for training and evaluating a UNet for segmentation or downscaling of geospatial data, including remote sensing, satellite, and aerial imagery or weather/climate data. This code works with datasets that are a collection of large-scale tifs. These tifs can contain NaNs, have different extents, and be quite large (e.g., 10,000 px x 6,000 px x 8 channels, ~500 MB each). The dataloader dynamically loads small tiles from the full-scale tifs into memory for training. After training, predictions across a new set of large-scale tifs can be created with predict.py, which sweeps the model across the full area of each tif.
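The tile-based loading described above can be sketched as follows. This is a minimal illustration, not the repository's dataloader: `sample_tile`, its `read_window` callback, and the `max_nan_frac` rejection rule are hypothetical names and choices. The sketch assumes that missing data is stored as NaN and that each tif is read lazily through a window-reading callback, so only one tile is ever held in memory.

```python
import numpy as np


def sample_tile(read_window, height, width, tile_size, rng,
                max_nan_frac=0.5, max_tries=20):
    """Draw a random square tile from a large raster, rejecting mostly-NaN tiles.

    read_window(row_off, col_off, size) must return an array of shape
    (channels, size, size). Returns (row_off, col_off, tile), or None if no
    acceptable tile was found within max_tries draws.
    """
    if tile_size > height or tile_size > width:
        raise ValueError("tile_size must not exceed the raster extent")
    for _ in range(max_tries):
        # Choose the top-left corner so the tile lies fully inside the raster.
        row_off = int(rng.integers(0, height - tile_size + 1))
        col_off = int(rng.integers(0, width - tile_size + 1))
        tile = read_window(row_off, col_off, tile_size)
        # Fraction of NaN values across all channels of this tile.
        if np.isnan(tile).mean() <= max_nan_frac:
            return row_off, col_off, tile
    return None
```

With rasterio, a reader for an open dataset `src` could be `lambda r, c, s: src.read(window=Window(c, r, s, s))` (using `rasterio.windows.Window(col_off, row_off, width, height)`), with `src.height` and `src.width` as the extent. If a file encodes missing data with a nodata value instead of NaN, it would first need converting.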
Installation
We recommend installing conda and then setting up the project with the following lines. Installing PyTorch, CUDA, and GDAL is a bit tricky, but the lines below worked on our machines:
```
# click 'use this template' -> set -> click 'private'
git clone git@github.com:
```
Test the PyTorch-CUDA and GDAL installation:
```
# If the following line returns 'True', the torch-GPU connection seems to work.
python -c 'import torch; print(torch.cuda.is_available())'
# This line should not throw an error
python -c 'from osgeo import gdal'
```
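The two one-liners above check PyTorch-CUDA and GDAL separately. A small sketch can report all key imports at once without crashing on the first missing package. The function name `check_environment` and the package list are assumptions for illustration; the list uses the import names of dependencies mentioned on this page (`osgeo` for GDAL, `segmentation_models_pytorch` for smp).

```python
import importlib.util


def check_environment(packages=("torch", "osgeo", "rasterio",
                                "segmentation_models_pytorch")):
    """Report which top-level packages are importable and whether torch sees a GPU."""
    # find_spec returns None for top-level modules that are not installed.
    report = {name: importlib.util.find_spec(name) is not None for name in packages}
    if report.get("torch"):
        import torch  # imported only when it is actually installed
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {ok}")
```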
Train a model
```
python geospatial_unet_pytorch/train.py --cfg_path runs/unet_smp/demo_run/config/config.yaml --no_wandb
# Remove --no_wandb to monitor the training curves on Weights & Biases
```
Successful training after ~1 hr on the hrmelt sample dataset could look like this:

W&B will also log predictions across the large-scale tif:

Use the trained model to create predictions
```
python geospatial_unet_pytorch/predict.py --load 'runs/unet_smp/demo_run/checkpoints/checkpoint_epoch10.pth'
```
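predict.py sweeps the model across the full area of each tif. The sketch below shows one common way to do that; it is not the repository's implementation. `predict_full_raster` and `predict_tile` are hypothetical names, `predict_tile` stands in for the trained UNet (in PyTorch it would wrap a forward pass under `torch.no_grad()`), and averaging overlapping tiles plus zero-filling NaN inputs are assumed design choices.

```python
import numpy as np


def predict_full_raster(predict_tile, image, tile_size, stride):
    """Sweep a tile-level model across a (C, H, W) image and stitch the outputs.

    predict_tile maps a (C, tile_size, tile_size) array to a
    (tile_size, tile_size) prediction. Overlapping predictions are averaged.
    NaN inputs are zero-filled before prediction, and output pixels with any
    NaN input channel are set back to NaN.
    """
    _, height, width = image.shape
    if tile_size > height or tile_size > width:
        raise ValueError("tile_size must not exceed the image extent")
    if not 0 < stride <= tile_size:
        raise ValueError("stride must be in (0, tile_size] to cover every pixel")
    nan_mask = np.isnan(image).any(axis=0)
    filled = np.nan_to_num(image, nan=0.0)
    out_sum = np.zeros((height, width), dtype=np.float64)
    out_cnt = np.zeros((height, width), dtype=np.float64)

    def starts(length):
        # Regular grid of offsets plus a final offset flush with the border.
        s = list(range(0, length - tile_size + 1, stride))
        if s[-1] != length - tile_size:
            s.append(length - tile_size)
        return s

    for r in starts(height):
        for c in starts(width):
            pred = predict_tile(filled[:, r:r + tile_size, c:c + tile_size])
            out_sum[r:r + tile_size, c:c + tile_size] += pred
            out_cnt[r:r + tile_size, c:c + tile_size] += 1
    out = out_sum / out_cnt
    out[nan_mask] = np.nan
    return out
```

Using a stride smaller than the tile size trades extra compute for smoother seams, because each border pixel is averaged over several tiles.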
How to add a custom dataset
- We recommend getting this repository running with demo_run and dataset_hrmelt before adding a custom dataset.
- Write a function that creates the data splits and saves them as train.csv, val.csv, test.csv, and periodical_eval.csv, with one filepath entry per full-scale tif. Each filepath should occur only once in each file. (The train.csv in unet_smp/demo_run/config currently contains the same filepath multiple times; this is only to illustrate that the model can overfit on the training set.) The filepaths in train.csv, val.csv, and test.csv should be mutually exclusive. We recommend making periodical_eval.csv a subselection of ~50 tifs from val.csv.
- Create a train.csv, val.csv, test.csv, and periodical_eval.csv
- Then, create a new file, geospatial_unet_pytorch/dataset/dataset_yourdataset.py, similar to the hrmelt dataset
- Edit train.py to use the new dataset class
- Compute the mean and standard deviation of each input channel across the dataset in a language of your choosing, and insert the values into config.yaml.
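Two of the steps above, writing the split CSVs and computing per-channel statistics, could be sketched in Python as follows. Everything here is an assumption for illustration: `write_splits` and `channel_mean_std` are hypothetical helpers, the `filepath` column header and the 15% val/test fractions are placeholders (check the CSVs in demo_run/config for the format the dataset class actually expects), and the statistics use a simple NaN-aware running sum of values and squares.

```python
import csv
import random
from pathlib import Path

import numpy as np


def write_splits(filepaths, out_dir, val_frac=0.15, test_frac=0.15,
                 n_periodical=50, seed=0):
    """Write mutually exclusive train/val/test CSVs plus a periodical_eval subset of val."""
    paths = sorted(set(filepaths))  # each filepath appears only once
    rng = random.Random(seed)
    rng.shuffle(paths)
    n_val = int(len(paths) * val_frac)
    n_test = int(len(paths) * test_frac)
    splits = {
        "val": paths[:n_val],
        "test": paths[n_val:n_val + n_test],
        "train": paths[n_val + n_test:],
    }
    # periodical_eval is drawn from val only, so it never overlaps train or test.
    splits["periodical_eval"] = rng.sample(splits["val"], min(n_periodical, n_val))
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for name, rows in splits.items():
        with open(out_dir / f"{name}.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["filepath"])  # hypothetical column name
            writer.writerows([p] for p in rows)
    return splits


def channel_mean_std(arrays):
    """NaN-aware per-channel mean and population std over (C, H, W) arrays."""
    total = total_sq = count = 0.0
    for a in arrays:
        a = np.asarray(a, dtype=np.float64)
        valid = ~np.isnan(a)
        vals = np.where(valid, a, 0.0)  # NaNs contribute nothing to the sums
        total = total + vals.sum(axis=(1, 2))
        total_sq = total_sq + (vals ** 2).sum(axis=(1, 2))
        count = count + valid.sum(axis=(1, 2))
    mean = total / count
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)
```

For ~500 MB tifs, `arrays` could be a generator that yields windows or tiles rather than whole files, so memory stays bounded. The sum-of-squares formula can lose precision for channels with a large mean relative to their spread; Welford's algorithm is a more robust alternative.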
Optional: Rename the src folder into
```
Rename geospatial_unet_pytorch folder into
Replace geospatial_unet_pytorch in pyproject.toml with
Replace geospatial_unet_pytorch in all .py and .ipynb files with
```
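The optional renaming steps above can be automated with a short script. This is a sketch under assumptions: `rename_package` is a hypothetical helper, `my_project` is a placeholder for your new name, and a plain string replacement is assumed to be safe for your files (it also rewrites any other occurrence of the old name). Hidden directories such as `.git` are skipped.

```python
from pathlib import Path


def rename_package(root, old="geospatial_unet_pytorch", new="my_project"):
    """Rename the package folder, then replace its name in pyproject.toml, .py and .ipynb files."""
    root = Path(root)
    src = root / old
    if src.is_dir():
        src.rename(root / new)  # rename first so the search below sees final paths
    targets = [root / "pyproject.toml"] + list(root.rglob("*.py")) + list(root.rglob("*.ipynb"))
    changed = []
    for path in targets:
        if not path.is_file():
            continue
        # Skip files inside hidden directories such as .git.
        if any(part.startswith(".") for part in path.relative_to(root).parts):
            continue
        text = path.read_text(encoding="utf-8")
        if old in text:
            path.write_text(text.replace(old, new), encoding="utf-8")
            changed.append(path)
    return changed
```

Running it on a fresh clone, for example `rename_package(".", new="my_project")`, and reviewing the result with `git diff` keeps the change easy to undo.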
Features and functionality
Implemented:
- UNet from the segmentation_models.pytorch (smp) library
- Pretrained model weights via smp's integration with timm
- Logging and monitoring runs via Weights & Biases
- Tested reproducibility and random seeds
- Training on a single GPU
- Parallel batches during training and prediction
- L1 and Dice loss on NaN-masked targets
- Periodic evaluation of the model during training on predictions across the full-scale tif

Not implemented:
- Evaluating predictions of multiple models with one script
- Python typing in all functions
- Multi-GPU training
- Integration with pretrained models for geospatial data, e.g., SatCLIP
- Mixed precision
- Other loss functions on NaN-masked targets
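The idea behind L1 and Dice loss on NaN-masked targets is that pixels with missing ground truth are excluded before the loss is reduced. The NumPy sketch below illustrates that idea; it is not the repository's PyTorch code. The function names are hypothetical, the Dice variant is the standard soft Dice for binary targets, and returning 0.0 when no valid pixel exists is an assumed convention. In PyTorch the same masking works with `torch.isnan` and boolean indexing, which keeps gradients for the selected pixels.

```python
import numpy as np


def masked_l1(pred, target):
    """Mean absolute error over pixels whose target is not NaN."""
    valid = ~np.isnan(target)
    if not valid.any():
        return 0.0  # assumed convention: no valid pixels means zero loss
    return float(np.abs(pred[valid] - target[valid]).mean())


def masked_soft_dice(prob, target, eps=1e-6):
    """Soft Dice loss (1 - Dice) for binary targets, ignoring NaN target pixels."""
    valid = ~np.isnan(target)
    p, t = prob[valid], target[valid]
    intersection = (p * t).sum()
    # eps keeps the ratio defined when both prediction and target are empty.
    return float(1.0 - (2.0 * intersection + eps) / (p.sum() + t.sum() + eps))
```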
Owner
- Name: Björn Lütjens (he/him)
- Login: blutjens
- Kind: user
- Company: MIT
- Website: https://blutjens.github.io/
- Twitter: bjornlutjens
- Repositories: 31
- Profile: https://github.com/blutjens
Postdoctoral Associate in tackling climate change with AI @ MIT. Project overview at https://blutjens.github.io/
GitHub Events
Total
- Watch event: 1
- Push event: 11
Last Year
- Watch event: 1
- Push event: 11
Issues and Pull Requests
Last synced: over 1 year ago
All Time
- Total issues: 1
- Total pull requests: 0
- Average time to close issues: 17 days
- Average time to close pull requests: N/A
- Total issue authors: 1
- Total pull request authors: 0
- Average comments per issue: 2.0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 1
- Pull requests: 0
- Average time to close issues: 17 days
- Average time to close pull requests: N/A
- Issue authors: 1
- Pull request authors: 0
- Average comments per issue: 2.0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- melisandeteng (1)
Dependencies
- gdal *
- jupyter *
- matplotlib *
- numpy *
- pandas *
- pip *
- rasterio *
- scikit-image *
- scikit-learn *
- segmentation-models-pytorch *
- tqdm *