多模态动态融合模型Predictive Dynamic Fusion论文阅读与代码分析运行1-信度概念与基础参数指标

多模态动态融合模型Predictive Dynamic Fusion论文阅读与代码分析运行1-信度概念与基础参数指标

参考文:Cao B, Xia Y, Ding Y, et al. Predictive Dynamic Fusion[J]. arXiv preprint arXiv:2406.04802, 2024.[2406.04802] Predictive Dynamic Fusion

一、理论

今天就先看看论文中的各个指标含义和多模态训练代码的参数吧

文章中一个比较重要的概念就是置信度的概念了,在论文前段,对置信度的扩展比较多同时没有什么具体说明,不知道概念的话读着还是很混乱的;

置信度

在机器学习中,置信度表示模型对其预测结果“有多确定”。
它刻画的是:模型认为自己预测是正确的程度

例如,在分类任务中:“这是正类的概率是 0.92”,那么 0.92 就可以视为模型对该预测的置信度

在监督学习中,给定输入样本 xxx,模型预测类别为 y^\hat{y}y^​,则置信度通常定义为:

即:模型对预测类别的后验概率估计

置信度 和 不确定性(补充)

文中用来衡量整体不确定性,算是置信度的一种扩展:

关于熵的概念,之前在b站看到的一位up主讲的很生动:https://www.bilibili.com/video/BV15V411W7VB/

置信度高 <=> 熵低

分类评价指标对比

指标含义对照

指标一句话解释
Accuracy模型整体准不准
Precision模型说“是”的时候靠谱吗
Recall真正“是”的有没有被找全
F1Precision 和 Recall 的折中
ROC-AUC正样本排在负样本前面的能力

The Mono-Confidences and Holo-Confidences

该文的目的之一是为了解决模态权重融合的权重问题;也就是,多个模态分别从多个维度评价目标的状态,给出不一样的结果,怎么融合这几个结果的问题。

目前可以确定的是:融合权重 ω 应当与损失 l 呈负相关,并且与其他模态的损失呈正相关。也就是:当前模态越可靠 → 权重越大;其他模态越不可靠 → 当前模态权重越大

对单个模态的模型,权重 ω 是要求的权重,损失loss是:

所以,就有人两个信度指标:

The Mono-ConfidencesHolo-Confidences
当前模态本身有多可靠相对其他模态我有多可靠

将他们统合:

Co-Belief(协同信度)

Mono-Confidence:只看自己;Holo-Confidence:只看别人;但多模态融合需要:既考虑自身可靠性,又考虑整体模态状态。

故有:

再由协同信度确定该模态的权重。

理论先到这里,其他的后面再看;

二、代码

1、运行环境

代码训练环境没有明确说明,但根据结构可以看得出来用的是autodl里的云服务器,Ubuntu20.04+python3.11的版本,卡随便租一个都一样。

论文附带代码只有2mb,明显缺失了很多预训练结构与数据集文件;

2、数据集文件

这里选用了代码中可选的第二个训练集MVSA_Single,需要自己到网站下好转到autodl服务器上:MVSA_Single

训练集之类的划分源代码已有了,自己按要求放到同一目录下即可。

3、词向量文件

源代码缺失了预训练好的词向量文件glove.840B.300d,需要自己使用指令下载到指定目录

wget https://nlp.stanford.edu/data/glove.840B.300d.zip

4、源代码逻辑错误

训练代码中的forward函数存在运行逻辑错误,文本和图像的loss(txt_clf_loss和img_clf_loss)定义在了if之外,会运行不成功;估计是作者没有仔细整理,代码算法逻辑倒没什么问题;

原代码150行左右:

def model_forward(i_epoch, model, args, criterion,optimizer, batch,mode='eval'): txt, segment, mask, img, tgt,idx = batch freeze_img = i_epoch < args.freeze_img freeze_txt = i_epoch < args.freeze_txt if args.model == "bow": txt = txt.cuda() out = model(txt) elif args.model == "img": img = img.cuda() out = model(img) elif args.model == "concatbow": txt, img = txt.cuda(), img.cuda() out = model(txt, img) elif args.model == "bert": txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda() out = model(txt, mask, segment) elif args.model == "concatbert": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) elif args.model == "latefusion_pdf": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() tgt = tgt.cuda() maeloss = nn.L1Loss(reduction='mean') out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred = model(txt, mask,segment,img,'pdf_train') label = F.one_hot(tgt, num_classes=args.n_classes) # [b,c] if args.task_type == "multilabel": txt_pred = torch.sigmoid(txt_logits) img_pred = torch.sigmoid(img_logits) else: txt_pred = torch.nn.functional.softmax(txt_logits, dim=1) img_pred = torch.nn.functional.softmax(img_logits, dim=1) txt_tcp, _ = torch.max(txt_pred * label, dim=1,keepdim=True) img_tcp, _ = torch.max(img_pred * label, dim=1,keepdim=True) tcp_pred_loss = maeloss(txt_tcp_pred, txt_tcp.detach()) + maeloss(img_tcp_pred, img_tcp.detach()) else: assert args.model == "mmbt" for param in model.enc.img_encoder.parameters(): param.requires_grad = not freeze_img for param in model.enc.encoder.parameters(): param.requires_grad = not freeze_txt txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) tgt = tgt.cuda() txt_clf_loss = nn.CrossEntropyLoss()(txt_logits, tgt) img_clf_loss = nn.CrossEntropyLoss()(img_logits, tgt) clf_loss=txt_clf_loss+img_clf_loss+nn.CrossEntropyLoss()(out,tgt) if mode=='train': loss = torch.mean(clf_loss)+torch.mean(tcp_pred_loss) return loss,out,tgt else: loss= torch.mean(clf_loss)+torch.mean(tcp_pred_loss) return loss,out,tgt

