Source code for gptcache.processor.context.summarization_context

from typing import Dict, Any

import numpy as np

from gptcache.processor import ContextProcess
from gptcache.utils import import_huggingface

import_huggingface()

import transformers  # pylint: disable=C0413


def summarize_to_length(summarizer, text, target_len, max_len=1024):
    """Repeatedly summarize ``text`` until its token length drops to ``target_len`` or below.

    The text is split into segments of at most ``max_len - 100`` tokens so that each
    segment fits into the summarizer's input window; the per-segment summaries are
    concatenated and the loop repeats on the result.
    """
    tokenizer = summarizer.tokenizer

    def token_length(text):
        return len(tokenizer.encode(text))

    segment_len = max_len - 100
    summary_result = text
    while token_length(text) > target_len:
        tokens = tokenizer.encode(text)
        segments = [
            tokens[i : i + segment_len] for i in range(0, len(tokens), segment_len - 1)
        ]
        summary_result = ""
        for segment in segments:
            # Aim for roughly a 4x compression of each segment.
            len_seg = int(len(segment) / 4)
            summary = summarizer(
                tokenizer.decode(segment),
                min_length=max(len_seg - 10, 1),
                max_length=len_seg,
            )
            summary_result += summary[0]["summary_text"]
        text = summary_result
    return summary_result
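# Example (illustrative sketch, not part of the original module): given a
# transformers summarization pipeline, ``summarize_to_length`` keeps compressing
# the text until its token count is at most ``target_len``. The model name and
# variable names below are assumptions for demonstration only.
#
#   summarizer = transformers.pipeline(task="summarization", model="facebook/bart-large-cnn")
#   short_text = summarize_to_length(summarizer, long_text, target_len=256, max_len=1024)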
class SummarizationContextProcess(ContextProcess):
    """A context processor for summarizing large amounts of text data using a summarizer model.

    :param model_name: The name of the summarization model to use, defaults to "facebook/bart-large-cnn".
    :type model_name: str
    :param tokenizer: The tokenizer used to measure the length of the output text.
    :type tokenizer: transformers.PreTrainedTokenizer
    :param target_length: The target length of the summarized text.
    :type target_length: int

    Example:
        .. code-block:: python

            from gptcache.processor.context.summarization_context import SummarizationContextProcess

            context_process = SummarizationContextProcess()
            cache.init(pre_embedding_func=context_process.pre_process)
    """

    def __init__(
        self, model_name="facebook/bart-large-cnn", tokenizer=None, target_length=512
    ):
        summarizer = transformers.pipeline(task="summarization", model=model_name)
        self.summarizer = summarizer
        self.target_length = target_length
        if tokenizer is None:
            tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
        self.tokenizer = tokenizer
        self.content = ""
    def summarize_to_sentence(self, sentences, target_size=1000):
        lengths = []
        for sentence in sentences:
            lengths.append(len(sentence))
        total_length = np.array(lengths).sum()
        target_lengths = [int(target_size * l / total_length) for l in lengths]
        target_sentences = []
        for sent, target_len in zip(sentences, target_lengths):
            if len(self.tokenizer.tokenize(sent)) > target_len:
                response = summarize_to_length(
                    self.summarizer, sent, target_len, self.tokenizer.model_max_length
                )
                target_sentence = response
            else:
                target_sentence = sent
            target_sentences.append(target_sentence)
        result = ""
        for target_sentence in target_sentences:
            result = result + target_sentence
        return result
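    # Illustrative walk-through (assumed numbers, not part of the original module):
    # with two messages of 300 and 700 characters and target_size=1000, the
    # per-message budgets are int(1000 * 300 / 1000) = 300 and int(1000 * 700 / 1000) = 700.
    # Each budget is derived from character counts but compared against the
    # message's token count; only messages whose token count exceeds their budget
    # are passed to summarize_to_length.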
    def format_all_content(self, data: Dict[str, Any], **params: Dict[str, Any]):
        contents = []
        for query in data["messages"]:
            contents.append(query)
        self.content = contents
    def process_all_content(self) -> (Any, Any):
        def serialize_content(content):
            # Serialize each message as [#RS]<role>[#RE][#CS]<content>[#CE] so the
            # original conversation can be stored alongside the summary.
            ret = ""
            for message in content:
                ret += "[#RS]{}[#RE][#CS]{}[#CE]".format(
                    message["role"], message["content"]
                )
            return ret

        result = self.summarize_to_sentence(
            [message["content"] for message in self.content], self.target_length
        )
        save_content = serialize_content(self.content)
        embedding_content = result
        return save_content, embedding_content
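# Illustrative usage sketch (not part of the original module). The conversation
# below is made up for demonstration; it uses the OpenAI-style ``messages``
# format that ``format_all_content`` expects.
if __name__ == "__main__":
    context_process = SummarizationContextProcess()
    context_process.format_all_content(
        {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Please summarize the plot of Moby-Dick."},
            ]
        }
    )
    save_content, embedding_content = context_process.process_all_content()
    # save_content keeps the full conversation serialized with the
    # [#RS]...[#RE][#CS]...[#CE] markers; embedding_content holds the summarized
    # text that is used for embedding and cache lookup.
    print(save_content)
    print(embedding_content)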