在鹅厂实习阶段,follow苏神(科学空间)的博客,启发了idea,成功改进了线上的一款模型。想法产出和实验进展很大一部分得益于苏神设计的bert4keras,清晰轻量、基于keras,可以很简洁的实现bert,同时附上了很多易读的example,对nlp新手及其友好!本文推荐几篇基于bert4keras的项目,均来自苏神,对新手入门bert比较合适~
- tokenizer:分词器,主要方法:encode,decode。
- build_transformer_model:建立bert模型,建议看源码,可以加载多种权重和模型结构(如unilm)。
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
tokenizer = Tokenizer(dict_path, do_lower_case=True)
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, with_mlm=True
)
token_ids, segment_ids = tokenizer.encode(u'科学技术是第一生产力')
token_ids[3] = token_ids[4] = tokenizer._token_mask_id
token_ids, segment_ids = to_array([token_ids], [segment_ids])
probas = model.predict([token_ids, segment_ids])[0]
print(tokenizer.decode(probas[3:5].argmax(axis=1)))
- 句子1和句子2拼接在一起输入bert。
- bert模型的pooler输出经dropout和mlp投影到2维空间,做分类问题。
- 最终整个模型是一个标准的keras model。
class data_generator(DataGenerator):
"""数据生成器
"""
def __iter__(self, random=False):
batch_token_ids, batch_segment_ids, batch_labels = [], [], []
for is_end, (text1, text2, label) in self.sample(random):
token_ids, segment_ids = tokenizer.encode(
text1, text2, maxlen=maxlen
)
batch_token_ids.append(token_ids)
batch_segment_ids.append(segment_ids)
batch_labels.append([label])
if len(batch_token_ids) == self.batch_size or is_end:
batch_token_ids = sequence_padding(batch_token_ids)
batch_segment_ids = sequence_padding(batch_segment_ids)
batch_labels = sequence_padding(batch_labels)
yield [batch_token_ids, batch_segment_ids], batch_labels
batch_token_ids, batch_segment_ids, batch_labels = [], [], []
bert = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
with_pool=True,
return_keras_model=False,
)
output = Dropout(rate=0.1)(bert.model.output)
output = Dense(
units=2, activation='softmax', kernel_initializer=bert.initializer
)(output)
model = keras.models.Model(bert.model.input, output)
model = build_transformer_model(
config_path,
checkpoint_path,
application='unilm',
keep_tokens=keep_tokens,
)
NLG任务的loss是交叉熵,示例中的实现很美观:
- CrossEntropy类继承Loss类,重写compute_loss。
- 将参与计算loss的变量过一遍CrossEntropy,这个过程中loss会被计算,具体阅读Loss类源码。
- 最终整个模型是一个标准的keras model。
class CrossEntropy(Loss):
"""交叉熵作为loss,并mask掉输入部分
"""
def compute_loss(self, inputs, mask=None):
y_true, y_mask, y_pred = inputs
y_true = y_true[:, 1:]
y_mask = y_mask[:, 1:]
y_pred = y_pred[:, :-1]
loss = K.sparse_categorical_crossentropy(y_true, y_pred)
loss = K.sum(loss * y_mask) / K.sum(y_mask)
return loss
model = build_transformer_model(
config_path,
checkpoint_path,
application='unilm',
keep_tokens=keep_tokens,
)
output = CrossEntropy(2)(model.inputs + model.outputs)
model = Model(model.inputs, output)
model.compile(optimizer=Adam(1e-5))
model.summary()
预测阶段自回归解码,继承AutoRegressiveDecoder类可以很容易实现beam_search。
项目地址:SimBert
融合了unilm和对比学习,data generator和loss类的设计很巧妙,值得仔细阅读,建议看不懂的地方打开jupyter对着一行一行print来理解。
bert4keras项目的优点:
- build_transformer_model一句代码构建bert模型,一个参数即可切换为unilm结构。
- 继承Loss类,重写compute_loss方法,很容易计算loss。
- 深度基于keras,训练、保存和keras一致。
- 丰富的example!苏神的前沿算法研究也会附上bert4keras实现。
Original: https://blog.csdn.net/weixin_44597588/article/details/123910248
Author: 一只用R的浣熊
Title: 简洁优美的深度学习包-bert4keras
相关文章