修改后:

def model_forward(i_epoch, model, args, criterion, optimizer, batch, mode='eval'): txt, segment, mask, img, tgt, idx = batch tgt = tgt.cuda() clf_loss = 0.0 tcp_pred_loss = 0.0 # ⭐ 先初始化,避免炸 # ---------- 普通单 / 早期融合模型 ---------- if args.model == "bow": txt = txt.cuda() out = model(txt) clf_loss = criterion(out, tgt) elif args.model == "img": img = img.cuda() out = model(img) clf_loss = criterion(out, tgt) elif args.model == "concatbow": txt, img = txt.cuda(), img.cuda() out = model(txt, img) clf_loss = criterion(out, tgt) elif args.model == "bert": txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda() out = model(txt, mask, segment) clf_loss = criterion(out, tgt) elif args.model == "concatbert": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) clf_loss = criterion(out, tgt) # ---------- late fusion(特例) ---------- elif args.model == "latefusion_pdf": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred = \ model(txt, mask, segment, img, 'pdf_train') # 分类 loss txt_loss = criterion(txt_logits, tgt) img_loss = criterion(img_logits, tgt) clf_loss = txt_loss + img_loss # TCP loss maeloss = nn.L1Loss(reduction='mean') label = F.one_hot(tgt, num_classes=args.n_classes) if args.task_type == "multilabel": txt_pred = torch.sigmoid(txt_logits) img_pred = torch.sigmoid(img_logits) else: txt_pred = F.softmax(txt_logits, dim=1) img_pred = F.softmax(img_logits, dim=1) txt_tcp, _ = torch.max(txt_pred * label, dim=1, keepdim=True) img_tcp, _ = torch.max(img_pred * label, dim=1, keepdim=True) tcp_pred_loss = ( maeloss(txt_tcp_pred, txt_tcp.detach()) + maeloss(img_tcp_pred, img_tcp.detach()) ) # ---------- mmbt ---------- else: assert args.model == "mmbt" txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) clf_loss = criterion(out, tgt) # ---------- 总 loss ---------- loss = clf_loss + tcp_pred_loss return loss, out, tgt

四、各训练参数

主要是get_args里面的参数解释:

训练与优化相关参数

参数名默认值含义说明影响阶段备注 / 建议
batch_sz128每个 batch 的样本数量训练大 batch 更稳定,但占显存
gradient_accumulation_steps24梯度累积步数训练等效 batch = batch_sz × steps
lr1e-4初始学习率训练BERT 微调常用 1e-5~5e-5
weight_decay0.0权重衰减系数(L2 正则)训练防止过拟合
dropout0.1Dropout 概率模型Transformer 常用 0.1
max_epochs100最大训练轮数训练搭配 early stopping
patience10Early stopping 容忍轮数训练验证集无提升时停止
warmup0.1学习率 warmup 比例训练防止初期梯度震荡
lr_factor0.5学习率衰减倍率训练ReduceLROnPlateau
lr_patience2学习率衰减等待轮数训练验证集不提升则降 lr
seed123随机种子全局保证实验可复现
n_workers12DataLoader 线程数数据加载与 CPU 核数相关

文本模态:

参数名默认值含义说明影响阶段备注
bert_model./bert-base-uncasedBERT 预训练模型路径模型可换成 large
freeze_txt0是否冻结文本编码器训练1 表示不更新 BERT
max_seq_len512文本最大 token 长度数据BERT 上限
embed_sz300词向量维度模型对应 GloVe
glove_pathglove.840B.300d.txtGloVe 文件路径数据300 维
hidden_sz768文本隐藏层维度模型BERT-base 默认

图像模态(Image)相关参数

参数名默认值含义说明影响阶段备注
img_hidden_sz2048图像特征维度模型ResNet 输出
num_image_embeds1图像 token 数模型MMBT 中常见
img_embed_pool_typeavg图像特征池化方式模型avg / max
freeze_img0是否冻结图像编码器训练1 表示冻结
drop_img_percent0.0随机丢弃图像比例数据增强模态缺失模拟

