# _visual.py 10.4 KB
import numpy as np
import pandas as pd

from PIL import Image
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union
from transformers.pipelines import Pipeline, pipeline

from bertopic.representation._mmr import mmr
from bertopic.representation._base import BaseRepresentation


class VisualRepresentation(BaseRepresentation):
    """From a collection of representative documents, extract
    images to represent topics. These topics are represented by a
    collage of images.

    Arguments:
        nr_repr_images: Number of representative images to extract
        nr_samples: The number of candidate documents to extract per cluster.
        image_height: The height of the resulting collage
        image_squares: Whether to resize each image in the collage
                       to a square. This can be visually more appealing
                       if all input images are all almost squares.
        image_to_text_model: The model to caption images. Either a
                             transformers `Pipeline`, the name of a model
                             that is loaded into an "image-to-text"
                             pipeline, or None (no captioning model).
        batch_size: The number of images to pass to the
                    `image_to_text_model` at once. If None, all images
                    are passed in a single call.

    Usage:

    ```python
    from bertopic.representation import VisualRepresentation
    from bertopic import BERTopic

    # The visual representation is typically not a core representation
    # and is advised to pass to BERTopic as an additional aspect.
    # Aspects can be labeled with dictionaries as shown below:
    representation_model = {
        "Visual_Aspect": VisualRepresentation()
    }

    # Use the representation model in BERTopic as a separate aspect
    topic_model = BERTopic(representation_model=representation_model)
    ```
    """

    def __init__(
        self,
        nr_repr_images: int = 9,
        nr_samples: int = 500,
        image_height: int = 600,
        image_squares: bool = False,
        image_to_text_model: Union[str, Pipeline, None] = None,
        batch_size: Union[int, None] = 32,
    ):
        self.nr_repr_images = nr_repr_images
        self.nr_samples = nr_samples
        self.image_height = image_height
        self.image_squares = image_squares

        # Image-to-text (captioning) model
        if isinstance(image_to_text_model, Pipeline):
            self.image_to_text_model = image_to_text_model
        elif isinstance(image_to_text_model, str):
            self.image_to_text_model = pipeline("image-to-text", model=image_to_text_model)
        elif image_to_text_model is None:
            self.image_to_text_model = None
        else:
            raise ValueError(
                "Please select a correct transformers pipeline. For example: "
                "pipeline('image-to-text', model='nlpconnect/vit-gpt2-image-captioning')"
            )
        self.batch_size = batch_size

    def extract_topics(
        self,
        topic_model,
        documents: pd.DataFrame,
        c_tf_idf: csr_matrix,
        topics: Mapping[str, List[Tuple[str, float]]],
    ) -> Mapping[str, List[Tuple[str, float]]]:
        """Extract topics.

        Arguments:
            topic_model: A BERTopic model
            documents: All input documents; must contain an "Image" column whose
                       entries are either file paths or image objects
            c_tf_idf: The topic c-TF-IDF representation
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            representative_images: Representative images per topic, each a
                                   collage built by `get_concat_tile_resize`
        """
        # Extract image ids of most representative documents
        images = documents["Image"].values.tolist()
        (_, _, _, repr_docs_ids) = topic_model._extract_representative_docs(
            c_tf_idf,
            documents,
            topics,
            nr_samples=self.nr_samples,
            nr_repr_docs=self.nr_repr_images,
        )
        unique_topics = sorted(list(topics.keys()))

        # Combine representative images into a single representation
        representative_images = {}
        for topic in tqdm(unique_topics):
            # Get the representative image indices and arrange them in rows of 3
            exemplars = repr_docs_ids[topic + topic_model._outliers]
            rows = [exemplars[i : i + 3] for i in range(0, len(exemplars), 3)]

            # Only images opened from a path here are closed afterwards (even on
            # error); image objects supplied by the caller are left untouched.
            opened = []
            try:
                images_to_combine = [self._load_images([images[index] for index in row], opened) for row in rows]
                representative_images[topic] = get_concat_tile_resize(
                    images_to_combine, self.image_height, self.image_squares
                )
            finally:
                for image in opened:
                    image.close()

        return representative_images

    @staticmethod
    def _load_images(sources, opened):
        """Return images for `sources`, opening any entry that is a file path.

        Arguments:
            sources: Sequence of file paths (str) and/or already-loaded images.
            opened: List to which every image opened here is appended, so the
                    caller can close exactly those (and not caller-owned images).

        Returns:
            List of images in the same order as `sources`.
        """
        loaded = []
        for source in sources:
            if isinstance(source, str):
                source = Image.open(source)
                opened.append(source)
            loaded.append(source)
        return loaded

    def _convert_image_to_text(self, images: List, verbose: bool = False) -> List[str]:
        """Convert a list of images to captions.

        Arguments:
            images: A list of images to be converted to text.
            verbose: Controls the verbosity of the process

        Returns:
            List of captions, one per image

        Raises:
            ValueError: If no `image_to_text_model` was configured.
        """
        if self.image_to_text_model is None:
            raise ValueError(
                "An `image_to_text_model` is required to convert images to text. "
                "Pass one to VisualRepresentation, for example: "
                "VisualRepresentation(image_to_text_model='nlpconnect/vit-gpt2-image-captioning')"
            )

        # Batch-wise image conversion
        if self.batch_size is not None:
            documents = []
            for batch in tqdm(self._chunks(images), disable=not verbose):
                outputs = self.image_to_text_model(batch)
                documents.extend(output[0]["generated_text"] for output in outputs)

        # Convert all images in a single call
        else:
            outputs = self.image_to_text_model(images)
            documents = [output[0]["generated_text"] for output in outputs]

        return documents

    def image_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame:
        """Caption the most representative images of each topic.

        Arguments:
            documents: Input documents with a "Topic" and an "Image" column; an
                       "Image" entry is either a file path or an image object.
            embeddings: Image embeddings, one row per document, in the same
                        order as `documents`.

        Returns:
            A new dataframe with columns "Document" (the generated caption),
            "ID", "Topic" and "Image", holding up to `nr_repr_images` rows
            per topic.
        """
        # Create image topic embeddings: mean embedding per topic, in sorted topic order
        topics = documents.Topic.values.tolist()
        images = documents.Image.values.tolist()
        df = pd.DataFrame(np.hstack([np.array(topics).reshape(-1, 1), embeddings]))
        image_topic_embeddings = df.groupby(0).mean().values

        # Extract image centroids: a diverse (MMR) selection of images per topic
        topic_array = np.array(topics)
        image_centroids = {}
        unique_topics = sorted(list(set(topics)))
        for topic, topic_embedding in zip(unique_topics, image_topic_embeddings):
            indices = np.flatnonzero(topic_array == topic)
            top_n = min([self.nr_repr_images, len(indices)])
            indices = mmr(
                topic_embedding.reshape(1, -1),
                embeddings[indices],
                indices,
                top_n=top_n,
                diversity=0.1,
            )
            image_centroids[topic] = indices

        # Caption the selected images; rows are collected in a list because
        # growing a DataFrame one row at a time copies it on every append.
        rows = []
        current_id = 0
        for topic, image_ids in tqdm(image_centroids.items()):
            opened = []
            try:
                selected_images = self._load_images([images[index] for index in image_ids], opened)
                captions = self._convert_image_to_text(selected_images)
            finally:
                # Close only images opened above, also when captioning fails
                for image in opened:
                    image.close()

            for caption, image_id in zip(captions, image_ids):
                rows.append([caption, current_id, topic, images[image_id]])
                current_id += 1

        return pd.DataFrame(rows, columns=["Document", "ID", "Topic", "Image"])

    def _chunks(self, images):
        """Yield consecutive slices of `images` of length `batch_size` (the last may be shorter)."""
        for i in range(0, len(images), self.batch_size):
            yield images[i : i + self.batch_size]


