A deep learning framework built on top of PyTorch.

PyTorch-Cosma

Overview

PyTorch-Cosma is a deep learning framework built on top of PyTorch, designed to facilitate the creation, training, and visualization of neural networks. The framework supports various types of models, including convolutional autoencoders, graph neural networks, and vision transformers. It also provides utilities for latent space exploration and graph visualization.

Project Structure

├── pytorch_cosma/
│   ├── config_validation.py
│   ├── autoencoders.py
│   ├── basic_layers.py
│   ├── utils.py
│   ├── vision_transformer.py
│   ├── graphs.py
│   ├── latent_space.py
│   ├── model_yaml_parser.py
│   ├── network_construction.py
│   └── twin_dataset_maker.py
├── configs/
│   ├── example_conv_autoencoder.yaml
│   ├── example_gatconv_network.yaml
│   └── ...
├── examples/
│   ├── mnist_autoencode_and_latent_inspection.py
│   └── ...
├── unit_testing/
│   └── examples/
│       ├── test_mnist_autoencode_and_latent_inspection.py
│       └── ...
├── data/
├── README.md
├── .gitignore
└── .vscode/
    ├── launch.json
    └── settings.json

Installation

  1. Create a virtual environment and activate it:
    python -m venv venv
    source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
    

You can either install the package directly from PyPI or clone the repository and install it from source:

Option 1: Install from PyPI

  1. Install the package:
    pip install pytorch-cosma
    

Option 2: Clone the Repository

  1. Clone the repository:

    git clone https://github.com/yourusername/pytorch-cosma.git
    cd pytorch-cosma
    
  2. Install the package and its dependencies from the cloned source:

    pip install .
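Whichever option you choose, you can confirm the installation with Python's standard `importlib` machinery. This sketch assumes the distribution name `pytorch-cosma` (from the `pip install` command) and the import package `pytorch_cosma` (from the project tree); the helper names are hypothetical and not part of the package.

```python
import importlib.metadata
import importlib.util
from typing import Optional


def installed_version(dist_name: str = "pytorch-cosma") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def is_importable(module_name: str = "pytorch_cosma") -> bool:
    """Check whether a top-level module can be found without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("pytorch-cosma version:", installed_version())
    print("pytorch_cosma importable:", is_importable())
```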
    

Usage

Configuration

Model architectures are defined using YAML configuration files. Examples can be found in the configs/ directory.
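The YAML schema is defined by `config_validation.py` and the files in `configs/`, and it is not reproduced here. To illustrate the general idea of config-driven construction, the sketch below builds a `torch.nn.Sequential` from a dictionary of the kind a YAML parser returns. The `layers`/`type`/`params` keys are a hypothetical schema, not the one PyTorch-Cosma uses, and `build_from_config` is a hypothetical helper.

```python
from typing import Any, Dict

import torch
from torch import nn


def build_from_config(config: Dict[str, Any]) -> nn.Sequential:
    """Build a Sequential model from a hypothetical {'layers': [...]} config."""
    modules = []
    for spec in config["layers"]:
        # Look up the layer class by name in torch.nn, e.g. "Conv2d" -> nn.Conv2d
        layer_cls = getattr(nn, spec["type"])
        modules.append(layer_cls(**spec.get("params", {})))
    return nn.Sequential(*modules)


# A dict like this could come from yaml.safe_load() on a config file
example_config = {
    "layers": [
        {"type": "Conv2d", "params": {"in_channels": 1, "out_channels": 8, "kernel_size": 3, "padding": 1}},
        {"type": "ReLU"},
        {"type": "Flatten"},
        {"type": "Linear", "params": {"in_features": 8 * 28 * 28, "out_features": 16}},
    ]
}

model = build_from_config(example_config)
out = model(torch.zeros(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 16])
```

Resolving layer names with `getattr` keeps the config readable, but it will instantiate any `torch.nn` class named in the file. This is why checking the parsed config before building the model (the `ConfigModel` validation step in the training example below) is worthwhile.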

Training a Model

To train a model, use the provided example scripts or create your own. Below is an example of training an autoencoder on the MNIST dataset:

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from pytorch_cosma.config_validation import ConfigModel
from pytorch_cosma.latent_space import LatentSpaceExplorer, Visualizer
from pytorch_cosma.model_yaml_parser import YamlParser
from pytorch_cosma.network_construction import BaseModel

# Define device (GPU/CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Load configuration from YAML
raw_config = YamlParser("configs/example_conv_autoencoder.yaml").parse()

# Validate configuration
validated_config = ConfigModel(**raw_config).to_dict()

# Create model from configuration
model = BaseModel(validated_config, use_reconstruction=True, device=device)

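# Train the model with a reconstruction (MSE) loss and the Adam optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train_model(train_loader, criterion, optimizer, epochs=5)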
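`train_model` hides the optimization loop. The following plain-PyTorch sketch shows what an autoencoder training loop typically looks like. It is not `BaseModel.train_model`'s actual implementation. The small fully connected autoencoder and random tensors stand in for the YAML-defined model and MNIST, and `train_autoencoder` is a hypothetical helper.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_autoencoder(model, loader, criterion, optimizer, epochs, device):
    """Reconstruction training: each input batch is also its own target.

    Assumes `model` has already been moved to `device`.
    Returns the dataset-averaged loss for each epoch.
    """
    epoch_losses = []
    for _ in range(epochs):
        model.train()
        total, count = 0.0, 0
        for inputs, _labels in loader:  # class labels are not used
            inputs = inputs.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), inputs)
            loss.backward()
            optimizer.step()
            total += loss.item() * inputs.size(0)
            count += inputs.size(0)
        epoch_losses.append(total / count)
    return epoch_losses


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in data: random 1x28x28 "images" with dummy labels, shaped like MNIST
data = TensorDataset(torch.rand(256, 1, 28, 28), torch.zeros(256, dtype=torch.long))
loader = DataLoader(data, batch_size=64, shuffle=True)

# Stand-in autoencoder: 784 -> 32 -> 784, reshaped back to image form
autoencoder = nn.Sequential(
    nn.Flatten(), nn.Linear(784, 32), nn.ReLU(),
    nn.Linear(32, 784), nn.Sigmoid(), nn.Unflatten(1, (1, 28, 28)),
).to(device)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
losses = train_autoencoder(autoencoder, loader, nn.MSELoss(), optimizer, epochs=3, device=device)
print(losses)
```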

Latent Space Exploration

To explore the latent space of a trained model:

# Latent space exploration (continues the training example: reuses model, train_loader, device, np)
explorer = LatentSpaceExplorer(model, train_loader, device)
latent_points, labels_points, all_inputs = explorer.extract_latent_space()
reduced_dimensionality = explorer.reduce_dimensionality(latent_points)

# Randomly sample points for visualization (never more than are available)
sample_size = min(100, len(reduced_dimensionality))
indices = np.random.choice(len(reduced_dimensionality), size=sample_size, replace=False)
reduced_dimensionality = reduced_dimensionality[indices]
selected_labels = np.asarray(labels_points)[indices]  # keep labels aligned with the sampled points
selected_inputs = all_inputs[indices]

# Visualize latent space
visualizer = Visualizer(reduced_dimensionality, selected_labels, selected_inputs)
hover_images = visualizer.generate_hover_images()
app = visualizer.create_dash_app(hover_images)
app.run(debug=True)  # app.run() replaces run_server(), which Dash 3.0 removed
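The method `reduce_dimensionality` uses (PCA, t-SNE, UMAP, ...) is not documented here. As a stand-in, this sketch projects latent vectors to 2-D with PCA computed via NumPy's SVD. It then draws one shared set of random indices so that points, labels, and inputs stay aligned. The function names and random data are hypothetical.

```python
import numpy as np


def pca_2d(latents: np.ndarray) -> np.ndarray:
    """Project (n_samples, n_features) latents onto their top two principal components."""
    centered = latents - latents.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def sample_aligned(points, labels, inputs, sample_size, seed=None):
    """Subsample points, labels and inputs with one shared set of indices."""
    rng = np.random.default_rng(seed)
    n = len(points)
    indices = rng.choice(n, size=min(sample_size, n), replace=False)
    return points[indices], labels[indices], inputs[indices]


rng = np.random.default_rng(0)
latents = rng.normal(size=(500, 16))    # stand-in for extracted latent vectors
labels = rng.integers(0, 10, size=500)  # stand-in digit labels
inputs = rng.random((500, 28, 28))      # stand-in images
reduced = pca_2d(latents)
pts, lbls, imgs = sample_aligned(reduced, labels, inputs, sample_size=100, seed=1)
print(reduced.shape, pts.shape, lbls.shape, imgs.shape)  # (500, 2) (100, 2) (100,) (100, 28, 28)
```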

Graph Visualization

To visualize a graph together with node-level predictions and ground-truth labels:

import networkx as nx
import torch

from pytorch_cosma.graphs import GraphVisualizer

# Create a sample graph
G = nx.karate_club_graph()

# Generate random predictions and ground truth
node_predictions = torch.randint(0, 2, (len(G.nodes),))
node_ground_truth = torch.randint(0, 2, (len(G.nodes),))

# Initialize the visualizer
visualizer = GraphVisualizer(G, node_predictions, node_ground_truth, subset_size=10)

# Create and run the Dash app
app = visualizer.create_dash_app()
app.run(debug=True)  # app.run() replaces run_server(), which Dash 3.0 removed
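How `GraphVisualizer` uses `subset_size` is not documented here. One plausible reading is that it draws only a subset of the nodes. Under that assumption, this stdlib-only sketch computes node-level accuracy, colors each node by whether its prediction matches the ground truth, and keeps the subgraph induced by a random node subset. All names are hypothetical. For the torch tensors in the example above, call `.tolist()` on them first.

```python
import random
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[int, int]


def node_colors(predictions: Sequence[int], ground_truth: Sequence[int]) -> Dict[int, str]:
    """Map each node index to 'green' if its prediction is correct, else 'red'."""
    return {
        node: "green" if pred == truth else "red"
        for node, (pred, truth) in enumerate(zip(predictions, ground_truth))
    }


def accuracy(predictions: Sequence[int], ground_truth: Sequence[int]) -> float:
    """Fraction of nodes whose predicted label equals the ground truth."""
    correct = sum(p == t for p, t in zip(predictions, ground_truth))
    return correct / len(ground_truth)


def induced_subgraph(edges: List[Edge], num_nodes: int, subset_size: int, seed: int = 0):
    """Pick a random node subset and keep only the edges between chosen nodes."""
    rng = random.Random(seed)
    nodes = set(rng.sample(range(num_nodes), min(subset_size, num_nodes)))
    kept = [(u, v) for u, v in edges if u in nodes and v in nodes]
    return nodes, kept


# Tiny stand-in graph: a 6-node ring
edges = [(i, (i + 1) % 6) for i in range(6)]
preds = [0, 1, 1, 0, 1, 0]
truth = [0, 1, 0, 0, 1, 1]
print(accuracy(preds, truth))  # 0.6666666666666666
print(node_colors(preds, truth))
nodes, kept = induced_subgraph(edges, num_nodes=6, subset_size=4)
```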

Unit Testing

Unit tests are located in the unit_testing/ directory. To run the tests:

python -m unittest discover unit_testing/
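`unittest discover` collects files whose names match its default pattern `test*.py`. The following hypothetical module (for example, a file named `test_sampling.py`) shows the shape such a test takes. It tests a small stand-in helper rather than any PyTorch-Cosma function, and it also shows how to run a test case programmatically.

```python
import random
import unittest


def sample_indices(n, sample_size, seed=0):
    """Hypothetical helper: draw up to `sample_size` distinct indices from range(n)."""
    return random.Random(seed).sample(range(n), min(sample_size, n))


class TestSampleIndices(unittest.TestCase):
    def test_indices_are_unique_and_in_range(self):
        idx = sample_indices(50, 10)
        self.assertEqual(len(set(idx)), 10)
        self.assertTrue(all(0 <= i < 50 for i in idx))

    def test_sample_size_is_capped_at_population(self):
        self.assertEqual(sorted(sample_indices(5, 100)), [0, 1, 2, 3, 4])


if __name__ == "__main__":
    # Run this module's tests programmatically
    suite = unittest.TestLoader().loadTestsFromTestCase(TestSampleIndices)
    unittest.TextTestRunner(verbosity=2).run(suite)
```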

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgements

This project uses the following libraries:

Contact

For questions or suggestions, please open an issue or contact the repository owner at mahmoud.raad@yahoo.co.uk

Download files

Source Distribution

pytorch_cosma-0.1.9.tar.gz (23.4 kB)

Built Distribution

pytorch_cosma-0.1.9-py3-none-any.whl (24.2 kB)

File details

Details for the file pytorch_cosma-0.1.9.tar.gz.

File metadata

  • Download URL: pytorch_cosma-0.1.9.tar.gz
  • Upload date:
  • Size: 23.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.10

File hashes

Hashes for pytorch_cosma-0.1.9.tar.gz
Algorithm Hash digest
SHA256 4c7ebc61dc265c65b5decb6541b73d07d54b872b4eab3c5b0801d9ccc13d3402
MD5 f8a10e7837b138241cf801aeb7447d84
BLAKE2b-256 2dcaf53ed6f936a4c83560af23c6a9b8411deb57958fef86f98f214ca39cced7
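To check a downloaded file against the digests listed here, compute its SHA256 with Python's standard `hashlib` and compare. This sketch assumes the sdist was saved under its published file name in the current directory. The helper name is hypothetical. pip can also enforce this check itself through `--hash=sha256:...` entries in a requirements file.

```python
import hashlib
from pathlib import Path

EXPECTED_SDIST_SHA256 = "4c7ebc61dc265c65b5decb6541b73d07d54b872b4eab3c5b0801d9ccc13d3402"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Stream a file through SHA256 in chunks and return the hex digest."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    sdist = Path("pytorch_cosma-0.1.9.tar.gz")
    if sdist.exists():
        ok = sha256_of(sdist) == EXPECTED_SDIST_SHA256
        print("hash matches" if ok else "HASH MISMATCH")
    else:
        print(f"{sdist} not found")
```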

File details

Details for the file pytorch_cosma-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: pytorch_cosma-0.1.9-py3-none-any.whl
  • Upload date:
  • Size: 24.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.10

File hashes

Hashes for pytorch_cosma-0.1.9-py3-none-any.whl
Algorithm Hash digest
SHA256 6e089abd587930633e71476bb0a2c0cbe5db5dc36c3cb3f19798a8cd9db8b21b
MD5 a1d9a8ead0d0529c74fa1b810d6c5846
BLAKE2b-256 c0be49a70c7c909e08d5c53d9bdf9b7f33e94a8f2f017f19cf046bdb3309ba31
