import torch
import torch.nn as nn
# Save and Load
#
# Save an arg dict with python pickle:
#     torch.save(arg, PATH)
#     arg = torch.load(PATH)
#     model.load_state_dict(arg)
# Save the whole model with python pickle:
#     torch.save(model, PATH)
#     model = torch.load(PATH)
#     model.eval()
# Recommended method (save only the state_dict):
#     torch.save(model.state_dict(), PATH)
#     model = Model(*args, **kwargs)
#     model.load_state_dict(torch.load(PATH))
#     model.eval()
Test model
class Model(nn.Module):
    """Minimal logistic-regression model: one linear layer followed by a sigmoid.

    Args:
        n_input_features: number of input features fed to the linear layer.
    """

    def __init__(self, n_input_features):
        super(Model, self).__init__()
        # Single output unit -> one probability per sample.
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        # Sigmoid squashes the linear output into (0, 1).
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
# Build the model. Note: constructing torch.device("cuda") only creates a
# device handle; it does not itself require a GPU to be present.
device = torch.device("cuda")
model = Model(n_input_features=6)
model
# Model(
#   (linear): Linear(in_features=6, out_features=1, bias=True)
# )
# ---- Method 1: save/load the whole model object ----
# This pickles the class definition too, so loading requires the exact same
# class/file layout to be importable.
FILE = "model.pth1"
torch.save(model, FILE)

model = torch.load(FILE)
model.eval()  # switch to inference mode (affects dropout/batchnorm layers)
# Model(
#   (linear): Linear(in_features=6, out_features=1, bias=True)
# )

for param in model.parameters():
    print(param)
# Parameter containing:
# tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604,  0.1516]],
#        requires_grad=True)
# Parameter containing:
# tensor([-0.3233], requires_grad=True)
# ---- Method 2 (recommended): save/load only the state_dict ----
FILE = "model.pth2"
torch.save(model.state_dict(), FILE)

model = Model(n_input_features=6)        # recreate the architecture first
model.load_state_dict(torch.load(FILE))  # then restore the learned weights
model.eval()
# Model(
#   (linear): Linear(in_features=6, out_features=1, bias=True)
# )

for param in model.parameters():
    print(param)
# Parameter containing:
# tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604,  0.1516]],
#        requires_grad=True)
# Parameter containing:
# tensor([-0.3233], requires_grad=True)

model.state_dict()
# OrderedDict([('linear.weight',
#               tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604,  0.1516]])),
#              ('linear.bias', tensor([-0.3233]))])
# An optimizer also has a state_dict; it must be saved to resume training.
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
optimizer.state_dict()
# {'state': {},
#  'param_groups': [{'lr': 0.01,
#                    'momentum': 0,
#                    'dampening': 0,
#                    'weight_decay': 0,
#                    'nesterov': False,
#                    'maximize': False,
#                    'foreach': None,
#                    'differentiable': False,
#                    'params': [0, 1]}]}
# ---- Checkpoint: bundle epoch, model weights, and optimizer state ----
# Saving both state_dicts lets training resume exactly where it stopped.
checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
checkpoint
# {'epoch': 90,
#  'model_state': OrderedDict([('linear.weight',
#                               tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604,  0.1516]])),
#                              ('linear.bias', tensor([-0.3233]))]),
#  'optim_state': {'state': {},
#                  'param_groups': [{'lr': 0.01,
#                                    'momentum': 0,
#                                    'dampening': 0,
#                                    'weight_decay': 0,
#                                    'nesterov': False,
#                                    'maximize': False,
#                                    'foreach': None,
#                                    'differentiable': False,
#                                    'params': [0, 1]}]}}
torch.save(checkpoint, "checkpoint.pth")

loaded_checkpoint = torch.load("checkpoint.pth")
loaded_checkpoint['epoch']
# 90

model.load_state_dict(loaded_checkpoint['model_state'])
model.state_dict()
# OrderedDict([('linear.weight',
#               tensor([[-0.3002, -0.2477, -0.2695, -0.1810, -0.0604,  0.1516]])),
#              ('linear.bias', tensor([-0.3233]))])

optimizer.load_state_dict(loaded_checkpoint['optim_state'])
optimizer.state_dict()
# {'state': {},
#  'param_groups': [{'lr': 0.01,
#                    'momentum': 0,
#                    'dampening': 0,
#                    'weight_decay': 0,
#                    'nesterov': False,
#                    'maximize': False,
#                    'foreach': None,
#                    'differentiable': False,
#                    'params': [0, 1]}]}