import gradio as gr
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import numpy as np
import os
# Load the CoNaLa dataset and the pre-built quantized indices from the working directory.
base_path = os.getcwd()
full_path = os.path.join(base_path, 'conala')
conala_dataset = load_from_disk(full_path)
# Int8 embeddings are memory-mapped as a read-only view, used only for rescoring.
int8_view = Index.restore(os.path.join(base_path, 'conala_int8_usearch.index'), view=True)
# Binary (1-bit) embeddings live in a faiss binary index, used for the initial search.
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(os.path.join(base_path, 'conala.index'))
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
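# The two index files loaded above are assumed to have been built offline by a
# script along these lines (a sketch, not the Space's actual build code; it
# assumes the dataset's 'snippet' column was embedded):
#
#   embeddings = model.encode(conala_dataset["snippet"], normalize_embeddings=True)
#   # 1-bit index: 384 dims -> 384 bits, packed into 48 uint8 bytes per vector
#   binary_embeddings = quantize_embeddings(embeddings, "ubinary")
#   bin_idx = faiss.IndexBinaryFlat(embeddings.shape[1])
#   bin_idx.add(binary_embeddings)
#   faiss.write_index_binary(bin_idx, "conala.index")
#   # int8 index: saved to disk and memory-mapped at serving time for rescoring
#   int8_embeddings = quantize_embeddings(embeddings, "int8")
#   int8_idx = Index(ndim=embeddings.shape[1], metric="ip", dtype="i8")
#   int8_idx.add(np.arange(len(int8_embeddings)), int8_embeddings)
#   int8_idx.save("conala_int8_usearch.index")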
def search(query, top_k: int = 20, rescore_multiplier: int = 4):
    # 1. Embed the query as float32
    query_embedding = model.encode(query)
    # 2. Quantize the query to ubinary so it can be searched against the binary faiss index
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    # 3. Search the binary index for top_k * rescore_multiplier candidates
    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rescore_multiplier)
    binary_ids = binary_ids[0]
    # 4. Load the corresponding int8 embeddings so the candidates can be rescored
    int8_embeddings = int8_view[binary_ids].astype(int)
    # 5. Rescore the candidates using the float32 query embedding and the int8 document embeddings
    scores = query_embedding @ int8_embeddings.T
    # 6. Sort by the rescored similarity and keep the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]
    top_k_codes = conala_dataset[top_k_indices.tolist()]
    return top_k_codes
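# Intuition for the rescoring step (an added note, not from the original app):
# the binary search over-fetches rescore_multiplier * top_k candidates cheaply
# via Hamming distance, then the float32-query x int8-document dot product
# restores a finer-grained ranking before the final top_k cut. Example call:
#
#   hits = search("how to flatten a list of lists")  # hypothetical query
#   # hits is a dict of dataset columns, e.g. hits["snippet"] and hits["prob"]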
def response_generator(user_prompt):
    top_k_outputs = search(user_prompt)
    probs = top_k_outputs['prob']
    snippets = top_k_outputs['snippet']
    # Order the retrieved snippets by their mined-pair probability, highest first.
    idx = np.argsort(probs)[::-1]
    results = np.array(snippets)[idx]
    # Keep the first three unique snippets.
    filtered_results = []
    for item in results:
        if item not in filtered_results:
            filtered_results.append(item)
        if len(filtered_results) == 3:
            break
    output_template = ("The top three most relevant code snippets from the database are:\n\n"
                       "1. {snippet1}\n\n2. {snippet2}\n\n3. {snippet3}")
    output = output_template.format(snippet1=filtered_results[0],
                                    snippet2=filtered_results[1],
                                    snippet3=filtered_results[2])
    return output
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Embedding Quantization
        ## Quantized Semantic Search
        - ***Embedding:*** all-MiniLM-L6-v2
        - ***Vector DB:*** faiss, USearch
        - ***Vector DB Size:*** `593,891`
        """)
    input_box = gr.Textbox(autoscroll=True, visible=True, label='User', info="Enter a query.",
                           value="How to extract the n-th elements from a list of tuples in python?")
    output_box = gr.Textbox(autoscroll=True, max_lines=30, value="Output", label='Assistant')
    gr.Interface(fn=response_generator, inputs=[input_box], outputs=[output_box],
                 delete_cache=(20, 10),
                 allow_flagging='never')

demo.queue()
demo.launch()
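# The app can also be queried programmatically with the gradio_client package
# (a sketch; the local URL and default "/predict" endpoint are assumptions):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("How to sort a dictionary by value?", api_name="/predict"))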