# Source code for gptcache.adapter.dolly
from typing import Any
from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
from gptcache.utils import import_huggingface, import_torch
import_torch()
import_huggingface()
from transformers import pipeline # pylint: disable=wrong-import-position
class Dolly:
    """Wrapper for Dolly (https://github.com/databrickslabs/dolly.git).

    An instance holds a pipeline object; calling the instance hands the prompt
    and that pipeline to ``gptcache.adapter.adapter.adapt``.

    Example using from_model:
        .. code-block:: python

            import torch
            from gptcache import cache
            from gptcache.processor.pre import get_inputs
            cache.init(pre_embedding_func=get_inputs)

            from gptcache.adapter.dolly import Dolly
            dolly = Dolly.from_model(
                model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
            )

    Example passing pipeline in directly:
        .. code-block:: python

            import torch
            from transformers import pipeline
            from gptcache import cache
            from gptcache.processor.pre import get_inputs
            cache.init(pre_embedding_func=get_inputs)

            from gptcache.adapter.dolly import Dolly
            pipe = pipeline(
                model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
            )
            dolly = Dolly(pipe)
    """

    def __init__(self, dolly_pipeline: Any) -> None:
        """Store the pipeline that will be handed to ``adapt`` on each call.

        :param dolly_pipeline: the pipeline object to wrap; it is stored as-is
            and not validated here.
        """
        self._dolly_pipeline = dolly_pipeline

    @classmethod
    def from_model(cls, model: str, **kwargs) -> "Dolly":
        """Build a pipeline with ``transformers.pipeline`` and wrap it.

        :param model: forwarded to ``pipeline`` as its ``model`` keyword.
        :param kwargs: any further keyword arguments, forwarded to ``pipeline``
            unchanged.
        :return: a new instance of ``cls`` holding the created pipeline.
        """
        pipe = pipeline(model=model, **kwargs)
        return cls(pipe)

    def __call__(self, prompt: str, **kwargs) -> Any:
        """Run ``prompt`` through ``adapt`` with the wrapped pipeline.

        :param prompt: passed to ``adapt`` as the ``inputs`` keyword argument.
        :param kwargs: extra keyword arguments forwarded to ``adapt``. Passing
            ``inputs`` here as well raises ``TypeError`` (duplicate keyword).
        :return: whatever ``adapt`` returns. NOTE(review): presumably either
            the pipeline output or the list built by ``_cache_data_convert``
            -- confirm against ``adapt``.
        """
        return adapt(
            self._dolly_pipeline,
            _cache_data_convert,
            _update_cache_callback,
            inputs=prompt,
            **kwargs
        )
def _cache_data_convert(cache_data):
return [{"generated_text": cache_data, "gptcache": True}]
def _update_cache_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
    """Pass the first generated text to the cache updater and return the data.

    :param llm_data: indexable result whose first element is a mapping with a
        ``"generated_text"`` key. An empty ``llm_data`` raises ``IndexError``;
        a first element without that key raises ``KeyError``.
    :param update_cache_func: callable invoked once with an ``Answer`` built
        from the first element's ``"generated_text"`` and ``DataType.STR``.
    :param args: accepted but unused.
    :param kwargs: accepted but unused.
    :return: ``llm_data``, unchanged.
    """
    # Only the first element's text is handed to the cache updater.
    update_cache_func(Answer(llm_data[0]["generated_text"], DataType.STR))
    return llm_data