Tensorflow2——模型保存与加载以及训练数据保存和断点续训

人工智能33

通过阅读这篇博客,你可以了解如何在Tensorflow训练过程中保存准确率和loss,以及如何在tensorflow中保存与加载模型,如何再重新接着上一轮的训练过程继续训练。

最近在神经网络训练的过程中,需要保存训练过程中的数据,并且再下次训练的时候能够接着上次训练的结果进行断点续训。所以通过Tensorflow2官网查询到对于model.fit相关的回调函数的编写方法。下面总结了一下,在Tensorflow中,对于参数保存以及断点续训的内容,在最后会给出一个示例代码,供大家参考。

1. 保存训练数据

如何保存训练过程的数据,包括训练轮数(Epoch),训练集acc,训练集loss,验证集acc,验证集loss。通过Tensorflow官网可以找到一个保存训练数据的回调函数,即tf.keras.callbacks.CSVLogger。使用方法很简单,只需要指定保存路径,然后将方法加入到model.fit的callbacks参数列表中。示例代码如下:


csv_logger = CSVLogger('training.log',append=False)
model.fit(X_train, Y_train, callbacks=[csv_logger])

2. 保存与加载模型

保存模型可以借助tf.keras.callbacks.ModelCheckpoint()进行保存。接口说明如下:我们在使用过程中,大致只需要使用filepath,save_best_only,save_weight_only三个参数。其中filepath指明保存文件的路径,save_best_only说明是否只保存为佳模型,save_weight_only说明是否只保存模型权重。

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, initial_value_threshold=None, **kwargs
)

我们保存了模型,就需要加载模型,加载模型就通过model.load_weights(filepath)进行模型的加载操作。

这里是一个示例代码供您参考,具体要求可以根据代码进行修改。

[En]

Here is a sample code for your reference, the specific requirements can be modified according to the code.


model = TestModel()
model.compile(....)

checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

cp_callback_save = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)

cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_best_path,save_weights_only=True,save_best_only=True)

model.fit(.....,callbacks=[cp_callback_save,cp_callback_save_best])

3. 示例代码

前面介绍了参数和模型的保存,回到刚刚的问题,我们需要在上一次训练后继续之前的训练过程,并且保存数据参数。首先,如果要接着上一轮继续训练,那么就需要知道上一轮训练了多少轮,我们可以通过我们的参数数据文件,很容易得出我们训练了多少论,接着我们可以借助model.fit的initial_epoch指定起始轮数,这样就可以使得训练接着上一轮继续训练。参考代码如下:


def get_init_epoch(filename):
    with open(filename) as f:
        f_csv = csv.DictReader(f)
        count = 0
        for row in f_csv:
            count = count+1
        return count

init_epoch = 0
if os.path.exists(filename):
    init_epoch = get_init_epoch(filename)
model = Test()
model.compile(...)

checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

cp_callback_save = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_best_path,save_weights_only=True,save_best_only=True)

csv_logger = CSVLogger('training_log',append=True)

model.fit(....,initial_epoch=init_epoch,callbacks=[csv_logger,cp_callback_save_best,cp_callback_save])

所有回调函数都将 keras.callbacks.Callback 类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。

回调函数方法概述
全局方法

on_(train|test|predict)_begin(self, logs=None)
在 fit/evaluate/predict 开始时调用。

on_(train|test|predict)_end(self, logs=None)
在 fit/evaluate/predict 结束时调用。

Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。

on_(train|test|predict)_batch_end(self, batch, logs=None)
在训练/测试/预测批次结束时调用。在此方法中,logs 是包含指标结果的字典。

周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
在训练期间周期开始时调用。

on_epoch_end(self, epoch, logs=None)
在训练期间周期开始时调用。

基本示例
让我们来看一个具体的例子。

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

Original: https://blog.csdn.net/weixin_43387852/article/details/123507105
Author: 归来空空
Title: Tensorflow2——模型保存与加载以及训练数据保存和断点续训



相关阅读

Title: ubuntu18.04安装nvidia_driver_510+cuda_11.6+cudnn_11.x

一、安装nvidia_driver

1、在 软件和更新 中选择一个可用的驱动Tensorflow2——模型保存与加载以及训练数据保存和断点续训

2.1首先我们需要添加源

sudo add-apt-repository ppa:graphics-drivers/ppa
sudo apt update

2.2选择一个版本安装即可(如1,我选择安装510)

sudo apt install nvidia-driver-510

2.3 重启电脑后终端输入

nvidia-smi

查看驱动信息

Tensorflow2——模型保存与加载以及训练数据保存和断点续训

二、安装CUDA

1、官网找到自己的版本(如2.3版本信息中,我510驱动对应的cuda_11.6)

