Source code for gptcache.embedding.timm

import numpy as np

from gptcache.utils import import_timm, import_torch, import_pillow
from gptcache.embedding.base import BaseEmbedding

import_torch()
import_timm()
import_pillow()

import torch  # pylint: disable=C0413
from timm.models import create_model  # pylint: disable=C0413
from timm.data import create_transform, resolve_data_config  # pylint: disable=C0413
from PIL import Image  # pylint: disable=C0413


class Timm(BaseEmbedding):
    """Generate image embedding for given image using pretrained models from Timm.

    :param model: model name, defaults to 'resnet18'.
    :type model: str
    :param device: device used for inference. With the value 'default',
        'cuda' is selected when ``torch.cuda.is_available()`` and 'cpu' otherwise;
        any other value is used as given.
    :type device: str

    Example:
        .. code-block:: python

            from gptcache.embedding import Timm

            encoder = Timm(model='resnet50')
            embed = encoder.to_embeddings('path/to/image')
    """

    def __init__(self, model: str = "resnet18", device: str = "default"):
        if device == "default":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model_name = model
        self.model = create_model(model_name=model, pretrained=True)
        # Actually place the weights on the selected device; storing the
        # device name alone would leave inference on the CPU.
        self.model.to(self.device)
        self.model.eval()

        # Not every model exposes ``embed_dim``; when it is missing the
        # dimension is measured lazily in the ``dimension`` property.
        self.__dimension = getattr(self.model, "embed_dim", None)

    def to_embeddings(self, data, skip_preprocess: bool = False, **_):
        """Generate embedding given image data

        :param data: image path, or an already transformed image tensor when
            ``skip_preprocess`` is True.
        :type data: str
        :param skip_preprocess: flag to skip preprocess, defaults to False,
            enable this if the input data is torch.tensor.
        :type skip_preprocess: bool

        :return: an image embedding in shape of (dim,).
        """
        if not skip_preprocess:
            data = self.preprocess(data)
        if data.dim() == 3:
            # Add the batch axis expected by the model.
            data = data.unsqueeze(0)
        # Input must live on the same device as the model weights.
        data = data.to(self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            feats = self.model.forward_features(data)
        emb = self.post_proc(feats).squeeze(0).detach().numpy()
        # Single copy into a float32 array that does not share memory with torch.
        return np.array(emb, dtype="float32")

    def post_proc(self, features):
        """Reduce the output of ``forward_features`` to a 2-D tensor on the CPU.

        A 3-D tensor is reduced by taking index 0 along its second axis
        (presumably the class token of transformer-style models -- confirm
        against the model in use). A 4-D tensor is average-pooled to a
        spatial size of 1 and flattened.

        :param features: tensor returned by the model's ``forward_features``.
        :type features: torch.Tensor

        :return: a 2-D tensor located on the CPU.
        :raises AssertionError: if the result is not 2-D.
        """
        features = features.to("cpu")
        if features.dim() == 3:
            features = features[:, 0]
        if features.dim() == 4:
            global_pool = torch.nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
            features = features.flatten(1)
        assert features.dim() == 2, f"Invalid output dim {features.dim()}"
        return features

    def preprocess(self, image_path):
        """Load image from path and then transform image to torch.tensor with model transformations.

        :param image_path: image path.
        :type image_path: str

        :return: an image tensor (without batch size).
        """
        data_cfg = resolve_data_config(self.model.pretrained_cfg)
        transform = create_transform(**data_cfg)
        # The context manager closes the underlying file once the pixel data
        # has been copied into the converted RGB image.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        image_tensor = transform(rgb_image)
        return image_tensor

    @property
    def dimension(self):
        """Embedding dimension.

        When the dimension is not known yet, it is measured once by embedding
        a random tensor of the model's configured input size and then cached.

        :return: embedding dimension
        """
        if self.__dimension is None:
            # ``tuple`` guards against the config storing the size as a list.
            input_size = tuple(self.model.pretrained_cfg["input_size"])
            dummy_input = torch.rand((1,) + input_size)
            feats = self.to_embeddings(dummy_input, skip_preprocess=True)
            self.__dimension = feats.shape[0]
        return self.__dimension