
A simple package that makes it easy to load remote sensing foundation models for downstream use cases.


SatlasPretrain Models: Foundation models for satellite and aerial imagery.

SatlasPretrain is a large-scale pre-training dataset for remote sensing image understanding. This work was published at ICCV 2023. Details and download links for the dataset can be found here.

This repository contains satlaspretrain_models, a lightweight library to easily load pretrained SatlasPretrain models for:

  • Sentinel-2
  • Sentinel-1
  • Landsat 8/9
  • 0.5-2 m/pixel aerial imagery

These models can be fine-tuned on downstream tasks that use these image sources, leading to faster training and improved performance compared to training from other initializations.

Model Structure and Usage

The SatlasPretrain models consist of three main components: backbone, feature pyramid network (FPN), and prediction head.

[Figure: SatlasPretrain model architecture diagram, described in the next paragraph.]

For models trained on multi-image input, the backbone is applied to each individual image, and then max pooling is applied in the temporal dimension, i.e., across the multiple aligned images. Single-image models take a single image as input.
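
For intuition, here is a minimal sketch of this temporal pooling. The backbone below is a stand-in convolution, not the package's actual Swin backbone:

import torch

# Stand-in per-image feature extractor; the real models use e.g. Swin-v2-Base.
backbone = torch.nn.Conv2d(3, 128, kernel_size=4, stride=4)

images = torch.zeros((8, 4, 3, 512, 512))  # (batch, time, channels, height, width)
b, t, c, h, w = images.shape

# Apply the backbone to each image independently by folding time into the batch,
feats = backbone(images.reshape(b * t, c, h, w))
feats = feats.reshape(b, t, *feats.shape[1:])

# then max-pool across the temporal dimension.
pooled = feats.max(dim=1).values
print(pooled.shape)  # torch.Size([8, 128, 128, 128])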

This package allows you to load the backbone or backbone+FPN using a model checkpoint ID from the tables below.

MODEL_CHECKPOINT_ID = "Sentinel2_SwinB_SI_RGB"
# Backbone only:
model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID)
# Backbone + FPN:
model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID, fpn=True)

The output of the model is the multi-scale feature map (either from the backbone or from the FPN).

For a complete fine-tuning example, see our tutorial on fine-tuning the pre-trained model on EuroSAT.

Installation

conda create --name satlaspretrain python==3.9
conda activate satlaspretrain
pip install satlaspretrain-models

Available Pretrained Models

The tables below list available model checkpoint IDs (like Sentinel2_SwinB_SI_RGB). Checkpoints are released under ODC-BY. This package will download model checkpoints automatically, but you can download them directly using the links below if desired.

Sentinel-2 Pretrained Models

Backbone       Single-image, RGB            Multi-image, RGB
Swin-v2-Base   Sentinel2_SwinB_SI_RGB       Sentinel2_SwinB_MI_RGB
Swin-v2-Tiny   Sentinel2_SwinT_SI_RGB       Sentinel2_SwinT_MI_RGB
Resnet50       Sentinel2_Resnet50_SI_RGB    Sentinel2_Resnet50_MI_RGB
Resnet152      Sentinel2_Resnet152_SI_RGB   Sentinel2_Resnet152_MI_RGB

Backbone       Single-image, MS             Multi-image, MS
Swin-v2-Base   Sentinel2_SwinB_SI_MS        Sentinel2_SwinB_MI_MS
Swin-v2-Tiny   Sentinel2_SwinT_SI_MS        Sentinel2_SwinT_MI_MS
Resnet50       Sentinel2_Resnet50_SI_MS     Sentinel2_Resnet50_MI_MS
Resnet152      Sentinel2_Resnet152_SI_MS    Sentinel2_Resnet152_MI_MS

Sentinel-1 Pretrained Models

Backbone       Single-image, VH+VV   Multi-image, VH+VV
Swin-v2-Base   Sentinel1_SwinB_SI    Sentinel1_SwinB_MI

Landsat 8/9 Pretrained Models

Backbone       Single-image, all bands   Multi-image, all bands
Swin-v2-Base   Landsat_SwinB_SI          Landsat_SwinB_MI

Aerial (0.5-2m/px high-res imagery) Pretrained Models

Backbone       Single-image, RGB   Multi-image, RGB
Swin-v2-Base   Aerial_SwinB_SI     Aerial_SwinB_MI

Single-image models learn strong representations for individual satellite or aerial images, while multi-image models use multiple image captures of the same location for added robustness when making predictions about static objects. In multi-image models, feature maps from the backbone are passed through temporal max pooling, so the backbone itself is still applied on individual images, but is trained to provide strong representations after the temporal max pooling step. See ModelArchitecture.md for more details.

Sentinel-2 RGB models input the B4, B3, and B2 bands only, while the multi-spectral (MS) models input 9 bands. The aerial (0.5-2m/px high-res imagery) models input RGB NAIP and other high-res images, and we have found them to be effective on aerial imagery from a variety of sources. Landsat models input B1-B11 (all bands). Sentinel-1 models input VV and VH bands. See Normalization.md for details on how pixel values should be normalized for input to the pre-trained models.
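
For reference, here is a minimal helper consolidating the normalization rules quoted in the usage examples below; the Landsat constants come from the Landsat example, and per-band specifics live in Normalization.md:

import torch

def normalize_sentinel2_rgb(tci_image: torch.Tensor) -> torch.Tensor:
    # Sentinel-2 L1C TCI: 8-bit 0-255 values scaled to 0-1.
    return tci_image.float() / 255

def normalize_sentinel1(vh_vv_image: torch.Tensor) -> torch.Tensor:
    # Sentinel-1 VH+VV: 16-bit values divided by 255, then clipped to 0-1.
    return torch.clip(vh_vv_image.float() / 255, 0, 1)

def normalize_landsat(bands: torch.Tensor) -> torch.Tensor:
    # Landsat B1-B11: 16-bit values shifted and scaled, then clipped to 0-1.
    return torch.clip((bands.float() - 4000) / 16320, 0, 1)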

Usage Examples

First initialize a Weights instance:

import satlaspretrain_models
import torch
weights_manager = satlaspretrain_models.Weights()

Then choose a model_identifier from the tables above to specify the pretrained model you want to load. Below are examples showing how to load a few of the available models.

Pretrained single-image Sentinel-2 RGB model, backbone only:

model = weights_manager.get_pretrained_model(model_identifier="Sentinel2_SwinB_SI_RGB")

# Expected input is a portion of a Sentinel-2 L1C TCI image.
# The 0-255 pixel values should be divided by 255 so they are 0-1.
# tensor = tci_image[None, :, :, :] / 255
tensor = torch.zeros((1, 3, 512, 512), dtype=torch.float32)

# Since we only loaded the backbone, it outputs feature maps from the Swin-v2-Base backbone.
output = model(tensor)
print([feature_map.shape for feature_map in output])
# [torch.Size([1, 128, 128, 128]), torch.Size([1, 256, 64, 64]), torch.Size([1, 512, 32, 32]), torch.Size([1, 1024, 16, 16])]
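
A common downstream pattern (not part of the package, just an illustration) is to pool the coarsest feature map into a single embedding vector:

# Global-average-pool the deepest feature map into a 1024-dim embedding.
embedding = output[-1].mean(dim=(2, 3))
print(embedding.shape)  # torch.Size([1, 1024])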

Pretrained single-image Sentinel-1 model, backbone+FPN:

model = weights_manager.get_pretrained_model("Sentinel1_SwinB_SI", fpn=True)

# Expected input is a portion of a Sentinel-1 vh+vv image (in that order).
# The 16-bit pixel values should be divided by 255 and clipped to 0-1 (any pixel values greater than 255 become 1).
# tensor = torch.clip(torch.stack([vh_image, vv_image], dim=0)[None, :, :, :] / 255, 0, 1)
tensor = torch.zeros((1, 2, 512, 512), dtype=torch.float32)

# The model outputs feature maps from the FPN.
output = model(tensor)
print([feature_map.shape for feature_map in output])
# [torch.Size([1, 128, 128, 128]), torch.Size([1, 128, 64, 64]), torch.Size([1, 128, 32, 32]), torch.Size([1, 128, 16, 16])]

Prediction heads

Although the checkpoints include prediction head parameters, these heads are task-specific, so loading the head parameters is not supported in this repository. Computing outputs from the pre-trained prediction heads is supported in the dataset codebase.

For convenience when fine-tuning, though, satlaspretrain_models supports attaching one of several heads (initialized randomly) to the pre-trained model:

# Backbone and FPN parameters initialized from checkpoint, head parameters initialized randomly.
model = weights_manager.get_pretrained_model(MODEL_CHECKPOINT_ID, fpn=True, head=satlaspretrain_models.Head.CLASSIFY, num_categories=2)

The following head architectures are available:

  • Segmentation: U-Net Decoder w/ Cross Entropy loss
  • Detection: Faster R-CNN Decoder
  • Instance Segmentation: Mask R-CNN Decoder
  • Regression: U-Net Decoder w/ L1 loss
  • Classification: Pooling + Linear layers
  • Multi-label Classification: Pooling + Linear layers
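
For instance, attaching a segmentation head should mirror the CLASSIFY usage above. Note that Head.SEGMENT is assumed here by analogy with the CLASSIFY and DETECT members used in this README; check satlaspretrain_models.Head for the exact member names:

# Hypothetical sketch: backbone + FPN with a randomly initialized
# U-Net segmentation head for 10 classes (Head.SEGMENT is an assumption).
model = weights_manager.get_pretrained_model(
    "Sentinel2_SwinB_SI_RGB",
    fpn=True,
    head=satlaspretrain_models.Head.SEGMENT,
    num_categories=10,
)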

Pretrained multi-image aerial model, backbone + FPN + classification head:

# num_categories is the number of categories to predict.
# All heads are randomly initialized and provided only for convenience for fine-tuning.
model = weights_manager.get_pretrained_model("Aerial_SwinB_MI", fpn=True, head=satlaspretrain_models.Head.CLASSIFY, num_categories=2)

# Expected input is 8-bit (0-255) aerial images at 0.5 - 2 m/pixel.
# The 0-255 pixel values should be divided by 255 so they are 0-1.
# This multi-image model is trained to input 4 images but should perform well with different numbers of images.
# tensor = torch.stack([rgb_image1, rgb_image2], dim=0)[None, :, :, :, :] / 255
tensor = torch.zeros((1, 4, 3, 512, 512), dtype=torch.float32)

# The head needs to be fine-tuned on a downstream classification task.
# It outputs classification probabilities.
model.eval()
output = model(tensor.reshape(1, 4*3, 512, 512))
print(output)
# tensor([[0.0266, 0.9734]])
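
A bare-bones fine-tuning loop for this classification model might look like the following sketch. The random batch and labels are placeholders for a real data loader, and since the example above shows the head emitting probabilities, we apply NLL on their log; substitute your task's preferred loss:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model.train()

# Placeholder batch: 2 samples of 4 aligned RGB images each, labels in {0, 1}.
images = torch.rand((2, 4 * 3, 512, 512))
labels = torch.randint(0, 2, (2,))

for step in range(10):
    optimizer.zero_grad()
    probs = model(images)  # (2, num_categories) class probabilities
    loss = torch.nn.functional.nll_loss(torch.log(probs + 1e-8), labels)
    loss.backward()
    optimizer.step()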

Pretrained multi-image Landsat model, backbone + FPN + detection head:

# num_categories is the number of bounding box detection categories.
# All heads are randomly initialized and provided only for convenience for fine-tuning.
model = weights_manager.get_pretrained_model("Landsat_SwinB_MI", fpn=True, head=satlaspretrain_models.Head.DETECT, num_categories=5)

# Expected input is Landsat B1-B11 stacked in order.
# This multi-image model is trained to input 8 images but should perform well with different numbers of images.
# The 16-bit pixel values are normalized as follows:
# landsat_images = torch.stack([landsat_image1, landsat_image2], dim=0)
# tensor = torch.clip((landsat_images[None, :, :, :, :] - 4000) / 16320, 0, 1)
tensor = torch.zeros((1, 8, 11, 512, 512), dtype=torch.float32)

# The head needs to be fine-tuned on a downstream object detection task.
# It outputs bounding box detections.
model.eval()
output = model(tensor.reshape(1, 8*11, 512, 512))
print(output)
#[{'boxes': tensor([[ 67.0772, 239.2646, 95.6874, 16.3644], ...]),
# 'labels': tensor([3, ...]),
# 'scores': tensor([0.5443, ...])}]
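
Downstream code would typically filter these raw detections by score, for example:

# Keep only detections above a confidence threshold (value is task-dependent).
detections = output[0]
keep = detections['scores'] > 0.5
print(detections['boxes'][keep], detections['labels'][keep])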

Demos

We provide a demo showing how to fine-tune a SatlasPretrain Sentinel-2 model on the EuroSAT classification task.

We also provide a torchgeo demo, showing how to load SatlasPretrain weights into a model, download a dataset, initialize a trainer, and fine-tune the model on the UCMerced classification task. Note: a separate conda environment must be initialized to run this demo; see details in the notebook.

Tests

There are tests that check loading each of the pretrained models, plus one that checks randomly initialized models.

To run the tests, run the following command from the root directory: pytest tests/

Contact

If you have any questions, please email satlas@allenai.org or open an issue here.
