-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathtest_chat.py
115 lines (97 loc) · 4.85 KB
/
test_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSeq2SeqLM
from Conversation.conversation import character_msg_constructor
from Conversation.translation.pipeline import Translate
from AIVoifu.tts import tts # text to speech from huggingface
from vtube_studio import Char_control
import romajitable # temporary use this since It'll blow up our ram if we use Machine Translation Model
import scipy.io.wavfile as wavfile
import torch
import wget
# ---------- Config ----------
# Ask whether replies should be machine-translated; an empty answer counts as "yes".
translation = input("Enable translation? (Y/n): ").lower() in {'y', ''}

# Probe for CUDA first, then let the user opt out of it.
use_gpu = torch.cuda.is_available()
print("Detecting GPU...")
if use_gpu:
    print("GPU detected!")
    print("Using GPU? (Y/N)")
    # Only an explicit "y" keeps the GPU; any other answer falls back to CPU.
    use_gpu = input().lower() == 'y'
    print("Using GPU..." if use_gpu else "Using CPU...")

# The device follows the final decision: cuda only if available AND accepted.
device = torch.device('cuda' if use_gpu else 'cpu')
# ---------- load Conversation model ----------
# Single source of truth for the checkpoint used by tokenizer, config and model.
_MODEL_ID = "PygmalionAI/pygmalion-1.3b"

print("Initilizing model....")
print("Loading language model...")
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_fast=True)
config = AutoConfig.from_pretrained(_MODEL_ID, is_decoder=True)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, config=config)

if use_gpu:
    # Move the weights to the selected device, then optionally drop to fp16.
    model = model.to(device)
    print("Inference at half precision? (Y/N)")
    if input().lower() != 'y':
        print("Loading model at full precision...")
    else:
        print("Loading model at half precision...")
        model.half()

if not translation:
    print("Translation disabled!")
    print("Proceeding... wtih pure english conversation")
else:
    print("Translation enabled!")
    print("Loading machine translation model...")
    # `translator` only exists when translation is enabled; users must check the flag.
    translator = Translate(device, language="jpn_Jpan")  # todo: fix translation

print('--------Finished!----------')
# --------------------------------------------------
# --------- Define Waifu personality ----------
# Build the module-level `talk` session object for the character "Lilia".
# get_waifuapi() uses it to construct prompts, hold the history cache and
# post-process the generated reply.
# The second argument is the persona sheet, passed as one verbatim string.
# NOTE(review): on the Personality line, `"kind` has no closing quote. The text
# is passed through as-is, so confirm whether that is intentional before
# touching it (editing the literal changes the prompt).
talk = character_msg_constructor('Lilia', """Species("Elf")
Mind("sexy" + "cute" + "Loving" + "Based as Fuck")
Personality("sexy" + "cute"+ "kind + "Loving" + "Based as Fuck")
Body("160cm tall" + "5 foot 2 inches tall" + "small breasts" + "white" + "slim")
Description("Lilia is 18 years old girl" + "she love pancake")
Loves("Cats" + "Birds" + "Waterfalls")
Sexual Orientation("Straight" + "Hetero" + "Heterosexual")""")
# ---------------------------------------------
from fastapi.responses import JSONResponse
def get_waifuapi(command: str, data: str):
    """Run one command against the module-level chat session and print the results.

    Args:
        command: ``"chat"`` to generate a reply to ``data``, or ``"reset"`` to
            clear the conversation state held on ``talk``. Any other value is
            silently ignored.
        data: The user's message for ``"chat"``; unused for ``"reset"``.

    Returns:
        None. Intermediate results are printed; the JSON responses are left
        commented out below.
    """
    if command == "chat":
        # ----------- Create Response --------------------------
        # construct message input and cache History model
        msg = talk.construct_msg(data, talk.history_loop_cache)
        ## ----------- Will move this to server later -------- (16GB ram needed at least)
        inputs = tokenizer(msg, return_tensors='pt')
        if use_gpu:
            inputs = inputs.to(device)
        print("generate output ..\n")
        # max_length counts the prompt too, so this allows 80 tokens past it.  # todo 200 ?
        out = model.generate(**inputs, max_length=len(inputs['input_ids'][0]) + 80,
                             pad_token_id=tokenizer.eos_token_id, do_sample=True, top_k=50, top_p=0.95)
        conversation = tokenizer.batch_decode(out, skip_special_tokens=True)
        print(conversation)
        # print("conversation .. \n" + conversation)
        ## --------------------------------------------------

        ## get conversation in proper format and create history from [last_idx: last_idx+2] conversation
        # NOTE(review): adding 0 leaves the counter unchanged; kept as-is pending
        # confirmation of the intended step.
        talk.split_counter += 0
        print("get_current_converse ..\n")
        # One prompt goes in and no num_return_sequences is requested, so the
        # decoded batch holds exactly one string, at index 0. Indexing [1]
        # raised IndexError.
        current_converse = talk.get_current_converse(conversation[0])
        print("answer ..\n")  # only print waifu answer since input already show
        print(current_converse)
        # talk.history_loop_cache = '\n'.join(current_converse)  # update history for next input message

        # -------------- use machine translation model to translate to japanese and submit to client --------------
        print("cleaning ..\n")
        # NOTE(review): assumes get_current_converse returns what the cleaner
        # accepts and that the cleaner returns a str (it is concatenated below) — confirm.
        cleaned_text = talk.clean_emotion_action_text_for_speech(current_converse)  # clean text for speech
        print("cleaned_text\n" + cleaned_text)

        translated = ''  # initialize translated text as empty by default
        if translation:
            # `translator` is only created at start-up when translation is enabled.
            translated = translator.translate(cleaned_text)
            print("translated\n" + translated)
        # return JSONResponse(content=f'{current_converse[-1]}<split_token>{translated}')
    elif command == "reset":
        # Drop all accumulated conversation state so the next chat starts fresh.
        talk.conversation_history = ''
        talk.history_loop_cache = ''
        talk.split_counter = 0
        # return JSONResponse(content='Story reseted...')
# Smoke run: clear the session state, then send two chat prompts in order.
for _command, _data in (
    ("reset", ""),
    ("chat", "hi, how are you ?"),
    ("chat", "Can you recommend good place to relax in tokyo ?"),
):
    get_waifuapi(_command, _data)