# micro_config.py

An opinionated Python dataclass config framework for deep learning, based on hierarchical dataclasses.
I hope this approach to configuration makes your life easier, but if it doesn't, please submit a pull request or an issue and I'll see what I can do. This config system is certainly still under development, so I'm open to any new ideas or suggestions.
## Installation
Install with pip:

```shell
pip install micro-config
```
Or install from source: place `micro_config.py` at the root of your project.
## Repo Guide
The config framework is defined in `micro_config.py`. The rest of the repo provides a demo for how one might actually want to use `micro_config.py` in a deep learning project. Specifically, I implement transformer language model training on wikitext in PyTorch.
To run the demo, navigate to the root directory and run:

```shell
pip install -r requirements.txt
export PYTHONPATH="$PWD"
cd scripts
python train_lm.py
```
Optionally, you can pass command-line arguments to `train_lm.py` like so:

```shell
python train_lm.py epochs=1 bsize=16 model.transformer_config.hidden_dim=256
```
Overview of the demo project code:

- `scripts/train_lm.py` defines the training configuration and script execution.
- `base_config.py` defines the config schema and defaults for all main config objects: `WikiDataConfig`, `TransformerConfig`, `LMModelConfig`, and `AdamWConfig`.
- `general_train_loop.py` defines the config schema and script for training models.
- `src/` defines all of the core demo project code.
## Quick Start / Walkthrough
Most demo code in this section is adapted from the demo project provided in the repo.
Python dataclasses provide a more natural and flexible config definition interface than `.yaml` files.
- All config schemas should be defined as `@dataclass`-decorated subclasses of `ConfigScript` or `ModelConfigScript`.
- A `ConfigScript` first defines a parameter schema and default config values.
For example, a simple dataset object configuration:
```python
from dataclasses import dataclass, asdict
from micro_config import ConfigScript

# data config
@dataclass
class WikiDataConfig(ConfigScript):
    f_path: str='data/wikitext-2-raw/wiki.train.raw'
    max_len: int=256
```
`ConfigScript`s load associated objects or functions.

- To do this, all `ConfigScript`s implement `unroll(self, metaconfig)`.
- The `metaconfig` parameter is another dataclass which specifies configs for the config framework itself.
For example, loading the dataset from the config:
```python
from dataclasses import dataclass, asdict
from micro_config import ConfigScript, MetaConfig
from src.data import WikitextDataset
import torch
import os

# data config
@dataclass
class WikiDataConfig(ConfigScript):
    f_path: str='data/wikitext-2-raw/wiki.train.raw'
    max_len: int=256

    def unroll(self, metaconfig):
        # metaconfig.convert_path converts paths relative to metaconfig.project_root into absolute paths
        return WikitextDataset(metaconfig.convert_path(self.f_path), self.max_len)

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True, device='cpu')
    data_config = WikiDataConfig(max_len=512)
    data = data_config.unroll(metaconfig)
```
Configurations can be defined hierarchically.
- You can define `ConfigScript`s as parameters of other `ConfigScript`s.
- You can define lists or dictionaries of `ConfigScript`s as parameters of a `ConfigScript` by wrapping your list or dict in `ConfigScriptList` or `ConfigScriptDict` respectively (see the sketch below).
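For instance, here is a minimal sketch of wrapping a list of dataset configs. This is illustrative only: `MultiDataConfig` is a hypothetical class, and the exact `ConfigScriptList` constructor and unroll semantics are my assumptions, not the library's documented API.

```python
from dataclasses import dataclass, field
from micro_config import ConfigScript, ConfigScriptList

# hypothetical config holding a list of sub-configs
@dataclass
class MultiDataConfig(ConfigScript):
    # assumption: ConfigScriptList wraps a plain list of ConfigScripts
    datasets: ConfigScriptList=field(default_factory=lambda: ConfigScriptList([
        WikiDataConfig(f_path='data/wikitext-2-raw/wiki.train.raw'),
        WikiDataConfig(f_path='data/wikitext-2-raw/wiki.valid.raw'),
    ]))

    def unroll(self, metaconfig):
        # assumption: unrolling the wrapper unrolls each element in order
        return self.datasets.unroll(metaconfig)
```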
For example, the LM model config below defines `ConfigScript`s for both a `dataset` and a `transformer_config` as parameters:
```python
from dataclasses import dataclass, field
from micro_config import ModelConfigScript, MetaConfig
from src.lm import LMModel
import os

# model config (WikiDataConfig and TransformerConfig as defined above)
@dataclass
class LMModelConfig(ModelConfigScript):
    dataset: WikiDataConfig=field(default_factory=lambda: WikiDataConfig())
    transformer_config: TransformerConfig=field(default_factory=lambda: TransformerConfig(max_len=256))

    def unroll(self, metaconfig):
        dataset = self.dataset.unroll(metaconfig)
        transformer_config = self.transformer_config.unroll(metaconfig)
        return LMModel(dataset, transformer_config, self.device)

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True, device='cpu')
    model_config = LMModelConfig(
        checkpoint_path=None,
        strict_load=True,
        device=None,
        dataset=WikiDataConfig(f_path='data/wikitext-2-raw/wiki.train.raw', max_len=256),
        transformer_config=TransformerConfig(
            max_len=256,
            heads=12,
            hidden_dim=768,
            attn_dim=64,
            intermediate_dim=3072,
            num_blocks=12,
            dropout=0.1
        )
    )
    model = model_config.unroll(metaconfig)
```
`ModelConfigScript`, as used above, is a subclass of `ConfigScript` which provides some default functionality for loading the PyTorch module returned by `unroll`. It provides default parameters for:

- `checkpoint_path`: optionally load the model from a checkpoint, specified relative to the path `micro_config.py` is in.
- `strict_load`: enforce strict loading when loading model parameters from the checkpoint.
- `device`: specify the model device. If `None`, use the default device provided in the metaconfig.
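For example, here is a quick sketch of resuming from a saved checkpoint using these parameters (the checkpoint path is purely illustrative):

```python
model_config = LMModelConfig(
    checkpoint_path='outputs/lm_checkpoint.pkl',  # hypothetical checkpoint path
    strict_load=True,   # fail loudly on missing/unexpected parameters
    device=None,        # fall back to metaconfig.device
)
model = model_config.unroll(metaconfig)
```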
Configs and scripts are unified: every config is a script, and every script is a config.
`unroll(self, metaconfig)` can be used not only to load objects, but also to define script logic.
For example, let's define a simple configurable training loop:
```python
from dataclasses import dataclass, asdict
from typing import Any
from torch.utils.data import DataLoader
from tqdm import tqdm
from micro_config import ConfigScript
from src.utils import combine_logs

@dataclass
class TrainLoop(ConfigScript):
    train_dataset: Any=None
    eval_dataset: Any=None
    model: Any=None
    optim: Any=None
    epochs: int=10
    bsize: int=32

    def unroll(self, metaconfig):
        print('using config:', asdict(self))
        device = metaconfig.device
        train_dataset = self.train_dataset.unroll(metaconfig)
        eval_dataset = self.eval_dataset.unroll(metaconfig)
        model = self.model.unroll(metaconfig)
        model.train()
        train_dataloader = DataLoader(train_dataset, batch_size=self.bsize)
        eval_dataloader = DataLoader(eval_dataset, batch_size=self.bsize)
        optim = self.optim.unroll(metaconfig)(model)
        step = 0
        for epoch in range(self.epochs):
            for x in tqdm(train_dataloader):
                loss, logs = model.get_loss(x.to(device))
                optim.zero_grad()
                loss.backward()
                optim.step()
                model.eval()
                val_x = next(iter(eval_dataloader))
                _, val_logs = model.get_loss(val_x.to(device))
                print({'train': combine_logs([logs]), 'val': combine_logs([val_logs]), 'step': (step+1)})
                model.train()
                step += 1
        return model
```
Objects returned by `unroll(self, metaconfig)` respect the reference structure of the config hierarchy.

- If the same config object is referenced multiple times in a config hierarchy, the object's `unroll(self, metaconfig)` method will only be called once and its output cached; subsequent calls return the cached output.
For example, `train_dataset` is referenced twice in `train_config_script`:
```python
import torch
import os

# WikiDataConfig, LMModelConfig, TransformerConfig, TrainLoop,
# AdamWConfig, and MetaConfig as defined/imported above
train_dataset = WikiDataConfig(f_path='data/wikitext-2-raw/wiki.train.raw', max_len=256)
eval_dataset = WikiDataConfig(f_path='data/wikitext-2-raw/wiki.valid.raw', max_len=256)

model = LMModelConfig(
    checkpoint_path=None,
    strict_load=True,
    dataset=train_dataset,
    transformer_config=TransformerConfig(
        max_len=256,
        heads=12,
        hidden_dim=768,
        attn_dim=64,
        intermediate_dim=3072,
        num_blocks=12,
        dropout=0.1
    )
)

train_config_script = TrainLoop(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    model=model,
    optim=AdamWConfig(lr=1e-4, weight_decay=0.01),
    epochs=10,
    bsize=16,
)

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True, device='cpu')
    # run the script
    train_config_script.unroll(metaconfig)
```
The dataset object configured by `train_dataset` will only be loaded once in the above hierarchy, even though both `LMModelConfig` and `TrainLoop` take it as input.
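You can check this caching behavior directly. A small sketch, assuming the cache keys on config object identity (my reading of the description above, not a documented guarantee):

```python
# both calls unroll the same WikiDataConfig instance; the second call
# should return the cached WikitextDataset rather than reloading it
d1 = train_dataset.unroll(metaconfig)
d2 = train_dataset.unroll(metaconfig)
assert d1 is d2  # assumption: outputs are cached per config instance
```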
Methods for parsing command-line args are provided.

- `parse_args()` parses the command-line arguments into a dictionary.
- `deep_replace(config, **kwargs)` implements a nested version of the standard `dataclasses.replace` function.
```python
from micro_config import parse_args, deep_replace
import os

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True, device='cpu')
    train_config_script = deep_replace(train_config_script, **parse_args())
    # run the script
    train_config_script.unroll(metaconfig)
```
To edit any arguments in the hierarchy through the command line, call the script like so:

```shell
python train_lm.py epochs=1 bsize=16 model.transformer_config.hidden_dim=256
```
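To make the nesting concrete: the command above is roughly equivalent to applying the standard `dataclasses.replace` by hand at each level of the hierarchy, which is the traversal `deep_replace` automates (a sketch, ignoring how string values get coerced to field types):

```python
from dataclasses import replace

# hand-rolled equivalent of deep_replace for the overrides above
new_config = replace(
    train_config_script,
    epochs=1,
    bsize=16,
    model=replace(
        train_config_script.model,
        transformer_config=replace(
            train_config_script.model.transformer_config,
            hidden_dim=256,
        ),
    ),
)
```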