即插即用hilo注意力机制,捕获低频高频特征

题目:Fast Vision Transformers with HiLo Attention

论文地址:  https://arxiv.org/abs/2205.13213

创新点

  • HiLo自注意力机制:作者提出了一种新的自注意力机制,称为HiLo注意力,旨在同时捕捉图像中的高频和低频信息。该方法通过将自注意力分为两个分支,高频分支(Hi-Fi)处理局部的高分辨率细节,低频分支(Lo-Fi)处理全局的低分辨率结构。这样可以提高计算效率,特别是在高分辨率图像上,同时保持准确性。

  • LITv2模型:基于HiLo注意力机制,文献引入了LITv2模型,该模型在多个主流计算机视觉任务(如图像分类、物体检测和语义分割)上表现优越。LITv2通过在早期阶段删除多头自注意力(MSA)层,并在后期阶段使用高效的HiLo注意力机制,提升了模型的速度和内存效率。

  • 速度优化:作者通过实际平台上的速度评估(而非通常的FLOPs计算)设计了该模型,以确保其在GPU和CPU上的实际速度更快。例如,HiLo机制在CPU上比局部窗口注意力机制快1.6倍,比空间缩减注意力机制快1.4倍。

  • 相对位置编码优化:文献还对相对位置编码进行了优化,采用了3×3的深度卷积层代替传统的固定相对位置编码,这大大加快了密集预测任务(如分割)的训练和推理速度。

方法

整体结构

       LITv2模型基于HiLo注意力机制,分离处理高频和低频信息,通过局部窗口自注意力捕捉细节、高效全局注意力处理全局结构。此外,模型采用3×3深度卷积层替代位置编码,减少计算复杂度并扩大感受野。整体架构分为多阶段,生成金字塔特征图,适用于密集预测任务,结合残差连接和全局自注意力确保性能与效率的平衡。

  • Patch Embedding层:模型首先将输入图像切分为固定大小的图像块(patch),然后通过线性变换将每个patch映射到一个高维特征空间,这与大多数Vision Transformer类似。

  • HiLo注意力机制:这是模型的核心创新点。HiLo注意力机制将多头自注意力(MSA)分成两个部分:

  • 高频(Hi-Fi)注意力:处理局部的高频细节信息,使用的是局部窗口自注意力(例如2×2窗口),能够高效捕获图像中的细节信息。

  • 低频(Lo-Fi)注意力:处理全局的低频信息,先通过平均池化获得低频特征,再进行全局自注意力计算,从而减少计算复杂度。

  • 深度卷积层(Depthwise Convolution Layer):为了进一步提高效率,LITv2引入了3×3的深度卷积层用于代替传统的多层感知机(MLP)中的位置编码。这种设计不仅减少了位置编码的计算负担,还扩大了早期阶段特征的感受野。

  • 多阶段结构:模型通常分为多个阶段(例如4个阶段),在每个阶段生成金字塔结构的特征图(pyramid feature maps),用于处理不同分辨率的特征。这使得模型在图像分类之外的密集预测任务(如物体检测和语义分割)中更具优势。

  • 残差连接和归一化:在每个Transformer模块中,模型使用标准的残差连接和LayerNorm层。这些是标准的ViT组件,用于稳定训练并保持特征的传递。

  • 后期的全局自注意力:在模型的后期阶段,虽然早期阶段使用了高效的局部自注意力和低频注意力机制,但后期阶段会使用标准的多头自注意力机制来处理下采样后的低分辨率特征图,以进一步提升性能。

即插即用模块

将HiLo注意力机制提取为即插即用模块,主要适用于以下场景:

  • 高分辨率图像处理:在需要处理高分辨率图像的任务中,例如图像分类、目标检测、语义分割等,HiLo通过高效分离高频和低频信息,显著减少计算复杂度和内存占用,提升推理速度和处理能力。

  • 低延迟应用场景:HiLo能够在实际硬件平台(如GPU和CPU)上加快推理速度,特别适用于需要低延迟的场景,例如无人机图像处理、自动驾驶中的实时感知系统等。

  • 视觉任务中的密集预测:在需要对每个像素进行精细预测的任务中,如语义分割和实例分割,HiLo能够高效处理局部细节和全局结构,提升预测的准确性和速度。

