Particle Swarm Optimization using the torch.optim API.

Torch PSO

Particle Swarm Optimization is an optimization technique that iteratively attempts to improve a set of candidate solutions. Each candidate solution is called a "particle", and collectively they are called a "swarm". In each step of the optimization, each particle moves in a random direction while simultaneously being pulled towards the other particles in the swarm. A simple introduction to the algorithm can be found in its Wikipedia article.
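To make the update rule concrete, here is a minimal pure-Python sketch of the textbook form of the algorithm, in which each particle is pulled towards its own best known position and the best position found by the swarm. It is an illustration only and is not taken from this package's source: the function names (sphere, pso_minimize) and the cognitive_coefficient and social_coefficient parameters are assumptions made for the example.

```python
import random


def sphere(position):
    """Toy objective: sum of squares, minimized at the origin."""
    return sum(x * x for x in position)


def pso_minimize(objective, dim, num_particles=30, steps=100,
                 inertial_weight=0.5, cognitive_coefficient=1.0,
                 social_coefficient=1.0, min_value=-1.0, max_value=1.0,
                 seed=0):
    """Textbook PSO sketch (hypothetical helper, not the torch_pso code)."""
    rng = random.Random(seed)
    # Particles start at random positions inside the box, with zero velocity.
    positions = [[rng.uniform(min_value, max_value) for _ in range(dim)]
                 for _ in range(num_particles)]
    velocities = [[0.0] * dim for _ in range(num_particles)]
    # Each particle remembers the best position it has visited so far.
    best_positions = [p[:] for p in positions]
    best_scores = [objective(p) for p in positions]
    leader = min(range(num_particles), key=best_scores.__getitem__)
    swarm_best = best_positions[leader][:]
    swarm_best_score = best_scores[leader]

    for _ in range(steps):
        for i in range(num_particles):
            for d in range(dim):
                # Inertia keeps the particle moving; the two random pulls draw
                # it towards its own best and towards the swarm's best.
                velocities[i][d] = (
                    inertial_weight * velocities[i][d]
                    + cognitive_coefficient * rng.random()
                    * (best_positions[i][d] - positions[i][d])
                    + social_coefficient * rng.random()
                    * (swarm_best[d] - positions[i][d])
                )
                positions[i][d] += velocities[i][d]
            score = objective(positions[i])
            if score < best_scores[i]:
                best_scores[i] = score
                best_positions[i] = positions[i][:]
            if score < swarm_best_score:
                swarm_best_score = score
                swarm_best = positions[i][:]
    return swarm_best, swarm_best_score


best, score = pso_minimize(sphere, dim=3)
print(best, score)
```

The sketch only uses min_value and max_value to initialize the particles and does not clamp positions afterwards; an actual implementation may handle the bounds differently.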

This package implements Particle Swarm Optimization using the PyTorch Optimizer API, making it compatible with most pre-existing Torch training loops.

Installation

To install Torch PSO from PyPI, run the following command:

$ pip install torch-pso

Getting Started

To use the ParticleSwarmOptimizer, import it and use it like any other PyTorch Optimizer, optionally specifying its hyperparameters. One difference from a typical gradient-based loop is that step is given a closure that re-evaluates the model and returns the loss, as in the example below. In practice, most PyTorch tutorials can be adapted into a use-case by substituting the ParticleSwarmOptimizer for the optimizer they use. The simplified use-case below trains a small neural network to match its output to a target.

import torch
from torch.nn import Sequential, Linear, MSELoss
from torch_pso import ParticleSwarmOptimizer

net = Sequential(Linear(10,100), Linear(100,100), Linear(100,10))
optim = ParticleSwarmOptimizer(net.parameters(),
                               inertial_weight=0.5,
                               num_particles=100,
                               max_param_value=1,
                               min_param_value=-1)
criterion = MSELoss()
target = torch.rand((10,)).round()

x = torch.rand((10,))
for _ in range(100):
    
    def closure():
        # Clear any grads from before the optimization step, since we will be changing the parameters
        optim.zero_grad()  
        return criterion(net(x), target)
    
    optim.step(closure)
    print('Prediction', net(x))
    print('Target    ', target)
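The closure passed to step in the example above is what lets a gradient-free method fit the Optimizer API: the optimizer can call it as many times as it needs to score candidate parameters. As a hedged illustration of that contract, the sketch below defines a hypothetical RandomSearchOptimizer (it is not part of torch_pso, and it is random search rather than PSO) that keeps a random perturbation only when the closure reports a lower loss.

```python
import torch
from torch.nn import Linear, MSELoss
from torch.optim import Optimizer


class RandomSearchOptimizer(Optimizer):
    """Hypothetical gradient-free optimizer used only to show the closure contract."""

    def __init__(self, params, step_size=0.1):
        super().__init__(params, {"step_size": step_size})

    @torch.no_grad()
    def step(self, closure):
        # Score the current parameters, then try one random move per tensor.
        best_loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                previous = p.clone()
                p.add_(torch.randn_like(p) * group["step_size"])
                loss = closure()
                if loss < best_loss:
                    best_loss = loss      # keep the improvement
                else:
                    p.copy_(previous)     # undo the move
        return best_loss


torch.manual_seed(0)
net = Linear(10, 10)
optim = RandomSearchOptimizer(net.parameters(), step_size=0.05)
criterion = MSELoss()
x = torch.rand((10,))
target = torch.rand((10,)).round()


def closure():
    optim.zero_grad()
    return criterion(net(x), target)


initial_loss = closure().item()
for _ in range(50):
    optim.step(closure)
final_loss = closure().item()
print(initial_loss, final_loss)
```

Because moves are only accepted when they lower the loss, the final loss here cannot exceed the initial one. A swarm-based optimizer can use the same closure to score every particle within a single step call.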
