PyTorch——模型保存与模型加载(一)

模型保存与加载有两种方式,本文暂时只讨论模型参数方式

1> 单GPU

保存

1 torch.save(model.state_dict(), "model.pth")

加载

1 model = SimpleNet()
2 model.load_state_dict(torch.load("./model.pth"))

2> 多GPU 

保存

1 torch.save(model.module.state_dict(), "./model.pth")

加载

1 mdoel = SimpleNet()
2 model.load_state_dict(torch.load("./model.pth"))