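"""Dispatch a chat request to the `chats` submodule selected by `state["model_type"]`."""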
from chats import stablelm
from chats import alpaca
from chats import koalpaca
from chats import flan_alpaca
from chats import os_stablelm
from chats import vicuna
from chats import starchat
from chats import redpajama
from chats import mpt
from chats import alpacoom
from chats import baize
from chats import guanaco
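# Every chats submodule exposes a chat_stream() generator that accepts the
# exact argument list used below, so dispatch reduces to picking a module.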
# Several model types share one implementation: alpaca also serves
# "alpaca-gpt4", "llama-deus", "camel", and "nous-hermes", and vicuna
# serves all of the vicuna variants.
MODEL_TYPE_TO_MODULE = {
    "stablelm": stablelm,
    "baize": baize,
    "alpaca": alpaca,
    "alpaca-gpt4": alpaca,
    "alpacoom": alpacoom,
    "llama-deus": alpaca,
    "camel": alpaca,
    "koalpaca-polyglot": koalpaca,
    "flan-alpaca": flan_alpaca,
    "os-stablelm": os_stablelm,
    "t5-vicuna": vicuna,
    "stable-vicuna": vicuna,
    "vicuna": vicuna,
    "evolinstruct-vicuna": vicuna,
    "starchat": starchat,
    "mpt": mpt,
    "redpajama": redpajama,
    "guanaco": guanaco,
    "nous-hermes": alpaca,
}

def chat_stream(
    idx, local_data, user_message, state, model_num,
    global_context, ctx_num_lconv, ctx_sum_prompt,
    res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
):
    model_type = state["model_type"]
    module = MODEL_TYPE_TO_MODULE.get(model_type)
    if module is None:
        # Fail loudly for unknown model types instead of leaving `cs` unbound,
        # which would surface later as an opaque NameError.
        raise ValueError(f"unsupported model_type: {model_type!r}")
    cs = module.chat_stream(
        idx, local_data, user_message, state, model_num,
        global_context, ctx_num_lconv, ctx_sum_prompt,
        res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
    )
    # Re-yield whatever the delegate generator streams back.
    yield from cs
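# Illustrative usage sketch, not part of the Space itself. The argument
# values below are hypothetical placeholders (in the running app they come
# from the UI controls), and the parameter meanings are inferred from the
# names: res_temp/res_topp/res_topk look like sampling settings, res_mnts
# like max new tokens, res_eosid/res_padid like special-token ids.
#
#   state = {"model_type": "vicuna"}
#   for update in chat_stream(
#       0, [], "Hello!", state, 0,    # chat index, history, message, state, model slot
#       "", 3, "",                    # global context, last-N conversations, summary prompt
#       0.95, 0.9, 50, 1.2, 256, 1,   # temperature, top_p, top_k, rep. penalty, max new tokens, beams
#       True, True, 2, 0,             # use cache, do sample, eos id, pad id
#   ):
#       ...  # each update carries the partially generated response for streaming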