PyTorch-Cosma
Overview
PyTorch-Cosma is a deep learning framework built on top of PyTorch, designed to facilitate the creation, training, and visualization of neural networks. The framework supports various types of models, including convolutional autoencoders, graph neural networks, and vision transformers. It also provides utilities for latent space exploration and graph visualization.
Project Structure
├── pytorch_cosma/
│   ├── config_validation.py
│   ├── autoencoders.py
│   ├── basic_layers.py
│   ├── utils.py
│   ├── vision_transformer.py
│   ├── graphs.py
│   ├── latent_space.py
│   ├── model_yaml_parser.py
│   ├── network_construction.py
│   └── twin_dataset_maker.py
├── configs/
│   ├── example_conv_autoencoder.yaml
│   ├── example_gatconv_network.yaml
│   └── ...
├── examples/
│   ├── mnist_autoencode_and_latent_inspection.py
│   └── ...
├── unit_testing/
│   └── examples/
│       ├── test_mnist_autoencode_and_latent_inspection.py
│       └── ...
├── data/
├── README.md
├── .gitignore
└── .vscode/
    ├── launch.json
    └── settings.json
Installation
- Create a virtual environment and activate it:
python -m venv venv
source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
You can either install the package directly from PyPI or clone the repository and install it from source:
Option 1: Install from PyPI
- Install the package:
pip install pytorch-cosma
Option 2: Clone the Repository
- Clone the repository:
git clone https://github.com/yourusername/pytorch-cosma.git
cd pytorch-cosma
- Install the package and its dependencies from the cloned source:
pip install .
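Whichever option you choose, you can confirm the installation by asking Python for the installed version of the distribution. The sketch below uses only the standard library's importlib.metadata (Python 3.8+). The helper name installed_version is hypothetical and not part of PyTorch-Cosma, and it assumes the distribution is registered under the name pytorch-cosma, as in the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    Hypothetical helper (not part of pytorch_cosma); it relies only on the
    standard library's importlib.metadata.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("pytorch-cosma")
    print(found if found else "pytorch-cosma is not installed in this environment")
```

Run it inside the activated virtual environment; checking the distribution metadata avoids importing the package (and PyTorch) just to verify that it is present.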
Usage
Configuration
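Model architectures are defined in YAML configuration files. A configuration is parsed with YamlParser and validated with ConfigModel before BaseModel builds the network from it. Example configurations are in the configs/ directory.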
Training a Model
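To train a model, use the example scripts in the examples/ directory or write your own. The following example trains a convolutional autoencoder on the MNIST dataset, using the architecture defined in configs/example_conv_autoencoder.yaml: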
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pytorch_cosma.config_validation import ConfigModel
from pytorch_cosma.latent_space import LatentSpaceExplorer, Visualizer
from pytorch_cosma.model_yaml_parser import YamlParser
from pytorch_cosma.network_construction import BaseModel
# Define device (GPU/CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Load configuration from YAML
raw_config = YamlParser("configs/example_conv_autoencoder.yaml").parse()
# Validate configuration
validated_config = ConfigModel(**raw_config).to_dict()
# Create model from configuration
model = BaseModel(validated_config, use_reconstruction=True, device=device)
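# Train the model: MSE reconstruction loss, Adam optimizer (lr=1e-3), 5 epochs
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train_model(train_loader, criterion, optimizer, epochs=5)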
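The example trains with nn.MSELoss(). With its default reduction='mean', this loss is the mean of the squared element-wise differences between the model output and the target. With use_reconstruction=True the target is presumably the input image itself; that is an assumption, since train_model is not documented here. The plain-Python sketch below reproduces the arithmetic on flat lists; reconstruction_mse is a hypothetical name used only for illustration.

```python
from typing import Sequence


def reconstruction_mse(original: Sequence[float], reconstructed: Sequence[float]) -> float:
    """Mean squared error between an input and its reconstruction.

    Mirrors the arithmetic of torch.nn.MSELoss(reduction='mean') on flat
    sequences: the average of (x_i - x_hat_i) ** 2 over all elements.
    """
    if len(original) != len(reconstructed):
        raise ValueError("inputs must have the same number of elements")
    if len(original) == 0:
        raise ValueError("inputs must not be empty")
    squared = [(x - x_hat) ** 2 for x, x_hat in zip(original, reconstructed)]
    return sum(squared) / len(squared)


# A 2x2 "image" flattened to four pixel values in [0, 1]
pixels = [0.0, 0.5, 1.0, 0.25]
reconstruction = [0.0, 0.5, 0.5, 0.25]
print(reconstruction_mse(pixels, reconstruction))  # 0.0625
```

Only one pixel differs (by 0.5), so the loss is 0.25 averaged over four elements. On a batch of MNIST images the same average is taken over every pixel of every image in the batch.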
Latent Space Exploration
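To explore the latent space of the trained model, continue from the training example above; the snippet reuses its imports, model, train_loader, and device: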
# Latent space exploration
explorer = LatentSpaceExplorer(model, train_loader, device)
latent_points, labels_points, all_inputs = explorer.extract_latent_space()
reduced_dimensionality = explorer.reduce_dimensionality(latent_points)
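# Randomly sample points for visualization; the same indices are applied to
# points, labels, and inputs so that they stay aligned
sample_size = min(100, len(reduced_dimensionality))
indices = np.random.choice(len(reduced_dimensionality), size=sample_size, replace=False)
reduced_dimensionality = reduced_dimensionality[indices]
selected_labels = labels_points[indices]
selected_inputs = all_inputs[indices]
# Visualize latent space
visualizer = Visualizer(reduced_dimensionality, selected_labels, selected_inputs)
hover_images = visualizer.generate_hover_images()
app = visualizer.create_dash_app(hover_images)
app.run(debug=True)  # Recent Dash versions use app.run; app.run_server was deprecated and later removed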
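When subsampling for visualization, the points, their labels, and the input images must all be indexed with the same random indices; otherwise a label or hover image ends up attached to the wrong point. The standard-library sketch below illustrates that idea on plain Python lists. sample_aligned is a hypothetical helper, not part of pytorch_cosma.latent_space, and the optional seed is added only to make runs reproducible.

```python
import random
from typing import Any, List, Optional, Sequence, Tuple


def sample_aligned(
    columns: Sequence[Sequence[Any]], sample_size: int, seed: Optional[int] = None
) -> Tuple[List[Any], ...]:
    """Draw the same random rows, without replacement, from parallel sequences."""
    if not columns:
        raise ValueError("at least one column is required")
    lengths = {len(column) for column in columns}
    if len(lengths) != 1:
        raise ValueError("all columns must have the same length")
    n = lengths.pop()
    # Never request more rows than exist (random.sample would raise otherwise)
    k = min(sample_size, n)
    rng = random.Random(seed)
    indices = rng.sample(range(n), k)
    # Apply the same indices to every column so that rows stay aligned
    return tuple([column[i] for i in indices] for column in columns)


points = [(0.1, 0.2), (0.3, 0.4), (0.5, 0.6), (0.7, 0.8)]
labels = [0, 1, 2, 3]
images = ["img0", "img1", "img2", "img3"]
sampled_points, sampled_labels, sampled_images = sample_aligned(
    [points, labels, images], sample_size=2, seed=0
)
# Each sampled label still belongs to its sampled point and image
```

Capping the sample size at the number of available rows mirrors what np.random.choice with replace=False requires: it raises an error if asked for more samples than there are elements.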
Graph Visualization
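To visualize node-level predictions on a graph, pass a NetworkX graph together with per-node predictions and ground-truth labels to GraphVisualizer: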
import networkx as nx
import torch
from pytorch_cosma.graphs import GraphVisualizer
# Create a sample graph
G = nx.karate_club_graph()
# Generate random predictions and ground truth
node_predictions = torch.randint(0, 2, (len(G.nodes),))
node_ground_truth = torch.randint(0, 2, (len(G.nodes),))
# Initialize the visualizer
visualizer = GraphVisualizer(G, node_predictions, node_ground_truth, subset_size=10)
# Create and run the Dash app
app = visualizer.create_dash_app()
app.run(debug=True)  # Recent Dash versions use app.run; app.run_server was deprecated and later removed
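The example passes random per-node predictions and ground-truth labels to GraphVisualizer. For a quick numeric summary of such predictions alongside the visualization, you can compare the two label vectors node by node. The sketch below does this in plain Python. compare_node_labels is a hypothetical helper and is not claimed to be what GraphVisualizer computes internally; torch tensors such as node_predictions can be converted with .tolist() before calling it.

```python
from typing import Dict, List, Sequence


def compare_node_labels(
    predictions: Sequence[int], ground_truth: Sequence[int]
) -> Dict[str, object]:
    """Split node indices into correctly and incorrectly predicted nodes.

    Assumes node IDs are 0..n-1 in list order, as in nx.karate_club_graph().
    """
    if len(predictions) != len(ground_truth):
        raise ValueError("predictions and ground truth must cover the same nodes")
    correct: List[int] = []
    incorrect: List[int] = []
    for node, (pred, true) in enumerate(zip(predictions, ground_truth)):
        (correct if pred == true else incorrect).append(node)
    accuracy = len(correct) / len(predictions) if len(predictions) else 0.0
    return {"correct": correct, "incorrect": incorrect, "accuracy": accuracy}


summary = compare_node_labels([0, 1, 1, 0], [0, 1, 0, 0])
print(summary)  # {'correct': [0, 1, 3], 'incorrect': [2], 'accuracy': 0.75}
```

With the random labels from the example, accuracy should hover around 0.5, which is a useful sanity check before plugging in a trained graph model.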
Unit Testing
Unit tests are located in the unit_testing/ directory. To run the tests:
python -m unittest discover unit_testing/
License
This project is licensed under the MIT License. See the LICENSE file for details.
Acknowledgements
This project uses the following libraries:
Contact
For questions or suggestions, please open an issue or contact the repository owner at mahmoud.raad@yahoo.co.uk
File details
Details for the file pytorch_cosma-0.1.10.tar.gz.
File metadata
- Download URL: pytorch_cosma-0.1.10.tar.gz
- Size: 23.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 900e214ec9295b1c9cf487311996973fc445e69bad9094a4b34206de8fdbb048 |
| MD5 | d294d32a9d425097544bd04376189bfe |
| BLAKE2b-256 | 0a9ce3ea9ff7747f45b60d4e28d849b8b30019a166dab1d7ae526ddd48b0973b |
File details
Details for the file pytorch_cosma-0.1.10-py3-none-any.whl.
File metadata
- Download URL: pytorch_cosma-0.1.10-py3-none-any.whl
- Size: 24.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 96fb3ed867c9296d2222540372d411a993c718c90dd675630a89466024cd26ee |
| MD5 | a0a22bf598eea791dfe7a1281abc8799 |
| BLAKE2b-256 | 006386d98ccd2f550219366817579d10ddfde5eb7b2a7b19af4808ad79dfae9c |