CUDA Toolkit Archive | NVIDIA Developer 选择runfile格式的CUDA文件下载

2、选择环境,并根据官网步骤安装Tensorflow2——模型保存与加载以及训练数据保存和断点续训

3、下载完成后,解压,并运行上图中的命令,会有条款,接受即可。

3.1注意安装CUDA的时候不要安装驱动(因为在第一步我们已经安装过了)

Tensorflow2——模型保存与加载以及训练数据保存和断点续训3.2添加环境变量

sudo gedit ~/.bashrc

在打开的txt文件末尾加

export CUDA_HOME=/usr/local/cuda
export PATH=$PATH:$CUDA_HOME/bin
export LD_LIBRARY_PATH=/usr/local/cuda-11.6/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

保存,退出。终端执行

source ~/.bashrc

3.3 验证cuda

3.3.1

11.6版本cuda的安装目录/usr/local/cuda-11.6/samples里只有一个txt文件,大致意思是告诉你新版本的cuda,samples中内容需要自己在github下载。

由于github下载过慢,在此放上gitee链接

git clone https://gitee.com/liwuhaoooo/cuda-samples.git

在samples文件夹下打开终端执行上述语句。

大概率无权访问,此时在cuda-11.6文件夹下打开终端

su
输入密码切换超级用户
chmod 777 samples

再次执行git clone 就可以了。

3.3.2

进入/usr/local/cuda-11.6/samples/cuda-samples/Samples

cd /usr/local/cuda/samples/1_Utilities/deviceQuery
sudo make
./deviceQuery

输出

Tensorflow2——模型保存与加载以及训练数据保存和断点续训

则安装成功。

三、安装cuDNN

1、

进入NVIDIA cuDNN | NVIDIA Developer注册,并选择合适的版本下载(cuDNN Library for Linux),然后解压;

2、

并进入到/home/lwh/Downloads/cudnn-11.3-linux-x64-v8.2.1.32目录,运行以下命令:

sudo cp cuda/include/cudnn.h /usr/local/cuda-11.6/include
sudo cp cuda/lib64/libcudnn* /usr/local/cuda-11.6/lib64
sudo chmod a+r /usr/local/cuda-11.6/include/cudnn.h
sudo chmod a+r /usr/local/cuda-11.6/lib64/libcudnn*

若无权访问,像3.3.1一样,分别更改include和lib64文件夹权限。

四、验证GPU可用

import torch
print(torch.cuda.is_available())

Tensorflow2——模型保存与加载以及训练数据保存和断点续训

Original: https://blog.csdn.net/weixin_54424184/article/details/122654844
Author: CPU疼
Title: ubuntu18.04安装nvidia_driver_510+cuda_11.6+cudnn_11.x

相关文章
计算机视觉 什么是计算机视觉 人工智能

计算机视觉 什么是计算机视觉

1、什么是计算机视觉? 作为人类,我们可以轻松地感知周围世界的三维结构。想想当你看着坐在你旁边桌子上的花瓶时,三维感知是多么生动。您可以通过在其表面上播放的微妙的光影图案来分辨每个花瓣的形状和半透明度...
esp32的智能遥控 人工智能

esp32的智能遥控

文章目录 * - 一、演示视频 - 二、程序框架 - 三、硬件设计 - 四、模块介绍 - + 1、语音识别模块 + * 离线语音识别 * 优化语音识别 + 2、BLE模块 + 3、MQTT模块 + *...
Python语音识别实践【百度AI平台】 人工智能

Python语音识别实践【百度AI平台】

这几天想要用Python来体验一下语音识别技术,虽然我知道有很多开源免费的语音识别库,例如,CMU Sphinx,好像以前玩过,但只为了愉快地体验,这次选择百度AI平台来简单实践一下,后期再深入研究开...
docker安装TensorFlow 人工智能

docker安装TensorFlow

TensorFlow 是一个端到端开源机器学习平台。它拥有一个全面而灵活的生态系统,其中包含各种工具、库和社区资源,可助力研究人员推动先进机器学习技术的发展,并使开发者能够轻松地构建和部署由机器学习提...
基于PaddleOCR的体检报告识别 人工智能

基于PaddleOCR的体检报告识别

✨ 写在前面:强烈推荐给大家一个优秀的人工智能学习网站,内容包括人工智能基础、机器学习、深度学习神经网络等,详细介绍各部分概念及实战教程,通俗易懂,非常适合人工智能领域初学者及研究者学习。➡️ 点击跳...
wenet mask原理解析 人工智能

wenet mask原理解析

简介: 该程序主要对wenet使用的mask原理进行分析,更多详细内容参照 https://zhuanlan.zhihu.com/p/381271607 代码位置: wenet/mask.py at ...