PyTorch入门(三)模块的保存与加载

本文将介绍如何使用PyTorch保存模块和加载模型。

PyTorch模型保存与加载

在PyTorch中,一个torch.nn.Module模型的可训练参数(即权重与偏移项)保存在模型的参数(parameters,可使用model.parameters()获得)中。一个state_dict就是一个简单的Python字典,将每层映射到其参数张量。PyTorch的模型文件通常以.pt或.pth为后缀。使用函数torch.save保存模型,使用函数torch.load加载模型。

PyTorch有两种保存与加载模型的方式,一种是保存整个模型(包括模型结构及参数值),另一种是只保存模型的参数值(即state_dict)。

  1. 保存整个网络结构信息和模型参数信息:
1
# Serialize the whole module object (structure and parameter values) to ./model.pth
torch.save(model_object, './model.pth')

直接加载即可使用:

1
# A fully pickled module can only be restored with weights_only=False
# (PyTorch >= 2.6 defaults to weights_only=True, which rejects arbitrary classes).
# NOTE: this unpickles arbitrary Python objects -- only load files you trust.
model = torch.load('./model.pth', weights_only=False)
  2. 只保存网络的模型参数:
1
# Serialize only the parameter dictionary (state_dict) to ./params.pth
torch.save(model_object.state_dict(), './params.pth')

加载则要先从本地网络模块导入网络,然后再加载参数:

1
2
3
# The network class must be importable, because the file holds parameters only
from models import Model
model = Model()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict contains
model.load_state_dict(torch.load('./params.pth', weights_only=True))

示例代码

我们以文章PyTorch入门(二)搭建MLP模型实现分类任务中的二分类MLP模型为例,来演示如何在PyTorch中保存和加载模型。

只保存模型参数值的示例Python代码(save_model.py)如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# -*- coding: utf-8 -*-
from numpy import vstack
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import torch
from torch import Tensor
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import Linear, ReLU, Sigmoid, Module, BCELoss
from torch.nn.init import kaiming_uniform_, xavier_uniform_


# dataset definition
class CSVDataset(Dataset):
    """Dataset backed by a header-less CSV file.

    All columns except the last are used as inputs; the last column is the
    target, which is label-encoded and stored as a float32 column vector.
    """

    def __init__(self, path):
        """Read the CSV file at ``path`` and cache inputs (``X``) and targets (``y``)."""
        table = read_csv(path, header=None).values
        # inputs: every column but the last, as float32
        self.X = table[:, :-1].astype('float32')
        # targets: label-encode the last column, then reshape to (n_rows, 1) float32
        encoded = LabelEncoder().fit_transform(table[:, -1])
        self.y = encoded.astype('float32').reshape((len(encoded), 1))

    def __len__(self):
        """Return the number of rows in the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``[inputs, target]`` for the row at index ``idx``."""
        features, label = self.X[idx], self.y[idx]
        return [features, label]

    def get_splits(self, n_test=0.3):
        """Randomly split the dataset into train and test subsets.

        ``n_test`` is the fraction of rows assigned to the test subset; the
        result is whatever ``random_split`` returns for ``[train_size, test_size]``.
        """
        total = len(self.X)
        test_size = round(n_test * total)
        return random_split(self, [total - test_size, test_size])


# model definition
class MLP(Module):
    """Multilayer perceptron: n_inputs -> 10 -> 8 -> 1 with a sigmoid output."""

    def __init__(self, n_inputs):
        """Create the three linear layers and their activations.

        ``n_inputs`` is the width of the input feature vector.
        """
        super().__init__()
        # first hidden layer: ReLU activation, Kaiming-uniform weights
        self.hidden1 = Linear(n_inputs, 10)
        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
        self.act1 = ReLU()
        # second hidden layer: ReLU activation, Kaiming-uniform weights
        self.hidden2 = Linear(10, 8)
        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
        self.act2 = ReLU()
        # output layer: a single unit with Xavier-uniform weights and a sigmoid
        self.hidden3 = Linear(8, 1)
        xavier_uniform_(self.hidden3.weight)
        self.act3 = Sigmoid()

    def forward(self, X):
        """Pass ``X`` through each layer and activation in order and return the result."""
        pipeline = (
            self.hidden1, self.act1,
            self.hidden2, self.act2,
            self.hidden3, self.act3,
        )
        for stage in pipeline:
            X = stage(X)
        return X


