SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning

¹Visual AI Lab, The University of Hong Kong
²Visual Geometry Group, University of Oxford

Abstract


Generalized Category Discovery (GCD) aims to classify unlabelled images from both 'seen' and 'unseen' classes by transferring knowledge from a set of labelled 'seen' class images. A key theme in existing GCD approaches is adapting large-scale pre-trained models for the GCD task.

An alternative perspective, however, is to adapt the data representation itself for better alignment with the pre-trained model. In this paper, we therefore introduce a two-stage adaptation approach termed SPTNet, which iteratively optimizes model parameters (i.e., model finetuning) and data parameters (i.e., prompt learning). Furthermore, we propose a novel spatial prompt tuning method (SPT) that accounts for the spatial structure of image data, enabling the method to better focus on object parts, which can transfer between seen and unseen classes.

We thoroughly evaluate SPTNet on standard benchmarks and demonstrate that it outperforms existing GCD methods. Notably, our method achieves an average accuracy of 61.4% on the Semantic Shift Benchmark (SSB), surpassing prior state-of-the-art methods by approximately 10%. The improvement is particularly notable because our method introduces extra parameters amounting to only 0.042% of those in the backbone architecture.
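To put the 0.042% figure in perspective, the short calculation below converts it into an absolute parameter count. It assumes a ViT-B/16 backbone of roughly 86M parameters; only a "pre-trained ViT" is mentioned here, so both the variant and the rounded size are assumptions, and the result is an order-of-magnitude estimate rather than SPTNet's exact count.

```python
# Assumed backbone: ViT-B/16, commonly reported at about 86M parameters.
# Neither the variant nor this count is stated in the text above.
ASSUMED_BACKBONE_PARAMS = 86_000_000
EXTRA_FRACTION = 0.042 / 100  # 0.042% written as a fraction


def extra_param_budget(backbone_params: int, fraction: float) -> int:
    """Approximate number of extra parameters for a given fraction of the backbone."""
    return round(backbone_params * fraction)


budget = extra_param_budget(ASSUMED_BACKBONE_PARAMS, EXTRA_FRACTION)
print(f"~{budget:,} extra parameters")  # prints "~36,120 extra parameters"
```

Under that assumption, the learnable prompts amount to a few tens of thousands of values, which is negligible next to the frozen-or-finetuned backbone.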

Framework


SPTNet is a two-stage adaptation approach that iteratively optimizes model parameters (i.e., model finetuning) and data parameters (i.e., prompt learning).


In the first stage, we attach the same set of spatial prompts to all input images. During training, we freeze the model parameters and update only the prompt parameters.

In the second stage, we freeze the prompt parameters and learn the model parameters. Using our spatial prompt learning as a strong augmentation, we aim to obtain a representation that better distinguishes samples from different classes, since the core mechanism of contrastive learning implicitly clusters samples from the same class together.
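The contrastive mechanism mentioned above can be illustrated with a standard InfoNCE loss, sketched below in plain Python. Here the anchor is the embedding of a hand-crafted view of an image, the positive is the embedding of its prompted view, and the negatives come from other images. This pairing, the cosine similarity, and the temperature of 0.1 are illustrative assumptions, not SPTNet's exact loss terms.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchor, positive, negatives, temperature=0.1):
    """InfoNCE for one anchor: negative log-softmax score of the positive pair.

    Assumed setup: anchor = hand-crafted view, positive = prompted view of the
    same image, negatives = views of other images.
    """
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    # Log-sum-exp with the max subtracted for numerical stability.
    top = max(logits)
    log_denom = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_denom - logits[0]
```

Minimizing this loss pulls the prompted view towards its hand-crafted counterpart and pushes other images away; samples of the same class tend to end up close together as a side effect, which is the implicit clustering the text refers to.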

Unlike prior works, which apply only hand-crafted augmentations, we treat prompting the input with learnable prompts as a new type of augmentation. The 'prompted' version of the input can be used by all loss terms. In this way, our framework benefits from a learned augmentation that evolves throughout training, enabling the backbone to learn discriminative representations. Each stage optimizes its parameters for k iterations before training switches to the other stage.
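The alternating schedule can be sketched on a toy problem. In the minimal example below, a scalar w stands in for the model parameters and a scalar p for the prompt, which is added to the input x. Stage 1 freezes w and takes k gradient steps on p; stage 2 freezes p and takes k steps on w using both the raw and the prompted input, mirroring the idea that the prompted view acts as an extra augmentation. The squared-error losses, learning rate, and number of rounds are hypothetical choices for illustration only.

```python
def grad_prompt(w, p, data):
    """d/dp of the prompted squared error, with the model weight w frozen."""
    return sum(2 * (w * (x + p) - y) * w for x, y in data) / len(data)


def grad_model(w, p, data):
    """d/dw of the loss averaged over raw and prompted inputs, with p frozen."""
    g = 0.0
    for x, y in data:
        for inp in (x, x + p):  # raw view and prompted view
            g += 2 * (w * inp - y) * inp
    return g / (2 * len(data))


def combined_loss(w, p, data):
    """Mean squared error over both the raw and the prompted view of each input."""
    total = 0.0
    for x, y in data:
        for inp in (x, x + p):
            total += (w * inp - y) ** 2
    return total / (2 * len(data))


def alternate(w, p, data, k=5, rounds=20, lr=0.05):
    """Alternate k prompt-only steps with k model-only steps; log loss per round."""
    history = []
    for _ in range(rounds):
        # Stage 1: model frozen, update only the prompt.
        for _ in range(k):
            p -= lr * grad_prompt(w, p, data)
        # Stage 2: prompt frozen, update only the model.
        for _ in range(k):
            w -= lr * grad_model(w, p, data)
        history.append(combined_loss(w, p, data))
    return w, p, history
```

For example, `alternate(0.5, 0.0, [(1.0, 2.0), (2.0, 4.0)])` on data generated by y = 2x drives the recorded loss down round by round while (w, p) approaches (2, 0). Each stage only ever touches one block of parameters, which is what keeps the two optimizations decoupled.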


A key insight in GCD is that object parts are effective for transferring knowledge between old and new categories. We therefore propose Spatial Prompt Tuning (SPT), a learned data augmentation that enables the model to focus on local object regions while adapting the data representation to the pre-trained ViT model and maintaining alignment with it.
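The exact prompt layout is not specified in this overview. One plausible reading, assumed here and not taken from the authors' code, is a learnable border of width m inside every ViT patch, added to the pixels in the spirit of padding-style visual prompts. The pure-Python sketch below represents images as nested [C][H][W] lists; the names `border_mask` and `SpatialPrompt` are hypothetical. It applies the same prompt values to every input image and leaves the interior of each patch untouched.

```python
def border_mask(height, width, patch, m):
    """True where a pixel lies within m pixels of its own patch boundary."""
    assert height % patch == 0 and width % patch == 0 and 2 * m < patch
    mask = []
    for i in range(height):
        row = []
        for j in range(width):
            li, lj = i % patch, j % patch  # coordinates inside the patch
            row.append(li < m or li >= patch - m or lj < m or lj >= patch - m)
        mask.append(row)
    return mask


class SpatialPrompt:
    """Hypothetical per-patch border prompt, shared by all images, added to pixels."""

    def __init__(self, channels, height, width, patch, m, init=0.0):
        self.mask = border_mask(height, width, patch, m)
        # One learnable value per (channel, border pixel); interior pixels stay 0.
        self.values = [
            [[init if self.mask[i][j] else 0.0 for j in range(width)] for i in range(height)]
            for _ in range(channels)
        ]
        self.num_params = channels * sum(sum(row) for row in self.mask)

    def __call__(self, image):
        """Return image + prompt; only the border pixels of each patch change."""
        return [
            [[px + self.values[c][i][j] for j, px in enumerate(row)] for i, row in enumerate(plane)]
            for c, plane in enumerate(image)
        ]
```

In a real training loop, these border values would be the only tensors updated during the first stage; here they are plain floats so the layout is easy to inspect.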

Performance


We evaluate SPTNet on three generic datasets: CIFAR-10, CIFAR-100, and ImageNet-100. We compare SPTNet with previous state-of-the-art methods and two concurrent methods. The results are shown below; our method consistently outperforms previous state-of-the-art methods.

The results on the four fine-grained benchmarks (CUB, Stanford Cars, FGVC-Aircraft, and Herbarium19) are shown below.

Visualization


Comparing the representations of SimGCD and SimGCD+VPT shows that VPT appears to harm the representation in the GCD setting, causing seen and unseen classes (e.g., bird and dog) to become entangled. Both SPTNet-P and SPTNet produce more discriminative features and more compact clusters than SimGCD.

In addition, SPTNet attends to diverse regions in different attention heads while still covering the object in almost all heads.

BibTeX

@inproceedings{wang2024sptnet,
    author    = {Wang, Hongjun and Vaze, Sagar and Han, Kai},
    title     = {SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning},
    booktitle = {International Conference on Learning Representations (ICLR)},
    year      = {2024}
}