Using a YAML file to create a PyTorch model
So. I was tuning a model for my college research thesis. I was already using YAML files to control parameters such as learning rates, dropout rates, and so on. But I also wanted to tune the model by changing the number of layers and neurons, and every time I did, I had to modify the code to change the model architecture. I thought, what if I could change the model architecture using config files too? That would save me a lot of time and also help me track changes. I know I can track my code changes using git, but for quick prototyping and testing, swapping YAML files is far faster and more efficient than creating multiple branches.
So let's get started!
I wrote a small module to achieve this:
import torch.nn as nn
import yaml


class TorchFromYAML:
    """
    Create and load a torch model using YAML files.

    params:
        config_file (str): YAML file path
    """

    def __init__(self, config_file):
        self.config_file = config_file
        self.config = self.load_config()
        # Maps the "type" field of the config to the method that builds it.
        self.SUPPORTED_MODELS = {"Sequential": self.load_sequential}

    def build_model_from_config(self):
        """
        Builds the model from the config file.

        returns:
            torch model
        """
        model_type = self.config["type"]
        if model_type not in self.SUPPORTED_MODELS:
            # Fail loudly on a model type we do not know how to build.
            raise ValueError(f"Unsupported model type: {model_type}")
        return self.SUPPORTED_MODELS[model_type]()

    def load_sequential(self):
        """
        Loads a sequential model from the config file.

        returns:
            sequential torch model
        """
        layers = []
        for layer_config in self.config["layers"]:
            # Look up the layer class on torch.nn by name, e.g. "Linear" -> nn.Linear.
            layer_type = getattr(nn, layer_config["type"])
            # Every key except "type" is passed to the layer as a keyword argument.
            parameters = {
                key: value for key, value in layer_config.items() if key != "type"
            }
            layer = layer_type(**parameters)
            layers.append(layer)
        return nn.Sequential(*layers)

    def load_config(self):
        """
        Loads the config file.

        returns:
            dictionary for the "model" section of the YAML file
        """
        with open(self.config_file, "r") as f:
            config = yaml.safe_load(f)
        return config["model"]
The trick is to use Python's built-in getattr function, which looks up each layer class on torch.nn by the name given in the config. The code is pretty much self-explanatory after that. For my first iteration, I only support sequential models, because I can use multiple sequential models in different parts of the overall model architecture.
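To make the getattr trick concrete, here is a minimal sketch of the lookup on its own. The helper name build_layer and its extra validation are assumptions added for illustration; they are not part of the module above.

```python
import torch.nn as nn


def build_layer(layer_config):
    """Hypothetical helper: turn one layer entry (a dict) into an nn.Module."""
    name = layer_config["type"]
    # getattr(nn, "Linear") returns the class nn.Linear.
    # The third argument makes it return None when torch.nn has no such name.
    layer_cls = getattr(nn, name, None)
    # Reject names that are missing or that are not layer classes
    # (for example "functional", which is a submodule of torch.nn).
    if not (isinstance(layer_cls, type) and issubclass(layer_cls, nn.Module)):
        raise ValueError(f"torch.nn has no layer named {name!r}")
    # Everything except "type" becomes a keyword argument of the layer.
    kwargs = {key: value for key, value in layer_config.items() if key != "type"}
    return layer_cls(**kwargs)


linear = build_layer({"type": "Linear", "in_features": 64, "out_features": 128})
relu = build_layer({"type": "ReLU"})
```

Checking that the name resolves to an nn.Module subclass is optional, but it turns a typo in the config into a readable error instead of an AttributeError from deep inside the builder.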
The YAML file should be in the following format:
#config.yaml
model:
  type: Sequential
  layers:
    - type: Linear
      in_features: 64
      out_features: 128
    - type: ReLU
    - type: Linear
      in_features: 128
      out_features: 10
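It may help to see what this file turns into once it is parsed. The sketch below assumes that type and layers are nested under model, and that each layer's parameters are nested under its list entry. The YAML is inlined as a string (CONFIG_TEXT is just an illustrative name) so the snippet runs without a file on disk.

```python
import yaml

# Same content as config.yaml, inlined for the example.
CONFIG_TEXT = """
model:
  type: Sequential
  layers:
    - type: Linear
      in_features: 64
      out_features: 128
    - type: ReLU
    - type: Linear
      in_features: 128
      out_features: 10
"""

# safe_load returns plain Python objects: dicts, lists, strings and ints.
config = yaml.safe_load(CONFIG_TEXT)["model"]

# Each entry of "layers" is a dict with a "type" key plus optional parameters.
layer_types = [layer["type"] for layer in config["layers"]]
first_layer = config["layers"][0]
```

Each list entry is a plain dictionary, which is why the builder can pass everything except type straight to the layer constructor as keyword arguments.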
Using the class is simple!
model_builder = TorchFromYAML("your_model_config.yaml")
model = model_builder.build_model_from_config()
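A quick way to check that a config produces the model you expect is to push a random batch through it and look at the output shape. The sketch below is self-contained, so it uses a condensed stand-in function (build_sequential, a hypothetical name) in place of the TorchFromYAML class, and writes the example config to a temporary file first.

```python
import os
import tempfile

import torch
import torch.nn as nn
import yaml

CONFIG_TEXT = """
model:
  type: Sequential
  layers:
    - type: Linear
      in_features: 64
      out_features: 128
    - type: ReLU
    - type: Linear
      in_features: 128
      out_features: 10
"""


def build_sequential(config_file):
    """Condensed, hypothetical stand-in for TorchFromYAML (Sequential only)."""
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)["model"]
    layers = []
    for layer_config in config["layers"]:
        kwargs = {k: v for k, v in layer_config.items() if k != "type"}
        layers.append(getattr(nn, layer_config["type"])(**kwargs))
    return nn.Sequential(*layers)


# Write the config to a temporary file and build the model from it.
with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "config.yaml")
    with open(config_path, "w") as f:
        f.write(CONFIG_TEXT)
    model = build_sequential(config_path)

# 4 samples with 64 features each, matching in_features of the first layer.
batch = torch.randn(4, 64)
with torch.no_grad():
    output = model(batch)
output_shape = tuple(output.shape)  # one row per sample, 10 outputs each
```

If the in_features of one layer does not match the out_features of the layer before it, this forward pass is where the mismatch shows up, so it is a cheap test to run after every config change.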
I hope you enjoyed it! It's a short article this time, but I really hope you find it useful for prototyping. You can check out my git repository here: https://github.com/siddhi47/torch-from-yaml-