# prepare the dataset
def prepare_data(path):
    """Load the CSV at ``path`` and return ``(train_dl, test_dl)`` data loaders."""
    train_subset, test_subset = CSVDataset(path).get_splits()
    # training batches are shuffled in groups of 32; test batches keep their
    # order and hold up to 1024 rows each
    return (
        DataLoader(train_subset, batch_size=32, shuffle=True),
        DataLoader(test_subset, batch_size=1024, shuffle=False),
    )


# train the model
def train_model(train_dl, model, epochs=100, lr=0.01, momentum=0.9):
    """Train ``model`` in place with SGD on a binary cross-entropy loss.

    Args:
        train_dl: iterable yielding ``(inputs, targets)`` mini-batches.
        model: module whose output is passed to ``BCELoss`` together with the targets.
        epochs: number of passes over ``train_dl`` (default 100).
        lr: SGD learning rate (default 0.01).
        momentum: SGD momentum (default 0.9).

    The defaults reproduce the previously hard-coded settings, so existing
    two-argument calls behave as before. The loss of every mini-batch is printed.
    """
    criterion = BCELoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    for epoch in range(epochs):
        for i, (inputs, targets) in enumerate(train_dl):
            # gradients accumulate across backward() calls, so reset them first
            optimizer.zero_grad()
            yhat = model(inputs)
            loss = criterion(yhat, targets)
            loss.backward()
            print("epoch: {}, batch: {}, loss: {}".format(epoch, i, loss.data))
            # apply the gradient step
            optimizer.step()


# evaluate the model
def evaluate_model(test_dl, model):
    """Return the accuracy of ``model`` over all batches of ``test_dl``.

    Args:
        test_dl: iterable yielding ``(inputs, targets)`` batches.
        model: module whose outputs are rounded to obtain class values.

    Returns:
        The value of ``accuracy_score(actuals, predictions)``.

    Inference runs in eval mode and without gradient tracking; the model's
    previous train/eval mode is restored before returning.
    """
    predictions, actuals = [], []
    was_training = model.training
    model.eval()
    try:
        # no autograd graph is needed for evaluation
        with torch.no_grad():
            for inputs, targets in test_dl:
                # round the outputs to class values
                yhat = model(inputs).detach().numpy().round()
                actual = targets.numpy()
                predictions.append(yhat)
                # store targets as a column vector to match the predictions
                actuals.append(actual.reshape((len(actual), 1)))
    finally:
        # leave the model in the mode the caller had it in
        model.train(was_training)
    predictions, actuals = vstack(predictions), vstack(actuals)
    return accuracy_score(actuals, predictions)


# make a class prediction for one row of data
def predict(row, model):
    """Run ``model`` on a single row and return its output as a numpy array.

    Args:
        row: sequence of numbers forming one input sample.
        model: callable module applied to a batch containing only ``row``.

    Returns:
        The model output converted with ``.numpy()``; its first dimension is
        the batch dimension of size 1.
    """
    # wrap the row in a list so the tensor has a leading batch dimension
    batch = Tensor([row])
    # inference only: skip building the autograd graph
    with torch.no_grad():
        yhat = model(batch)
    return yhat.detach().numpy()


if __name__ == '__main__':
    # build the train/test loaders from the ionosphere CSV and report their sizes
    train_dl, test_dl = prepare_data('./data/ionosphere.csv')
    n_train, n_test = len(train_dl.dataset), len(test_dl.dataset)
    print(n_train, n_test)
    # create a network with 34 inputs and show its structure
    model = MLP(34)
    print(model)
    # fit the network on the training loader
    train_model(train_dl, model)
    # persist only the learned parameters (state_dict), then display them
    torch.save(model.state_dict(), 'binary_classification.pth')
    print(model.state_dict())
    # report accuracy on the held-out loader
    acc = evaluate_model(test_dl, model)
    print('Accuracy: %.3f' % acc)

运行代码,会输出该MLP模型的参数值(state_dict)如下:

