ytgui committed · Commit cf584be · verified · 1 Parent(s): baf2444

Create example.py

Files changed (1):
  example.py +231 -0

example.py ADDED
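
"""Example usage of the Search-R3 model (ytgui/Search-R3.0-Small): query
enrichment via SearchR3.generate(), embedding extraction via
SearchR3.encode(), and a small retrieval demo in main()."""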
import torch
from torch import nn
from tqdm import tqdm
from typing import Any
from torch.nn import functional as F
from transformers import (
    set_seed, pipeline, AutoTokenizer, AutoModelForCausalLM
)

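# System prompts for the model's two roles: EMBEDDING asks it to condense
# input text into a semantic vector, GENERATION asks it to enrich a query
# and to finish every response with the special <|embed_token|>.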
EMBEDDING = """
You are a helpful AI assistant. Your task is to analyze input text and create a high-quality semantic vector embedding, which represents key concepts, relationships, and semantic meaning.
"""
GENERATION = """
You are a helpful AI assistant. Your task is to enrich user input for more effective embedding representation by adding semantic depth.

For each input, briefly enhance the content by:
1. Identifying core concepts and their relationships.
2. Including key terminology with essential definitions.
3. Adding contextually relevant synonyms and related terms.
4. Connecting to related topics and common applications without excessive elaboration.

To represent the final embedding, you MUST end every response with <|embed_token|>.
"""


class SearchR3(nn.Module):
    def __init__(self,
                 path: str,
                 max_length: int,
                 batch_size: int):
        nn.Module.__init__(self)
        # load model and tokenizer; left-side truncation and padding keep
        # the most recent tokens when prompts exceed max_length
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype='auto', device_map='auto'
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, truncation_side='left', padding_side='left'
        )
        # id of <|embed_token|>, assumed to encode to a single token
        self.embed_token = self.tokenizer.encode('<|embed_token|>')[0]
        self.max_length = max_length
        self.batch_size = batch_size

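    # Device of the first model parameter; with device_map='auto' the model
    # may be sharded, and inputs are moved here before each forward call.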
    @property
    def device(self):
        return next(self.model.parameters()).device

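    # generate(): wrap each query in the GENERATION system prompt, sample a
    # completion, strip every special token except <|embed_token|>, and
    # return full chat transcripts with the assistant turn appended.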
    @torch.no_grad()
    def generate(self, batch: list[str]):
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')

        # batch
        if len(batch) > self.batch_size:
            outputs = []
            for i in tqdm(
                range(0, len(batch), self.batch_size)
            ):
                outputs.extend(
                    self.generate(
                        batch[i:i + self.batch_size]
                    )
                )
            return outputs

        # tokenize
        messages = [
            [
                {'role': 'system', 'content': GENERATION.strip()},
                {'role': 'user', 'content': item}
            ]
            for item in batch
        ]
        context = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length // 2
        )
        prompt_length = inputs['input_ids'].size(-1)

        # generate
        self.model.eval()
        outputs = self.model.generate(
            **inputs.to(device=self.device),
            max_new_tokens=self.max_length - prompt_length
        )
        outputs = self.tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=False
        )

        # cleanup
        for special_token in self.tokenizer.all_special_tokens:
            if special_token == '<|embed_token|>':
                continue
            outputs = [
                item.replace(special_token, '') for item in outputs
            ]
        messages = [
            item + [
                {'role': 'assistant', 'content': outputs[i].strip()}
            ]
            for i, item in enumerate(messages)
        ]
        return messages

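    # format(): wrap raw strings in the EMBEDDING system prompt plus a fixed
    # assistant turn that ends in <|embed_token|>.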
    def format(self, batch: list[str]):
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')
        return [
            [
                {'role': 'system', 'content': EMBEDDING.strip()},
                {'role': 'user', 'content': item},
                {'role': 'assistant', 'content': 'The embedding is: <|embed_token|>'}
            ]
            for item in batch
        ]

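    # encode(): accept raw strings or full chat transcripts, ensure each
    # conversation ends with an assistant turn containing <|embed_token|>,
    # then pool the final hidden state at that token into a unit vector.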
    @torch.no_grad()
    def encode(self, batch: list[Any]):
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')

        # batch
        if len(batch) > self.batch_size:
            outputs = [
                self.encode(
                    batch[i:i + self.batch_size]
                )
                for i in tqdm(
                    range(0, len(batch), self.batch_size)
                )
            ]
            return torch.cat(outputs, dim=0)

        # format
        if all(isinstance(v, str) for v in batch):
            batch = self.format(batch=batch)

        # validate
        if any(
            m[-1]['role'] != 'assistant' for m in batch
        ):
            raise RuntimeError('unexpected role')
        if any(
            m[-2]['role'] != 'user' for m in batch
        ):
            raise RuntimeError('unexpected role')

        # ensure <|embed_token|>
        batch = [
            m if '<|embed_token|>' in m[-1]['content']
            else self.format([m[-2]['content']])[0]
            for m in batch
        ]
        if any(
            '<|embed_token|>' not in m[-1]['content'] for m in batch
        ):
            raise RuntimeError('unexpected embed token')

        # tokenize
        context = self.tokenizer.apply_chat_template(
            batch, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length
        )

        # forward
        self.model.eval()
        outputs = self.model(
            **inputs.to(device=self.device),
            return_dict=True, output_hidden_states=True
        )
        hidden_state = outputs['hidden_states'][-1]

        # pooling: keep hidden states at <|embed_token|> positions, restricted
        # to the last 4 positions of each sequence (presumably so that embed
        # tokens quoted earlier in the text are ignored); expects exactly one
        # match per row, yielding a (batch, hidden) matrix
        length = inputs['input_ids'].size(-1)
        valid_mask = torch.arange(length, device=self.device)
        valid_mask = torch.where(
            valid_mask.unsqueeze(0) > length - 5, True, False
        )
        embed_mask = torch.where(
            inputs['input_ids'] == self.embed_token, True, False
        )
        embed_mask = embed_mask.logical_and(valid_mask)
        return F.normalize(
            hidden_state[embed_mask].cpu().float(), dim=-1
        )


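# main(): smoke-test chat via pipeline(), then enrich a query, embed it along
# with two documents, and compare them by L2 distance (the pandas document is
# expected to land closer to the query than the panda one).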
def main():
    # init
    set_seed(42)
    from pprint import pprint

    # basic
    generator = pipeline(
        task='text-generation',
        model='ytgui/Search-R3.0-Small',
        torch_dtype='auto', device_map='auto'
    )
    messages = [
        {'role': 'user', 'content': 'Who are you?'},
    ]
    response = generator(messages, max_new_tokens=256)
    pprint(response)

    # reasoning
    model = SearchR3(
        'ytgui/Search-R3.0-Small', max_length=1024, batch_size=8
    )
    reasoning = model.generate(
        batch=['what python library is useful for data analysis?']
    )
    pprint(reasoning)

    # embedding
    documents = [
        'pandas is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool, built on top of the Python programming language.',
        'The giant panda (Ailuropoda melanoleuca), also known as the panda bear or simply panda, is a bear species endemic to China. It is characterised by its white coat with black patches around the eyes, ears, legs and shoulders.',
    ]
    E_d = model.encode(batch=documents)
    E_q = model.encode(batch=reasoning)
    print('distance:', torch.cdist(E_q, E_d, p=2.0))


if __name__ == '__main__':
    main()