#问题及解决方案整理# | python | anaconda+TensorFlow软件安装运行中的各种问题(已成功解决)
#问题及解决方案整理# | python | anaconda+TensorFlow软件安装运行中的各种问题(已成功解决) 背景 * 常用指令 - 查看conda的版本 查看python的版本 查看a...

二进制数的运算原理与门电路实现
本文地址:https://www.cnblogs.com/faranten/p/16099916.html 转载请注明作者与出处 1 数据和表示方法 1.1 数字的表示 1.1.1 定点数 在计算...

语音控制小车运动APP(基于百度语音识别)
项目背景 由于暑期优秀本科生项目需求,开发了一款控制机器人行走的APP,具体要求如下: 在第一个界面(连接界面)实现Socket连接,连接成功则跳转到下一个页面(控制界面)。 在控制界面中创建5个按钮...

【Opencv小项目 1】Opencv实现简单颜色识别
参考 Opencv简单颜色识别 Youtube教学视频 BGR HSV颜色模型 步骤 一、 BGR 和 HSV 颜色模型 BGR Model BGR模型表示三种颜色通道:红、绿、蓝,采用BGR模型的图...

Unity -Demo 之 ✨ 接入“科大讯飞”语音识别SDK(完整)
抵扣说明: 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。 2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。 Original: https://blog.cs...

NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction 论文笔记
文章目录 Related Works 方法 * Rendering Procedure - 场景表示 Scene Representation 渲染 Rendering 权重函数 weight fun...

PyTorch: 目标检测(object detection)介绍
目标检测(object detection) 一、 介绍 在图像分类任务中,我们假设图像中只有一个主要物体对象,我们只关注如何识别其类别。 然而,很多时候图像里有多个我们感兴趣的目标,我们不仅想知道它...

数据增强之Mosaic数据增强的优点、Mixup,Cutout,CutMix的区别
一、Mosaic data augmentation Mosaic数据增强方法是YOLOV4论文中提出来的,主要思想是将 四张图片进行随机裁剪,再拼接到一张图上作为训练数据。 这样做有以下几个优点: ...

jQuery triger与trigerHandler的区别
trigger(event, [data]) 与 triggerHandler(event, [data]) 都是用于触发一个事件。 其两者的区别在于,如果触发的事件是有浏览器默认行为的,trigge...

【论文阅读】ICCV2021|超分辨重建论文整理和阅读
本文主要对ICCV2021中超分辨率重建相关论文进行整理与阅读。 1. Learning A Single Network for Scale-Arbitrary Super-Resolution P...

哈工大 计算机网络 实验二 可靠数据传输协议(停等协议与GBN协议)
计算机网络实验代码与文件可见github:计算机网络实验整理 实验名称 可靠数据传输协议(停等协议与GBN协议) 实验目的: 本次实验的主要目的。 理解可靠数据传输的基本原理;掌握停等协议的工作原理;...

Anaconda安装tensorflow和keras包
1.背景 在Anaconda中无法直接安装这两个包,安装过程异常漫长。 2.准备工作 添加清华源 1.在Anaconda prompt中(可利用全局搜索查找)运行 conda config命令,然后寻...

添加metadata到tflite模型
1、metadata简介 TensorFlow Lite 元数据提供了模型描述的标准。 元数据是关于模型做什么及其输入/输出信息的重要知识来源。 元数据主要由人类可读部分和机器可读部分组成。 注: 在...

自然语言处理实战:小说读取及分析(附代码)
自然语言处理实战:小说读取及分析 ——— 本文来自于萌新的小作业 目录 自然语言处理实战:小说读取及分析 前言 一、自然语言处理是什么? * 1、概念 2、应用 二、准备工作 * 1、学习目标 2、库...

Yolov5自学笔记之二–在游戏中实时推理并应用(实例:哈利波特手游跳舞小游戏中自动按圈圈)
上一篇帖子我已经自学了Yolov5的基本流程,并运用yolov5进行图片、视频、摄像头、网络视频流等多种方式的推理,这些结合到实际工作中就可以有很广泛的应用了。但是还有一类情况,就是在电脑中的某个程序...