A simple package that makes it easy to load remote sensing foundation models for downstream use cases.
SatlasPretrain Models: Foundation models for satellite and aerial imagery.
SatlasPretrain is a large-scale pre-training dataset for remote sensing image understanding. This work was published at ICCV 2023. Details and download links for the dataset can be found here.
This repository contains satlaspretrain_models, a lightweight library to easily load pretrained SatlasPretrain models for:
- Sentinel-2
- Sentinel-1
- Landsat 8/9
- 0.5-2 m/pixel aerial imagery
These models can be fine-tuned on downstream tasks that use these image sources, leading to faster training and improved performance compared to training from other initializations.
Model Structure and Usage
The SatlasPretrain models consist of three main components: backbone, feature pyramid network (FPN), and prediction head.
For models trained on multi-image input, the backbone is applied to each individual image, and then max pooling is applied in the temporal dimension, i.e., across the multiple aligned images. Single-image models take a single image as input.
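The per-image backbone followed by temporal max pooling can be sketched in NumPy as below. This is an illustration of the idea, not the package's implementation: `toy_backbone` is a hypothetical stand-in for the real Swin or ResNet backbone and simply average-pools each image by 4x4, and `multi_image_features` is likewise a hypothetical name.

```python
import numpy as np


def toy_backbone(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a real backbone.

    Maps one (C, H, W) image to a (C, H/4, W/4) feature map via 4x4 average pooling.
    """
    c, h, w = image.shape
    return image.reshape(c, h // 4, 4, w // 4, 4).mean(axis=(2, 4))


def multi_image_features(images: np.ndarray) -> np.ndarray:
    """Apply the backbone to each aligned image, then max-pool over time.

    images: array of shape (T, C, H, W) holding T aligned captures of one location.
    Returns a single (C, H/4, W/4) feature map.
    """
    per_image = np.stack([toy_backbone(img) for img in images], axis=0)  # (T, C, h, w)
    return per_image.max(axis=0)  # element-wise maximum across the T images


rng = np.random.default_rng(0)
images = rng.random((4, 3, 32, 32), dtype=np.float32)  # 4 aligned RGB captures
features = multi_image_features(images)
print(features.shape)  # (3, 8, 8)
```

Because the maximum is taken element-wise over the time axis, the pooled result does not depend on the order of the images and is defined for any number of images.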
This package allows you to load the backbone or backbone+FPN using a model checkpoint ID from the tables below.
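import satlaspretrain_models
weights_manager = satlaspretrain_models.Weights()
MODEL_CHECKPOINT_ID = "Sentinel2_SwinB_SI_RGB"
model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID)  # backbone only
model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID, fpn=True)  # backbone + FPN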
The model outputs multi-scale feature maps, one per scale, from either the backbone or the FPN.
For a complete fine-tuning example, see our tutorial on fine-tuning the pre-trained model on EuroSAT.
Installation
conda create --name satlaspretrain python==3.9
conda activate satlaspretrain
pip install satlaspretrain-models
Available Pretrained Models
The tables below list available model checkpoint IDs (like Sentinel2_SwinB_SI_RGB).
Checkpoints are released under ODC-BY.
This package will download model checkpoints automatically, but you can download them directly using the links below if desired.
Sentinel-2 Pretrained Models
Backbone | Single-image, RGB | Multi-image, RGB |
---|---|---|
Swin-v2-Base | Sentinel2_SwinB_SI_RGB | Sentinel2_SwinB_MI_RGB |
Swin-v2-Tiny | Sentinel2_SwinT_SI_RGB | |
Resnet50 | Sentinel2_Resnet50_SI_RGB | Sentinel2_Resnet50_MI_RGB |
Resnet152 | Sentinel2_Resnet152_SI_RGB | Sentinel2_Resnet152_MI_RGB |
Backbone | Single-image, MS | Multi-image, MS |
---|---|---|
Swin-v2-Base | Sentinel2_SwinB_SI_MS | Sentinel2_SwinB_MI_MS |
Resnet50 | Sentinel2_Resnet50_SI_MS | |
Resnet152 | Sentinel2_Resnet152_SI_MS | |
Sentinel-1 Pretrained Models
Backbone | Single-image, VH+VV | Multi-image, VH+VV |
---|---|---|
Swin-v2-Base | Sentinel1_SwinB_SI | Sentinel1_SwinB_MI |
Landsat 8/9 Pretrained Models
Backbone | Single-image, all bands | Multi-image, all bands |
---|---|---|
Swin-v2-Base | Landsat_SwinB_SI | Landsat_SwinB_MI |
Aerial (0.5-2m/px high-res imagery) Pretrained Models
Backbone | Single-image, RGB | Multi-image, RGB |
---|---|---|
Swin-v2-Base | Aerial_SwinB_SI | Aerial_SwinB_MI |
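The checkpoint IDs in these tables follow a consistent pattern: source, backbone, SI or MI, and (for Sentinel-2) the band set. The helper below is a hypothetical convenience, not part of satlaspretrain_models, that splits an ID into those parts, for example to pick the right normalization or input layout. It assumes the `<source>_<backbone>_<SI|MI>[_<bands>]` pattern holds for every ID.

```python
from typing import NamedTuple, Optional


class CheckpointInfo(NamedTuple):
    source: str           # e.g. "Sentinel2", "Sentinel1", "Landsat", "Aerial"
    backbone: str         # e.g. "SwinB", "SwinT", "Resnet50", "Resnet152"
    multi_image: bool     # True for "MI", False for "SI"
    bands: Optional[str]  # "RGB" or "MS" for Sentinel-2 IDs, None otherwise


def parse_checkpoint_id(checkpoint_id: str) -> CheckpointInfo:
    """Split an ID like "Sentinel2_SwinB_SI_RGB" into its components.

    Hypothetical helper; assumes the <source>_<backbone>_<SI|MI>[_<bands>] pattern.
    """
    parts = checkpoint_id.split("_")
    if len(parts) not in (3, 4) or parts[2] not in ("SI", "MI"):
        raise ValueError(f"unrecognized checkpoint ID: {checkpoint_id!r}")
    bands = parts[3] if len(parts) == 4 else None
    return CheckpointInfo(parts[0], parts[1], parts[2] == "MI", bands)


print(parse_checkpoint_id("Sentinel2_Resnet50_MI_RGB"))
# CheckpointInfo(source='Sentinel2', backbone='Resnet50', multi_image=True, bands='RGB')
```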
Single-image models learn strong representations for individual satellite or aerial images, while multi-image models use multiple image captures of the same location for added robustness when making predictions about static objects. In multi-image models, feature maps from the backbone are passed through temporal max pooling, so the backbone itself is still applied on individual images, but is trained to provide strong representations after the temporal max pooling step. See ModelArchitecture.md for more details.
Sentinel-2 RGB models input the B4, B3, and B2 bands only, while the multi-spectral (MS) models input 9 bands. The aerial (0.5-2m/px high-res imagery) models input RGB NAIP and other high-res images, and we have found them to be effective on aerial imagery from a variety of sources. Landsat models input B1-B11 (all bands). Sentinel-1 models input VV and VH bands. See Normalization.md for details on how pixel values should be normalized for input to the pre-trained models.
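The pixel scaling used in the usage examples in this README can be collected into small helpers, sketched here in NumPy. The function names and the `EXPECTED_CHANNELS` table are hypothetical; the formulas are the ones stated in the example comments (divide 8-bit values by 255; Sentinel-1 divided by 255 and clipped; Landsat minus 4000, divided by 16320, and clipped). Sentinel-2 multi-spectral normalization is deliberately left out, since it is documented in Normalization.md rather than here.

```python
import numpy as np

# Input channels per model family, as listed in the paragraph above (hypothetical table).
EXPECTED_CHANNELS = {
    "Sentinel2_RGB": 3,  # B4, B3, B2
    "Sentinel2_MS": 9,   # multi-spectral; normalization described in Normalization.md
    "Sentinel1": 2,      # VH, VV
    "Landsat": 11,       # B1-B11
    "Aerial": 3,         # RGB
}


def normalize_8bit_rgb(image: np.ndarray) -> np.ndarray:
    """Sentinel-2 TCI or aerial RGB: scale 0-255 values to 0-1."""
    return image.astype(np.float32) / 255


def normalize_sentinel1(vh: np.ndarray, vv: np.ndarray) -> np.ndarray:
    """Sentinel-1: stack VH then VV, divide by 255, clip to 0-1."""
    stacked = np.stack([vh, vv], axis=0).astype(np.float32)  # (2, H, W)
    return np.clip(stacked / 255, 0, 1)


def normalize_landsat(bands: np.ndarray) -> np.ndarray:
    """Landsat B1-B11 stacked in order: subtract 4000, divide by 16320, clip to 0-1."""
    # Convert to float first so the subtraction cannot wrap around in uint16.
    return np.clip((bands.astype(np.float32) - 4000) / 16320, 0, 1)


vh = np.array([[100, 600]], dtype=np.uint16)
vv = np.array([[255, 0]], dtype=np.uint16)
s1 = normalize_sentinel1(vh, vv)
print(s1.shape)  # (2, 1, 2)
```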
Usage Examples
First initialize a Weights instance:
import satlaspretrain_models
import torch
weights_manager = satlaspretrain_models.Weights()
Then choose a model_identifier from the tables above to specify the pretrained model you want to load. Below are examples showing how to load in a few of the available models.
Pretrained single-image Sentinel-2 RGB model, backbone only:
model = weights_manager.get_pretrained_model(model_identifier="Sentinel2_SwinB_SI_RGB")
# Expected input is a portion of a Sentinel-2 L1C TCI image.
# The 0-255 pixel values should be divided by 255 so they are 0-1.
# tensor = tci_image[None, :, :, :] / 255
tensor = torch.zeros((1, 3, 512, 512), dtype=torch.float32)
# Since we only loaded the backbone, it outputs feature maps from the Swin-v2-Base backbone.
output = model(tensor)
print([feature_map.shape for feature_map in output])
# [torch.Size([1, 128, 128, 128]), torch.Size([1, 256, 64, 64]), torch.Size([1, 512, 32, 32]), torch.Size([1, 1024, 16, 16])]
Pretrained single-image Sentinel-1 model, backbone+FPN:
model = weights_manager.get_pretrained_model("Sentinel1_SwinB_SI", fpn=True)
# Expected input is a portion of a Sentinel-1 vh+vv image (in that order).
# The 16-bit pixel values should be divided by 255 and clipped to 0-1 (any pixel values greater than 255 become 1).
# tensor = torch.clip(torch.stack([vh_image, vv_image], dim=0)[None, :, :, :] / 255, 0, 1)
tensor = torch.zeros((1, 2, 512, 512), dtype=torch.float32)
# The model outputs feature maps from the FPN.
output = model(tensor)
print([feature_map.shape for feature_map in output])
# [torch.Size([1, 128, 128, 128]), torch.Size([1, 128, 64, 64]), torch.Size([1, 128, 32, 32]), torch.Size([1, 128, 16, 16])]
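The two printed outputs above suggest a simple rule for Swin-v2-Base models: four scales at strides 4, 8, 16, and 32, with 128, 256, 512, and 1024 backbone channels, or 128 channels at every scale after the FPN. The hypothetical helper below encodes that rule so output shapes can be checked before wiring up a head. It is only a sketch inferred from the 512x512 outputs shown; it covers Swin-v2-Base only (ResNet backbones will differ) and, as a simplification, only accepts sizes divisible by 32.

```python
def expected_feature_shapes(height, width, batch=1, fpn=False):
    """Hypothetical helper: predicted output shapes for a Swin-v2-Base model.

    Assumes strides 4, 8, 16, 32; backbone channels 128, 256, 512, 1024;
    and 128 channels at every scale when the FPN is attached.
    """
    if height % 32 or width % 32:
        raise ValueError("this sketch only handles sizes divisible by 32")
    shapes = []
    for level, stride in enumerate((4, 8, 16, 32)):
        channels = 128 if fpn else 128 * 2 ** level
        shapes.append((batch, channels, height // stride, width // stride))
    return shapes


print(expected_feature_shapes(512, 512))
# [(1, 128, 128, 128), (1, 256, 64, 64), (1, 512, 32, 32), (1, 1024, 16, 16)]
print(expected_feature_shapes(512, 512, fpn=True))
# [(1, 128, 128, 128), (1, 128, 64, 64), (1, 128, 32, 32), (1, 128, 16, 16)]
```

With a loaded model, the comparison would be `[tuple(f.shape) for f in output] == expected_feature_shapes(512, 512, fpn=True)`.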
Prediction heads
Although the checkpoints include prediction head parameters, these heads are task-specific, so loading the head parameters is not supported in this repository. Computing outputs from the pre-trained prediction heads is supported in the dataset codebase.
For convenience when fine-tuning on certain types of tasks, though, satlaspretrain_models supports attaching certain heads (initialized randomly) to the pre-trained model:
# Backbone and FPN parameters initialized from checkpoint, head parameters initialized randomly.
model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID, fpn=True, head=satlaspretrain_models.Head.CLASSIFY, num_categories=2)
The following head architectures are available:
- Segmentation: U-Net Decoder w/ Cross Entropy loss
- Detection: Faster R-CNN Decoder
- Instance Segmentation: Mask R-CNN Decoder
- Regression: U-Net Decoder w/ L1 loss
- Classification: Pooling + Linear layers
- Multi-label Classification: Pooling + Linear layers
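To make the "Pooling + Linear layers" description concrete, here is a NumPy sketch of a classification head applied to one feature map. It is not the package's head: the choice of global average pooling, a single linear layer, and the function name `classify_head` are assumptions made for illustration. A multi-label head would typically apply a per-category sigmoid instead of the softmax used here.

```python
import numpy as np


def classify_head(feature_map, weight, bias):
    """Sketch of a pooling + linear classification head (assumed design).

    feature_map: (B, C, H, W); weight: (num_categories, C); bias: (num_categories,).
    Returns (B, num_categories) class probabilities.
    """
    pooled = feature_map.mean(axis=(2, 3))                # global average pooling -> (B, C)
    logits = pooled @ weight.T + bias                     # linear layer -> (B, num_categories)
    logits = logits - logits.max(axis=1, keepdims=True)   # stabilize the softmax
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)           # softmax over categories


rng = np.random.default_rng(0)
features = rng.standard_normal((1, 128, 16, 16))  # e.g. the coarsest FPN map for a 512x512 input
num_categories = 2
weight = rng.standard_normal((num_categories, 128)) * 0.01  # random init, like the provided heads
bias = np.zeros(num_categories)
probs = classify_head(features, weight, bias)
print(probs)
```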
Pretrained multi-image aerial model, backbone + FPN + classification head:
# num_categories is the number of categories to predict.
# All heads are randomly initialized and provided only for convenience for fine-tuning.
model = weights_manager.get_pretrained_model("Aerial_SwinB_MI", fpn=True, head=satlaspretrain_models.Head.CLASSIFY, num_categories=2)
# Expected input is 8-bit (0-255) aerial images at 0.5 - 2 m/pixel.
# The 0-255 pixel values should be divided by 255 so they are 0-1.
# This multi-image model is trained to input 4 images but should perform well with different numbers of images.
# tensor = torch.stack([rgb_image1, rgb_image2], dim=0)[None, :, :, :, :] / 255
tensor = torch.zeros((1, 4, 3, 512, 512), dtype=torch.float32)
# The head needs to be fine-tuned on a downstream classification task.
# It outputs classification probabilities.
model.eval()
output = model(tensor.reshape(1, 4*3, 512, 512))
print(output)
# tensor([[0.0266, 0.9734]])
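The multi-image examples stack images as (batch, images, channels, H, W) and then flatten to (batch, images*channels, H, W) before calling the model. The NumPy sketch below shows where each image's channels end up after that reshape; `torch.Tensor.reshape` on a contiguous tensor uses the same row-major order. How the package unpacks the flattened tensor internally is not shown in this README, so the inverse reshape here is only an illustration.

```python
import numpy as np

T, C, H, W = 4, 3, 8, 8
rng = np.random.default_rng(0)
stacked = rng.random((1, T, C, H, W))   # (batch, images, channels, H, W)
flat = stacked.reshape(1, T * C, H, W)  # layout passed to model(...) in the examples

# Image t occupies channels t*C through t*C + C - 1 of the flattened tensor.
t = 2
assert np.array_equal(flat[:, t * C:(t + 1) * C], stacked[:, t])

# A per-image view can be recovered with the inverse reshape.
recovered = flat.reshape(1, T, C, H, W)
assert np.array_equal(recovered, stacked)
print(flat.shape)  # (1, 12, 8, 8)
```

Keeping the channels of each image contiguous is what lets a model split the flattened input back into T separate images of C bands each.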
Pretrained multi-image Landsat model, backbone + FPN + detection head:
# num_categories is the number of bounding box detection categories.
# All heads are randomly initialized and provided only for convenience for fine-tuning.
model = weights_manager.get_pretrained_model("Landsat_SwinB_MI", fpn=True, head=satlaspretrain_models.Head.DETECT, num_categories=5)
# Expected input is Landsat B1-B11 stacked in order.
# This multi-image model is trained to input 8 images but should perform well with different numbers of images.
# The 16-bit pixel values are normalized as follows:
# landsat_images = torch.stack([landsat_image1, landsat_image2], dim=0)
# tensor = torch.clip((landsat_images[None, :, :, :, :] - 4000) / 16320, 0, 1)
tensor = torch.zeros((1, 8, 11, 512, 512), dtype=torch.float32)
# The head needs to be fine-tuned on a downstream object detection task.
# It outputs bounding box detections.
model.eval()
output = model(tensor.reshape(1, 8*11, 512, 512))
print(output)
#[{'boxes': tensor([[ 67.0772, 239.2646, 95.6874, 16.3644], ...]),
# 'labels': tensor([3, ...]),
# 'scores': tensor([0.5443, ...])}]
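Once the detection head has been fine-tuned, a common post-processing step is to drop low-confidence detections. The sketch below does this in plain Python on a dict with parallel "boxes", "labels", and "scores" entries, mirroring one element of the list printed above; the function name, the 0.5 default threshold, and the example values are all made up for illustration. With the real tensor output, boolean mask indexing such as `det["boxes"][det["scores"] >= 0.5]` achieves the same.

```python
def filter_detections(detections, min_score=0.5):
    """Keep only detections whose score is at least min_score (hypothetical helper).

    detections: dict with parallel "boxes", "labels", and "scores" sequences.
    """
    keep = [i for i, score in enumerate(detections["scores"]) if score >= min_score]
    return {key: [detections[key][i] for i in keep] for key in ("boxes", "labels", "scores")}


example = {
    "boxes": [[10.0, 20.0, 50.0, 60.0], [5.0, 5.0, 15.0, 25.0]],  # made-up values
    "labels": [3, 1],
    "scores": [0.91, 0.12],
}
print(filter_detections(example))
# {'boxes': [[10.0, 20.0, 50.0, 60.0]], 'labels': [3], 'scores': [0.91]}
```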
Demos
We provide a demo showing how to finetune a SatlasPretrain Sentinel-2 model on the EuroSAT classification task.
Tests
There are tests for loading pretrained models and one for randomly initialized models.
To run the tests, run the following command from the root directory:
pytest tests/
Contact
If you have any questions, please email satlas@allenai.org or open an issue here.