A library to help with transfer learning for structured data.

transfertab

Enables transfer learning on structured data.

Install

pip install transfertab

How to use

TransferTab enables effective transfer learning from models trained on tabular data.

To make use of transfertab, you'll need:
* A PyTorch model that contains some embeddings in a layer group.
* Another model to transfer these embeddings to, along with metadata about the dataset on which this model will be trained.

The whole process takes place in two main steps:
1. Extraction
2. Transfer

Extraction

Extraction stores the embeddings present in the model in a JSON structure. This JSON contains the embeddings for the categorical variables and can later be transferred to another model that can also benefit from these categories. It will also be possible to construct multiple JSON files from various models with different categorical variables and then use them together.

Here we'll quickly construct a ModuleList with a bunch of Embedding layers and see how to transfer its embeddings.

import torch.nn as nn

# Each pair is (number of classes, embedding width).
emb_szs1 = ((3, 10), (2, 8))
emb_szs2 = ((2, 10), (2, 8))
embed1 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs1])
embed2 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs2])
embed1
ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)

We can call the extractembeds function to extract the embeddings. Take a look at the documentation to see other dispatch methods, and details on the parameters.

import pandas as pd

df = pd.DataFrame({"old_cat1": [1, 2, 3, 4, 5], "old_cat2": ['a', 'b', 'b', 'b', 'a'], "old_cat3": ['A', 'B', 'B', 'B', 'A']})
cats = ("old_cat2", "old_cat3")
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats)
embdict
{'old_cat2': {'classes': ['a', 'b'],
  'embeddings': [[-0.28762340545654297,
    -0.142189621925354,
    0.2027226686477661,
    1.1096185445785522,
    -0.4540262520313263,
    -1.346120834350586,
    0.048871781677007675,
    0.1740419715642929,
    0.002095407573506236,
    0.721653163433075],
   [-0.9072648882865906,
    2.674738645553589,
    -0.8560850024223328,
    -1.119917869567871,
    -0.19618849456310272,
    1.1431224346160889,
    -0.5177133679389954,
    -0.6497849822044373,
    -0.9011525511741638,
    0.9314191341400146]]},
 'old_cat3': {'classes': ['A', 'B'],
  'embeddings': [[2.5755045413970947,
    -1.3670053482055664,
    -0.3207620680332184,
    -1.1824427843093872,
    0.07631386071443558,
    0.501422107219696,
    0.8510317802429199,
    -0.6687257289886475],
   [-1.3658113479614258,
    -0.27968257665634155,
    0.26537612080574036,
    0.36773681640625,
    -0.9940593242645264,
    0.9408144354820251,
    0.5295664668083191,
    -0.5038257241249084]]}}
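
Because the extracted dictionary holds only strings and nested lists of floats, it can be serialized with the standard `json` module. The sketch below is not part of transfertab: the helper names `save_embdict` and `merge_embdicts` and the file names are hypothetical, and the vectors are shortened toy values. It shows one possible way to store an extraction result and to combine JSON files extracted from different models, assuming each file follows the structure shown above.

```python
import json
import os
import tempfile


def save_embdict(embdict, path):
    """Write an extracted embedding dictionary to a JSON file."""
    with open(path, "w") as f:
        json.dump(embdict, f, indent=2)


def merge_embdicts(paths):
    """Combine several JSON files into one dictionary keyed by category.

    Assumption: a category name that appears in more than one file is
    treated as an error instead of being silently overwritten.
    """
    merged = {}
    for path in paths:
        with open(path) as f:
            part = json.load(f)
        duplicates = merged.keys() & part.keys()
        if duplicates:
            raise ValueError(f"categories found in several files: {sorted(duplicates)}")
        merged.update(part)
    return merged


# Toy stand-ins for two extraction results (vectors shortened for readability).
embdict_a = {"old_cat2": {"classes": ["a", "b"], "embeddings": [[0.1, 0.2], [0.3, 0.4]]}}
embdict_b = {"old_cat3": {"classes": ["A", "B"], "embeddings": [[1.0, 2.0], [3.0, 4.0]]}}

with tempfile.TemporaryDirectory() as tmp:
    path_a = os.path.join(tmp, "model_a.json")
    path_b = os.path.join(tmp, "model_b.json")
    save_embdict(embdict_a, path_a)
    save_embdict(embdict_b, path_b)
    combined = merge_embdicts([path_a, path_b])

print(sorted(combined))
```

Raising on duplicate category names is a design choice of this sketch; a real workflow might prefer to rename or prioritize one source instead.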

Transfer

The transfer step reuses trained parameters, taken either from the extracted weights or directly from a model. The metadict defines how this happens: it maps every category in the current dataset to the category it corresponds to in the previous dataset (the one used to train the old model), and it specifies how the new classes map to the old classes. Multiple old classes can be mapped to a single new class, in which case the aggfn parameter is used to aggregate their embedding vectors.

import json

json_file_path = "../data/jsons/metadict.json"

with open(json_file_path, 'r') as j:
    metadict = json.load(j)
metadict
{'new_cat1': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['a', 'b'],
   'new_class2': ['b'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat3',
  'classes_info': {'new_class1': ['A'], 'new_class2': []}}}
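
Before transferring, it can help to check that a metadict such as the one above only refers to categories and classes that were actually extracted. The following is a minimal sketch, not part of transfertab: `check_metadict` is a hypothetical helper, and it assumes the metadict and embdict layouts shown in this section.

```python
import copy


def check_metadict(metadict, embdict):
    """Return a list of problems; an empty list means the mapping is consistent."""
    problems = []
    for new_cat, info in metadict.items():
        old_cat = info["mapped_cat"]
        if old_cat not in embdict:
            problems.append(f"{new_cat}: mapped category {old_cat!r} was not extracted")
            continue
        known = set(embdict[old_cat]["classes"])
        for new_class, old_classes in info["classes_info"].items():
            missing = [c for c in old_classes if c not in known]
            if missing:
                problems.append(f"{new_cat}/{new_class}: unknown old classes {missing}")
    return problems


metadict = {
    "new_cat1": {"mapped_cat": "old_cat2",
                 "classes_info": {"new_class1": ["a", "b"],
                                  "new_class2": ["b"],
                                  "new_class3": []}},
    "new_cat2": {"mapped_cat": "old_cat3",
                 "classes_info": {"new_class1": ["A"], "new_class2": []}},
}

# Only the class lists matter for this check, so the embeddings are left empty.
embdict = {
    "old_cat2": {"classes": ["a", "b"], "embeddings": []},
    "old_cat3": {"classes": ["A", "B"], "embeddings": []},
}

print(check_metadict(metadict, embdict))

# A deliberately broken copy: 'c' is not a class of old_cat2.
bad = copy.deepcopy(metadict)
bad["new_cat1"]["classes_info"]["new_class2"] = ["c"]
print(check_metadict(bad, embdict))
```

The first call reports no problems for the metadict shown above; the second reports the unknown class in the broken copy.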

We take a look at the layer parameters before and after transferring to see if it worked as expected.

embed1.state_dict()
OrderedDict([('0.weight',
              tensor([[-0.6940, -0.0337,  0.9491, -1.0520,  0.7804,  2.0246,  0.4242, -1.8351,
                        0.4660,  1.7667],
                      [-0.2802,  0.6081, -0.8459, -0.3288, -1.1264,  0.7621,  0.9347,  1.8096,
                       -0.1998, -0.2541],
                      [ 0.5706, -0.5213, -0.1398, -0.3742, -1.1951,  1.9640,  0.4132,  2.0365,
                        0.0655,  0.5189]])),
             ('1.weight',
              tensor([[ 0.9506, -0.0057,  0.2754,  0.8276,  0.8675,  1.2238, -1.5603,  1.0301],
                      [-0.7315, -0.3735,  0.6059,  0.2659, -0.4918,  1.5501,  0.0221, -0.6199]]))])
embed2.state_dict()
OrderedDict([('0.weight',
              tensor([[-2.8762e-01, -1.4219e-01,  2.0272e-01,  1.1096e+00, -4.5403e-01,
                       -1.3461e+00,  4.8872e-02,  1.7404e-01,  2.0954e-03,  7.2165e-01],
                      [-9.0726e-01,  2.6747e+00, -8.5609e-01, -1.1199e+00, -1.9619e-01,
                        1.1431e+00, -5.1771e-01, -6.4978e-01, -9.0115e-01,  9.3142e-01]])),
             ('1.weight',
              tensor([[ 2.5755, -1.3670, -0.3208, -1.1824,  0.0763,  0.5014,  0.8510, -0.6687],
                      [-1.3658, -0.2797,  0.2654,  0.3677, -0.9941,  0.9408,  0.5296, -0.5038]]))])
transfer_cats = ("new_cat1", "new_cat2")
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")

newcatdict = {"new_cat1" : ["new_class1", "new_class2", "new_class3"], "new_cat2" : ["new_class1", "new_class2"]}
oldcatdict = {"old_cat2" : ["a", "b"], "old_cat3" : ["A", "B"]}

transferembeds_(embed1, embdict, metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict)
embed1.state_dict()
OrderedDict([('0.weight',
              tensor([[-0.5974,  1.2663, -0.3267, -0.0051, -0.3251, -0.1015, -0.2344, -0.2379,
                       -0.4495,  0.8265],
                      [-0.9073,  2.6747, -0.8561, -1.1199, -0.1962,  1.1431, -0.5177, -0.6498,
                       -0.9012,  0.9314],
                      [-0.5974,  1.2663, -0.3267, -0.0051, -0.3251, -0.1015, -0.2344, -0.2379,
                       -0.4495,  0.8265]])),
             ('1.weight',
              tensor([[ 2.5755, -1.3670, -0.3208, -1.1824,  0.0763,  0.5014,  0.8510, -0.6687],
                      [ 0.6048, -0.8233, -0.0277, -0.4074, -0.4589,  0.7211,  0.6903, -0.5863]]))])

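As the updated state_dict shows, the embeddings have been transferred over: the row for new_class2 of new_cat1 now equals the old embedding of 'b', and the row for new_class1 of new_cat2 equals the old embedding of 'A'.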
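
The transferred rows printed above can be checked by hand. The sketch below is not transfertab's implementation: `aggregate` and `rows_for` are hypothetical helpers, and two points are assumptions inferred only from the printed numbers, namely that the aggregation used here behaves like an element-wise mean and that a new class with an empty mapping receives the mean over all old classes. Only the first two components of each extracted vector are used to keep the example short.

```python
from statistics import fmean


def aggregate(vectors):
    """Element-wise mean of equally long vectors (assumed aggregation)."""
    return [fmean(column) for column in zip(*vectors)]


def rows_for(classes_info, old_classes, old_vectors):
    """Build one embedding row per new class from the old class vectors."""
    lookup = dict(zip(old_classes, old_vectors))
    rows = {}
    for new_class, mapped in classes_info.items():
        # Assumption inferred from the output: an empty mapping falls back
        # to aggregating every old class.
        sources = mapped or old_classes
        rows[new_class] = aggregate([lookup[c] for c in sources])
    return rows


# First two components of the extracted 'old_cat2' embeddings shown earlier.
old_classes = ["a", "b"]
old_vectors = [
    [-0.28762340545654297, -0.142189621925354],
    [-0.9072648882865906, 2.674738645553589],
]
classes_info = {"new_class1": ["a", "b"], "new_class2": ["b"], "new_class3": []}

rows = rows_for(classes_info, old_classes, old_vectors)
for name, row in rows.items():
    print(name, [round(x, 4) for x in row])
```

Rounded to four decimals, the three rows agree with the first two columns of `0.weight` in the transferred `embed1.state_dict()` above.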
