Fine-Tuned CLIP: Better Listing Experience and 80% More Budget-Friendly

This post is for Day 23 of Mercari Advent Calendar 2023, brought to you by a 2023 New Grad, @andy971022, from Mercari’s US@Tokyo Machine Learning team. For those curious about the term “US@Tokyo”, it represents a team serving Mercari’s US marketplace while being based in Tokyo.

Introduction

Pre-filling the category, brand, title, and color fields when a user uploads an image during listing has been a long-standing feature in both Mercari JP and US. However, few people know about the engineering effort behind this feature, which supports half a million listings daily.

Example of the Service

In this episode, we’ll demonstrate how we fine-tuned CLIP (Contrastive Language-Image Pre-Training) to significantly boost the performance of item category and brand prediction, requiring users to fill in fewer fields and hence improving the listing experience overall. In addition to streamlining the user experience, our efforts yielded an impressive 80% reduction in serving costs, highlighting the cost-effectiveness of our approach.

Background

We started our journey with InceptionV3, a 24-million-parameter classification model with over 5,000 classes, trained on millions of Mercari listing images. The model was not used to predict the item fields directly, as we have more brands and categories than the model has classes. Instead, we extracted an embedding from the listing image and queried it against a vector index of 50-100 million item image embeddings generated with the same model to retrieve the top-K similar items. These similar items were then collected for a vote on the brand and category.
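To make the mechanism concrete, here is a minimal sketch of the retrieve-and-vote step; the index client, its search method, and the metadata field names are hypothetical placeholders for our internal services.

from collections import Counter

def predict_brand_and_category(query_embedding, index, k=20):
    """Retrieve the top-K similar listings and let them vote on brand/category.

    `index` is a hypothetical nearest-neighbor client that returns the
    metadata of the K most similar items for a query embedding.
    """
    neighbors = index.search(query_embedding, top_k=k)  # hypothetical client call
    brand_votes = Counter(n["brand"] for n in neighbors if n.get("brand"))
    category_votes = Counter(n["category"] for n in neighbors if n.get("category"))
    brand = brand_votes.most_common(1)[0][0] if brand_votes else None
    category = category_votes.most_common(1)[0][0] if category_votes else None
    return brand, category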

Earlier this year, we migrated this ML service to a GCP-managed service, namely Vertex AI Vector Search (previously known as Matching Engine), and switched from InceptionV3 to a CLIP variant as part of our ongoing pursuit to simplify and elevate the selling experience of our users. But why CLIP?

CLIP

CLIP was released in early 2021 and stood as the best zero-shot pretrained contrastive learning model at the time. Capable of comprehending both text and image inputs, the base version of CLIP has 151 million parameters and outputs 512-dimensional embeddings. Interestingly, CLIP naturally excels as both an image encoder and a text encoder.

We can see from the pseudo-code of CLIP’s model architecture below that an L2-normalization is applied to both image and text embeddings after the final projection. Intuitively, this maps text/image embeddings onto the surface of the same hyper-dimensional unit sphere, meaning that all points on the surface are equidistant from the center of the sphere, with a Euclidean distance of 1.

(Source: (Left) Learning Transferable Visual Models From Natural Language Supervision, https://arxiv.org/pdf/2103.00020.pdf, (Right) Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere, https://arxiv.org/pdf/2005.10242.pdf)
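As a minimal illustration of that normalization step in PyTorch (the tensors below are random stand-ins for the projected embeddings, not our data):

import torch
import torch.nn.functional as F

image_embeds = torch.randn(8, 512)  # stand-in for projected image embeddings
text_embeds = torch.randn(8, 512)   # stand-in for projected text embeddings

image_embeds = F.normalize(image_embeds, dim=-1)  # L2-normalize onto the unit sphere
text_embeds = F.normalize(text_embeds, dim=-1)

print(image_embeds.norm(dim=-1))  # all 1.0: every embedding lies on the unit sphere
# With unit-norm vectors, the dot product of two embeddings equals their cosine similarity.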

The InfoNCE loss maximizes the distance between unalike image-text, image-image, and text-text pairs and minimizes it for pairs that are alike. Figuratively, it forces the model to “use up” all the space on the surface of that sphere (uniformity) while keeping similar inputs close (alignment). This mimics clustering the embeddings, which eases downstream tasks such as classification or similarity search.

(Source: InfoNCE Loss Provably Learns Cluster-Preserving Representations, https://arxiv.org/pdf/2302.07920.pdf)
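For reference, the standard CLIP-style symmetric InfoNCE loss over a batch of matching image-text pairs can be sketched as below; this follows the published formulation rather than our exact training code, and the temperature value is illustrative.

import torch
import torch.nn.functional as F

def clip_info_nce_loss(image_embeds, text_embeds, temperature=0.07):
    """Symmetric InfoNCE loss; assumes L2-normalized inputs where the
    i-th image and i-th text form a positive pair and the rest are negatives."""
    logits = image_embeds @ text_embeds.t() / temperature  # pairwise cosine similarities
    targets = torch.arange(len(logits), device=logits.device)
    loss_i2t = F.cross_entropy(logits, targets)     # image -> text direction
    loss_t2i = F.cross_entropy(logits.t(), targets)  # text -> image direction
    return (loss_i2t + loss_t2i) / 2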

Optimizing Model Performance and Cost

After the migration, we saw an opportunity to improve our system from a model performance and cost perspective.

  1. Publicly available CLIP variants output embeddings of at least 512 dimensions, and out of the box they perform at a level similar to the internally trained InceptionV3 model, so optimization is essential.
  2. Scaling down/up resources for cost savings isn’t straightforward with Vertex AI Vector Search being a managed service.

These seemingly different problems turned out to have a single common solution – CLIP Fine-Tuning + Dimensional Reduction.

Improving the model’s performance can be considered a domain-specific task, which is commonly tackled by adding extra linear layers at the end of the model and training on in-domain data, i.e., fine-tuning. With the extra linear layers, the dimension of the output embeddings can also be specified. Another common dimensional reduction approach, PCA (Principal Component Analysis), is not a viable solution in this context because its compression is lossy and often results in performance degradation.

Vertex AI Vector Search bills us by the number of instances we use: the larger the index, the more expensive it is to serve. An index’s size is the product of the number of vectors, their dimension, and the bytes required by their datatype. By reducing the embedding dimension to 64, we shrink the index to an eighth of its initial 512-dimension, 4-byte float32 size without reducing the number of items in the index. This in turn cuts the number of e2-standard-16 instances needed to serve the index, and hence the cost, by 80%. As a sample calculation, scaling 10 e2-standard-16 instances down to 2 alone saves $4,000+ per month, or ¥560,000+ at a rate of $1:¥140.
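As a back-of-the-envelope check, assuming 100 million vectors (a figure within the 50-100 million range mentioned above) stored as 4-byte float32:

num_vectors = 100_000_000
bytes_per_float32 = 4

size_512d_gb = num_vectors * 512 * bytes_per_float32 / 1e9  # ~204.8 GB
size_64d_gb = num_vectors * 64 * bytes_per_float32 / 1e9    # ~25.6 GB

print(size_512d_gb, size_64d_gb, size_64d_gb / size_512d_gb)  # 204.8 25.6 0.125, i.e., one eighth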

All things combined, we were convinced that fine-tuning the CLIP model with additional lower-dimension linear layers was the way to go.

Fine-Tuning CLIP on the Cloud

This time we went an order of magnitude further and fine-tuned the base version of CLIP on a curated dataset of over 10 million Mercari item images and text features, using the same InfoNCE loss. The fine-tuning process consisted of two rounds.

  1. In the first round, we continued training the model on our data for around 25 epochs with standard hyperparameter settings. This adapts the model to our data domain.
  2. The checkpoint from the epoch with the best validation performance in the first round was carried into the second round of training, where we froze CLIP’s vision model, a locked-image tuning technique for zero-shot transfer popularized by Zhai et al. (2022), and trained the dimensional reduction layers (512×64) added before the final L2-normalization.

Below are some code fragments that illustrate our fine-tuning implementation and architecture.

# Freezing the vision model
def freeze_vision_model(self):
    for param in self.vision_model.parameters():
        param.requires_grad = False

# Custom linear layers for dimensional reduction,
# added after the final projection and before normalization
self.image_embed = nn.Linear(512, embed_size)  # embed_size=64
self.text_embed = nn.Linear(512, embed_size)

# image
image_embeds = self.visual_projection(image_embeds)
image_embeds = self.image_embed(image_embeds)  # custom linear layer for dimension reduction

# text
text_embeds = self.text_projection(text_embeds)
text_embeds = self.text_embed(text_embeds)  # custom linear layer for dimension reduction
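For completeness, the fragments above could fit together roughly as in the following sketch, assuming the Hugging Face transformers CLIPModel; the class name and defaults are ours for illustration and not our production code.

import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel

class ReducedCLIP(nn.Module):
    """Pretrained CLIP with 512 -> 64 reduction layers and L2-normalized outputs."""

    def __init__(self, name="openai/clip-vit-base-patch32", embed_size=64):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(name)
        dim = self.clip.config.projection_dim  # 512 for the base model
        self.image_embed = nn.Linear(dim, embed_size)
        self.text_embed = nn.Linear(dim, embed_size)

    def freeze_vision_model(self):
        for param in self.clip.vision_model.parameters():
            param.requires_grad = False

    def forward(self, pixel_values, input_ids, attention_mask=None):
        # get_image_features / get_text_features already apply the final projections
        image_embeds = self.clip.get_image_features(pixel_values=pixel_values)
        text_embeds = self.clip.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # reduce to 64 dimensions, then L2-normalize back onto the unit sphere
        image_embeds = F.normalize(self.image_embed(image_embeds), dim=-1)
        text_embeds = F.normalize(self.text_embed(text_embeds), dim=-1)
        return image_embeds, text_embeds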

The size of the output dimension was chosen by weighing cost against performance. We found 64 to strike a good balance, with diminishing returns beyond that point. The figure below shows the relative brand accuracy and the relative index size against the embedding dimension.

Relative Performance and Index Size against Dimensions

For reference, the entire two-round fine-tuning process took about 5 days: 50 epochs in total, with a training batch size of 234 and a validation batch size of 1000, on 2 A100s, using over 10 million 224×224 image-text pairs. The batch sizes were chosen to best utilize our GPU resources. Do note that the batch size we used was far from the 32K batch size used to train the base CLIP model.

Apart from the loss, another metric we evaluate during training is the image_to_text_mean_rank. For each image embedding, it ranks all the text embeddings in the same validation batch by cosine similarity and records the position of the ground truth, i.e., the image’s corresponding text, with 1 being the highest; the metric is the mean of these ranks.

Image_to_text_mean_rank vs Epoch, Lower is Better
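A minimal sketch of how such a metric can be computed for one validation batch, assuming L2-normalized embeddings; it follows the description above rather than our exact training code.

import torch

def image_to_text_mean_rank(image_embeds, text_embeds):
    """Mean rank of each image's ground-truth text by cosine similarity (1 = best)."""
    sims = image_embeds @ text_embeds.t()         # cosine similarities for normalized inputs
    ground_truth = sims.diag().unsqueeze(1)       # the i-th text is the i-th image's ground truth
    ranks = (sims > ground_truth).sum(dim=1) + 1  # texts scoring higher than the ground truth, plus one
    return ranks.float().mean().item()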

Generating Embeddings, Building Index, and Offline Experiments

After the model was trained, we carried out offline experiments based on the generated embeddings and an index built on top of ScaNN (Scalable Nearest Neighbors), the similarity search algorithm behind Vertex AI Vector Search. Downloading 50-100 million images took 2-3 days, and generating the corresponding embeddings took another day or two with 10-20+ T4 GPU instances running in parallel. To ensure data consistency with the production environment, we used a dedicated Dataflow job for the embedding pipeline.
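Building and querying a ScaNN index for these offline experiments looks roughly like the sketch below, based on the open-source scann package; the tuning parameters and file name are illustrative, not the values we use in production.

import numpy as np
import scann

# Precomputed, L2-normalized item embeddings of shape (N, 64)
embeddings = np.load("item_embeddings.npy")  # hypothetical file

searcher = (
    scann.scann_ops_pybind.builder(embeddings, 10, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

query_embedding = embeddings[0]  # example: query with the first item
neighbors, distances = searcher.search(query_embedding)  # indices and scores of the top-10 items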

Below is an example demonstrating image-to-image search using CLIP and the index built from our inventory. As shown in the example, the majority of the similar listings returned by the search were also Nike sneakers, which in turn voted “Nike” as the brand and “Shoes” as the category. In our offline experiments, we repeated this process for 100K to 1 million items from a held-out test dataset to better understand how the model would perform online.

Querying on the CLIP Index Using an Image of a Pair of Nike Sneakers

Reflection

Reflecting upon our journey, we realized that too many challenges and stories remain to be shared. The reasoning and engineering behind the migration, handling hundreds of millions of image read/write operations, dealing with GPU shortages, conducting countless experiments, the hardship of being an early adopter of a novel GCP service, and all the backend adjustments: any of these could easily be expanded into another blog post. While we couldn’t elaborate on them all, we have condensed what we think is most important.

Mercari’s US@Tokyo ML team has consistently been leveraging AI techniques to simplify the selling experience for our users. Among those efforts is the development and continuous improvement of the models that predict listing fields like category and brand. We genuinely hope that you found this a fruitful read and that we can continue to be visionary and deliver enriching content.

Acknowledgments

I express my sincere gratitude to Karen Wang and Zainul Din for their invaluable contributions that played a pivotal role in bringing this project to fruition. Special thanks are extended to Rishabh Kumar Shrivastava, Shotaro Kohama, Takuma Yamaguchi, Ajay Daptardar, and Vamshi Teja Racha for their unwavering support and insightful guidance throughout the development process.

Tomorrow’s article will be by @jye. Look forward to it!

Bibliography

  1. Yamaguchi, T. (2017/12/23). 画像での商品検索に向けて [Toward image-based item search]. Mercari Engineering Blog. https://engineering.mercari.com/blog/entry/2017-12-23-100000/
  2. Radford, Alec, et al. "Learning transferable visual models from natural language supervision." International conference on machine learning. PMLR, 2021.
  3. Szegedy, Christian, et al. "Rethinking the inception architecture for computer vision." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
  4. Google Cloud Platform. (2023). Vertex AI Documentation. Google Cloud. https://cloud.google.com/vertex-ai/pricing#vectorsearch
  5. Wang, Tongzhou, and Phillip Isola. "Understanding contrastive representation learning through alignment and uniformity on the hypersphere." International Conference on Machine Learning. PMLR, 2020.
  6. Parulekar, Advait, et al. "InfoNCE Loss Provably Learns Cluster-Preserving Representations." arXiv preprint arXiv:2302.07920 (2023).
  7. Zhai, Xiaohua, et al. "LiT: Zero-shot transfer with locked-image text tuning." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.