import torch

from torch import nn
from tqdm import tqdm
from torch.nn import functional as F
from transformers import (
    set_seed, pipeline, AutoTokenizer, AutoModelForCausalLM
)
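
# System prompt for the embedding pass: the hidden state of the final
# <|embed_token|> is read out as the text's vector representation.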
EMBEDDING = """
You are a helpful AI assistant. Your task is to analyze input text and create a high-quality semantic vector embedding, which represents key concepts, relationships, and semantic meaning.
"""
GENERATION = """
You are a helpful AI assistant. Your task is to enrich user input for more effective embedding representation by adding semantic depth.

For each input, briefly enhance the content by:
1. Identifying core concepts and their relationships.
2. Including key terminology with essential definitions.
3. Adding contextually relevant synonyms and related terms.
4. Connecting to related topics and common applications without excessive elaboration.

To represent the final embedding, you MUST end every response with <|embed_token|>.
"""


class SearchR3(nn.Module):
    """Wraps a causal LM for Search-R3 retrieval: `generate` enriches
    queries, and `encode` turns text into normalized embeddings read
    from the hidden state of a trailing <|embed_token|>."""

    def __init__(self,
                 path: str,
                 max_length: int,
                 batch_size: int):
        nn.Module.__init__(self)

        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype='auto', device_map='auto'
        )
        # Left-side truncation and padding keep the end of each prompt,
        # where <|embed_token|> sits, intact.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, truncation_side='left', padding_side='left'
        )
        # convert_tokens_to_ids avoids grabbing a prepended BOS token,
        # which encode('<|embed_token|>')[0] could return on some tokenizers.
        self.embed_token = self.tokenizer.convert_tokens_to_ids('<|embed_token|>')
        self.max_length = max_length
        self.batch_size = batch_size

    @property
    def device(self):
        return next(self.model.parameters()).device

    @torch.no_grad()
    def generate(self, batch: list[str]):
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')
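
        # Split oversized batches into chunks and recurse on each slice.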
        if len(batch) > self.batch_size:
            outputs = []
            for i in tqdm(
                range(0, len(batch), self.batch_size)
            ):
                outputs.extend(
                    self.generate(
                        batch[i:i + self.batch_size]
                    )
                )
            return outputs
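
        # Build one enrichment conversation per input string.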
        messages = [
            [
                {'role': 'system', 'content': GENERATION.strip()},
                {'role': 'user', 'content': item}
            ]
            for item in batch
        ]
        context = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
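        # Truncate prompts to half the window so that at least
        # max_length // 2 tokens remain available for generation.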
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length // 2
        )
        prompt_length = inputs['input_ids'].size(-1)

        self.model.eval()
        outputs = self.model.generate(
            **inputs.to(device=self.device),
            max_new_tokens=self.max_length - prompt_length
        )
        outputs = self.tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=False
        )
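
        # Drop every special token except <|embed_token|>, which encode()
        # later uses to locate the embedding position.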
        for special_token in self.tokenizer.all_special_tokens:
            if special_token == '<|embed_token|>':
                continue
            outputs = [
                item.replace(special_token, '') for item in outputs
            ]
        messages = [
            item + [
                {'role': 'assistant', 'content': outputs[i].strip()}
            ]
            for i, item in enumerate(messages)
        ]
        return messages

    def format(self, batch: list[str]):
        """Wrap raw strings into embedding conversations; the pre-filled
        assistant turn ends with <|embed_token|>."""
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')
        return [
            [
                {'role': 'system', 'content': EMBEDDING.strip()},
                {'role': 'user', 'content': item},
                {'role': 'assistant', 'content': 'The embedding is: <|embed_token|>'}
            ]
            for item in batch
        ]

    @torch.no_grad()
    def encode(self, batch: list):
        # `batch` holds either raw strings or chat-message lists.
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
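
        # Recurse over fixed-size chunks and concatenate the embeddings.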
        if len(batch) > self.batch_size:
            outputs = [
                self.encode(
                    batch[i:i + self.batch_size]
                )
                for i in tqdm(
                    range(0, len(batch), self.batch_size)
                )
            ]
            return torch.cat(outputs, dim=0)
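
        # Wrap plain strings into the embedding chat template.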
        if all(isinstance(v, str) for v in batch):
            batch = self.format(batch=batch)

        if any(
            m[-1]['role'] != 'assistant' for m in batch
        ):
            raise RuntimeError('unexpected role')
        if any(
            m[-2]['role'] != 'user' for m in batch
        ):
            raise RuntimeError('unexpected role')
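
        # Conversations whose final turn lacks <|embed_token|> are rebuilt
        # from their last user message.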
        batch = [
            m if '<|embed_token|>' in m[-1]['content']
            else self.format([m[-2]['content']])[0]
            for m in batch
        ]
        if any(
            '<|embed_token|>' not in m[-1]['content'] for m in batch
        ):
            raise RuntimeError('missing embed token')

        context = self.tokenizer.apply_chat_template(
            batch, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length
        )

        self.model.eval()
        outputs = self.model(
            **inputs.to(device=self.device),
            return_dict=True, output_hidden_states=True
        )
        hidden_state = outputs['hidden_states'][-1]
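
        # Keep only <|embed_token|> positions within the last 4 slots of
        # each (left-padded) sequence, so embed tokens quoted earlier in
        # the conversation are ignored.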
        length = inputs['input_ids'].size(-1)
        valid_mask = torch.arange(length, device=self.device)
        valid_mask = valid_mask.unsqueeze(0) > length - 5
        embed_mask = inputs['input_ids'] == self.embed_token
        embed_mask = embed_mask.logical_and(valid_mask)
        return F.normalize(
            hidden_state[embed_mask].cpu().float(), dim=-1
        )


def main():
    from pprint import pprint
    set_seed(42)
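
    # Sanity check: the checkpoint still behaves as an ordinary chat model.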
    generator = pipeline(
        task='text-generation',
        model='ytgui/Search-R3.0-Small',
        torch_dtype='auto', device_map='auto'
    )
    messages = [
        {'role': 'user', 'content': 'Who are you?'},
    ]
    response = generator(messages, max_new_tokens=256)
    pprint(response)
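
    # Enrich a query with the GENERATION prompt before embedding it.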
    model = SearchR3(
        'ytgui/Search-R3.0-Small', max_length=1024, batch_size=8
    )
    reasoning = model.generate(
        batch=['what python library is useful for data analysis?']
    )
    pprint(reasoning)
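
    # The pandas (library) document should sit closer to the enriched
    # query, i.e. yield the smaller L2 distance, than the panda-bear one.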
    documents = [
        'pandas is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool, built on top of the Python programming language.',
        'The giant panda (Ailuropoda melanoleuca), also known as the panda bear or simply panda, is a bear species endemic to China. It is characterised by its white coat with black patches around the eyes, ears, legs and shoulders.',
    ]
    E_d = model.encode(batch=documents)
    E_q = model.encode(batch=reasoning)
    print('distance:', torch.cdist(E_q, E_d, p=2.0))


if __name__ == '__main__':
    main()