_utils.py
5.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from ._base import BaseEmbedder
# Imports for light-weight variant of BERTopic
from bertopic.backend._sklearn import SklearnEmbedder
from bertopic._utils import MyLogger
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline as ScikitPipeline
# Logger shared by the backend helpers in this module. It starts at WARNING;
# select_backend re-sets the level on every call (INFO when verbose, else WARNING).
logger = MyLogger()
logger.configure("WARNING")

# Lower-case language names accepted by select_backend. "english" is routed to
# the English model; every other entry selects the multilingual model.
# The list is interpolated verbatim into select_backend's "not supported"
# error message, so the order below is user-visible and must be kept.
languages = [
    "arabic", "bulgarian", "catalan", "czech", "danish",
    "german", "greek", "english", "spanish", "estonian",
    "persian", "finnish", "french", "canadian french", "galician",
    "gujarati", "hebrew", "hindi", "croatian", "hungarian",
    "armenian", "indonesian", "italian", "japanese", "georgian",
    "korean", "kurdish", "lithuanian", "latvian", "macedonian",
    "mongolian", "marathi", "malay", "burmese", "norwegian bokmal",
    "dutch", "polish", "portuguese", "brazilian portuguese", "romanian",
    "russian", "slovak", "slovenian", "albanian", "serbian",
    "swedish", "thai", "turkish", "ukrainian", "urdu",
    "vietnamese", "chinese (simplified)", "chinese (traditional)",
]
def select_backend(embedding_model, language: str = None, verbose: bool = False) -> BaseEmbedder:
    """Select an embedding model based on language or a specific provided model.

    When selecting a language, we choose all-MiniLM-L6-v2 for English and
    paraphrase-multilingual-MiniLM-L12-v2 for all other languages as it supports 100+ languages.
    If sentence-transformers is not installed (a lightweight installation) and a
    language is given, a scikit-learn TF-IDF + TruncatedSVD backend is used instead.

    Arguments:
        embedding_model: A BERTopic ``BaseEmbedder`` (returned as-is), a scikit-learn
            ``Pipeline``, a model object from Flair, spaCy, Gensim, a TensorFlow
            saved model (USE), sentence-transformers, a Hugging Face ``pipeline``,
            Model2Vec or FastEmbed, or a string naming a sentence-transformers model.
            Anything not recognized (e.g. ``None``) falls through to language-based
            selection.
        language: Language name (case-insensitive), ``"en"``, or ``"multilingual"``,
            used only when ``embedding_model`` is not recognized.
        verbose: If True, the module logger is set to INFO; otherwise WARNING.

    Returns:
        model: The selected model backend.

    Raises:
        ValueError: If ``language`` is given but not supported (only raised when
            sentence-transformers can be imported).
        ModuleNotFoundError: Re-raised when the missing module is not
            ``sentence_transformers`` itself; the final default branch may also
            raise it if sentence-transformers is not installed.
    """
    logger.set_level("INFO" if verbose else "WARNING")

    # BERTopic language backend
    if isinstance(embedding_model, BaseEmbedder):
        return embedding_model

    # Scikit-learn backend
    if isinstance(embedding_model, ScikitPipeline):
        return SklearnEmbedder(embedding_model)

    # Third-party backends are recognized by the textual name of the model's type,
    # so their optional libraries are only imported once a match is found.
    model_type = str(type(embedding_model))

    # Flair word embeddings
    if "flair" in model_type:
        from bertopic.backend._flair import FlairBackend

        return FlairBackend(embedding_model)

    # Spacy embeddings
    if "spacy" in model_type:
        from bertopic.backend._spacy import SpacyBackend

        return SpacyBackend(embedding_model)

    # Gensim embeddings
    if "gensim" in model_type:
        from bertopic.backend._gensim import GensimBackend

        return GensimBackend(embedding_model)

    # USE embeddings: both substrings must be tested explicitly
    # (`"a" and "b" in s` would only test "b", since "a" is always truthy).
    if "tensorflow" in model_type and "saved_model" in model_type:
        from bertopic.backend._use import USEBackend

        return USEBackend(embedding_model)

    # Sentence Transformer embeddings; a plain string is treated as a model name.
    # Checked before Hugging Face because "sentence_transformers" contains "transformers".
    if "sentence_transformers" in model_type or isinstance(embedding_model, str):
        from ._sentencetransformers import SentenceTransformerBackend

        return SentenceTransformerBackend(embedding_model)

    # Hugging Face embeddings (both substrings required, see the USE check above)
    if "transformers" in model_type and "pipeline" in model_type:
        from ._hftransformers import HFTransformerBackend

        return HFTransformerBackend(embedding_model)

    # Model2Vec embeddings
    if "model2vec" in model_type:
        from ._model2vec import Model2VecBackend

        return Model2VecBackend(embedding_model)

    # FastEmbed word embeddings
    if "fastembed" in model_type:
        from bertopic.backend._fastembed import FastEmbedBackend

        return FastEmbedBackend(embedding_model)

    # Select embedding model based on language
    if language:
        try:
            from ._sentencetransformers import SentenceTransformerBackend

            lang = language.lower()
            if lang in ("english", "en"):
                return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
            elif lang in languages or lang == "multilingual":
                return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
            else:
                raise ValueError(
                    f"{language} is currently not supported. However, you can "
                    f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n"
                    "Else, please select a language from the following list:\n"
                    f"{languages}"
                )

        # A ModuleNotFoundError might be a lightweight installation
        except ModuleNotFoundError as e:
            if e.name != "sentence_transformers":
                # Error occurred in a downstream module, probably not a lightweight install
                raise
            # Whole sentence_transformers module is missing, probably a lightweight install
            if verbose:
                logger.info(
                    "Automatically selecting lightweight scikit-learn embedding backend as sentence-transformers appears to not be installed."
                )
            pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
            return SklearnEmbedder(pipe)

    # No recognized model and no language: default to the English sentence-transformers model.
    from ._sentencetransformers import SentenceTransformerBackend

    return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")