Particle Swarm Optimization using the torch.optim API.
Torch PSO
Particle Swarm Optimization is an optimization technique that iteratively attempts to improve a population of candidate solutions. Each candidate solution is called a "particle", and collectively they are called a "swarm". In each step of the optimization, each particle moves in a partly random direction while simultaneously being pulled towards the best solutions found so far, both its own and the swarm's. A simple introduction to the algorithm can be found in its Wikipedia article.
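The update rule described above can be sketched in plain Python. This is a minimal, illustrative global-best PSO, not the torch_pso implementation: the function name `pso_minimize`, its parameters (`cognitive`, `social`, the bounds and `seed`) and the coefficient values are assumptions chosen as common textbook defaults, not the package's defaults.

```python
import random


def pso_minimize(loss_fn, dim, num_particles=30, iterations=100,
                 inertial_weight=0.7, cognitive=1.5, social=1.5,
                 lower=-1.0, upper=1.0, seed=0):
    """Hypothetical sketch: minimize loss_fn over [lower, upper]^dim with global-best PSO."""
    rng = random.Random(seed)
    # Random initial positions inside the bounds, zero initial velocities
    positions = [[rng.uniform(lower, upper) for _ in range(dim)]
                 for _ in range(num_particles)]
    velocities = [[0.0] * dim for _ in range(num_particles)]
    # Each particle remembers the best position it has visited
    personal_best = [p[:] for p in positions]
    personal_best_loss = [loss_fn(p) for p in positions]
    # The swarm remembers the best position any particle has visited
    best_index = min(range(num_particles), key=personal_best_loss.__getitem__)
    global_best = personal_best[best_index][:]
    global_best_loss = personal_best_loss[best_index]

    for _ in range(iterations):
        for i in range(num_particles):
            for d in range(dim):
                r1, r2 = rng.random(), rng.random()
                # Keep part of the old velocity (inertia), plus random pulls towards
                # the particle's own best and the swarm's best positions
                velocities[i][d] = (inertial_weight * velocities[i][d]
                                    + cognitive * r1 * (personal_best[i][d] - positions[i][d])
                                    + social * r2 * (global_best[d] - positions[i][d]))
                # Move, then clamp back into the search bounds
                positions[i][d] = min(upper, max(lower, positions[i][d] + velocities[i][d]))
            loss = loss_fn(positions[i])
            if loss < personal_best_loss[i]:
                personal_best[i], personal_best_loss[i] = positions[i][:], loss
                # The global best is the minimum of the personal bests
                if loss < global_best_loss:
                    global_best, global_best_loss = positions[i][:], loss
    return global_best, global_best_loss


# Usage: find the point closest to a fixed target in 2D
target = [0.3, -0.2]


def squared_distance(p):
    return sum((a - b) ** 2 for a, b in zip(p, target))


best, best_loss = pso_minimize(squared_distance, dim=2)
print(best, best_loss)
```

No gradients are used anywhere: particles are scored only by calling `loss_fn`, which is what lets PSO work on losses that are not differentiable.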
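This package implements Particle Swarm Optimization using the PyTorch Optimizer API, making it compatible with most existing PyTorch training loops. Because PSO evaluates the loss of each particle instead of following gradients, step() is called with a closure that recomputes and returns the loss, as in the example below.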
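To show how a gradient-free method can fit the torch.optim.Optimizer interface, here is a minimal sketch of a PSO optimizer subclass. It is not the torch_pso source: the class name `MinimalPSO`, the `cognitive_coefficient` and `social_coefficient` names and all default values are assumptions for illustration (only `inertial_weight`, `num_particles`, `min_param_value` and `max_param_value` mirror the example in Getting Started). All parameters are flattened into one vector with `torch.nn.utils.parameters_to_vector`, each particle's position is loaded into the model with `vector_to_parameters`, and the closure is called once per particle to score it.

```python
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters


class MinimalPSO(torch.optim.Optimizer):
    """Hypothetical global-best PSO optimizer (a sketch, not the torch_pso source)."""

    def __init__(self, params, num_particles=20, inertial_weight=0.7,
                 cognitive_coefficient=1.5, social_coefficient=1.5,
                 min_param_value=-1.0, max_param_value=1.0):
        defaults = dict(num_particles=num_particles, inertial_weight=inertial_weight,
                        cognitive_coefficient=cognitive_coefficient,
                        social_coefficient=social_coefficient,
                        min_param_value=min_param_value, max_param_value=max_param_value)
        super().__init__(params, defaults)
        # Treat every parameter of every group as one flat search vector
        self._params = [p for group in self.param_groups for p in group["params"]]
        flat = parameters_to_vector(self._params).detach()
        # Particles start uniformly inside the bounds with zero velocity
        self.positions = torch.empty(num_particles, flat.numel(), dtype=flat.dtype,
                                     device=flat.device).uniform_(min_param_value, max_param_value)
        self.velocities = torch.zeros_like(self.positions)
        self.best_positions = self.positions.clone()
        self.best_losses = torch.full((num_particles,), float("inf"))
        self.global_best = self.positions[0].clone()
        self.global_best_loss = float("inf")

    @torch.no_grad()
    def step(self, closure=None):
        if closure is None:
            raise ValueError("PSO needs a closure that returns the loss")
        d = self.defaults
        # Score each particle by loading its position into the model and calling the closure
        for i in range(self.positions.size(0)):
            vector_to_parameters(self.positions[i], self._params)
            loss = float(closure())
            if loss < self.best_losses[i]:
                self.best_losses[i] = loss
                self.best_positions[i] = self.positions[i]
            if loss < self.global_best_loss:
                self.global_best_loss = loss
                self.global_best = self.positions[i].clone()
        # Velocity update: inertia plus random pulls towards personal and global bests
        r1 = torch.rand_like(self.positions)
        r2 = torch.rand_like(self.positions)
        self.velocities = (d["inertial_weight"] * self.velocities
                           + d["cognitive_coefficient"] * r1 * (self.best_positions - self.positions)
                           + d["social_coefficient"] * r2 * (self.global_best - self.positions))
        self.positions = (self.positions + self.velocities).clamp(
            d["min_param_value"], d["max_param_value"])
        # Leave the model holding the best parameters found so far
        vector_to_parameters(self.global_best, self._params)
        return self.global_best_loss
```

In this sketch, step() leaves the model holding the swarm's best parameters, so evaluating the model after the loop uses the best solution found. Calling the closure once per particle also means each step costs num_particles forward passes.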
Installation
To install Torch PSO using PyPI, run the following command:
$ pip install torch-pso
Getting Started
To use the ParticleSwarmOptimizer, import it and use it like any other PyTorch optimizer; its hyperparameters can be passed to the constructor. In practice, most PyTorch tutorials can be adapted by substituting the ParticleSwarmOptimizer for the optimizer they use. The simplified example below trains a small neural network to match its output to a target.
import torch
from torch.nn import Sequential, Linear, MSELoss
from torch_pso import ParticleSwarmOptimizer
net = Sequential(Linear(10, 100), Linear(100, 100), Linear(100, 10))
optim = ParticleSwarmOptimizer(net.parameters(),
                               inertial_weight=0.5,
                               num_particles=100,
                               max_param_value=1,
                               min_param_value=-1)
criterion = MSELoss()
target = torch.rand((10,)).round()
x = torch.rand((10,))
for _ in range(100):
    def closure():
        # Clear any grads from before the optimization step, since we will be changing the parameters
        optim.zero_grad()
        # Return the loss for the current parameter values
        return criterion(net(x), target)
    optim.step(closure)
print('Prediction', net(x))
print('Target    ', target)
Download files
Source Distribution: torch_pso-0.0.1.dev1.tar.gz
Built Distribution: torch_pso-0.0.1.dev1-py3-none-any.whl
File details
Details for the file torch_pso-0.0.1.dev1.tar.gz
File metadata
- Download URL: torch_pso-0.0.1.dev1.tar.gz
- Upload date:
- Size: 8.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6f9534398b9e4b9abb2e8067b209c6d1b1eb2407c5f4bde15c4299ed8ad06d5e
MD5 | 7f31f73111e25db36ff1cf6a9562fe8d
BLAKE2b-256 | 59619e28504a2d444b9542e5f84869e3cf93dc0ba26389b2eb36722cd9d77e4c
File details
Details for the file torch_pso-0.0.1.dev1-py3-none-any.whl
File metadata
- Download URL: torch_pso-0.0.1.dev1-py3-none-any.whl
- Upload date:
- Size: 6.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | c26d6a8c7f59aa704fda2a69cecae50cdf7e351c2b210e2a940c0d7525bad2a7
MD5 | 4c7e9028f56005536522bdee9530d7f9
BLAKE2b-256 | 1ab2644563f27e39124c9445f91c66dc2b79b25a91cd9a7c88037e66c21aa3c7
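The SHA256 digests listed above can be checked locally after downloading a file. A minimal sketch using Python's standard hashlib module follows; the helper names `sha256_of_file` and `verify_download` are hypothetical, and pip can also enforce hashes itself through its hash-checking mode (`--require-hashes`).

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_sha256):
    """Return True if the file's SHA256 digest matches the published one (case-insensitive)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()


# Example, assuming the source distribution was downloaded to the current directory:
# verify_download("torch_pso-0.0.1.dev1.tar.gz",
#                 "6f9534398b9e4b9abb2e8067b209c6d1b1eb2407c5f4bde15c4299ed8ad06d5e")
```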