_sklearn.py 2.42 KB
from bertopic.backend import BaseEmbedder
from sklearn.utils.validation import check_is_fitted, NotFittedError


class SklearnEmbedder(BaseEmbedder):
    """Scikit-Learn based embedding model.

    This component allows the usage of scikit-learn pipelines for generating document and
    word embeddings.

    Arguments:
        pipe: A scikit-learn pipeline that can `.transform()` text.

    Examples:
    Scikit-Learn is very flexible and it allows for many representations.
    A relatively simple pipeline is shown below.

    ```python
    from sklearn.pipeline import make_pipeline
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer

    from bertopic.backend import SklearnEmbedder

    pipe = make_pipeline(
        TfidfVectorizer(),
        TruncatedSVD(100)
    )

    sklearn_embedder = SklearnEmbedder(pipe)
    topic_model = BERTopic(embedding_model=sklearn_embedder)
    ```

    This pipeline first constructs a sparse representation based on TF/idf and then
    makes it dense by applying SVD. Alternatively, you might also construct something
    more elaborate. As long as you construct a scikit-learn compatible pipeline, you
    should be able to pass it to Bertopic.

    !!! Warning
        One caveat to be aware of is that scikit-learns base `Pipeline` class does not
        support the `.partial_fit()`-API. If you have a pipeline that theoretically should
        be able to support online learning then you might want to explore
        the [scikit-partial](https://github.com/koaning/scikit-partial) project.
    """

    def __init__(self, pipe):
        super().__init__()
        self.pipe = pipe

    def embed(self, documents, verbose=False):
        """Embed a list of n documents/words into an n-dimensional
        matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: No-op variable that's kept around to keep the API consistent. If you want to get feedback on training times, you should use the sklearn API.

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        try:
            check_is_fitted(self.pipe)
            embeddings = self.pipe.transform(documents)
        except NotFittedError:
            embeddings = self.pipe.fit_transform(documents)

        return embeddings