import torch
import torch.nn as nn

# =============================================================
# Save and Load — overview of the three idioms shown below.
# (These fragments use placeholder names `arg`, `PATH`, `model`
# and are kept as comments; the runnable versions follow.)
# =============================================================
#
# 1) Save any object (e.g. a state dict) with python pickle:
#        torch.save(arg, PATH)
#        arg = torch.load(PATH)
#        model.load_state_dict(arg)
#
# 2) Save the whole model object with python pickle:
#        torch.save(model, PATH)
#        model = torch.load(PATH)
#        model.eval()
#
# 3) Recommended method — save only the state_dict:
#        torch.save(model.state_dict(), PATH)
#        model = Model(*args, **kwargs)
#        model.load_state_dict(torch.load(PATH))
#        model.eval()
class Model(nn.Module):
    """Minimal logistic-regression model: one linear layer + sigmoid.

    Args:
        n_input_features: number of input features per sample.
    """

    def __init__(self, n_input_features):
        super(Model, self).__init__()
        # Single linear layer mapping n_input_features -> 1 logit.
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        # Sigmoid squashes the logit into (0, 1), i.e. a probability.
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred


# NOTE(review): `device` is created but never used in the visible code —
# presumably intended for moving the model/data to GPU; confirm CUDA is
# actually available before calling .to(device).
device = torch.device("cuda")
model = Model(n_input_features=6)
# model
# Model(
#   (linear): Linear(in_features=6, out_features=1, bias=True)
# )

# --- Method 1: save/load the ENTIRE model object (python pickle). ---
# Simple, but the pickle is tied to the exact class definition and file
# layout at save time, so it is brittle across refactors.
FILE = "model.pth1"
torch.save(model, FILE)

model = torch.load(FILE)
model.eval()  # switch to inference mode after loading
# Model(
#   (linear): Linear(in_features=6, out_features=1, bias=True)
# )

for param in model.parameters():
    print(param)
# Parameter containing:
# tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604, 0.1516]],
#        requires_grad=True)
# Parameter containing:
# tensor([-0.3233], requires_grad=True)
# --- Method 2 (recommended): save only the state_dict, then load it into
# a freshly constructed model of the same architecture. ---
FILE = "model.pth2"
torch.save(model.state_dict(), FILE)

model = Model(n_input_features=6)
model.load_state_dict(torch.load(FILE))
model.eval()  # switch to inference mode after loading
# Model(
#   (linear): Linear(in_features=6, out_features=1, bias=True)
# )

for param in model.parameters():
    print(param)
# Parameter containing:
# tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604, 0.1516]],
#        requires_grad=True)
# Parameter containing:
# tensor([-0.3233], requires_grad=True)

model.state_dict()
# OrderedDict([('linear.weight',
#               tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604, 0.1516]])),
#              ('linear.bias', tensor([-0.3233]))])
# Optimizers also expose a state_dict (hyperparameters + per-parameter
# state), which is what gets bundled into a training checkpoint.
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

optimizer.state_dict()
# {'state': {},
#  'param_groups': [{'lr': 0.01,
#                    'momentum': 0,
#                    'dampening': 0,
#                    'weight_decay': 0,
#                    'nesterov': False,
#                    'maximize': False,
#                    'foreach': None,
#                    'differentiable': False,
#                    'params': [0, 1]}]}
# --- Checkpointing: bundle epoch, model state, and optimizer state into a
# single dict so training can resume exactly where it left off. ---
checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
# checkpoint
# {'epoch': 90,
#  'model_state': OrderedDict([('linear.weight',
#                 tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604, 0.1516]])),
#                 ('linear.bias', tensor([-0.3233]))]),
#  'optim_state': {'state': {}, 'param_groups': [{...'params': [0, 1]}]}}

torch.save(checkpoint, "checkpoint.pth")
loaded_checkpoint = torch.load("checkpoint.pth")
loaded_checkpoint["epoch"]  # 90

# Restore both model and optimizer from the loaded checkpoint.
model.load_state_dict(loaded_checkpoint["model_state"])
model.state_dict()
# OrderedDict([('linear.weight',
#               tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604, 0.1516]])),
#              ('linear.bias', tensor([-0.3233]))])

optimizer.load_state_dict(loaded_checkpoint["optim_state"])
optimizer.state_dict()
# {'state': {},
#  'param_groups': [{'lr': 0.01,
#                    'momentum': 0,
#                    'dampening': 0,
#                    'weight_decay': 0,
#                    'nesterov': False,
#                    'maximize': False,
#                    'foreach': None,
#                    'differentiable': False,
#                    'params': [0, 1]}]}