
Project description

Quick Start

Three classes are available: LearnableActivation, TabularDenseNet, and CustomLoss.

LearnableActivation is, as the name suggests, a learnable activation function that takes n inputs and produces n outputs. It is initialized as a linear (identity) activation, and its shape is learned by placing control points across a predefined width and interpolating between them. Inputs outside the predefined width are extrapolated linearly.

Example LearnableActivation initialization:

from learnable_activation import LearnableActivation
activation = LearnableActivation(input_size, width, density)

The width parameter defaults to 10, which interpolates over the range -5 to 5 and extrapolates outside that range. The density parameter determines how many control points are packed into each interval of width 1; the default density is 1.
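With these defaults, each feature's table holds width * density + 1 control points spaced 1/density apart. A quick pure-Python sketch of the resulting grid (the helper name is illustrative, not part of the package):

```python
def control_grid(width=10, density=1):
    """Positions of the control points used for interpolation."""
    n = width * density + 1       # number of control points
    step = 1.0 / density          # spacing between adjacent points
    return [-width / 2 + i * step for i in range(n)]

grid = control_grid()             # defaults: 11 points from -5.0 to 5.0
denser = control_grid(4, 2)       # 4*2 + 1 = 9 points over [-2, 2]
```

Raising density packs more points into the same range, letting the activation learn finer-grained shapes at the cost of more parameters.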

Class Definitions

import torch
import torch.nn as nn

class LearnableActivation(nn.Module):
    def __init__(self, num_features, width=10, density=1):
        super(LearnableActivation, self).__init__()
        self.num_features = num_features
        self.width = width
        self.density = density

        num_control_points = width * density + 1
        range_values = torch.linspace(-width / 2, width / 2, num_control_points)

        self.interp_tensor = nn.Parameter(range_values.repeat(num_features, 1))
        self.register_buffer("feature_idx", torch.arange(self.num_features).view(1, -1))
        self.location = self.width * self.density / 2
        self.max_index = self.width * self.density

    def forward(self, x):
        # Map x onto the control-point grid: one grid step per 1/density units.
        scaled_x = (x * self.density) + self.location

        lower_idx = torch.clamp(scaled_x.long(), min=0, max=self.max_index - 1)
        upper_idx = lower_idx + 1

        lower_value = self.interp_tensor[self.feature_idx, lower_idx]
        upper_value = self.interp_tensor[self.feature_idx, upper_idx]

        # Weights outside [0, 1] occur when x falls outside the width,
        # so torch.lerp extrapolates linearly along the outermost segment.
        interpolation_weight = scaled_x - lower_idx.float()
        return torch.lerp(lower_value, upper_value, interpolation_weight)
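To make the interpolation concrete, here is a torch-free scalar sketch of the same forward arithmetic for a single feature. A freshly initialized table reproduces the identity function inside the width, and inputs past the grid edge are extrapolated along the outermost segment. Names here are illustrative, not part of the package:

```python
def lerp(a, b, w):
    return a + w * (b - a)

def activation_1d(x, table, width=10, density=1):
    """Scalar version of LearnableActivation.forward for one feature."""
    location = width * density / 2
    max_index = width * density
    scaled = x * density + location               # position on the grid
    lower = min(max(int(scaled), 0), max_index - 1)
    weight = scaled - lower                       # may be < 0 or > 1 off-grid
    return lerp(table[lower], table[lower + 1], weight)

# Fresh table = linspace(-5, 5, 11), i.e. y = x at every control point:
table = [-5.0 + i for i in range(11)]
inside = activation_1d(1.5, table)    # interpolated between points 1 and 2
outside = activation_1d(7.0, table)   # extrapolated past the last point
```

Training then moves the table values, bending the function away from the identity wherever the data demands it.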

class TabularDenseNet(nn.Module):
    def __init__(self, input_size, output_size, num_layers=2, width=10, density=1):
        super(TabularDenseNet, self).__init__()

        self.layers = nn.ModuleList()
        self.activations = nn.ModuleList()

        for i in range(num_layers):
            self.activations.append(LearnableActivation(input_size, width, density))
            self.layers.append(nn.Linear(input_size, input_size, bias=False))

            with torch.no_grad():
                self.layers[-1].weight.copy_(torch.eye(input_size))

            input_size *= 2

        self.activation_second_last_layer = LearnableActivation(input_size, width, density)
        self.last_layer = nn.Linear(input_size, output_size, bias=False)

        with torch.no_grad():
            self.last_layer.weight.copy_(torch.zeros(output_size, input_size))

        self.activation_last_layer = LearnableActivation(output_size, width, density)

    def forward(self, x):
        outputs = [x]

        for i in range(len(self.layers)):
            concatenated_outputs = torch.cat(outputs, dim=1)
            outputs.append(self.layers[i](self.activations[i](concatenated_outputs)))

        outputs = torch.cat(outputs, dim=1)
        outputs = self.activation_second_last_layer(outputs)
        outputs = self.last_layer(outputs)
        outputs = self.activation_last_layer(outputs)
        return outputs.squeeze()
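Because each layer's input is the concatenation of all earlier outputs, the feature width doubles at every layer, so the final linear layer sees input_size * 2**num_layers features. A small sketch of those widths (hypothetical helper, not part of the package):

```python
def densenet_widths(input_size, num_layers):
    """Feature width seen by each dense layer, and by the last layer."""
    widths = []
    w = input_size
    for _ in range(num_layers):
        widths.append(w)   # layer i maps w features to w features
        w *= 2             # concatenating its output doubles the width
    return widths, w       # final w feeds last_layer

per_layer, final = densenet_widths(8, 3)   # ([8, 16, 32], 64)
```

This doubling is why the identity-initialized square Linear layers keep the network well-behaved at the start: each layer initially passes its (activated) input straight through.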

class CustomLoss(nn.Module):
    def __init__(self, criterion, l1_lambda=0.0, l2_lambda=0.0, f1_lambda=0.0, f2_lambda=0.0):
        super(CustomLoss, self).__init__()
        self.criterion = criterion
        self.l1_lambda = l1_lambda
        self.l2_lambda = l2_lambda
        self.f1_lambda = f1_lambda
        self.f2_lambda = f2_lambda

    def forward(self, outputs, labels, model):
        l1_norm = sum(
            p.abs().sum()
            for module in model.modules()
            if isinstance(module, nn.Linear)
            for pname, p in module.named_parameters()
            if "bias" not in pname
        )
        l1_loss = self.l1_lambda * l1_norm

        l2_norm = sum(
            p.pow(2.0).sum()
            for module in model.modules()
            if isinstance(module, nn.Linear)
            for pname, p in module.named_parameters()
            if "bias" not in pname
        )
        l2_loss = self.l2_lambda * l2_norm

        f1_loss = 0
        f2_loss = 0
        for name, module in model.named_modules():
            if isinstance(module, LearnableActivation):
                interp_tensor = module.interp_tensor

                f1_diff = interp_tensor[:, 1:] - interp_tensor[:, :-1]
                f1_loss += self.f1_lambda * f1_diff.abs().sum()

                f2_diff = f1_diff[:, 1:] - f1_diff[:, :-1]
                f2_loss += self.f2_lambda * f2_diff.abs().sum()

        return self.criterion(outputs, labels) + l1_loss + l2_loss + f1_loss + f2_loss

    def regular_loss(self, outputs, labels):
        return self.criterion(outputs, labels)
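The f1 and f2 terms penalise the first and second differences of each activation's control points, roughly a total-variation penalty (discouraging wiggly functions) and a curvature penalty (encouraging piecewise-linear smoothness). A pure-Python sketch of that arithmetic on a single row of interp_tensor, with illustrative helper names:

```python
def smoothness_penalties(row, f1_lambda=0.01, f2_lambda=0.01):
    """f1/f2 terms of CustomLoss for one feature's control points."""
    d1 = [b - a for a, b in zip(row, row[1:])]   # first differences
    d2 = [b - a for a, b in zip(d1, d1[1:])]     # second differences
    f1 = f1_lambda * sum(abs(v) for v in d1)     # total variation
    f2 = f2_lambda * sum(abs(v) for v in d2)     # curvature
    return f1, f2

# A freshly initialised row is linear: constant slope, zero curvature,
# so it pays a nonzero f1 but no f2.
f1, f2 = smoothness_penalties([-2.0, -1.0, 0.0, 1.0, 2.0])
```

In practice f2_lambda tends to be the gentler regulariser, since it leaves any straight-line activation unpenalised while still damping oscillations.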
