Showing 1 changed file with 8 additions and 3 deletions
@@ -3,6 +3,7 @@ from transformers.models.bert import BertTokenizer, BertModel
 import torch
 from tqdm import tqdm
 import numpy as np
+import jieba

 class BERT_CTM_Model:
     def __init__(self, bert_model_path):
@@ -19,9 +20,13 @@ class BERT_CTM_Model:
         outputs = self.model(**inputs)
         embeddings.append(outputs.last_hidden_state.cpu().numpy())  # [batch_size, sequence_length, hidden_size]
         return np.vstack(embeddings)
+
+    def chinese_tokenize(self, text):
+        """Segment Chinese text into space-separated words with jieba."""
+        return " ".join(jieba.cut(text))

 if __name__ == "__main__":
     model = BERT_CTM_Model('./bert_model')
-    texts = ["这是第一个文本", "这是第二个文本"]
-    embeddings = model.get_bert_embeddings(texts)
-    print(embeddings.shape)
+    text = "这是一个测试文本"
+    tokenized_text = model.chinese_tokenize(text)
+    print(tokenized_text)
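For reference, here is a minimal, self-contained sketch of how the updated class could be used end to end. The diff shows neither the `__init__` body nor the start of `get_bert_embeddings`, so the tokenizer setup, `max_length`, and padding strategy below are assumptions, not the repository's exact code.

```python
import jieba
import numpy as np
import torch
from transformers.models.bert import BertTokenizer, BertModel

class BERT_CTM_Model:
    def __init__(self, bert_model_path):
        # Assumed constructor: the diff only shows the signature.
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path)
        self.model.eval()

    def chinese_tokenize(self, text):
        """Segment Chinese text into space-separated words with jieba."""
        return " ".join(jieba.cut(text))

    def get_bert_embeddings(self, texts):
        # Assumed batching details; the diff shows only the tail of this
        # method. Fixed-length padding keeps np.vstack well-defined.
        embeddings = []
        with torch.no_grad():
            for text in texts:
                inputs = self.tokenizer(
                    text, return_tensors="pt", truncation=True,
                    padding="max_length", max_length=128,
                )
                outputs = self.model(**inputs)
                # [batch_size, sequence_length, hidden_size]
                embeddings.append(outputs.last_hidden_state.cpu().numpy())
        return np.vstack(embeddings)

if __name__ == "__main__":
    model = BERT_CTM_Model('./bert_model')
    text = "这是一个测试文本"
    tokenized = model.chinese_tokenize(text)
    print(tokenized)  # e.g. "这是 一个 测试 文本"
    emb = model.get_bert_embeddings([tokenized])
    print(emb.shape)  # (1, 128, 768) for a bert-base checkpoint
```

A likely motivation for the new helper: BERT's WordPiece tokenizer handles raw Chinese characters on its own, but CTM-style topic models need word-level tokens for their bag-of-words input, which jieba's segmentation provides.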