Source code for gptcache.similarity_evaluation.sbert_crossencoder

from typing import Dict, Tuple, Any
from gptcache.utils import import_sbert
from gptcache.similarity_evaluation import SimilarityEvaluation
from sentence_transformers import CrossEncoder # pylint: disable=C0413

class SbertCrossencoderEvaluation(SimilarityEvaluation):
    """Using SBERT cross-encoders to evaluate the similarity of a sentence pair.

    This evaluator uses a cross-encoder model to score how similar two
    sentences are.

    :param model: model name of SbertCrossencoderEvaluation. Default is
        'cross-encoder/quora-distilroberta-base'. For other models, refer to
        the pretrained cross-encoders listed in the sentence-transformers
        documentation.
    :type model: str

    Example:
        .. code-block:: python

            from gptcache.similarity_evaluation import SbertCrossencoderEvaluation

            evaluation = SbertCrossencoderEvaluation()
            score = evaluation.evaluation(
                {
                    'question': 'What is the color of sky?'
                },
                {
                    'question': 'hello'
                }
            )
    """

    def __init__(self, model: str = "cross-encoder/quora-distilroberta-base"):
        # Load the cross-encoder by name; CrossEncoder comes from the
        # module-level sentence_transformers import.
        self.model = CrossEncoder(model)

    def evaluation(
        self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_
    ) -> float:
        """Evaluate the similarity score of a pair.

        Only the ``"question"`` entry of each dictionary is compared; any
        extra keyword arguments are accepted and ignored.

        :param src_dict: the query dictionary to evaluate with cache.
        :type src_dict: Dict
        :param cache_dict: the cache dictionary.
        :type cache_dict: Dict

        :return: evaluation score as a builtin ``float``: ``1.0`` when the two
            questions are equal ignoring case, otherwise the cross-encoder's
            score for the pair, and ``0.0`` if anything fails (for example a
            missing ``"question"`` key, a non-string question, or a model
            error).
        """
        try:
            src_question = src_dict["question"]
            cache_question = cache_dict["question"]
            # Case-insensitive exact match: no need to run the model.
            if src_question.lower() == cache_question.lower():
                return 1.0
            # predict() scores a batch of pairs and returns one entry per
            # pair; convert the numpy scalar so callers always get a float.
            return float(self.model.predict([(src_question, cache_question)])[0])
        except Exception:  # pylint: disable=W0703
            # Deliberately best-effort: any failure is reported as the
            # minimum score instead of propagating the error.
            return 0.0

    def range(self) -> Tuple[float, float]:
        """Range of similarity score.

        NOTE(review): assumes the configured cross-encoder emits scores in
        [0, 1] (e.g. a sigmoid-activated single-label model such as the
        default) -- confirm when using a model that outputs raw logits.

        :return: minimum and maximum of similarity score.
        """
        return 0.0, 1.0