PromptKD是一个简单有效的基于prompt的视觉语言模型蒸馏新方法,在prompt learning的11个benchmark数据集上大幅领先,达到了SOTA。
已经很了解VLMs和prompt learning的同学可以直接跳过,到背景问题~
这里的介绍目的是让没有相关基础和背景的同学也可以看懂这篇工作,能有所收获。
什么是视觉-语言模型(Vision-Language Models, VLMs)?
视觉语言模型VLM一般由两个部分构成,即视觉(Vision)部分和语言(Language)部分。以一个经典的VLM网络 CLIP[1] 的结构为例:
如图1所示,CLIP由text branch和image branch组成。
其中,text branch主要由transformer构成,当要进行cls_num个类的分类任务时,会取每个类别对应的名称,如"plane", "car", "dog",与"a photo of a"进行组合,作为prompt输入进text encoder,得到大小为[cls_num, feat_dim]的text feature。
image branch的核心就是对输入的图像提取image feature,其通常为ResNet或者ViT[2]。图像经过image encoder之后得到image feature,其大小为[batch_size, feat_dim]。
将两个feature进行相乘就得到了预测logits。
CLIP有两个明确的特性,是这个工作的基础:
1. CLIP可以进行zero-shot分类,即对未见过的类别进行识别,并保持很高的性能。而传统的CNN或者ViT由于模型架构限制不可以。
2.对于已知的类别,CLIP的text branch只需要一次forward就可以得到对应text feature用于分类。
什么是提示学习(Prompt Learning)?
在Text Branch部分中,a photo of a {class_name} 这样的描述太过宽泛,明显不是最优的。例如对于图2(b)的花,手工设计的a flower photo a {class}要描述的更加精确,其产生的结果就更好。
这就产生来两个问题,第一,固定模板的prompt不是最优的。第二,针对性的手工设计费时费力,且无法泛化。
于是,提示学习(Prompt Learning)[3] [4]就提出将prompt变成了一种learnable的方式,通过优化的方法让prompt在下游数据集上学习适用的表征,来替代手工设计的prompt,参考图2中的绿色方块。
这样优势是,可以在少量数据的情况下,仅通过引入一少部分的可学习参数(即learnable prompt),就可以将原始的CLIP快速适用到下游的任务/数据,同时在性能上比全参数微调的结果更好[4]。
实验衡量指标是什么?
有三个指标,分别是base acc,novel acc和harmonic mean。
以imagenet-1k数据集为例,会取1000类中的前500类作为base class,后500类作为novel class。模型在base class上训练,完成后在base class和novel class上测试acc性能。因为novel class与base class数据类别不重复,所以novel acc可以有效反应模型泛化性能。harmonic mean指标是对base acc和novel acc的综合反映,为harmonic mean = (2*base acc*novel acc) / (base acc+novel acc)。总体的harmonic mean值越高,模型综合性能越好。
prompt learning的核心作用是,保持原始CLIP参数不变,通过引入小部分learnable prompt参数,来将大的原始的经过预训练的CLIP模型适用到下游任务/数据上,提升CLIP模型在下游任务的性能,同时保持CLIP模型zero-shot能力。
除去一直发展至今的各种设计prompt形式的工作[3] [5] [6] [7] [8] [9] [10] [11] [12] [13],现如今最前沿的prompt learning方法主要还可以分为另外两类:
1. 引入额外数据/信息。这一类工作核心就是通过引入额外的数据或信息,做法包括但不限于,
(1).通过LLM来生成{class_name}相关的语句,获得额外的有关{class_name}的特性 特征[14] [15] [16],或者更多描述性语句[17] [18] [19] [20]。
(2).引入额外的数据源,从wikipedia上引入文本描述[21],从额外数据集例如ImageNet-21K来做预训练 [22]。
(3).设计给原始图像数据引入额外的tag或标注[23] [24] [25]。
从以上的方式我们看到,大部分引入额外数据信息的工作都是围绕text branch展开,本质原因是输入的text本身"{class_name}"或"a photo of a {classname}"包含信息太少,丰富度要远低于image,通过额外的域内文本信息的引入,可以显著增强text feature的质量。 所以text feature的质量是关键。
同时,可以看到,围绕image branch的工作是相对较少的。这时候问题就来了:那我们可不可以用同样的思路来增强image feature呢?
诶,这个方法好!因为互联网内往往存在非常大量的图像数据,很容易获取。
但问题是这些图像往往是没有标注的,没办法用gt训,如果要去进行标注,需要消耗很多的时间或者钱。明显限制了这种方式的应用。 2. 利用原始CLIP自身信息约束模型学习[19] [26] [27] [28] [29] [30] [31],防止过拟合。
在Prompt learning中,learnable prompt的参数量是相对较少的,在经过大量base class数据训练之后,模型会对base class数据存在过拟合,丧失对novel class的泛化性能。要解决这个问题,一种非常有效的做法就是利用vanilla CLIP来约束带有prompt的模型的学习。
以ICCV 23 PromptSRC为例,如图3所示,
图3这篇工作就看两条线,蓝线和灰线。
蓝线,就是原始CLIP的前向计算路径,分别会得到对应的image和text feature。
灰线,就是带有learnable prompt的计算过程,也会得到对应的feature。
在两条线的末尾,计算了三个loss,这里就是用原始CLIP产生的image和text feature来约束由含有learnable prompt产生的image和text feature。通过这样的约束,限制了prompt向着base class过拟合,达到了SOTA的性能。
由这个工作我们就想,如果换一个更好的模型来做约束是不是性能会更好?
于是,这就引出了我们的工作。
PromptKD其实核心就在做一件事,引入更大的CLIP模型作为teacher,解决了上面提到的三个问题。
(1) 重用(Reuse) teacher CLIP产生的text feature用于学生的训练和推断。这样确保了text feature高质量的同时,还显著的节省计算量,训练时只涉及student的image encoder。
(2) 对齐学生CLIP和教师CLIP的logits。让大的CLIP模型给小的学生CLIP模型提供更好的监督。
(3) 因为有了教师CLIP的存在,就解决了数据量限制的问题,我们可以用大量的无标签domain data来训学生,不再拘泥于原来有限的有标签数据。在训练时,我们直接可以使用数据集的全量数据作为无标签数据进行蒸馏,这样一来就prompt就可以学到更广泛的domain knowledge。同时高性能的教师CLIP也保证了用于蒸馏的软标签的准确性。
我们先来看一个简单的结构缩略图:
黄色的方块部分代表的就是教师CLIP,在教师CLIP经过训练之后,直接一次forward,得到并保存下来对应类别的text feaure,也就得到了图4中的Pre-stored Text Feature。
蓝色的方块代表的是学生CLIP,这里其实就只有一个image encoder,在带有learnablr prompt的输入进入image encoder之后会得到对应的image feature,这是因为与teacher text feature在维度上不匹配,所以经过一个Projector,将512转成768维的特征。然后再与Pre-stored Text Feature相乘,得到logits。
然后进行蒸馏。
完整的框架图如图5所示:
图5里就是图4过程的细化。
这里将PromptKD的每个阶段都进行了详细的阐明。大家看图就明白了~
第一阶段,教师模型的预训练。在这里,我们选择之前的SOTA方法PromptSRC去预训练我们的教师ViT-L/14 CLIP模型,我们的学生模型是ViT-B/16 CLIP模型。
注意,这里的预训练不是必须的一步,选择去预训练教师模型,是为了让教师有一个更好的性能,从而有更好的学生蒸馏结果。如果直接使用vanilla ViT-L/14 CLIP作为教师,相比于baseline,也取得了明显的性能提升,具体结果请参考表4。
第二阶段,学生CLIP模型的蒸馏。
第三阶段,学生的推断。
最后再来一个简洁明了的流程概括图:
我们的PromptKD方法在prompt learning的11个benchmark dataset上都达到了SOTA的性能。
Base-to-novel实验
Cross-dataset实验
消融实验
为了实验快速进行,消融实验里使用的不是全量数据集,而是64 shots per class进行的训练。所以会与表1中的数据相比略低。
与其他同样使用了无标签数据的工作的性能对比:
教师预训练方法的选择
在PromptKD中,任意类型的ViT-L/14 CLIP教师模型都可以蒸馏出一个很好的ViT-B/16 CLIP模型,相比于baseline (70.22 HM)都有明显的提升。
这里有一点非常有意思的是,我们可以看到,第四行的Teacher(CLIP) ViT-L/14也就是原始的CLIP模型,在经过PromptKD的蒸馏之后,我们的ViT-B/16 CLIP的结果(表1(b))明显超过了原始的ViT-L/14 CLIP模型。(77.62 vs. 76.52)
不同容量教师模型的选择
如表5所示,绿色代表学生ViT-B/16 CLIP的HM分数,土黄色代表教师的HM分数。教师的性能越高,越能训练出更好的学生。
欢迎大家试用PrompKD~
1. 问:蒸馏和推理阶段向学生模型的输入中visual prompt在代码中的位置。
答:这个在CLIP的代码里已经实现了,在PromptKD/clip/model.py line 366开始以后,line 375的self.VPT就是learnable visual prompt的定义。在forward函数里面line 402就有concat的操作,将visual prompt与image token进行concat输入进ViT进行计算。
2. 问:想要找一个更小backbone的CLIP做蒸馏,只有ResNet-based CLIP了,但是ResNet-based CLIP不支持token形式的learnable prompt,怎么办?
答:两种方式,第一种,学生模型在这里不是必须要有prompt的,当变成resnet或者更小的模型时,也可以考虑去全参数微调去拟合下游任务。 第二种,当不支持token形式的prompt时,VPT论文其实给出了方案,就是在spatial的层面去加prompt,另外还可以参考MIT的工作《Exploring Visual Prompts for Adapting Large-Scale Models》,这篇论文的图里给出了很具体的可以应用的visual prompt实现方法,代码 (https://github.com/hjbahng/visual_prompting) 也已经开源了,可以参考去使用。
3. 问:Teacher CLIP如果没有prompt,不做pretrain可不可以?
答:是可以的。其实PromptKD这里的teacher不用局限在到底有没有经过pretrain这个事情上,我们在论文的table 6里也验证了,即使是最原始的ViT-L/14 CLIP用来做蒸馏,也可以取得明显的提升效果。因为promptkd本身是一种纯kd的训练方法,所以teacher的acc其实决定了student学习效果的上限,我们对teacher去进行pre-train,就是在提升这个上限,所以是上限越高蒸馏结果越好。但是如果不做pre-train,也不影响promptkd方法的使用。
4. 问:PromptKD和PromptSRC对硬件的需求。
答:我的实验是在A100的卡上完成的,所以没有特别在意这个,可能记得不太清楚具体细节了,PromptSRC对于卡的需要还是比较高的,最好是24g的卡,promptkd很省显存,我印象里之前跑某个实验时大概7-8G显存,用11g的1080ti应该就可以跑起来。
5. 问:蒸馏阶段的数据如果有真实标签怎么办?
答:在本文中,PromptKD受限于论文实验验证标准,使用的是无标签数据进行的蒸馏。而在现实中,如果训练数据包含有gt label,则可以考虑在学生的训练时直接使用gt label,即将图6里算loss这一行只有kd loss的情况换成loss = a* CE(l_stu, gt)+ b * KLD(l_stu, l_tea)进行训练,其中a,b为两项loss的超参,ce为朝着gt优化的cross entropy loss,在训练时可以先固定a=1不动,调整b来进行蒸馏实验,直到发现最优参数。
6. 问:蒸馏后的学生模型碰到新的类别怎么办?
答:当遇到新类别时,已有的预存储的文本特征已经不再适用,这时可以用教师的text encoder再inference一遍得到新的text feature用来计算。这里需要说明的一点是,这一步重新计算的操作确实存在限制,但这不是promptkd方法层面导致的,而是clip在遇到新的类别的时候都需要进行一次计算。promptkd是在clip本身基础上进行的对已知类别的优化。
7. 代码复现问题集合
(1) 问:结果是单次还是平均?
答:基于3个seed得到的结果取平均。
(2) 问:尝试自己预训练教师,然后进行蒸馏,结果达不到
答:首先需要注意的是,预训练教师是否达到了论文补充材料的table 10里面报告的acc。如果没达到,那么可能蒸馏的效果就不会那么好。(论文的预训练方法是采用promptsrc用默认训练的setting训的VIT-L/14 CLIP)。为了方便复现,github代码里面已经在各种渠道提供了预训练的模型,包括百度云,terabox,google cloud,github的repo下releases部分。推荐使用已有预训练模型进行训练。
(3) 问:复现结果有波动
答:prompt learning领域方法的论文有两个限制:训练数据量小和训练参数量小。这时候会存在一定的训练的波动,推荐多run几个seed,比如seed 1-5,然后去掉训练波动的数据值,取三个结果avg作为最终结果。更推荐大家在数据量大的数据集imagenet上做实验,实验结果会比较稳定。
这篇论文解读感谢师弟武戈同学的部分论文总结,PromptKD这篇工作感谢我的导师和co-author们,另外还非常感谢蚂蚁的申书恒,张长浩和傅幸同学的讨论和帮助。