def get_concat_h_multi_resize(im_list):
    """Concatenate images horizontally, left to right, on a white RGB canvas.

    Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/.

    Images are pasted at their current size and top-aligned. The canvas is as
    tall as the tallest image, and shorter images leave white space below them.
    No rescaling is done here.

    Arguments:
        im_list: A non-empty list of images forming one row.

    Returns:
        A new RGB image containing the row.
    """
    max_height = max(im.height for im in im_list)
    total_width = sum(im.width for im in im_list)
    dst = Image.new("RGB", (total_width, max_height), (255, 255, 255))
    pos_x = 0
    for im in im_list:
        dst.paste(im, (pos_x, 0))
        pos_x += im.width
    return dst


def get_concat_v_multi_resize(im_list):
    """Stack images vertically, top to bottom, on a white RGB canvas.

    Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/.

    Every image is rescaled with nearest-neighbour resampling (``resample=0``),
    keeping its aspect ratio, to the width of the widest image. Narrower images
    are therefore upscaled.

    Arguments:
        im_list: A non-empty list of images to stack.

    Returns:
        A new RGB image containing all images one below the other.
    """
    max_width = max(im.width for im in im_list)
    im_list_resize = [im.resize((max_width, int(im.height * max_width / im.width)), resample=0) for im in im_list]
    total_height = sum(im.height for im in im_list_resize)
    dst = Image.new("RGB", (max_width, total_height), (255, 255, 255))
    pos_y = 0
    for im in im_list_resize:
        dst.paste(im, (0, pos_y))
        pos_y += im.height
    return dst


def get_concat_tile_resize(im_list_2d, image_height=600, image_squares=False):
    """Arrange a 2D list of images into a grid collage of a fixed height.

    Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/.

    The input images are never modified. Every resize returns a new image, so
    no defensive copies are needed.

    Arguments:
        im_list_2d: List of rows, each a list of images.
        image_height: Height of the returned collage.
        image_squares: If True, every image is resized to a square of side
                       ``int(image_height / 3)``. Otherwise each image is
                       scaled toward the smallest width/height found in the grid.

    Returns:
        A new image whose height is ``image_height``; its width keeps the
        aspect ratio of the assembled grid.
    """
    if image_squares:
        side = int(image_height / 3)
        rows = [[image.resize((side, side)) for image in row] for row in im_list_2d]

    # Resize images based on minimum size
    else:
        min_width = min([min([img.width for img in row]) for row in im_list_2d])
        min_height = min([min([img.height for img in row]) for row in im_list_2d])
        rows = []
        for row in im_list_2d:
            resized_row = []
            for img in row:
                if img.height > img.width:
                    # Portrait: scale to the smallest height in the grid
                    resized = img.resize(
                        (int(img.width * min_height / img.height), min_height),
                        resample=0,
                    )
                elif img.width > img.height:
                    # Landscape: scale to the smallest width in the grid
                    resized = img.resize((min_width, int(img.height * min_width / img.width)), resample=0)
                else:
                    # Square: uses Pillow's default resampling filter (unchanged from original)
                    resized = img.resize((min_width, min_width))
                resized_row.append(resized)
            rows.append(resized_row)

    # Build the grid, then scale it to the requested height
    row_images = [get_concat_h_multi_resize(row) for row in rows]
    grid = get_concat_v_multi_resize(row_images)
    height_percentage = image_height / float(grid.size[1])
    adjusted_width = int(float(grid.size[0]) * float(height_percentage))
    return grid.resize((adjusted_width, image_height), Image.Resampling.LANCZOS)