戒酒的李白

Implement Chinese word segmentation with jieba

@@ -3,6 +3,7 @@ from transformers.models.bert import BertTokenizer, BertModel
 import torch
 from tqdm import tqdm
 import numpy as np
+import jieba
 
 class BERT_CTM_Model:
     def __init__(self, bert_model_path):
@@ -19,9 +20,13 @@ class BERT_CTM_Model:
             outputs = self.model(**inputs)
             embeddings.append(outputs.last_hidden_state.cpu().numpy())  # [batch_size, sequence_length, hidden_size]
         return np.vstack(embeddings)
+
+    def chinese_tokenize(self, text):
+        """Segment Chinese text into space-separated words with jieba."""
+        return " ".join(jieba.cut(text))
 
 if __name__ == "__main__":
     model = BERT_CTM_Model('./bert_model')
-    texts = ["这是第一个文本", "这是第二个文本"]
-    embeddings = model.get_bert_embeddings(texts)
-    print(embeddings.shape)
+    text = "这是一个测试文本"
+    tokenized_text = model.chinese_tokenize(text)
+    print(tokenized_text)
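
For a quick sanity check of the segmentation step this commit adds, here is a minimal standalone sketch. The sample sentence is the one used in the diff's `__main__` block, and the printed segmentation is illustrative only, since exact word boundaries depend on jieba's dictionary:

```python
import jieba

# Standalone sketch of the segmentation step added in this commit:
# jieba.cut returns a generator of word tokens; joining them with
# spaces mirrors what chinese_tokenize does.
text = "这是一个测试文本"  # "This is a test text" (the diff's own sample)
print(" ".join(jieba.cut(text)))
# Illustrative output (boundaries depend on jieba's dictionary):
# 这是 一个 测试 文本
```

Note that BERT's tokenizer performs its own subword tokenization regardless, so the space-joined output is mainly useful where explicit word boundaries matter; the commit does not yet feed it into `get_bert_embeddings`.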