今天在把.pt文件转ONNX文件时,遇到此错误。
报错
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_mm)
原因
代码中的Tensor,一会在CPU中运行,一会在GPU中运行,所以最好是都放在同一个device中执行。
pytorch有两种模型保存方式:
一、保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net
二、只保存神经网络的训练模型参数,save的对象是net.state_dict()
对应两种保存模型的方式,pytorch也有两种加载模型的方式。对应第一种保存方式,加载模型时通过torch.load('.pth')直接初始化新的神经网络对象;对应第二种保存方式,需要首先导入对应的网络,再通过net.load_state_dict(torch.load('.pth'))完成模型参数的加载。