TestSetup.py
4.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from topicgpt.TopicRepresentation import Topic
import unittest
from sklearn.datasets import fetch_20newsgroups
from topicgpt.TopicGPT import TopicGPT
class QuickTestTopicGPT_init_and_fit(unittest.TestCase):
    """
    Run some basic tests on TopicGPT that do not require any saved data.

    The tests construct TopicGPT instances and fit one of them on a small
    sample of the 20 Newsgroups corpus, so they need an OpenAI API key.
    """

    @classmethod
    def setUpClass(cls, sample_size: int = 500):
        """
        Download the 20 Newsgroups dataset and only keep a sample of it.

        params:
            sample_size: the number of non-empty documents to keep for the tests
        """
        # download the 20 Newsgroups dataset
        data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
        # drop empty documents first so the sample holds `sample_size` usable texts
        corpus = [doc for doc in data['data'] if doc != ""]
        cls.corpus = corpus[:sample_size]

    def setUp(self):
        """
        Resolve the OpenAI API key used by every test.

        The key is read from the module-level ``api_key`` variable (set by the
        ``--api-key`` command line handling when this file is run as a script)
        and otherwise from the ``OPENAI_API_KEY`` environment variable. The
        test is skipped when neither is available, e.g. when the module is
        collected by test discovery without credentials.
        """
        import os  # local import: only needed for the environment fallback

        # globals().get avoids a NameError when the module is imported rather
        # than executed as a script (the global is only assigned under __main__)
        key = globals().get("api_key") or os.environ.get("OPENAI_API_KEY")
        if not key:
            self.skipTest("no OpenAI API key: pass --api-key or set OPENAI_API_KEY")
        self.api_key_openai = key

    def test_init(self):
        """
        test the init function of the TopicGPT class
        """
        print("Testing init...")

        # valid argument combinations must yield a TopicGPT instance
        topicgpt = TopicGPT(api_key=self.api_key_openai)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai,
                            n_topics=20)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai,
                            n_topics=20,
                            corpus_instruction="This is a corpus instruction")
        self.assertIsInstance(topicgpt, TopicGPT)

        # check if assertions are triggered
        # missing api key
        with self.assertRaises(AssertionError):
            topicgpt = TopicGPT(api_key=None,
                                n_topics=32,
                                openai_prompting_model="gpt-4",
                                max_number_of_tokens=8000,
                                corpus_instruction="This is a corpus instruction")

        # zero topics
        with self.assertRaises(AssertionError):
            topicgpt = TopicGPT(api_key=self.api_key_openai,
                                n_topics=0,
                                max_number_of_tokens=8000,
                                corpus_instruction="This is a corpus instruction")

        # zero token budget
        with self.assertRaises(AssertionError):
            topicgpt = TopicGPT(api_key=self.api_key_openai,
                                n_topics=20,
                                max_number_of_tokens=0,
                                corpus_instruction="This is a corpus instruction")

    def test_fit(self):
        """
        test the fit function of the TopicGPT class
        """
        print("Testing fit...")

        def instance_test(topicgpt):
            """Fit `topicgpt` on the sampled corpus and check the fitted attributes."""
            topicgpt.fit(self.corpus)

            self.assertTrue(hasattr(topicgpt, "vocab"))
            self.assertTrue(hasattr(topicgpt, "topic_lis"))

            self.assertIsInstance(topicgpt.vocab, list)
            self.assertIsInstance(topicgpt.vocab[0], str)

            self.assertIsInstance(topicgpt.topic_lis, list)
            self.assertIsInstance(topicgpt.topic_lis[0], Topic)

            if topicgpt.n_topics is not None:
                self.assertEqual(len(topicgpt.topic_lis), topicgpt.n_topics)

            # the fitted state must be mirrored on the topic_prompting helper
            self.assertTrue(topicgpt.topic_lis == topicgpt.topic_prompting.topic_lis)
            self.assertTrue(topicgpt.vocab == topicgpt.topic_prompting.vocab)
            # NOTE(review): kept as a plain `==` check because the type of
            # vocab_embeddings is not visible here -- confirm it compares to a bool
            self.assertTrue(topicgpt.vocab_embeddings == topicgpt.topic_prompting.vocab_embeddings)

        topicgpt1 = TopicGPT(api_key=self.api_key_openai, n_topics=1)

        topic_gpt_list = [topicgpt1]

        for topic_gpt in topic_gpt_list:
            instance_test(topic_gpt)
import sys


def pop_api_key_arg(argv):
    """
    Remove ``--api-key <value>`` from `argv` in place and return the value.

    params:
        argv: the mutable argument list to scan (typically ``sys.argv``)
    returns:
        the string following the first ``--api-key`` flag, or None when the
        flag is absent or has no value after it (`argv` is then left unchanged)
    """
    try:
        flag_index = argv.index("--api-key")
    except ValueError:
        return None
    if flag_index + 1 >= len(argv):
        # flag given as the last argument: there is no value to consume
        return None
    key = argv[flag_index + 1]
    # drop both the flag and its value so unittest.main() does not see them
    del argv[flag_index:flag_index + 2]
    return key


if __name__ == "__main__":
    # api_key is a module-level global that the test case's setUp reads
    api_key = pop_api_key_arg(sys.argv)

    if api_key is None:
        print("API key must be provided with --api-key")
        sys.exit(1)

    unittest.main()