ALL YOU NEED TO KNOW ABOUT HOW TO SAVE YOUR MODEL IN PYTORCH!!!
- vidyakamath1004
- Apr 25, 2023
- 3 min read

Models created with PyTorch are commonly saved as .pth files. You may save a trained model so that it can be evaluated on a separate device or deployed in your applications. You can also save checkpoints throughout training and use them to resume training later. In both scenarios, we need to know how to save and load the model at any point.
Let us look at the two different approaches you can use to save your model.
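import torch
import torch.nn as nn
path = "model.pth" # PATH TO YOUR MODEL
################# FIRST APPROACH: LAZY METHOD ##################
#torch.save({YOUR_MODEL_NAME}, {PATH_TO_MODEL})
''' uses Python's pickle module to serialize the whole model object and save it '''
model = MyModel(arguments) # MyModel is your own nn.Module subclass
torch.save(model, path)
####################################
''' load your saved model (the MyModel class must be importable when loading) '''
#model = torch.load({PATH_TO_MODEL}, weights_only=False)
model = torch.load(path, weights_only=False) # weights_only=False is required for whole models on PyTorch 2.6+
model.eval() # set dropout and batch-norm layers to evaluation mode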
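To see the lazy method end to end, here is a minimal runnable sketch. It assumes `nn.Sequential` as a stand-in for `MyModel` and writes to a temporary directory instead of your own path. Because the whole object is pickled, its class must be importable when you load it (always true for `nn.Sequential`, but your own class must be defined or imported), and `torch.load` needs `weights_only=False` on recent PyTorch versions.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for MyModel: any nn.Module is saved the same way.
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(model, path)  # pickles the whole model object, class reference included

    # Unpickling an arbitrary nn.Module needs weights_only=False
    # (torch.load defaults to weights_only=True from PyTorch 2.6 on).
    loaded = torch.load(path, weights_only=False)
    loaded.eval()  # switch to evaluation mode before inference

x = torch.randn(2, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), loaded(x))
print(type(loaded).__name__, same_output)  # Sequential True
```

The convenience comes at a cost: the saved file is tied to the exact class definition and module path, which is why the state-dictionary approach is usually preferred.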
################## SECOND APPROACH: PREFERRED METHOD ################
''' save only your model's state dictionary (its learned parameters and buffers) and reload the dictionary into a new model '''
#torch.save({YOUR_MODEL_NAME}.state_dict(), {PATH_TO_MODEL})
torch.save(model.state_dict(), path)
####################################
''' re-create the model from its class, then load the saved state dictionary into it '''
#model = MyModel(arguments)
#model.load_state_dict(torch.load({PATH_TO_MODEL}))
model = MyModel(arguments)
model.load_state_dict(torch.load(path))
model.eval()
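A state dictionary is an ordered mapping from parameter and buffer names to tensors, which is why it can be loaded into any freshly constructed model with the same architecture. The sketch below uses a hypothetical `nn.Sequential` stand-in for `MyModel` and a temporary file; it prints the dictionary's keys and checks that the reloaded weights match.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # hypothetical stand-in for MyModel

state = model.state_dict()
print(list(state.keys()))  # ['0.weight', '0.bias']

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(model.state_dict(), path)

    # Re-create the same architecture first, then copy the saved tensors into it.
    fresh = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
    result = fresh.load_state_dict(torch.load(path))
    fresh.eval()

weights_match = all(torch.equal(state[k], fresh.state_dict()[k]) for k in state)
print(result.missing_keys, result.unexpected_keys, weights_match)  # [] [] True
```

`load_state_dict` reports any missing or unexpected keys, which is a quick way to catch a mismatch between the saved file and the model you rebuilt.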
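A checkpoint can store both the model's state and the optimizer's state (for example, Adam's running averages and the learning rate), together with other information such as the current epoch. Using the code below, you can save such a checkpoint during training and load it again later.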
import torch
import torch.nn as nn
path = "checkpoint.pth" # PATH TO YOUR CHECKPOINT
learning_rate = 0.01
''' create your model '''
model = MyModel(arguments)
''' define your optimizer before training '''
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
print(optimizer.state_dict()) # visualize the optimizer state
''' train your model '''
...
''' create a checkpoint during training as a dictionary object '''
checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}
''' save the checkpoint to a file '''
torch.save(checkpoint, path)
...
####################################
''' load the checkpoint '''
checkpoint = torch.load(path)
epoch = checkpoint["epoch"]
''' define your model and optimizer '''
model = MyModel(arguments)
optimizer = torch.optim.Adam(model.parameters(), lr=0)
# lr=0 is a placeholder: the saved lr is restored when the optimizer state dictionary is loaded
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
print(model.state_dict())
print(optimizer.state_dict())
''' continue training '''
...
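Here is a self-contained sketch of the checkpoint round trip. It assumes `nn.Linear(10, 1)` as a stand-in for `MyModel`, random data for a single dummy training step (so Adam has internal state to save), and a temporary file. It confirms that the learning rate comes back from the checkpoint even though the new optimizer was created with `lr=0`.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)  # hypothetical stand-in for MyModel
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# One dummy training step so that Adam has internal state worth saving.
x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    torch.save({
        "epoch": 90,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)
    checkpoint = torch.load(path)

# Rebuild both objects; lr=0 is only a placeholder that gets overwritten.
new_model = nn.Linear(10, 1)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0)
new_model.load_state_dict(checkpoint["model_state"])
new_optimizer.load_state_dict(checkpoint["optimizer_state"])

restored_lr = new_optimizer.param_groups[0]["lr"]
print(checkpoint["epoch"], restored_lr)  # 90 0.01
```

The optimizer must be created for the new model's parameters before its state is loaded, because the saved state refers to parameters by position, not by name.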
Let us summarize everything we have learned:
''' import the libraries '''
import torch
import torch.nn as nn
path = "model.pth"        # file for the final trained model
cpath = "checkpoint.pth"  # file for training checkpoints
# step 1: Create your model.
class MyModel(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
# step 2: Train your model.
learning_rate = 0.01
model = MyModel(in_channels=10)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
epochs = 150
# step 3a: load from a checkpoint
checkpoint = torch.load(cpath)
cepoch = checkpoint["epoch"]  # last epoch that was completed and saved
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
for epoch in range(cepoch + 1, epochs):  # resume after the saved epoch
    ...
    ''' continue training '''
# OR
for epoch in range(epochs):
    # step 3b: save a checkpoint
    ...
    ''' finish training an epoch and checkpoint it '''
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    torch.save(checkpoint, cpath)
''' after training '''
print("trained model:")
for param in model.parameters():  # print the params
    print(param)
print(model.state_dict())
print(optimizer.state_dict())
# step 4: Save your model.
torch.save(model.state_dict(), path)
# step 5a: Re-load your saved model to resume training.
loaded_model = MyModel(in_channels=10)
loaded_model.load_state_dict(torch.load(path))
print(loaded_model.state_dict())
# continue training the loaded_model
# OR
# step 5b: Re-load your saved model for inference.
loaded_model = MyModel(in_channels=10)
loaded_model.load_state_dict(torch.load(path))
loaded_model.eval()  # switch to evaluation mode before inference
# start model inference
...
print("loaded model:")
for param in loaded_model.parameters():  # print the params
    print(param)
print(loaded_model.state_dict())
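The resume logic in the summary hinges on one detail: the checkpoint stores the epoch that has just finished, so training should restart one epoch later. The sketch below checks this under stated assumptions: synthetic random data, a hypothetical `train` helper, and a hypothetical `stop_after` argument that simulates an interrupted run. It redefines the summary's `MyModel` so that it runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


def train(cpath, epochs, stop_after=None):
    """Hypothetical helper: train for `epochs` epochs, resuming from `cpath` if it exists.

    `stop_after` simulates an interruption. Returns the epoch indices run in this call.
    """
    torch.manual_seed(0)
    x = torch.randn(32, 10)
    y = (x.sum(dim=1, keepdim=True) > 0).float()  # synthetic binary labels

    model = MyModel(in_channels=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    start = 0
    if os.path.exists(cpath):
        checkpoint = torch.load(cpath)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start = checkpoint["epoch"] + 1  # the saved epoch is already finished

    ran = []
    for epoch in range(start, epochs):
        optimizer.zero_grad()
        loss = nn.functional.binary_cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }, cpath)
        ran.append(epoch)
        if stop_after is not None and len(ran) == stop_after:
            break
    return ran


with tempfile.TemporaryDirectory() as tmp:
    cpath = os.path.join(tmp, "checkpoint.pth")
    first = train(cpath, epochs=5, stop_after=3)  # "interrupted" after 3 epochs
    second = train(cpath, epochs=5)                # resumes from the checkpoint
print(first, second)  # [0, 1, 2] [3, 4]
```

Starting the resumed loop at `checkpoint["epoch"]` instead of `checkpoint["epoch"] + 1` would train the last saved epoch twice.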
...
ADDITIONAL
Now let us look at the scenario where you save and load your model on different devices. If you only use a CPU, the code above works as is. If you run the code on a GPU, however, you must move the model to that device, and when loading you should pass map_location so that the saved tensors end up on the device you are using.
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device) # map the model to your device
When deploying the model:
''' while creating the model, always move it to your device '''
model = MyModel(arguments)
model.to(device) # optional on a CPU-only machine, but required when training the model on a GPU
...
...
...
################################
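''' load the saved model onto your device '''
model = torch.load(path, map_location=device, weights_only=False) # first approach (whole model)
# OR
model = MyModel(arguments)
model.load_state_dict(torch.load(path, map_location=device)) # second approach (state dictionary)
model.to(device)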
...
...
...
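Here is a small sketch of device-aware loading. It assumes an in-memory `io.BytesIO` buffer in place of a file path and `nn.Linear(10, 1)` as a stand-in for `MyModel`. It runs on a CPU-only machine; when a GPU is available, `device` becomes `cuda` and the same code places the weights there.

```python
import io

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 1).to(device)  # hypothetical stand-in for MyModel

buffer = io.BytesIO()  # in-memory stand-in for a file path
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# map_location decides where the saved tensors are placed while loading,
# e.g. a checkpoint written on a GPU can be loaded onto a CPU-only machine.
state = torch.load(buffer, map_location=device)

loaded = nn.Linear(10, 1)
loaded.load_state_dict(state)
loaded.to(device)  # load_state_dict copies values; the module's own device is set by .to()
loaded.eval()

print(all(p.device.type == device.type for p in loaded.parameters()))  # True
```

Passing `map_location="cpu"` explicitly is a common choice when a model trained on a GPU has to be inspected on a machine without one.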