融合参数:

参数名默认值含义说明影响阶段备注
modellatefusion_pdf使用的模型结构模型PDF = Predictive Dynamic Fusion
hidden[]额外隐藏层结构模型如 [512,256]
include_bnTrue是否使用 BatchNorm模型提高训练稳定性
dfTrue是否启用动态融合模型PDF 核心开关
baselineNone对比方法名称实验仅用于记录

任务与数据相关参数:

参数名默认值含义说明影响阶段备注
taskMVSA_Single使用的数据集数据多模态情绪识别
task_typeclassification任务类型训练单标签 / 多标签
weight_classes1是否类别加权loss类别不平衡时用
noise0.0标签噪声比例数据鲁棒性实验
data_path/path/to/data_dir/数据集路径数据必须配置
savedir/path/to/save_dir/模型保存路径输出checkpoint

其中,很多任务数据相关参数都需要调整

Read more

【Java】反射详解

【Java】反射详解

Java 反射详解 运行时动态获取类信息、创建对象、调用方法完整教程 目录 * 一、反射概述 * 二、获取Class对象 * 三、构造方法反射 * 四、字段反射 * 五、方法反射 * 六、实战案例 一、反射概述 1.1 什么是反射 反射(Reflection)是Java提供的一种机制,允许程序在运行时检查和操作类的结构(类、方法、字段等)。 反射的核心功能: * 运行时获取类的信息 * 动态创建对象 * 动态调用方法 * 动态访问和修改字段 1.2 反射的应用场景 * 框架开发:Spring、Hibernate等框架大量使用反射 * 动态代理:AOP面向切面编程 * 注解处理:运行时处理注解 * 插件系统:动态加载类 * 序列化/反序列化:JSON、

By Ne0inhk
Arthas 快速上手与实战指南:Java 线上诊断利器全解析

Arthas 快速上手与实战指南:Java 线上诊断利器全解析

在日常 Java 开发中,我们经常会遇到线上问题难以重现、日志信息不足、重启成本高昂等棘手情况。为了解决这些问题,阿里巴巴开源了强大的诊断工具——Arthas。它可以在 不重启、不改代码、不侵入业务 的前提下,对运行中的 Java 应用进行实时排查与分析。本文将通过实战演示,带你从零上手 Arthas,并掌握其核心功能,提升线上故障的定位效率。 文章目录 * 1、Arthas 的介绍 * 2、Arthas 的安装和运行 * 2.1、下载 * 2.2、运行 * 3、Arthas 的应用案例 * 3.1、新建演示项目 * 3.2、启动演示项目 * 3.3、Arthas attach 进程 * 3.

By Ne0inhk
基于飞算JavaAI的学生成绩综合统计分析系统

基于飞算JavaAI的学生成绩综合统计分析系统

第一章:项目概述与背景 1.1 项目背景与意义 在教育信息化飞速发展的今天,学生成绩管理已成为学校教学管理的核心环节。传统的学生成绩管理多依赖于手工操作或基础的信息管理系统,存在数据处理效率低、统计分析功能薄弱、数据可视化缺失等问题。随着大数据技术的发展,教育领域对数据驱动的决策支持需求日益增长,一个能够提供综合统计分析功能的学生成绩管理系统显得尤为重要。 学生成绩综合统计分析系统旨在通过对学生成绩数据的深度挖掘和多维度分析,为教师、学生和管理者提供全面的数据支持。系统不仅能够实现基础的成绩录入和查询,更重要的是能够识别学习趋势、发现教学问题、预测学业表现,从而为个性化教学和精准教育干预提供科学依据。 1.2 飞算JavaAI平台介绍 飞算JavaAI是一款智能代码生成平台,采用人工智能技术辅助Java项目开发。 飞算JavaAI的核心功能模块,紧密围绕“高效、智能、安全”的Java开发全流程展开:左侧聚焦智能交互,包含三大实用工具——编程智能体可自动调用工具执行编程任务(如自动生成基础代码、辅助调试),智能问答提供实时技术答疑(快速解决开发中的疑难问题),Java Cha

By Ne0inhk

Spring Boot 版本怎么选?2/3/4 深度对比 + 迁移避坑指南(含 Java 8→21 适配要点)

Spring Boot 版本怎么选?2/3/4 深度对比 + 迁移避坑指南(含 Java 8→21 适配要点) 大家好,我是重阳。今天是2026年1月22日,Spring Boot 已经进入4.0时代。作为 Java 生态的核心框架,Spring Boot 的版本选择直接影响项目的稳定性、性能和维护成本。Spring Boot 2.x(2018年发布)是许多老项目的基石,3.x(2022年)带来了 Jakarta EE 和 Native 支持,而4.0(2025年11月GA)则聚焦模块化、Java 25优化和云原生增强。根据 Spring

By Ne0inhk