消融实验

  • 该表展示了LITv1-S模型在引入不同结构修改后的性能变化,包括加入3×3深度卷积层(ConvFFN)、去除相对位置编码(RPE)、以及使用HiLo注意力机制后的影响。

  • 结果表明:引入深度卷积层后,模型在ImageNet分类和COCO检测任务中的性能提升显著,移除RPE后虽然有轻微的性能下降,但推理速度(FPS)显著提升,使用HiLo注意力机制后进一步提升了模型效率,特别是在FLOPs和推理速度上。

  • 该图展示了HiLo注意力机制中高频和低频头部分配比例(α)的影响。随着α值的增加(更多头部用于低频注意力),FLOPs逐渐减少,模型的Top-1准确率在α=0.9时达到最佳。

  • 该实验表明高频和低频信息在自注意力中的合理分配对模型效率和性能有重要影响。

 

  • 该图通过Fast Fourier Transform(FFT)可视化了Hi-Fi和Lo-Fi注意力机制输出特征中的频率成分。结果显示,Hi-Fi注意力捕捉更多的高频信息,而Lo-Fi主要关注低频信息。

  • 该实验验证了HiLo注意力机制在分离高频和低频特征时的有效性,符合文献提出的设计理念。

即插即用模块HiLo

import math
import torch
import torch.nn as nn
# 论文:Fast Vision Transformers with HiLo Attention
# 论文地址:https://arxiv.org/abs/2205.13213
class HiLo(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim/num_heads)
        self.dim = dim

        # self-attention heads in Lo-Fi
        self.l_heads = int(num_heads * alpha)
        # token dimension in Lo-Fi
        self.l_dim = self.l_heads * head_dim

        # self-attention heads in Hi-Fi
        self.h_heads = num_heads - self.l_heads
        # token dimension in Hi-Fi
        self.h_dim = self.h_heads * head_dim

        # local window size. The `s` in our paper.
        self.ws = window_size

        if self.ws == 1:
            # ws == 1 is equal to a standard multi-head self-attention
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        # Low frequence attention (Lo-Fi)
        if self.l_heads > 0:
            if self.ws != 1:
                self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        # High frequence attention (Hi-Fi)
        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    def hifi(self, x):
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws

        total_groups = h_group * w_group

        x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)

        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2] # B, hw, n_head, ws*ws, head_dim

        attn = (q @ k.transpose(-2, -1)) * self.scale # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)

        x = self.h_proj(x)
        return x

    def lofi(self, x):
        B, H, W, C = x.shape

        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws > 1:
            x_ = x.permute(0, 3, 1, 2)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
        x = self.l_proj(x)
        return x

    def forward(self, x, H, W):
        B, N, C = x.shape

        x = x.reshape(B, H, W, C)

        if self.h_heads == 0:
            x = self.lofi(x)
            return x.reshape(B, N, C)

        if self.l_heads == 0:
            x = self.hifi(x)
            return x.reshape(B, N, C)

        hifi_out = self.hifi(x)
        lofi_out = self.lofi(x)

        x = torch.cat((hifi_out, lofi_out), dim=-1)
        x = x.reshape(B, N, C)

        return x

    def flops(self, H, W):
        # pad the feature map when the height and width cannot be divided by window size
        Hp = self.ws * math.ceil(H / self.ws)
        Wp = self.ws * math.ceil(W / self.ws)

        Np = Hp * Wp

        # For Hi-Fi
        # qkv
        hifi_flops = Np * self.dim * self.h_dim * 3
        nW = (Hp // self.ws) * (Wp // self.ws)
        window_len = self.ws * self.ws
        # q @ k and attn @ v
        window_flops = window_len * window_len * self.h_dim * 2
        hifi_flops += nW * window_flops
        # projection
        hifi_flops += Np * self.h_dim * self.h_dim

        # for Lo-Fi
        # q
        lofi_flops = Np * self.dim * self.l_dim
        kv_len = (Hp // self.ws) * (Wp // self.ws)
        # k, v
        lofi_flops += kv_len * self.dim * self.l_dim * 2
        # q @ k and attn @ v
        lofi_flops += Np * self.l_dim * kv_len * 2
        # projection
        lofi_flops += Np * self.l_dim * self.l_dim

        return hifi_flops + lofi_flops

if __name__ == '__main__':
    block = HiLo(dim=128)
    input = torch.rand(32, 128, 128) # input with shape (B, N, C)
    output = block(input, 16, 8) # H = 16, W = 8, since H * W should equal N
    print(input.size())
    print(output.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/890676.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

通信工程学习:什么是SPI串行外设接口

SPI:串行外设接口 SPI,即串行外设接口(Serial Peripheral Interface),是一种由Motorola公司首先在其MC68HCXX系列处理器上定义的同步串行接口技术。SPI接口主要用于微控制器(MCU)与外部设备之间…

1. 到底什么是架构

1. 什么是架构 定义:架构,又名软件架构,是有关软件整体结构与组件的抽象描述,用于指导大型软件系统各个方面的设计优秀架构的特点:优秀的性能、超强的TPS/QPS的承载能力、高可用决定了你能够支撑多少PV的流量 2. 什么…

【Linux修炼进程之权限篇】探讨Linux权限问题

【Linux修炼】——权限问题 目录 一:认识Linux下用户的分类 1.1:如何添加新用户【使用root用户创建添加】 1.2:su指令用法 二:Linux下权限是什么? 2.1:权限所认证的是身份(人身份角色) 2.2&#xff…

【WPF】04 Http消息处理类

这里引入微软官方提供的HttpClient类来实现我们的目的。 首先,介绍一下官方HttpClient类的内容。 HttpClient 类 定义 命名空间: System.Net.Http 程序集: System.Net.Http.dll Source: HttpClient.cs 提供一个类,用于从 URI 标识的资源发送 HTTP 请…

dbt doc 生成文档命令示例应用

DBT提供了强大的命令行工具,它使数据分析师和工程师能够更有效地转换仓库中的数据。dbt的一个关键特性是能够为数据模型生成文档,这就是dbt docs命令发挥作用的地方。本教程将指导您完成使用dbt生成和提供项目文档的过程。 dbt doc 命令 dbt docs命令有…

Gitxray:一款基于GitHub REST API的网络安全工具

关于Gitxray Gitxray是一款基于GitHub REST API的网络安全工具,支持利用公共 GitHub REST API 进行OSINT、信息安全取证和安全检测等任务。 Gitxray(Git X-Ray 的缩写)是一款多功能安全工具,专为 GitHub 存储库而设计。它可以用于…

STM32CUBEIDE的使用【三】RTC

于正点原子潘多拉开发板&#xff0c;使用stm32官方免费软件进行开发 CubeMx 配置 使用CubeMx 配置RTC 勾选RTC 设置日期和时间 配置LCD的引脚用来显示 STM32CUBEIDE 在usbd_cdc_if.c中重定向printf函数用于打印 #include <stdarg.h>void usb_printf(const char *f…

第十六章 RabbitMQ延迟消息之延迟插件优化

目录 一、引言 二、优化方案 三、核心代码实现 3.1. 生产者代码 3.2. 消息处理器 3.3. 自定义多延迟消息封装类 3.4. 订单实体类 3.5. 消费者代码 四、运行效果 一、引言 上一章节我们提到&#xff0c;直接使用延迟插件&#xff0c;创建一个延迟指定时间的消息&…

【C++算法】双指针

目录 一、快乐数&#xff1a; 二、有效三角形的个数&#xff1a; 三、盛最多水的容器&#xff1a; 四、复写0&#xff1a; 五、三数之和&#xff1a; 总结&#xff1a; 一、快乐数&#xff1a; 题目出处&#xff1a; 202. 快乐数 - 力扣&#xff08;LeetCode&#xff09…

ROS2 通信三大件之动作 -- Action

通信最后一个&#xff0c;也是不太容易理解的方式action&#xff0c;复杂且重要 1、创建action数据结构 创建工作空间和模块就不多说了 在模块 src/action_moudle/action/Counter.action 下创建文件 Counter.action int32 target # Goal: 目标 --- int32 current_value…

智能健康顾问:基于SpringBoot的系统

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

Qt:图片文字转base64程序

目录 一.Base64 1.编码原理 2.应用场景 3.优点 4.限制 5.变种 二.文字与Base64互转 1.ui设计 2.文字转Base64 3.Base64转文字 三.图片与Base64互转 1.ui设计 2.选择图片与图片路径 3.图片转Base64 4.Base64转图片 四.清空设置 五.效果 六.代码 base64conver…

PDF编辑不求人!4款高效工具,内容修改从此变得简单又快捷

咱们现在生活在一个数字时代&#xff0c;PDF文件可不就是工作、学习还有日常生活中经常要用的东西嘛。但遇到那些需要改动的PDF文件&#xff0c;是不是就觉得有点头疼啊&#xff1f; 因为传统的PDF文件真的不好编辑&#xff0c;这确实挺烦人的。不过呢&#xff0c;我今天要给你…

【北京迅为】《STM32MP157开发板嵌入式开发指南》- 第三十九章 Linux Misc驱动

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器&#xff0c;既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构&#xff0c;主频650M、1G内存、8G存储&#xff0c;核心板采用工业级板对板连接器&#xff0c;高可靠&#xff0c;牢固耐…

SpringBoot下的智能健康推荐引擎

3系统分析 3.1可行性分析 通过对本基于智能推荐的卫生健康系统实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本基于智能推荐的卫生健康系统采用SSM框架&#…

24秋面试笔记

文章目录 一、专业技能1.1 具备扎实的Java基础&#xff0c;熟练掌握面向对象编码规范、集合、反射以及Java8特性等。1.1.1 Java基础1.1.2 集合1.1.3 Java8新特性 1.2 熟悉常用的数据结构(链表、栈、队列、二叉树等)&#xff0c;熟练使用排序、动态规划、DPS等算法。1.2.1 数据结…

CountUp.js 实现数字增长动画 Vue

效果&#xff1a; 官网介绍 1. 安装 npm install --save countup.js2. 基本使用 // template <span ref"number1Ref"></span>// script const number1Ref ref<HTMLElement>() onMounted(() > {new CountUp(number1Ref.value!, 9999999).sta…

Centos7 搭建单机elasticsearch

以下是在 CentOS 7 上安装 Elasticsearch 7.17.7 的完整步骤&#xff1a;&#xff08;数据默认保存在/var/lib/elasticsearch下&#xff0c;自行更改&#xff09; 一、装 Java 环境 Elasticsearch 是用 Java 编写的&#xff0c;所以需要先安装 Java 运行环境。 检查系统中是…

弘景光电:以创新为翼,翱翔光学科技新蓝海

在科技日新月异的今天&#xff0c;光学镜头及模组作为智能设备的核心组件&#xff0c;其重要性日益凸显。广东弘景光电科技股份有限公司&#xff08;以下简称“弘景光电”&#xff09;正是在这一领域中&#xff0c;凭借其卓越的研发实力和市场洞察力&#xff0c;即将在创业板上…

001 Qt_从零开始创建项目

文章目录 前言什么是QtQt的优点Qt的应用场景创建项目小结 前言 本文是Qt专栏的第一篇文章&#xff0c;该文将会向你介绍如何创建一个Qt项目 什么是Qt Qt 是⼀个 跨平台的 C 图形⽤⼾界⾯应⽤程序框架 。它为应⽤程序开发者提供了建⽴艺术级图形界⾯所需的所有功能。它是完全…