Skip to main content

Apply a torch model to one or more datapoints.

Project description

torchapply

Apply a torch model to some datapoints.

Here's an example:

import torch
from torch import tensor


class Main(torch.nn.Module):
    """Example top-level module that routes a two-part input to two sub-models.

    In addition to ``forward`` it defines ``preprocess`` and ``postprocess``
    methods; presumably ``torchapply.apply_model`` calls these around
    ``forward`` -- confirm against the torchapply implementation.
    """

    def __init__(self, model_0, model_1):
        """Store both sub-models and the fruit-name lookup table."""
        super().__init__()
        self.model_0 = model_0
        self.model_1 = model_1
        # Fruit name -> integer id, consumed by ``preprocess``.
        self.dictionary = {'apple': 0, 'orange': 1, 'pear': 2}

    def preprocess(self, arg):
        """Replace the fruit names of one datapoint by their integer ids.

        ``arg[0]`` is a dict holding names at ``['a']['b']`` and ``['c']``;
        ``arg[1]`` is an iterable of names.  The result is a two-element
        list: a dict with the same layout holding ids, and a tensor of ids.
        """
        lookup = self.dictionary
        nested = arg[0]
        encoded = {
            'a': {'b': lookup[nested['a']['b']]},
            'c': lookup[nested['c']],
        }
        ids = torch.tensor([lookup[name] for name in arg[1]])
        return [encoded, ids]

    def forward(self, args):
        """Feed ``args[0]`` to ``model_0`` and ``args[1]`` to ``model_1``."""
        first = self.model_0(args[0])
        second = self.model_1(args[1])
        return first, second

    def postprocess(self, arg):
        """Collapse the outputs into a score and a decision.

        Each of ``arg[0]['a']['b']``, ``arg[0]['c']`` and ``arg[1]`` is
        summed; ``'score'`` is the total of those sums and ``'decision'``
        is whether that total is greater than zero.
        """
        partial_sums = [
            arg[0]['a']['b'].sum(),
            arg[0]['c'].sum(),
            arg[1].sum(),
        ]
        score = sum(partial_sums)
        return {'score': score, 'decision': score > 0}
        

class ModelA(torch.nn.Module):
    """Stand-in sub-model for the ``'a'`` branch of the example input."""

    def forward(self, args):
        """Return ``{'b': random tensor}`` with one 10-wide row per entry of ``args['b']``.

        Only the leading dimension of ``args['b']`` is used; its values are
        ignored.
        """
        batch_size = args['b'].shape[0]
        return {'b': torch.randn(batch_size, 10)}


class ModelC(torch.nn.Module):
    """Stand-in sub-model for the ``'c'`` branch of the example input."""

    def forward(self, args):
        """Return a random tensor with ``args.shape[0]`` rows and 10 columns.

        The values of ``args`` are ignored; only its leading dimension is read.
        """
        n_rows = args.shape[0]
        return torch.randn(n_rows, 10)


class Model1(torch.nn.Module):
    """Stand-in sub-model for the second element of the example input."""

    def forward(self, args):
        """Return random output shaped ``(args.shape[0], 10)``; input values are unused."""
        output_shape = (args.shape[0], 10)
        return torch.randn(output_shape)


class Model0(torch.nn.Module):
    """Container that dispatches the ``'a'`` and ``'c'`` entries to separate sub-models."""

    def __init__(self, model_a, model_c):
        """Keep references to the sub-models for the two branches."""
        super().__init__()
        self.model_a = model_a
        self.model_c = model_c

    def forward(self, args):
        """Run ``model_a`` on ``args['a']`` and ``model_c`` on ``args['c']``.

        The outputs are returned in a dict under the same two keys.
        """
        out_a = self.model_a(args['a'])
        out_c = self.model_c(args['c'])
        return {'a': out_a, 'c': out_c}


# Assemble the example model: Model0 (wrapping ModelA and ModelC) receives
# the first element of each input, Model1 receives the second.
model = Main(
    model_0=Model0(model_a=ModelA(), model_c=ModelC()),
    model_1=Model1(),
)

Apply to a single datapoint:

from torchapply import apply_model

# One datapoint: a (nested dict of names, tuple of names) pair, flagged
# with single=True.
apply_model(
    model,
    ({'a': {'b': 'orange'}, 'c': 'pear'}, ('apple', 'apple')),
    single=True,
)

Apply to multiple datapoints:

from torchapply import apply_model

# Several datapoints: a list holding ten copies of the same datapoint,
# flagged with single=False.
apply_model(
    model,
    [
        ({'a': {'b': 'orange'}, 'c': 'pear'}, ('apple', 'apple'))
        for _ in range(10)
    ],
    single=False,
)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchapply-0.0.4.tar.gz (4.2 kB view hashes)

Uploaded Source

Supported by

AWS (Cloud computing and Security Sponsor) · Datadog (Monitoring) · Fastly (CDN) · Google (Download Analytics) · Microsoft (PSF Sponsor) · Pingdom (Monitoring) · Sentry (Error logging) · StatusPage (Status page)