1
2
3
4
5
OrderedDict([('hidden1.weight', tensor([[-4.3042e-02, -1.3315e-01, -3.5050e-01, -1.4949e-01, -1.6642e-01,
......), ('hidden1.bias', tensor([ 0.2563, -0.0024, -0.1276, 0.1943, -0.2728, -0.2992, 0.3130, 0.0245,
-0.0381, 0.4498])), ('hidden2.weight', tensor([[-0.5759, -0.9750, 1.0027, 0.5148, 0.6903, 0.3534, -1.0665, 0.1220,
-0.0757, 0.4448], ......), ('hidden2.bias', tensor([ 1.7468e-01, 5.9972e-02, -4.2997e-02, -2.2675e-01, 8.3250e-01,
-3.2392e-04, 3.9665e-01, -2.5674e-01])), ('hidden3.weight', tensor([[ 1.3292, -0.6698, -0.2412, 1.0923, -2.5248, 0.3479, -1.1331, -0.0240]])), ('hidden3.bias', tensor([-0.8218]))])

值得注意的是,state_dict的输出格式为Python字典结构(OrderedDict)。模型参数被保存在文件binary_classification.pth中。

接着我们加载该模型文件,并对新数据进行预测,示例代码(load_model.py)如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# -*- coding: utf-8 -*-
import torch
from torch import Tensor

from save_model import MLP

# Rebuild the architecture, then fill it with the saved parameters
model = MLP(34)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict contains
state_dict = torch.load('./binary_classification.pth', weights_only=True)
model.load_state_dict(state_dict)
# switch to inference mode before predicting
model.eval()
print(model)
# make a single prediction (expect class=1)
row = [1, 0, 0.99539, -0.05889, 0.85243, 0.02306, 0.83398, -0.37708, 1, 0.03760, 0.85243, -0.17755, 0.59755, -0.44945,
       0.60536, -0.38223, 0.84356, -0.38542, 0.58212, -0.32192, 0.56971, -0.29674, 0.36946, -0.47357, 0.56811, -0.51171,
       0.41078, -0.46168, 0.21266, -0.34090, 0.42267, -0.54487, 0.18641, -0.45300]
row = Tensor([row])
# make prediction without tracking gradients
with torch.no_grad():
    yhat = model(row)
# retrieve numpy array
yhat = yhat.detach().numpy()
# yhat has a single element; .item() extracts it as a Python scalar so the
# %-formatting does not rely on deprecated array-to-scalar conversion
print('Predicted: %.3f (class=%d)' % (yhat.item(), yhat.round().item()))

如果我们想保存、加载整个模型及模型参数,则在模型保存代码(save_model.py)中使用代码:

1
# Save the entire model object rather than only its state_dict
torch.save(model, 'binary_classification.pth')

加载模型部分代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# -*- coding: utf-8 -*-
import torch
from torch import Tensor

from save_model import MLP

# The file holds a pickled MLP instance, so it must be loaded with
# weights_only=False (PyTorch >= 2.6 defaults to weights_only=True and would
# refuse to unpickle the class). The MLP import above is what lets the
# unpickler resolve the class.
# NOTE: this unpickles arbitrary Python objects -- only load files you trust.
model = torch.load('./binary_classification.pth', weights_only=False)
# switch to inference mode before predicting
model.eval()

# make a single prediction (expect class=1)
row = [1, 0, 0.99539, -0.05889, 0.85243, 0.02306, 0.83398, -0.37708, 1, 0.03760, 0.85243, -0.17755, 0.59755, -0.44945,
       0.60536, -0.38223, 0.84356, -0.38542, 0.58212, -0.32192, 0.56971, -0.29674, 0.36946, -0.47357, 0.56811, -0.51171,
       0.41078, -0.46168, 0.21266, -0.34090, 0.42267, -0.54487, 0.18641, -0.45300]
row = Tensor([row])
# make prediction without tracking gradients
with torch.no_grad():
    yhat = model(row)
# retrieve numpy array
yhat = yhat.detach().numpy()
# yhat has a single element; .item() extracts it as a Python scalar so the
# %-formatting does not rely on deprecated array-to-scalar conversion
print('Predicted: %.3f (class=%d)' % (yhat.item(), yhat.round().item()))

需要注意的是,模型结构MLP类仍需在代码中导入(虽然后面的代码中并没有直接用到MLP类),这样模型才能加载成功,否则加载时会因找不到MLP类的定义而报错。

总结

本文简单介绍了如何在PyTorch中保存和加载模型。本文介绍的模型代码已开源,Github地址为:https://github.com/percent4/PyTorch_Learning。后续将持续介绍PyTorch内容,欢迎大家关注~


PyTorch入门(三)模块的保存与加载
https://percent4.github.io/PyTorch入门(三)模块的保存与加载/
作者
Jclian91
发布于
2023年7月30日
许可协议