_multimodal.py
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import List, Union
from sentence_transformers import SentenceTransformer

from bertopic.backend import BaseEmbedder


class MultiModalBackend(BaseEmbedder):
    """Multimodal backend using Sentence-transformers.

    The sentence-transformers embedding model used for
    generating word, document, and image embeddings.

    Arguments:
        embedding_model: A sentence-transformers embedding model that
                         can either embed both images and text or only text.
                         If it only embeds text, then `image_model` needs
                         to be used to embed the images.
        image_model: A sentence-transformers embedding model that is used
                     to embed only images.
        batch_size: The number of images to embed per batch.

    Examples:
    To create a model, you can load in a string pointing to a
    sentence-transformers model:

    ```python
    from bertopic.backend import MultiModalBackend

    sentence_model = MultiModalBackend("clip-ViT-B-32")
    ```

    or you can instantiate a model yourself:

    ```python
    from bertopic.backend import MultiModalBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("clip-ViT-B-32")
    sentence_model = MultiModalBackend(embedding_model)
    ```
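
    If your embedding model only embeds text, you can pass a separate model
    for the images. The pairing below is a sketch; the model names are just
    examples of a text-only model and a CLIP image model, not a prescribed
    combination:

    ```python
    from bertopic.backend import MultiModalBackend

    sentence_model = MultiModalBackend("all-MiniLM-L6-v2", image_model="clip-ViT-B-32")
    ```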
"""
    def __init__(
        self,
        embedding_model: Union[str, SentenceTransformer],
        image_model: Union[str, SentenceTransformer] = None,
        batch_size: int = 32,
    ):
        super().__init__()
        self.batch_size = batch_size

        # Text or Text+Image model
        if isinstance(embedding_model, SentenceTransformer):
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = SentenceTransformer(embedding_model)
        else:
            raise ValueError(
                "Please select a correct SentenceTransformers model: \n"
                "`from sentence_transformers import SentenceTransformer` \n"
                "`model = SentenceTransformer('clip-ViT-B-32')`"
            )

        # Image model
        self.image_model = None
        if image_model is not None:
            if isinstance(image_model, SentenceTransformer):
                self.image_model = image_model
            elif isinstance(image_model, str):
                self.image_model = SentenceTransformer(image_model)
            else:
                raise ValueError(
                    "Please select a correct SentenceTransformers model: \n"
                    "`from sentence_transformers import SentenceTransformer` \n"
                    "`model = SentenceTransformer('clip-ViT-B-32')`"
                )

        # Grab the tokenizer, if available, so that documents can be truncated
        # to the model's maximum sequence length (e.g., 77 tokens for CLIP)
        try:
            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
        except AttributeError:
            self.tokenizer = self.embedding_model.tokenizer
        except:  # noqa: E722
            self.tokenizer = None

    def embed(self, documents: List[str], images: List[str] = None, verbose: bool = False) -> np.ndarray:
        """Embed a list of n documents/words or images into an n-dimensional
        matrix of embeddings.

        Either documents, images, or both can be provided. If both are provided,
        then the embeddings are averaged.

        Arguments:
            documents: A list of documents or words to be embedded
            images: A list of image paths to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
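
        Examples:
        A minimal usage sketch (the image path below is a placeholder for a
        real file on disk):

        ```python
        backend = MultiModalBackend("clip-ViT-B-32")

        # Text only
        text_embeddings = backend.embed(["a photo of a cat"])

        # Text and images; the two embeddings are averaged element-wise
        combined = backend.embed(["a photo of a cat"], images=["cat.jpg"])
        ```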
"""
        # Embed documents
        doc_embeddings = None
        if documents[0] is not None:
            doc_embeddings = self.embed_documents(documents)

        # Embed images
        image_embeddings = None
        if isinstance(images, list):
            image_embeddings = self.embed_images(images, verbose)

        # Average embeddings
        averaged_embeddings = None
        if doc_embeddings is not None and image_embeddings is not None:
            averaged_embeddings = np.mean([doc_embeddings, image_embeddings], axis=0)

        # Note: if neither documents nor images were embedded, None is
        # returned implicitly
        if averaged_embeddings is not None:
            return averaged_embeddings
        elif doc_embeddings is not None:
            return doc_embeddings
        elif image_embeddings is not None:
            return image_embeddings

    def embed_documents(self, documents: List[str], verbose: bool = False) -> np.ndarray:
        """Embed a list of n documents/words into an n-dimensional
        matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        truncated_docs = [self._truncate_document(doc) for doc in documents]
        embeddings = self.embedding_model.encode(truncated_docs, show_progress_bar=verbose)
        return embeddings

    def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray:
        """Embed a list of n words into an n-dimensional
        matrix of embeddings.

        Arguments:
            words: A list of words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        embeddings = self.embedding_model.encode(words, show_progress_bar=verbose)
        return embeddings

    def embed_images(self, images, verbose):
        """Embed a list of images (file paths or PIL images) into a matrix
        of embeddings, processing them in batches if `batch_size` is set.
        """
        if self.batch_size:
            nr_iterations = int(np.ceil(len(images) / self.batch_size))

            # Embed images per batch
            embeddings = []
            for i in tqdm(range(nr_iterations), disable=not verbose):
                start_index = i * self.batch_size
                end_index = (i * self.batch_size) + self.batch_size

                images_to_embed = [
                    Image.open(image) if isinstance(image, str) else image for image in images[start_index:end_index]
                ]
                if self.image_model is not None:
                    img_emb = self.image_model.encode(images_to_embed)
                else:
                    img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
                embeddings.extend(img_emb.tolist())

                # Close images that were opened from file paths
                if isinstance(images[0], str):
                    for image in images_to_embed:
                        image.close()
            embeddings = np.array(embeddings)
        else:
            images_to_embed = [Image.open(filepath) for filepath in images]
            if self.image_model is not None:
                embeddings = self.image_model.encode(images_to_embed)
            else:
                embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
        return embeddings

    def _truncate_document(self, document):
        if self.tokenizer:
            tokens = self.tokenizer.encode(document)

            if len(tokens) > 77:
                # Skip the starting token and keep only the first 75 tokens
                truncated_tokens = tokens[1:76]
                document = self.tokenizer.decode(truncated_tokens)

                # Recurse, because re-encoding the decoded text can yield a
                # different number of tokens
                return self._truncate_document(document)
        return document
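

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: embeds two
    # captions with a CLIP model. With clip-ViT-B-32, the resulting matrix
    # should have shape (2, 512).
    backend = MultiModalBackend("clip-ViT-B-32")
    embeddings = backend.embed(["a photo of a cat", "a photo of a dog"])
    print(embeddings.shape)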