defforward_features(self, x): if self.args['diff_aug']: x = DiffAugment(x, self.args.diff_aug, True)
B = x.shape[0] x = self.patch_embed(x).flatten(2).permute(0,2,1) # x (1, 64, 64) 第一个64是64个patch块,第二个64是通道
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) # x (1, 65, 64) x = x + self.pos_embed x = self.pos_drop(x) for blk in self.blocks: x = blk(x) x = self.norm(x) # x (1, 65, 64) return x[:,0] # (1, 64) 返回第一个位置的值
defforward(self, x): # x (1, 3, 64, 64) x = self.forward_features(x) # 特征提取过程 x (1, 64) x = self.head(x) # 线性层转换为2分类真假预测问题 (1, 2) return x