torch.load()加载模型及其map_location参数

Python68

函数格式为: torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。

  • 模型保存有两种形式,一种是保存模型的 state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例 model,之后通过 torch.load()将保存的模型参数加载进来,得到 dict,再通过 model.load_state_dict(dict)将模型的参数更新。
  • 另一种是将整个模型保存下来,之后加载的时候只需要通过 torch.load()将模型加载,即可返回一个加载好的模型。
    具体可参考:PyTorch模型的保存与加载

具体来说, map_location参数是用于重定向,比如此前模型的参数是在 cpu中的,我们希望将其加载到 cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

  • 首先定义一个AlexNet,并使用 cuda:0将其训练了一个猫狗分类,之后把模型存储起来。

  • 我们先把 state_dict加载进来。

model_path = "./cuda_model.pth"
model = torch.load(model_path)
print(next(model.parameters()).device)

结果为:

输入验证码查看隐藏内容

扫描二维码关注本站微信公众号 Johngo学长
或者在微信里搜索 Johngo学长
回复 svip 获取验证码
wechat Johngo学长

相关文章
Python

Python自定义排序及实际遇到的一些实例

写在前面,本文主要介绍Python基础排序和自定义排序的一些规则,如果都比较熟悉,可以直接翻到第三节,看下实际的笔试面试题中关于自定义排序的应用。 一、基础排序 排序是比较基础的算法,与很多语言一样,...
Python

python_cookbook学习笔记

目录​ ​​一、数据结构和算法:4 ​​​ ​​1、解压序列赋值给多个变量4 ​​​ ​​2、解压可迭代对象赋值给多个变量5 ​​​ ​​3、保留最后N个元素collections.deque 5 ​...
Python

threading.local()实现线程数据隔离

同一个进程下,多个线程是共享进程的数据,多线程为了保证数据的安全性,多线程的写操作会加锁,加锁也就意味着多线程模型下,效率将降低。 threading.local()可以为每个线程创建局部名称空间,t...
Python

python—获取元素 Xpath

python---获取元素 Xpath 原创 夕陌2022-07-19 11:27:10©著作权 文章标签 绝对路径 firefox 元素定位 文章分类 Python 编程语言 ©著作权归作者所有:来...
Python

Ubuntu下安装PyTorch杂记

最近几天我一直常用的Kubuntu(KDE yes!)更新至22.04后居然出现无法更改软件源的bug,去Kubuntu论坛一看有同样问题的人还不在少数,但却没有好的解决办法,故而只有备份数据装回Ub...
Python

自动化运维开发-ansible接口

目录​ ​​探测模块和工具2 ​​​ ​​存活扫描nmap|telnetlib 2 ​​​ ​​主机登录探测pexpect|paramiko 2 ​​​ ​​ansible运维4 ​​​ ​​ansi...
Python

萌妹子Python入门指北(三)

前两篇网站我简单介绍了python环境的安装和基本的变量及运算。到目前为止,我们没办法用python做任何事,所以这篇文章我会介绍python的判断和循环语句,据说 顺序、判断、循环可以解决计算机中的...
Python

【整理】最常见的10道Python面试题及答案!

学完Python技术之后,接下来将要面临的就是面试找工作的问题了,虽说找工作面试很关键,但提前做好准备更重要。今天小编为大家准备了10道Python面试题及答案,希望能够给你们带来帮助。 1、如何在P...