在学习中涉及到了TensorFlow的自定义算子实现,现将整个工程中的一些思考写下来,有问题的部分也请大家指正!!!
OP和Kernel是TensorFlow框架最重要的两个概念,OP类似于函数声明,Kernel类似于实现。要注意以下四个方面:一是所有Op包含注册和实现两部分;二是OpKernel类(./core/framework/op_kernel.h)是所有Op类的基类;三是所有Op类的实现需要overide抽象基函数void Compute(OpKernelContext* context),实现自身Op功能;四是所有Op操作的属性定义和描述符合protobuf协议。
一、自定义算子实现基本流程
在一个C++文件中注册新Op,其注册与实现相互独立,该文件指定自定义算子的输入输出、参数,命名采用驼峰命名法。
/**
* ./tensorflow/core/framework/op.h
* #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
* #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
* #define REGISTER_OP_UNIQ(ctr, name) \
* static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
* TF_ATTRIBUTE_UNUSED = \
* ::tensorflow::register_op::OpDefBuilderWrapper(name)
* REGISTER_OP本质是创建了一个OpDefBuilderReceiver对象,
* 并将Attr,Input,Output等保存在OpDefBuilder对象中。
*/
REGISTER_OP("myFunc") //: ,通过context参数访问这个属性
.Input("in1: int32")
.Input("in2: int32")
.Output("out: int32")
.Attr("Para1: int")
.Attr("Para2: int")
.SetShapeFn([](InferenceContext *c){return Status::OK();})
上述表示:注册名为myFunc的算子,输入in1和in2,类型为int32;输出为out,类型为int32;参数为Para1和Para2,类型为int,ShapeFn用于shape推断。
也可以在注册时赋予默认值,默认值支持的语法将在最终GraphDef定义的pb表示中被使用。
/**
* tensorflow/core/framework/op_kernel.h
* class OpKernel {
* public:
* explicit OpKernel(OpKernelConstruction* context);
*
* OpKernel(OpKernelConstruction* context, bool is_deferred);
*
* OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
* bool is_deferred);
* ...
* TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
* };
*/
class myFuncOp: public OpKernel{ //创建一个类,继承OpKernel类
public:
//创建构造函数并显示调用OpKernel(context)
explicit myFuncOp(OpKernelConstruction* context):OpKernel(context)
{
//参数获取
OP_REQUIRES_OK(context,context->GetAttr("attr_name",&attr_name));
}
void Compute(OpKernelContext* context) override //重写OpKernel类的Compute方法
{
//输入tensor
Tensor* in1 = const_cast(&context->input(0));
Tensor* in2 = const_cast(&context->input(1));
//创建一个输出, 使用context->allocate_ouput()分配空间
Tensor* out = NULL;
TensorShape out_shape(...);
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &out));
...
//算子行为的具体实现
...
}
}
/**
* #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
* REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
* #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
* REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
* #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
* constexpr bool should_register_##ctr##__flag = \
* SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
* static ::tensorflow::kernel_factory::OpKernelRegistrar \
* registrar__body__##ctr##__object( \
* should_register_##ctr##__flag \
* ? ::tensorflow::register_kernel::kernel_builder.Build() \
* : nullptr, \
* #__VA_ARGS__, \
* [](::tensorflow::OpKernelConstruction* context) \
* -> ::tensorflow::OpKernel* { \
* return new __VA_ARGS__(context); \
* });
* REGISTER_KERNEL_BUILDER实质是创建一个名称唯一的类型为OpKernelRegistrar的全局静态变量
* class OpKernelRegistrar {
* public:
* OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
* std::unique_ptr factory) {
* if (kernel_def != nullptr) {
* InitInternal(kernel_def, kernel_class_name, std::move(factory));
* }
* }
* OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
* OpKernel* (*create_fn)(OpKernelConstruction*)) {
* if (kernel_def != nullptr) {
* InitInternal(kernel_def, kernel_class_name,
* absl::make_unique(create_fn));
* }
* }
* }
* OpKernelRegistrar的构造需要三个被包装到KernelRegistration这个结构体里的参数,并作为Kernel注册表的值:
* 第一个是KernelDef,第二个是定义Kernel的类名,第三个是创建kernel对象的函数;
* 首先调用KernelDefBuilder的Build函数获得对应的KernelDef;
* 然后获取用于创建这个Kernel的C++类名称;
* 最后包装一个factory函数用来接收传进来的OpKernelConstruction*,创建对应的Kernel类对象,并返回其指针。
*/
REGISTER_KERNEL_BUILDER(Name("myFunc").Device(DEVICE_CPU), myFuncOp);
二、示例(基于《智能计算系统》实验7-1)
在NMS实现之后,需要将其集成到TF框架中重编译,整个过程涉及接口封装与算子集成。
利用CNML PluginOP封装出便于用户使用的CNPlugin接口(该过程已实现)。
//plugin_yolov3_detection_output_op.cc
cnmlStatus_t cnmlCreatePluginYolov3DetectionOutputOp(//算子创建、参数声明及初始化...
cnmlBaseOp_t *op,
cnmlPluginYolov3DetectionOutputOpParam_t param,
cnmlTensor_t *yolov3_input_tensors,
cnmlTensor_t *yolov3_output_tensors){...}
cnmlStatus_t cnmlComputePluginYolov3DetectionOutputOpForward(...)//调用cnmlComputePluginOpForward完成计算
{
...
cnmlComputePluginOpForward_V3(...);//cnmlComputePluginOpForward_V4(...)
...
}
直接封装CNML和CNPlugin算子,结果供算子的DLP实现函数调用,该封装目的是将高层调用与底层实现有效隔离。
//mlu_lib_ops.cc & mlu_lib_ops.h
tensorflow::Status CreateYolov3DetectionOutputOp(...)
{
CNML_RETURN_STATUS(cnmlCreatePluginYolov3DetectionOutputOp(op, param, input_tensors, output_tensors));
}
tensorflow::Status ComputeYolov3DetectionOutputOp(...)//
{
...
cnmlComputePluginYolov3DetectionOutputOpForward(op, inputs, input_num, outputs, output_num, &compute_forw_param, queue);
}
//mlu_ops.h 算子类声明
struct MLUYolov3DetectionOutputOpParam{//数据成员声明
...
MLUYolov3DetectionOutputOpParam(...): ...{}
}
/**
* 类声明,继承自MLUBaseOpWrapper
* CreateMLUOp(inputs, outputs, param)
* Compute(const std::vector &inputs, const std::vector &outputs, cnrtQueue_t queue) override
*/
DECLARE_OP_CLASS(MLUYolov3DetectionOutput);
//yolov3detectionoutput.cc 实现
Status MLUYolov3DetectionOutput::CreateMLUOp(std::vector &inputs, std::vector &outputs, void *param){
//定义输入输出tensor
...
//参量初始化
...
//调用cnmlCreatePluginYolov3DetectionOutputOpParam
//调用CreateYolov3DetectionOutputOp
...
}
Status MLUYolov3DetectionOutput::Compute(const std::vector &inputs, const std::vector &outputs, cnrtQueue_t queue)
{
//变量获取
...
//调用ComputeYolov3DetectionOutputOp
...
}
运行时会MLU自动将算子与运行时队列绑定并下发执行。
//mlu_stream.h
Status Yolov3DetectionOutput(OpKernelContext* ctx,
Tensor* tensor_input0,
Tensor* tensor_input1,
Tensor* tensor_input2,
...
Tensor* output1,
Tensor* output2){
//实例化MLUYolov3DetectionOutputOpParam
ops::MLUYolov3DetectionOutputOpParam op_param(...);
//调用MLUYolov3Detectionutput,CommonOpImpl接口用于处理输入输出并创建OP
return CommonOpImpl(
ctx,
{tensor_input0, tensor_input1, tensor_input2},
{output1, output2},
static_cast(&op_param));
}
//yolov3_detection_output_op_mlu.h
class MLUYolov3DetectionOutputOp: public MLUOpKernel{//创建继承自MLUOpKernel的类
public:
//创建构造函数并显示调用MLUOpKernel(context)
explicit MLUYolov3DetectionOutputOp(OpKernelConstruction* context):MLUOpKernel(context){
//参数获取
OP_REQUIRES_OK(context,context->GetAttr("Attr",&Attr_));
...
}
void ComputeOnMLU(OpKernelContext* context) override {
...
//将输入tensor从context中取出
Tensor* input0 = const_cast(&context->input(0));
Tensor* input1 = const_cast(&context->input(1));
Tensor* input2 = const_cast(&context->input(2));
...
//创建输出, 使用context->allocate_ouput()给它分配空间,并进行形状推断
Tensor* output;
Tensor* buffer;
TensorShape tf_output_shape {...};
TensorShape tf_buffer_shape {...};
OP_REQUIRES_OK(context, context->allocate_output(0, tf_output_shape, &output));
OP_REQUIRES_OK(context, context->allocate_output(0, tf_buffer_shape, &buffer));
//调用自定义算子
OP_REQUIRES_OK(context,stream->Yolov3DetectionOutput(...));
}
//参数声明
private:
int batchNum_;
int inputNum_;
int classNum_;
int maskGroupNum_;
int maxBoxNum_;
int netw_;
int neth_;
float confidence_thresh_;
float nms_thresh_;
std::vector inputWs_;
std::vector inputHs_;
std::vector biases_;
};
在进行形状推断时,需要注意以下:
//cnplugin.h
/*!
* @brief A function.
*
* This function creates PluginYolov3DetectionOutputOp with proper param,
* input, and output tensors.
*
* PluginYolov3DetectionOutputOp takes in feature maps and network
* parameters and computes valid bounding boxes based on two thresholds
* you have chosen.
*
* **Reference:**
* This implementation is based on the project on ``github/pjreddie/darknet`` .
*
* **Formula:** This op contains two steps:
*
* 1. DecodeAllBBoxes.
*
* Convert input feature maps into real ojectness score and coordinates.
* for inputIdx in (0, inputNum - 1)
*
* obj = sigmoid(obj_feature);
* x = (x_offset + sigmoid(x_feature)) / inputWs[inputIdx]
* y = (y_offset + sigmoid(y_feature)) / inputHs[inputIdx]
* w = (w_biases * exp(w_feature)) / netw
* h = (h_biases * exp(h_feature)) / neth
* Obj, x_feature, y_feature, w_feature, h_feature are data from input feature maps.
* x_offset, y_offset are the coordinates of the grid cell in the feature map.
* w_offset, h_biases are the shape of the anchor box.
*
* 2. Non-maximum Suppression
* For each class of data, compute IOU score for every pair of bounding boxes.
* If IOU score exceeds the IOU threshold, keep the box with larger score.
* x1 = x - w / 2
* y1 = y - y / 2
* x2 = x + w / 2
* y2 = y + y / 2
* for classIdx in (0, classNum - 1)
* conf = obj * probability[classIdx]
* max, maxIdx = findMaxValueAndIndex(conf)
* if (max >= confidence_thresh)
* for boxIdx in (0, boxNum - 1)
* iou = computeIOU(coord_maxIdx, coord_boxIdx) // where "coords" means x1,y1,x2,y2
* if (iou < nms_thresh)
* keep coords and conf for boxIdx
*
* **DataType:**
* Support only half(float16) type for both input and output tensors.
*
* **Performance Optimization:**
* The performance of detection layer depends on both the data size and the value.
* However, this op achieves relatively better performance when
* all of the following conditions are met:
* - inputH/Ws are 64-aligned(unit in number of data).
* - (5 + classNum) is 64-aligned(unit in number of data).
* The bigger the remainder of the value of param divided by 64, the better performance the op will achieve.
* Supports both MLU220 and MLU270.
*
* @param[out] op
* Output. A pointer to the base operator address.
* @param[in] param
* Input. A PluginYolov3DetectionOutput parameter struct pointer.
* @param[in] yolov3_input_tensors
* Input. An array of four-demensional cnmlTensors with a shape of
* [batchNum, (5 + classNum) * numMaskGroup, inputH, inputW](NCHW).
* Support only FLOAT16 dataType currently.
* @param[in] outputs
* Input. An array of four-demensional cnmlTensors with a shape of
* [batchNum, 64 + 7 * numMaxBox, 1, 1](NCHW).
* Support only FLOAT16 dataType currently.
* The first two numbers of each batch store the number of
* detected boxes. The data for each box starts from the 65th number,
* with an order of [batchId, classId, score, x1, y1, x2, y2], where
* (x1, y1) and (x2, y2) are the coordinates of top-left and bottom-
* -right points accordingly.
* @retval CNML_STATUS_SUCCESS
* The function ends normally
* @retval CNML_STATUS_INVALIDPARAM
* At least one of the following conditions is not met:
* - Base op pointer is nullptr
* - Param is nullptr or not initialized
* - Input / output tensor desps is nullptr or inconsistent with param.
*/
cnmlStatus_t cnmlCreatePluginYolov3DetectionOutputOp(
cnmlBaseOp_t *op,
cnmlPluginYolov3DetectionOutputOpParam_t param,
cnmlTensor_t *yolov3_input_tensors,
cnmlTensor_t *yolov3_output_tensors);
定义cnmlCreatePluginYolov3DetectionOutputOp时,对输出张量shape进行了明确,为[batchNum, 64 + 7 * numMaxBox, 1, 1]。
//yolov3_detection_output_op.cc Kernel注册
REGISTER_KERNEL_BUILDER( \
Name("Yolov3DetectionOutput") \
.Device(DEVICE_MLU) \
.TypeConstraint("T"), \
MLUYolov3DetectionOutputOp);
//image_ops.cc OP注册
REGISTER_OP("Yolov3DetectionOutput")
.Output("predicts: T")
.Input("input0: T")
.Input("input1: T")
.Input("input2: T")
.Attr("batchNum:int")
.Attr("inputNum:int")
.Attr("classNum:int")
.Attr("maskGroupNum:int")
.Attr("maxBoxNum:int")
.Attr("netw:int")
.Attr("neth:int")
.Attr("confidence_thresh:float")
.Attr("nms_thresh:float")
.Attr("inputWs: list(int)")
.Attr("inputHs: list(int)")
.Attr("biases: list(float)")
.Attr("T: type")
.SetShapeFn([](InferenceContext *c){return SetOutputForYolov3DetectionOutput(c);
});
在OP注册时,其涉及到的输入输出及参量和.pbtxt中node一一对应。
//./cnplugin.h
/*!
* @brief A function.
* This function creates a PluginYolov3DetectionOutputOp param object with
* the pointer and parameters provided by user.
* **Supports MLU220/MLU270**
* @param[out] param
* Output. The returning param descriptor.
* @param[in] batchNum
* Input. The number of input batches.
* No default value, a valid batchNum must be in the range of [1, inf).
* @param[in] inputNum
* Input. The number of input tensors.
* No default value, a valid inputNum must be in the range of [1, 7].
* @param[in] classNum
* Input. The number of input classes.
* No default value, a valid classNum must be in the range of [1, 4096].
* @param[in] maskGroupNum
* Input. The number of anchors used by every input tensors.
* No default value, a valid maskGroupNum must be in the range of [1, inf].
* @param[in] maxBoxNum
* Input. The largest possible number of output boxes.
* Default value is 1024, a valid maxBoxNum must be in the range of [1, inf].
* @param[in] netw
* Input. Width of input image of backbone network.
* No default value, a valid netw must be in the range of [1, inf).
* @param[in] neth
* Input. Height of input image of backbone network.
* No default value, a valid neth must be in the range of [1, inf).
* @param[in] confidence_thresh
* Input. Confidence threshold.
* No default value, a valid confidence_thresh must be in the range of [0, 1].
* @param[in] nms_thresh.
* Input. IOU threshold used in NMS function.
* No default value, a valid nms_thresh must be in the range of [0, 1].
* @param[in] core_version
* Input. Supported core version.
* No default value, a valid core_version must be either MLU220 or MLU270.
* @param[in] inputWs
* Input. Width of every input tensor. Must have the same order as inputHs
* No default value, the number of valid elements must be equal with inputNum.
* @param[in] inputHs
* Input. Height of every input tensor. Must have the same order as inputWs
* No default value, the number of valid elements must be equal with inputNum.
* @param[in] biases
* Input. Anchors of every input tensor.
* No default value. The number of valid elements must be equal with 2 x inputNum x maskGroupNum.
* The order of data from high to low, is [N(1) H(inputNum) W(maskGroupNum) C(2)]. For example:
* Width of anchor for mask0 input0, Height of anchor for mask0 input0,
* Width of anchor for mask1 input0, Height of anchor for mask1 input0,
* ...
* Width of anchor for maskN input0, Height of anchor for maskN input0,
* Width of anchor for mask0 input1, Height of anchor for mask0 input1,
* ......
* @retval CNML_STATUS_SUCCESS
* The object was set successfully.
* @retval CNML_STATUS_INVALIDPARAM
* The inputH/Ws ptr is nullptr or input param is invalid.
*/
cnmlStatus_t cnmlCreatePluginYolov3DetectionOutputOpParam(
cnmlPluginYolov3DetectionOutputOpParam_t *param,
int batchNum,
int inputNum,
int classNum,
int maskGroupNum,
int maxBoxNum,
int netw,
int neth,
float confidence_thresh,
float nms_thresh,
cnmlCoreVersion_t core_version,
int *inputWs,
int *inputHs,
float *biases);
在./cnplugin.h里定义了cnmlCreatePluginYolov3DetectionOutputOpParam,注释对每个参数含义进行了说明。对涉及到的参量,需要给定默认值,可以在OP注册时给定,也可以在添加node时给定。
其参数由数据集及算法特性给定:
①COCO共有80个类,原始图片全部resize为416 × 416;
②YOLOv3分别在尺度13 x 13, 26 x26, 52 x52上执行检测;
③在每个尺度上,每个单元使用 3 个锚点预测 3 个边界框,锚点的总数为 9,v3中每个尺度上平均检测三个锚点;
④在进行检测时,九个框分别是 (10×13),(16×30),(33×23),(30×61),(62×45),(59× 119), (116 × 90), (156 × 198),(373 × 326) ,顺序为w × h,数据依次从大到小排列。
三、自定义开发时涉及TensorFlow源码目录
tensorflow/core:
tensorflow/stream_executor:
运行时环境,管理TF中高性能并行编程设备的执行过程(限制哪些任务可以并发执行并指定存在哪些任务依赖项...)Original: https://blog.csdn.net/weixin_40943865/article/details/122225775
Author: 继明照于四方
Title: TensorFlow的自定义算子实现
相关阅读
Title: CONTINUAL SELF-TRAINING WITH BOOTSTRAPPED REMIXING FOR SPEECH ENHANCEMENT
题目:CONTINUAL SELF-TRAINING WITH BOOTSTRAPPED REMIXING FOR SPEECH ENHANCEMENT
时间:2021.10
作者:Efthymios Tzinis1,∗, Yossi Adi2, Vamsi K. Ithapu3, Buye Xu3, Anurag Kumar3
机构:University of Illinois at Urbana-Champaign, 2Facebook AI Research, 3Facebook Reality Labs Research
摘要:
我们提出了Remix IT,一个简单并创新的语音增强自监督训练方法。此方法基于连续自训练模式,这种模式克服了之前研究中的限制,包括域内噪声分布的假设和可获得的纯净语音目标。具体来说,首先在域外数据集上(OOD)预训练一个分离模型。并把他用于推断每个batch中域内mixture的估计目标信号,然后,通过使用排列的估计的干净和噪声信号生成人工mixture。最后,学生模型使用permuted 估计源作为目标训练模型,同时我们使用最新的学生模型周期性更新老师的权重。实验表明RemixIT在多语音增强任务下超越了之前最新的自监督方法,另外,RemixIT在语音增强任务中,在半监督和无监督之间实现了无缝连接,而且本方法能够应用在任何分离任务和分离模型一起使用。
引言:
神经网络已经被发现可以被高效并且广泛应用于大量语音任务上,包括语音增强,语音增强的目的是提升带噪语音的质量和可懂度。最近,有监督的,实时的,半监督的语音增强方法相继出现。大部分方法都是有监督的,训练这样的模型需要大量大量音频数据,并且期望这些训练数据可以和测试数据的分布相匹配,有限的监督数据虽然可得,使用这些数据训练的有监督模型由于不匹配测试数据的分布,测试时性能下降。
为了解决这些问题,减少对于纯监督数据的依赖,一些语音增强和声源分离方法转向了自监督的方法。在[5]中,训练模型估计带噪语音的SNR,并且为每个带噪片段设置一个置信值。其次,分离模型使用权重重构损失过滤带噪真实语音。最近提出的Mix IT已经能够i实现无监督分离,通过人工混合mixture of mixture,并且使分离模型估计和重新排列源混合物。MixIT提供了鲁棒性的无监督解决方法,语音增强方法中也有follow它的,然而,Mix IT假设能获得域内噪声类型并且能够改变输入数据的SNR分布通过在人工MoMs使用多于一种的噪声类型。
教师-学生模型已经在语音任务上表现出很大提升,包括:学生模型在预训练的MIx IT模型的输出上训练,解决了在训练集和测试集分布上,人工创造的SNR不匹配的问题。使用一个能力阈值减少出现在带噪语音的源数量,而且,此外,学生模型可以适应给定的测试集使用回归预先训练的教师的估计。与我们的工作最接近的自训练框架是一个半监督歌声分离,它使用教师在域外监督数据预训练,用来估计更大的域内带噪数据的源。带有新标签的数据集对低质量的分离源进行过滤,并存储为一个新的学生模型的离线训练,使用人工生成的来自估计的自标记估计源的混合物。之前的大部分方法都是冻结教师模型,在ASR上有一些使用moving mean teacher来更新教师模型表现出很大的提升。
本文章中,我们提出了自训练方法,能够在大的域内带噪数据集上进行自学习,仅仅要求域外预训练教师模型(Mix IT on an OOD dataset),与文献中使用ad-hoc 过滤程序来提高教师模型估计的质量的自我训练方法相比,我们的方法通过执行线上remix教师的估计源,而且,我们不冻结教师,RemixIT将自训练过程看作一个终身学习的过程使用连续,移动平均更新模式,者能够更快收敛,我们的实验证明了对自我监督语音增强的普遍适用性,半监督OOD泛化和zero-shot自适应
方法
我们提出了语音增强的一般情况,其目标是从一个有噪声的语音信号重建干净的语音。
分离模型:fs
输出:M个声源
估计的speech,s 纯净语音目标,估计的噪声信号,n:噪声目标
估计的语音加估计的噪声等于纯净语音+噪声使用mixture consistency layer
MixIT:
MixIT已经在不同的自监督增强方法中证明了他的有效性,Mix IT假设训练集是由两部分组成(Dm,Dn),Dm是数据集的一部分,包含了语音和一种噪声的带噪语音,Dn包含了分离的噪声,在训练中,人工合成的MoM:x=s+n1+n2,通过从Dm采样一个batch的带噪语音m,从Dn中采样一个batch的噪声n2,其中m=s+n1,分离模型会估计三个源:s,n1,n2,使用PIT loss最小化第b个输入MoM
然而,MixIT假设可以获得域内分离的噪声,这种情况在现实中是不现实的,我们不能总是能够使噪声分布与可用的噪声集Dn相匹配,并且可能需要和有点的监督数据一起使用。有数据增强方法提出:从域外噪声集注入一种噪声源到MoM中效果会得到提升,然而,这种方法的性能仍然依赖于在实际噪声分布与Dn之间的分布位移水平。
使用bootstrapped remix连续自训练
RemixIT不假设能够获得域内信息,因此我们仅仅从域内带噪数据集m=s+n描述mixture,域内带噪数据集包含一个单一的噪声源,m,s,n.RemixIT利用学生老师架构,通过排列之前的带噪估计重新混合,并且将他们作为训练目标。
RemixIT的连续自训练框架
我们假设可以使用监督或者自监督方法在域外数据集上预训练教师模型D',第一步是使用教师模型对给定的mixture batch估计新的带噪目标,m=s+n1
k是优化步骤,如果教师模型通过有监督OOD预训练得到的,(无监督使用MIx IT),会有M=2,M=3个输出,其次,我们使用估计得到的源生成新的noisy mixture,如下:
现在,我们使用排列的目标对训练学生模型
所提方法的损失函数与常规的监督学习中的信号级别的损失函数类似。我们的方法不会人工改变输入SNR的分布(Mixit就是这样做的),取而代之,学生模型使用相同数量的源(bootstrapped mixture),教师模型已经充分训练好。不同于之前的学生-老师模型,他们使用老师估计的相同的声源对作为学生网络的目标,提出的bootstrapped mixture增加了学生输入的多样性,可以更快的训练模型。Remix IT使用终身学习的训练方法,而不是冻结老师模型,使用静态的线下自标注数据集,分开训练学生模型。我们的方法能够与任何线上联合训练的方法一起使用,除了主要的学生模型训练之外连续更新教师模型权重
欧几里得标准下的误差分析
理想情况下,最后一项内积为0,如果教师产生的输出与干净的目标信号中或者上一个公式中的条件误差分布是独立的。直观地说,当我们不断更新教师模型和改进他的输出时,我们最小化了教师误差的范数,另外,这种方法强迫误差更加不相关,因为学生试图重构相同的干净语音信号,类似于他的老师,但是在不同的mixture分布下。学生努力重构s当观察bootstrapped mixture m=s+n~,而教师努力重构s从初始输入mixture m=s+n.
; 实验
数据集:
DNS2020,由64649和150对干净语音和噪声训练和测试,这个数据集用于验证方法的有效性在利用大量带噪语音
LibriFSD50K (LFSD):本研究主要将该数据集用于语音增强模式的OOD无监督或半监督预训练
WHAM!:这个数据集作为一个中等规模的数据集,其中有20,000个训练噪声语音对和3,000个测试mixture
VCTK:使用了他的测试集部分
分离模型
这里重点强调提出的方法可以和任何分离模型一起使用,我们选择了sudo-rm -rf架构,因为它在语音增强质量和时间内存计算要求之间实现了很好的权衡。实验中,我们将新的学生网络的深度增加到16和32。我们为MixIT模型或M=2固定了M=3的输出。
RemixIT 设置
对于无监督RemixIT,我们假设初始的教师模型使用Mix IT在特点的OOD数据集上预训练。对于半监督RemixIT,使用PIT预训练,我们还实验了各种在线教师更新协议,如
k是训练epoch数,对于顺序更新的老师,每20个epoch用最新的学生替换原来的老师。对于zero-shot域自适应实验,首先将学生和老师的参数设置为一样,然后使用moving average 以0.01更新教师
训练和测试细节
尽管我们能够使用有效的信号级别的loss,我们选择-SI-SDR
使损失与估计的源y^和目标信号y的尺度不变,
; 结果
教师估计的持续改进
这是使用具有不同的教师更新策略的RemixIT,所有的方法都使用与U=8相同的教师架构,使用WHAM的训练分割,以一种有监督的方式进行预先训练!
连续更新教师模型:每20个epoch用之前的学生模型取代教师,与使用相同的静态教师的方法相比,能够获得显著的改进。在半监督RemixIT设置中,学生模型超越初始的OOD上预训练的教师,在SI-SDR上提升了1DB。
我们的实验结果验证了假设,即语音增强模型可以在教师将与学生并行更新的终身学习过程中更快更有效地训练。顺序更新和冻结的教师方法所获得的语音增强性能如图上所示,在第20个epoch之前,两种策略是一样的,教师是静态的。在20个epoch之后,教师被最新的学生所取代,而下一个学生的深度增加,8→16。结果,连续更新的方法在40个epoch之后表现更好相比于冻结教师。顺序更新的教师规模比其他方法更好,这是我们在所有其他实验中使用的默认策略,除了零镜头适应,我们也表明,运行意味着教师更新方案也是一个有效的选择
自监督和半监督语音增强
在DNS测试集上的语音增强性能,方法有所提出的RemixIT、无监督MixIT、有监督域内训练,全部使用相同的Sudorm-rf模型(U=8)、和fullsubnet
表中%是指从DNS或者LFSD使用成对的数据。例如:无监督的RemixIT预训练老师模型要求无监督MixIT使用80%的LFSD数据对模拟带噪语音D'm其他20%用于干净OOD噪声Dn,而学生模型利用整个DNS带噪数据集
注意无监督和半监督RemixIT不依赖于干净的域内噪声样本,尽管如此,无监督的学生模型显著优于所有类似于mixit的方法,包括域内训练和最近提出的额外噪声增强方法(MoMs包含3种噪声源),此外,最大的无监督学生(U=32)在所有评估指标上都远远优于其OODMixIT无监督教师,这显示了RemixIT对自我监督设置的有效性。该方法在半监督下也产生了显著的提升,学生模型与使用U=8的默认Sudorm-rf模型和最近的最新模型的域内监督训练表现相当。我们想强调的是,我们的方法可以使用更复杂的教师模型,而不是高效的Sudorm-rf架构,并提供更高质量的语音增强性能。
zero-shot自适应
RemixIT也可以用于低资源的数据集,训练数据有限,但可以访问测试数据集来适应预先训练的模型。在更大的OOD数据集上,用各种有监督和无监督的预训练网络,描述了零镜头语音增强任务的性能改进,RemixIT在SI-SDR上提升了0.8db相比于未经校准的预训练模型当使用有限数量的域内mixture,模型的性能与可用noisy mixture有关,在WHAM(DNS)测试集上看到原因,最大的有3000(只有150)mixture。此外,我们还注意到,在训练数据和自适应集中的混合数据之间有很大的分布转移的情况下,我们有了很大的改进
结论
我们提出了一个新的连续自训练降噪方法,在几个现实语音增强任务上表明了他的效果,我们的方法依赖于域内noisy和一个纯使用OOD数据的预训练模型,这种数据可能无法捕捉域内数据的分布,自助再混合过程与连续双向师生自我训练框架的耦合,导致了零射击和自监督语音增强和半监督语音增强以及半监督域自适应的显著改进。在未来,我们的目标是将我们的方法应用于其他领域和去噪任务,并为我们的算法的收敛性提供更强的理论保证。
Original: https://blog.csdn.net/weixin_44223902/article/details/122102009
Author: weixin_44223902
Title: CONTINUAL SELF-TRAINING WITH BOOTSTRAPPED REMIXING FOR SPEECH ENHANCEMENT