Source code for gptcache.adapter.diffusers
import base64
from io import BytesIO
from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
from gptcache.utils import (
import_pillow, import_diffusers, import_huggingface
)
from gptcache.utils.error import CacheError
import_pillow()
import_huggingface()
import_diffusers()
from PIL import Image # pylint: disable=C0413
import diffusers # pylint: disable=C0413
[docs]class StableDiffusionPipeline(diffusers.StableDiffusionPipeline):
"""Diffuser StableDiffusionPipeline Wrapper
Example:
.. code-block:: python
import torch
from gptcache import cache
from gptcache.processor.pre import get_prompt
from gptcache.adapter.diffusers import StableDiffusionPipeline
# init gptcache
cache.init(pre_embedding_func=get_prompt)
# run with gptcache
model_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt=prompt).images[0]
"""
def _llm_handler(self, *llm_args, **llm_kwargs):
try:
return super().__call__(*llm_args, **llm_kwargs)
except Exception as e:
raise CacheError("diffuser error") from e
def __call__(self, *args, **kwargs):
def cache_data_convert(cache_data):
return _construct_resp_from_cache(cache_data)
def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
img = llm_data["images"][0]
buffered = BytesIO()
img.save(buffered, format="PNG")
img_b64 = base64.b64encode(buffered.getvalue())
update_cache_func(Answer(img_b64, DataType.IMAGE_BASE64))
return llm_data
return adapt(
self._llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)
def _construct_resp_from_cache(img_64):
im_bytes = base64.b64decode(img_64) # im_bytes is a binary image
im_file = BytesIO(im_bytes) # convert image to file-like object
img = Image.open(im_file)
return diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput(images=[img], nsfw_content_detected=None)