
Particle Swarm Optimization using the torch.optim API.

Project description

Torch PSO

Particle Swarm Optimization is an optimization technique that iteratively attempts to improve a set of candidate solutions. Each candidate solution is called a "particle", and collectively they are called a "swarm". In each step of the optimization, each particle moves in a partly random direction while being pulled toward the best positions found so far, both by itself and by the swarm as a whole. A simple introduction to the algorithm can be found in its Wikipedia article.
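For reference, the textbook (global-best) form of the PSO update is sketched below. For particle i at iteration t, x is its position, v its velocity, p_i its personal best position, and g the best position found by the whole swarm; w is the inertia weight (presumably what the inertial_weight hyperparameter in the example sets), c_1 and c_2 are the usual cognitive and social coefficients, and r_1, r_2 are fresh uniform random vectors in [0, 1]. This is the generic formulation, not a statement of the exact update torch_pso performs.

```latex
\begin{aligned}
v_i^{t+1} &= w\, v_i^{t} + c_1\, r_1 \odot \left(p_i - x_i^{t}\right) + c_2\, r_2 \odot \left(g - x_i^{t}\right) \\
x_i^{t+1} &= x_i^{t} + v_i^{t+1}
\end{aligned}
```

The inertia term and the random factors r_1, r_2 supply the "random direction" part of the motion; the two difference terms are the pulls toward p_i and g.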

This package implements Particle Swarm Optimization using the PyTorch Optimizer API, making it compatible with most existing PyTorch training loops.

Installation

To install Torch PSO from PyPI, run the following command:

$ pip install torch-pso

Getting Started

To use the ParticleSwarmOptimizer, import it and use it as you would any other PyTorch optimizer, optionally specifying its hyperparameters. In practice, most PyTorch tutorials can be adapted by substituting the ParticleSwarmOptimizer for the optimizer they use. Because PSO has to evaluate the loss for many candidate parameter sets, step() is passed a closure that recomputes and returns the loss (the same pattern used by torch.optim.LBFGS). The simplified example below trains a small neural network to match its output to a target.

import torch
from torch.nn import Sequential, Linear, MSELoss
from torch_pso import ParticleSwarmOptimizer

net = Sequential(Linear(10,100), Linear(100,100), Linear(100,10))
optim = ParticleSwarmOptimizer(net.parameters(),
                               inertial_weight=0.5,
                               num_particles=100,
                               max_param_value=1,
                               min_param_value=-1)
criterion = MSELoss()
target = torch.rand((10,)).round()

x = torch.rand((10,))
for _ in range(100):
    
    def closure():
        # Clear any grads from before the optimization step, since we will be changing the parameters
        optim.zero_grad()  
        return criterion(net(x), target)
    
    optim.step(closure)
    print('Prediction', net(x))
    print('Target    ', target)
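Since the closure in the example above reads the current network weights, an optimizer like this presumably calls it once per particle, loading each particle's position into the parameters before each call. The dependency-free sketch below illustrates that pattern with a hypothetical ToyPSO class that works on a plain Python list instead of tensors. It is not torch_pso's implementation: the cognitive/social coefficients, the clamping, and the seeding are assumptions, and min_param_value/max_param_value only mirror the names used in the example.

```python
import random
from typing import Callable, List


class ToyPSO:
    """Hypothetical, dependency-free PSO with a closure-based step().

    Illustration only, not torch_pso's implementation.
    """

    def __init__(self, params: List[float], num_particles: int = 20,
                 inertial_weight: float = 0.5, cognitive: float = 1.5,
                 social: float = 1.5, min_param_value: float = -1.0,
                 max_param_value: float = 1.0, seed: int = 0) -> None:
        self.params = params  # shared list, mutated in place like model weights
        self.w, self.c1, self.c2 = inertial_weight, cognitive, social
        self.lo, self.hi = min_param_value, max_param_value
        self.rng = random.Random(seed)
        dim = len(params)
        # Random initial positions inside [lo, hi], zero initial velocities.
        self.pos = [[self.rng.uniform(self.lo, self.hi) for _ in range(dim)]
                    for _ in range(num_particles)]
        self.vel = [[0.0] * dim for _ in range(num_particles)]
        self.best_pos = [p[:] for p in self.pos]      # personal bests
        self.best_loss = [float("inf")] * num_particles
        self.g_pos = self.pos[0][:]                   # swarm-wide best
        self.g_loss = float("inf")

    def step(self, closure: Callable[[], float]) -> float:
        # 1) Load each particle into the shared params and score it via the closure.
        for i, position in enumerate(self.pos):
            self.params[:] = position
            loss = closure()
            if loss < self.best_loss[i]:
                self.best_loss[i], self.best_pos[i] = loss, position[:]
            if loss < self.g_loss:
                self.g_loss, self.g_pos = loss, position[:]
        # 2) Update velocities (inertia + random pulls toward the bests), then move.
        for i, position in enumerate(self.pos):
            for d in range(len(position)):
                r1, r2 = self.rng.random(), self.rng.random()
                self.vel[i][d] = (self.w * self.vel[i][d]
                                  + self.c1 * r1 * (self.best_pos[i][d] - position[d])
                                  + self.c2 * r2 * (self.g_pos[d] - position[d]))
                # Clamp to [lo, hi], mirroring min/max_param_value in the example.
                position[d] = min(self.hi, max(self.lo, position[d] + self.vel[i][d]))
        # 3) Leave the best parameters found so far in place and report their loss.
        self.params[:] = self.g_pos
        return self.g_loss


# Usage: fit two numbers to a fixed target by minimizing squared error.
params = [0.0, 0.0]
target = [0.3, -0.7]


def closure() -> float:
    # Reads the shared params, just as a PyTorch closure reads model weights.
    return sum((p - t) ** 2 for p, t in zip(params, target))


opt = ToyPSO(params, num_particles=30)
history = [opt.step(closure) for _ in range(100)]
print(history[-1], params)
```

Scoring every particle before moving any of them keeps each closure call looking at exactly one candidate, and writing the swarm's best position back at the end means the parameters hold a usable solution between steps.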



Download files

Download the file for your platform.

Source Distribution

torch_pso-1.2.1.tar.gz (13.2 kB)

Uploaded Source

Built Distribution


torch_pso-1.2.1-py3-none-any.whl (20.2 kB)

Uploaded Python 3

File details

Details for the file torch_pso-1.2.1.tar.gz.

File metadata

  • Download URL: torch_pso-1.2.1.tar.gz
  • Size: 13.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torch_pso-1.2.1.tar.gz
Algorithm Hash digest
SHA256 27803976b9fd106c918e80f0f2e563d1ac8fd2c69ab05a34d23f378dca24d37f
MD5 a77d5c4f1b24143d99871f452cce6dac
BLAKE2b-256 f7372fd738ceb78d0b4f71792ed838752a1757bab789e91b752ba98d9399c38f
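To check a downloaded file against the SHA256 digest listed above, a minimal sketch using only Python's standard hashlib module is shown here. The helper names sha256_of and verify are illustrative, not part of torch_pso or pip.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for torch_pso-1.2.1.tar.gz.
EXPECTED_SDIST_SHA256 = "27803976b9fd106c918e80f0f2e563d1ac8fd2c69ab05a34d23f378dca24d37f"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.strip().lower()
```

For example, verify(Path("torch_pso-1.2.1.tar.gz"), EXPECTED_SDIST_SHA256) should return True for an intact download saved in the current directory (the path is an assumption). The pip hash command also prints a file's SHA256 digest.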


File details

Details for the file torch_pso-1.2.1-py3-none-any.whl.

File metadata

  • Download URL: torch_pso-1.2.1-py3-none-any.whl
  • Size: 20.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torch_pso-1.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 f050b9121127a50804786a152604ecffd66310fb3b5161c78cecf1e011081605
MD5 d3fce6d9c6141a96abdc37741eed7fcd
BLAKE2b-256 ea92da167d0cbc3d325d482b8cb787983fd9af3985fa286ac10f76a47804e8d4

