My own implementation of the UNet 3+ model, with a brief analysis (unet3plus tensorflow2 keras)

Introduction

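Quite a while ago I read the UNet 3+ paper on medical image segmentation. I originally planned to grab a Keras/TensorFlow implementation from GitHub, but the ones I found all seemed to differ somewhat from the official source code, so I wrote my own following the paper and the source. I cannot guarantee that it matches the source exactly either; I am posting it to get a discussion started. Many blog posts on UNet 3+ are well written; for a walkthrough of the whole paper, see the translation 【UNet3+(UNet+++)论文解读 玖零猴】. This post also briefly covers my own understanding.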

UNet 3+ paper
Source code (PyTorch)

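I. UNet 3+

In short, UNet 3+ has three main features, plus a fourth that I have not written up yet:
1 Full-scale skip connections, to avoid losing semantic information between downsampling and upsampling
2 Full-scale deep supervision, to learn hierarchical feature representations
3 A classification-guided module (CGM), proposed to suppress false-positive segmentations caused by noise in medical images
4 A new hybrid loss function (TODO)

Well, the first three points each have something to grumble about; more on that later.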

[Figure: UNet 3+ network architecture]
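The UNet 3+ architecture is shown in the figure above, and overall it is easy to follow. The authors argue that neither UNet nor UNet++ connects feature maps across all scales, so they pass encoder information from every scale to the decoder and also pass information across decoder levels, in order to reduce information loss (simple and brute-force =_=).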
[Figure: how decoder 3 aggregates full-scale features]
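Take decoder 3 as an example. It fuses the features of encoders 1, 2 and 3 and of decoders 4 and 5 (decoder 5 is the bottleneck, i.e. the output of encoder 5). These features are brought to decoder 3's feature-map size by max pooling (for the larger encoder features) or upsampling (for the smaller decoder features); encoder 3 is already at the right size. A convolution layer (Conv + BN + ReLU in the source code) then gives every branch the same number of channels. Finally, the concatenated feature maps pass through one more Conv + BN + ReLU block to produce the decoder output.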
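The resizing rule described above can be written down as a small lookup. This is only a sketch to make the scale factors explicit: the function name `full_scale_sources` and the labels `maxpool`, `identity` and `upsample` are made up for illustration, and the bottleneck is labelled `En5` to match the `xe5` variable in the code further down.

```python
def full_scale_sources(decoder_level, depth=5):
    """List (source, op, factor) for every input of decoder `decoder_level`.

    Level 1 is the full-resolution level and level `depth` is the bottleneck;
    each level halves the spatial size of the previous one.
    """
    sources = []
    for level in range(1, depth + 1):
        if level < decoder_level:
            # Shallower encoder maps are larger: shrink them with max pooling.
            sources.append((f"En{level}", "maxpool", 2 ** (decoder_level - level)))
        elif level == decoder_level:
            # Same-scale encoder map: no resizing needed.
            sources.append((f"En{level}", "identity", 1))
        else:
            # Deeper maps are smaller: enlarge them with upsampling.
            name = f"En{level}" if level == depth else f"De{level}"
            sources.append((name, "upsample", 2 ** (level - decoder_level)))
    return sources


for source, op, factor in full_scale_sources(3):
    print(source, op, factor)
```

For decoder 3 this lists two pooled encoder maps, the same-scale encoder map, and two upsampled deeper maps, which is the same set of five branches that the Keras code builds for `xd3`.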
[Figure: full-scale deep supervision and the classification-guided module (CGM)]
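This figure explains the other two features: full-scale deep supervision and the classification-guided module (CGM).
Full-scale deep supervision means that a loss is computed on the output of every decoder level, not only on the final one.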
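As a rough illustration of deep supervision, the total loss can be formed as a weighted sum of one loss per decoder output. The sketch below uses plain binary cross-entropy as a stand-in (the paper's hybrid loss is not covered in this post), and the equal default weights and the function names are assumptions made for the example.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over flat lists of labels and probabilities."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(y_true)


def deep_supervision_loss(y_true, side_outputs, weights=None):
    """Weighted sum of per-output losses; equal weights by default (an assumption)."""
    if weights is None:
        weights = [1.0] * len(side_outputs)
    return sum(w * binary_cross_entropy(y_true, y_pred)
               for w, y_pred in zip(weights, side_outputs))


# Toy example: a 4-pixel mask and the predictions of three decoder levels,
# each already upsampled to full resolution.
mask = [1, 0, 1, 0]
outputs = [
    [0.9, 0.1, 0.8, 0.2],
    [0.7, 0.3, 0.6, 0.4],
    [0.5, 0.5, 0.5, 0.5],
]
print(deep_supervision_loss(mask, outputs))
```

In Keras the same effect is usually obtained by giving a multi-output model one loss per named output, so no custom loop is needed during training.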
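To prevent false-positive segmentations caused by noise, the authors propose the classification-guided module. It is attached to the network's bottleneck (the bottom of the encoder, En5). This layer is the deepest one, has the most feature maps and the smallest spatial size, and may already have filtered out some of the noise. The authors add a small classification head after it (Dropout + Conv1x1 + Pooling + Sigmoid) that outputs a probability indicating whether the target organ is present in the input image. Multiplying this classification result with the output of the segmentation head can suppress false positives.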

For the comparison of results, just look at the figure:
[Figure: comparison of results]

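That covers the features; now for the complaints:
1 Full-scale connections are nice, and the authors make a point of noting that UNet 3+ has fewer parameters than UNet and UNet++. In practice, though, training seemed to take more time and more memory, apparently because UNet 3+ performs more convolutions (for example, each UNet decoder level needs only 2 convolutions, whereas Fig. 2 above shows that each UNet 3+ decoder level needs 6).
2 Haven't figured this one out yet.
3 The CGM is just a simple module. In my own experiments it overfit quickly even with Dropout: the validation loss of the segmentation head was still falling while the CGM loss had already started to rise.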
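To put rough numbers on the first complaint, the sketch below counts the convolutions, weights and multiply-accumulates of one UNet 3+ decoder level as it is built in the code of this post: one 3x3 convolution per incoming branch, then a fusion convolution over the concatenated channels. The helper names and the defaults (32 base features, a 256x256 input) are assumptions chosen to match the example at the end of the post, and the count ignores normalization layers and biases.

```python
def conv_params(c_in, c_out, k=3):
    """Number of weights in a bias-free k x k convolution."""
    return k * k * c_in * c_out


def decoder_stage_cost(level, base=32, depth=5, size=256, k=3):
    """Return (conv count, weights, multiply-accumulates) for decoder `level`."""
    enc_channels = [base * 2 ** i for i in range(depth)]  # En1..En5
    cat = base * depth  # channels after concatenation, also each decoder's output
    in_channels = []
    for src in range(1, depth + 1):
        if src <= level or src == depth:
            in_channels.append(enc_channels[src - 1])  # encoder feature
        else:
            in_channels.append(cat)  # feature from a deeper decoder
    weights = sum(conv_params(c, base, k) for c in in_channels)
    weights += conv_params(cat, cat, k)  # fusion conv after concatenation
    side = size // 2 ** (level - 1)  # spatial size at this level
    return len(in_channels) + 1, weights, weights * side * side


for level in range(1, 5):
    print(level, decoder_stage_cost(level))
```

Every level runs six convolutions, and since the multiply-accumulate count scales with the spatial size, the shallow levels dominate the cost. That is at least consistent with a low parameter count going together with slow, memory-hungry training, although it is not a measurement.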

II. Full code (Keras)

Note: I am still learning and wrote this code for fun, so it is not necessarily correct. If you spot a problem, feel free to point it out and discuss. Please credit the source when reposting.

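The implementation of the CGM output is still open to debate. In my code the CGM output and the segmentation mask are returned separately, so they have to be multiplied together manually afterwards.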
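A minimal sketch of that manual multiplication, applied to predictions as NumPy arrays. The function name `apply_cgm` and the 0.5 threshold are assumptions; the array shapes follow the model below, where the mask is (batch, height, width, class_num) and the CGM head outputs (batch, class_num).

```python
import numpy as np


def apply_cgm(mask_prob, cgm_prob, threshold=0.5):
    """Zero out the masks of images the classifier judges to contain no target.

    mask_prob: (batch, height, width, classes) segmentation probabilities.
    cgm_prob:  (batch, classes) classification probabilities from the CGM head.
    """
    has_target = (cgm_prob > threshold).astype(mask_prob.dtype)
    # Reshape to (batch, 1, 1, classes) so the flag broadcasts over H and W.
    return mask_prob * has_target[:, None, None, :]


# Toy example: two images with identical masks; the classifier only
# believes the first one contains the target.
mask = np.full((2, 4, 4, 1), 0.8, dtype=np.float32)
cgm = np.array([[0.9], [0.1]], dtype=np.float32)
gated = apply_cgm(mask, cgm)
print(gated[0].max(), gated[1].max())
```

With `using_deep_supervision=False` and `using_cgm=True` the model returns the mask and the CGM probability as two outputs, so the two arrays returned by `model.predict` can be passed to this function directly.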

1. Imports

import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
# add and AveragePooling2D are used by the helper functions below
from tensorflow.keras.layers import (Conv2D, Input, concatenate, add, MaxPooling2D, AveragePooling2D,
                                     UpSampling2D, Activation, BatchNormalization,
                                     LayerNormalization, Dropout, GlobalMaxPooling2D)

2. Helper functions


def normalization(input_tensor, normalization):
    # Apply the requested normalization layer; None returns the tensor unchanged.
    if normalization == 'batch':
        return BatchNormalization()(input_tensor)
    elif normalization == 'layer':
        return LayerNormalization()(input_tensor)
    elif normalization is None:
        return input_tensor
    else:
        raise ValueError('Invalid normalization')

def conv2d_block(input_tensor, filters, kernel_size,
                norm_type, use_residual, act_type='relu',
                double_features = False, dilation=(1, 1)):
    # Two Conv + Norm stages with an optional residual shortcut around them.
    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[0], use_bias=False, kernel_initializer='he_normal')(input_tensor)
    x = normalization(x, norm_type)
    x = Activation(act_type)(x)

    if double_features:
        filters *= 2

    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[1], use_bias=False, kernel_initializer='he_normal')(x)
    x = normalization(x, norm_type)

    if use_residual:
        # Project the shortcut with a 1x1 conv when the channel counts differ.
        if input_tensor.shape[-1] != x.shape[-1]:
            shortcut = Conv2D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(input_tensor)
            shortcut = normalization(shortcut, norm_type)
            x = add([x, shortcut])
        else:
            x = add([x, input_tensor])

    x = Activation(act_type)(x)

    return x

def down_layer_2d(input_tensor, down_pattern, filters, norm_type=None):
    if down_pattern == 'maxpooling':
        x = MaxPooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'avgpooling':
        x = AveragePooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'conv':
        # Strided conv; the bias is dropped only when a normalization layer follows.
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same', use_bias=False if norm_type is not None else True, kernel_initializer='he_normal')(input_tensor)
        x = normalization(x, norm_type)
    elif down_pattern == 'normconv':
        x = normalization(input_tensor, norm_type)
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same', kernel_initializer='he_normal')(x)
    else:
        raise ValueError('Invalid down_pattern')
    return x

def conv_norm_act(input_tensor, filters, kernel_size , norm_type='batch', act_type='relu', dilation=1):
    output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation), use_bias=False if norm_type is not None else True, kernel_initializer='he_normal')(input_tensor)
    output_tensor = normalization(output_tensor, normalization=norm_type)
    output_tensor = Activation(act_type)(output_tensor)
    return output_tensor

def aggregate(l1, l2, l3, l4, l5, filters, kernel_size, norm_type='batch', act_type='relu'):
    out = concatenate([l1, l2, l3, l4, l5], axis = -1)
    out = Conv2D(filters * 5, kernel_size, padding = 'same', use_bias=False if norm_type is not None else True, kernel_initializer = 'he_normal')(out)
    out = normalization(out, norm_type)
    out = Activation(act_type)(out)

    return out

def cgm_block(input_tensor, class_num, dropout_rate = 0.):
    x = Dropout(rate = dropout_rate)(input_tensor)
    x = Conv2D(class_num, 1, padding='same', kernel_initializer='he_normal')(x)

    x = GlobalMaxPooling2D()(x)
    x = Activation('sigmoid', name='cgm_output')(x)

    return x

3. Building the network


def unet3p_2d(input_shape, initial_features=32, kernel_size=3,
              class_num=1, norm_type='batch', double_features=False,
              use_residual=False, down_pattern='maxpooling', using_deep_supervision=True,
              using_cgm=False, cgm_drop_rate=0.5, show_summary=True):
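    '''
    input_shape: (height, width, channel)
    initial_features: int, initial number of feature maps, doubled at every downsampling step; the UNet 3+ paper uses 64
    kernel_size: int, convolution kernel size
    class_num: int, number of segmentation classes
    norm_type: str, normalization type, 'batch' or 'layer'; UNet 3+ uses BatchNormalization
    double_features: bool, whether the second convolution in conv2d_block doubles the number of feature maps; the 3D U-Net paper proposes this to avoid bottlenecks; can usually be left False
    use_residual: bool, whether the encoder uses residual connections
    down_pattern: str, downsampling method, 'maxpooling', 'avgpooling', 'conv' or 'normconv'; UNet 3+ uses MaxPooling
    using_deep_supervision: bool, whether to use full-scale deep supervision
    using_cgm: bool, whether to use the classification-guided module (CGM)
    cgm_drop_rate: float, Dropout rate in the CGM module
    show_summary: bool, whether to print the model summary
    '''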

    if class_num == 1:
        last_layer_activation = 'sigmoid'
    else:
        last_layer_activation = 'softmax'

    inputs = Input(input_shape)

    xe1 = conv2d_block(input_tensor=inputs, filters=initial_features, kernel_size=kernel_size,
                    norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe1_pool = down_layer_2d(input_tensor=xe1, down_pattern=down_pattern, filters=initial_features)

    xe2 = conv2d_block(input_tensor=xe1_pool, filters=initial_features * 2, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe2_pool = down_layer_2d(input_tensor=xe2, down_pattern=down_pattern, filters=initial_features * 2)

    xe3 = conv2d_block(input_tensor=xe2_pool, filters=initial_features * 4, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe3_pool = down_layer_2d(input_tensor=xe3, down_pattern=down_pattern, filters=initial_features * 4)

    xe4 = conv2d_block(input_tensor=xe3_pool, filters=initial_features * 8, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe4_pool = down_layer_2d(input_tensor=xe4, down_pattern=down_pattern, filters=initial_features * 8)

    xe5 = conv2d_block(input_tensor=xe4_pool, filters=initial_features * 16, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)

    if using_cgm:
        cgm = cgm_block(input_tensor = xe5 , class_num = class_num ,dropout_rate = cgm_drop_rate)

    xd4_from_xe5 = UpSampling2D(size=(2,2), interpolation='bilinear')(xe5)
    xd4_from_xe5 = conv_norm_act(input_tensor=xd4_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe4 = conv_norm_act(input_tensor=xe4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe3 = MaxPooling2D(pool_size = (2, 2))(xe3)
    xd4_from_xe3 = conv_norm_act(input_tensor=xd4_from_xe3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe2 = MaxPooling2D(pool_size = (4, 4))(xe2)
    xd4_from_xe2 = conv_norm_act(input_tensor=xd4_from_xe2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe1 = MaxPooling2D(pool_size = (8, 8))(xe1)
    xd4_from_xe1 = conv_norm_act(input_tensor=xd4_from_xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4 = aggregate(xd4_from_xe5, xd4_from_xe4, xd4_from_xe3, xd4_from_xe2, xd4_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    xd3_from_xe5 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xe5)
    xd3_from_xe5 = conv_norm_act(input_tensor=xd3_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xd4 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd4)
    xd3_from_xd4 = conv_norm_act(input_tensor=xd3_from_xd4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xe3 = conv_norm_act(input_tensor=xe3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xe2 = MaxPooling2D(pool_size = (2, 2))(xe2)
    xd3_from_xe2 = conv_norm_act(input_tensor=xd3_from_xe2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xe1 = MaxPooling2D(pool_size = (4, 4))(xe1)
    xd3_from_xe1 = conv_norm_act(input_tensor=xd3_from_xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3 = aggregate(xd3_from_xe5, xd3_from_xd4, xd3_from_xe3, xd3_from_xe2, xd3_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    xd2_from_xe5 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xe5)
    xd2_from_xe5 = conv_norm_act(input_tensor=xd2_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xd4 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd4)
    xd2_from_xd4 = conv_norm_act(input_tensor=xd2_from_xd4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xd3 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd3)
    xd2_from_xd3 = conv_norm_act(input_tensor=xd2_from_xd3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xe2 = conv_norm_act(input_tensor=xe2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xe1 = MaxPooling2D(pool_size = (2, 2))(xe1)
    xd2_from_xe1 = conv_norm_act(input_tensor=xd2_from_xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2 = aggregate(xd2_from_xe5, xd2_from_xd4, xd2_from_xd3, xd2_from_xe2, xd2_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    xd1_from_xe5 = UpSampling2D(size=(16, 16), interpolation='bilinear')(xe5)
    xd1_from_xe5 = conv_norm_act(input_tensor=xd1_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xd4 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xd4)
    xd1_from_xd4 = conv_norm_act(input_tensor=xd1_from_xd4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xd3 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd3)
    xd1_from_xd3 = conv_norm_act(input_tensor=xd1_from_xd3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xd2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd2)
    xd1_from_xd2 = conv_norm_act(input_tensor=xd1_from_xd2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xe1 = conv_norm_act(input_tensor=xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1 = aggregate(xd1_from_xe5, xd1_from_xd4, xd1_from_xd3, xd1_from_xd2, xd1_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    if using_deep_supervision:
        xd55 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xe5)
        xd55 = UpSampling2D(size=(16, 16))(xd55)
        xd55 = Activation(last_layer_activation, name='output_de5')(xd55)

        xd44 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd4)
        xd44 = UpSampling2D(size=(8, 8))(xd44)
        xd44 = Activation(last_layer_activation, name='output_de4')(xd44)

        xd33 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd3)
        xd33 = UpSampling2D(size=(4, 4))(xd33)
        xd33 = Activation(last_layer_activation, name='output_de3')(xd33)

        xd22 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd2)
        xd22 = UpSampling2D(size=(2, 2))(xd22)
        xd22 = Activation(last_layer_activation, name='output_de2')(xd22)

        xd11 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd1)
        xd11 = Activation(last_layer_activation, name='output_de1')(xd11)

        if using_cgm: outputs=[xd11, xd22, xd33, xd44, xd55, cgm]
        else: outputs=[xd11, xd22, xd33, xd44, xd55]

    else:
        conv_output = Conv2D(class_num, 1, activation=last_layer_activation, name='output')(xd1)
        if using_cgm: outputs=[conv_output, cgm]
        else: outputs = conv_output

    model = Model(inputs, outputs)
    if show_summary: model.summary()

    return model

4. Creating the model

If all of the code above is in the same .py file, you can append the following to try building the network:

if __name__ == '__main__':
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=True, using_cgm=False, show_summary=True)
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=True, using_cgm=True, show_summary=True)
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=False, using_cgm=False, show_summary=True)

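If you use a pretrained backbone, the encoder (En) part needs to be modified accordingly.

I feel like such a noob; not sure whether I will manage to graduate smoothly. Sigh, TAT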

Original: https://blog.csdn.net/weixin_42723174/article/details/125306304
Author: 求你涨点吧
Title: 自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)
