Distributed Sparse Neural Network implementation on COINSTAC (GPU).

coinstac-sparse-dinunet

Distributed Sparse Neural Network implementation on COINSTAC.

pip install coinstac-sparse-dinunet

List the packages your computation depends on, such as pytorch and torchvision, in a requirements.txt file.

Highlights:

1. Creates a sparse network using the single-shot pruning algorithm SNIP (https://arxiv.org/abs/1810.02340).
2. Performs distributed training and optimization with reduced bandwidth.
3. Automatic data splitting and k-fold cross-validation.
4. Automatic model checkpointing.
5. GPU-enabled local sites.
6. Customizable metrics (with automatic serialization between nodes) that work with any scheme.
7. Any custom reduction or learning mechanism can be integrated by extending coinstac_sparse_dinunet.distrib.reducer/learner.
...
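
To make the single-shot pruning idea in item 1 concrete, here is a minimal sketch of the SNIP mask-selection step in plain Python. It assumes the weights and their gradients from one batch are already available as flat lists; the function name snip_mask and the sample numbers are hypothetical, and this is not the package's own implementation. SNIP also normalizes the saliencies, which is omitted here because it does not change which connections rank highest.

```python
def snip_mask(weights, grads, sparsity_level):
    """Return a 0/1 keep-mask from SNIP-style saliency |w * g|.

    weights, grads: flat lists of floats of equal length (hypothetical inputs;
    in practice the gradients come from one forward/backward pass on a batch).
    sparsity_level: fraction of connections to remove, e.g. 0.8.
    """
    if len(weights) != len(grads):
        raise ValueError("weights and grads must have the same length")
    if not 0.0 <= sparsity_level < 1.0:
        raise ValueError("sparsity_level must be in [0, 1)")

    # Connection sensitivity: magnitude of weight times gradient.
    saliency = [abs(w * g) for w, g in zip(weights, grads)]
    n_keep = max(1, round(len(weights) * (1.0 - sparsity_level)))

    # Indices of the n_keep most salient connections.
    order = sorted(range(len(saliency)), key=lambda i: saliency[i], reverse=True)
    mask = [0] * len(weights)
    for i in order[:n_keep]:
        mask[i] = 1
    return mask


# Made-up values for illustration only.
weights = [0.5, -1.0, 0.1, 2.0, -0.3, 0.8, 0.05, -1.5, 0.4, 0.9]
grads = [0.2, 0.1, -0.9, 0.05, 0.4, -0.6, 1.0, 0.3, -0.1, 0.0]
mask = snip_mask(weights, grads, sparsity_level=0.8)
print(mask)
```

With a sparsity level of 0.8, only the two connections with the largest |w * g| are kept out of ten. Because the mask is computed once before training, only the surviving weights need to be exchanged afterwards, which is one plausible reading of the reduced-bandwidth claim in item 2.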

Running an analysis in the COINSTAC app.

Add a new NN computation to COINSTAC (Development guide):

imports

import torch
import torch.nn.functional as F

from coinstac_sparse_dinunet import COINNDataset, COINNTrainer, COINNLocal
from coinstac_sparse_dinunet.metrics import COINNAverages, Prf1a

1. Define Data Loader

class MyDataset(COINNDataset):
    def __init__(self, **kw):
        super().__init__(**kw)
        self.labels = None

    def load_index(self, id, file):
        data_dir = self.path(id, 'data_dir') # data_dir comes from inputspecs.json
        ...
        self.indices.append([id, file])

    def __getitem__(self, ix):
        id, file = self.indices[ix]
        data_dir = self.path(id, 'data_dir') # data_dir comes from inputspecs.json
        label_dir = self.path(id, 'label_dir') # label_dir comes from inputspecs.json
        ...
        # Logic to load, transform single data item.
        ...
        return {'inputs': ..., 'labels': ...}
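
The two-phase pattern above (index first, load on access) can be tried without the library. The sketch below is a hypothetical stand-in, not the COINNDataset API: the class IndexedDataset, the FAKE_FILES table, and the site name are all invented for illustration. It also returns the item index under 'ix', since the masking step in the trainer reads batch['ix']; whether the library adds that key itself is not stated here.

```python
# Invented in-memory "storage": file name -> (features, label).
FAKE_FILES = {
    'subj01.csv': ([0.1, 0.2, 0.3], 0),
    'subj02.csv': ([0.9, 0.8, 0.7], 1),
}


class IndexedDataset:
    """Hypothetical stand-in for the load_index / __getitem__ pattern."""

    def __init__(self):
        self.indices = []

    def load_index(self, site_id, file):
        # Index time: record only a lightweight (site, file) reference.
        self.indices.append([site_id, file])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, ix):
        # Access time: resolve the reference and load the actual item.
        site_id, file = self.indices[ix]
        inputs, label = FAKE_FILES[file]
        return {'inputs': inputs, 'labels': label, 'ix': ix}


dataset = IndexedDataset()
for name in sorted(FAKE_FILES):
    dataset.load_index('site_a', name)

item = dataset[1]
print(len(dataset), item)
```

Keeping only references in the index means the full data is read one item at a time, so the index stays small even when the files are large.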

2. Define Trainer

class MyTrainer(COINNTrainer):
    def __init__(self, **kw):
        super().__init__(**kw)

    def _init_nn_model(self):
        # MYModel stands for a user-defined torch.nn.Module (not defined here).
        self.nn['model'] = MYModel(in_size=self.cache['input_size'], out_size=self.cache['num_class'])

    def single_iteration_for_masking(self, model, batch):
        # Defines the sparsity level, loss function, and other parameters used for SNIP masking.
        sparsity_level = 0.85
        inputs = batch['inputs'].to(self.device['gpu']).float()
        labels = batch['labels'].to(self.device['gpu']).long()
        indices = batch['ix'].to(self.device['gpu']).long()
        model.zero_grad()
        out = F.log_softmax(model(inputs), 1)
        loss = F.nll_loss(out, labels)
        return {'out': out, 'loss': loss, 'indices': indices, 'sparsity_level': sparsity_level}

    def iteration(self, batch):
        inputs = batch['inputs'].to(self.device['gpu']).float()
        labels = batch['labels'].to(self.device['gpu']).long()

        out = F.log_softmax(self.nn['model'](inputs), 1)
        loss = F.nll_loss(out, labels)
        _, predicted = torch.max(out, 1)
        score = self.new_metrics()
        score.add(predicted, labels)
        val = self.new_averages()
        val.add(loss.item(), len(inputs))
        return {'out': out, 'loss': loss, 'averages': val,
                'metrics': score, 'prediction': predicted}
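
The iteration method passes the loss together with the batch size to the averages object, and predictions together with labels to the metrics object. The sketch below shows what such trackers might compute, in plain Python. The class names RunningAverage and BinaryPRF1A are hypothetical; the library's COINNAverages and Prf1a may behave differently, and the binary case is assumed for brevity.

```python
class RunningAverage:
    """Size-weighted mean, so a smaller final batch does not skew the result."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, value, n=1):
        self.total += value * n
        self.count += n

    def get(self):
        return self.total / self.count if self.count else 0.0


class BinaryPRF1A:
    """Precision, recall, F1, and accuracy from accumulated confusion counts."""

    def __init__(self):
        self.tp = self.fp = self.tn = self.fn = 0

    def add(self, predicted, labels):
        for p, y in zip(predicted, labels):
            if p == 1 and y == 1:
                self.tp += 1
            elif p == 1 and y == 0:
                self.fp += 1
            elif p == 0 and y == 0:
                self.tn += 1
            else:
                self.fn += 1

    @staticmethod
    def _ratio(num, den):
        return num / den if den else 0.0

    def get(self):
        precision = self._ratio(self.tp, self.tp + self.fp)
        recall = self._ratio(self.tp, self.tp + self.fn)
        f1 = self._ratio(2 * precision * recall, precision + recall)
        accuracy = self._ratio(self.tp + self.tn, self.tp + self.fp + self.tn + self.fn)
        return precision, recall, f1, accuracy


avg = RunningAverage()
avg.add(0.5, 4)   # mean loss 0.5 over a batch of 4
avg.add(1.0, 2)   # mean loss 1.0 over a batch of 2

score = BinaryPRF1A()
score.add([1, 0, 1, 1], [1, 0, 0, 1])
precision, recall, f1, accuracy = score.get()
print(avg.get(), precision, recall, f1, accuracy)
```

Both trackers store only sums and counts, so their state is a handful of numbers that can be merged by addition; this is one way metrics could be serialized and combined across nodes.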

Advanced use cases:

Referenced from the Trends Center coinstac-dinunet repository (https://github.com/trendscenter/coinstac-dinunet)

Source Distribution

coinstac-sparse-dinunet-16.0.3.tar.gz (34.3 kB)

Hashes for coinstac-sparse-dinunet-16.0.3.tar.gz
Algorithm Hash digest
SHA256 e66ec0f76cd644b171b0967685ac7d6302172a1738c60c518d3a6ead95cd1ca1
MD5 819c6eaf15d83b009874c3991825d60d
BLAKE2b-256 bcec9e975290ebaff39d3543e4d33639e756d7f9d99c6c3379b4babfabea0995
