-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_modules.py
198 lines (172 loc) · 5.71 KB
/
main_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import copy
import json
import random

import requests
from py2neo import Graph

from config import *
from modules.LR_GBDT.clf_model import CLFModel
# Local LR+GBDT intent model, built from the weights directory below; used by classifier().
LR_GBDT=CLFModel('./modules/LR_GBDT/model_weights/')
# Neo4j knowledge-graph connection queried by neo4j_searcher().
# NOTE(review): URI and credentials (including the password) are hard-coded in source;
# consider loading them from config / environment variables instead of committing them.
graph = Graph("http://localhost:7474", auth=("neo4j", "wangxiao1024"))
def intent_classifier(text, url='http://192.168.31.112:5002/service/api/bert', timeout=10):
    """Classify the intent of ``text`` with the remote BERT intent service.

    :param text: raw user utterance.
    :param url: endpoint of the intent service (defaults to the address the
        module has always used).
    :param timeout: seconds to wait for the service before giving up.
    :return: the ``data`` field of the service's JSON reply (text_analysis
        reads ``name`` and ``confidence`` from it), or -1 when the request
        fails, times out, returns a non-200 status, or returns invalid JSON.
    """
    data = {"text": text}
    headers = {'Content-Type': 'application/json;charset=utf8'}
    try:
        # Without a timeout requests.post can block forever on a hung service.
        response = requests.post(url, data=json.dumps(data), headers=headers, timeout=timeout)
        if response.status_code != 200:
            return -1
        return json.loads(response.text)['data']
    except (requests.RequestException, ValueError):
        # Connection errors / timeouts (RequestException) and malformed JSON
        # (ValueError) are reported with the same -1 sentinel as a non-200
        # status, so text_analysis falls back to the "unrecognized" reply.
        return -1
def slot_recognizer(text, url='http://192.168.31.112:5001/service/api/ner', timeout=10):
    """Recognize entities in ``text`` with the remote NER service.

    :param text: raw user utterance. NOTE(review): it is sent as-is under the
        key ``text_list`` (a single string, not wrapped in a list) — confirm
        the service accepts a bare string there.
    :param url: endpoint of the NER service (defaults to the address the
        module has always used).
    :param timeout: seconds to wait for the service before giving up.
    :return: the ``data`` field of the service's JSON reply (text_analysis
        iterates it as a list of dicts, each with an ``entities`` list of
        ``{"word": ..., "type": ...}`` items), or -1 when the request fails,
        times out, returns a non-200 status, or returns invalid JSON.
    """
    data = {"text_list": text}
    headers = {'Content-Type': 'application/json;charset=utf8'}
    try:
        # Without a timeout requests.post can block forever on a hung service.
        response = requests.post(url, data=json.dumps(data), headers=headers, timeout=timeout)
        if response.status_code != 200:
            return -1
        return json.loads(response.text)['data']
    except (requests.RequestException, ValueError):
        # Same -1 sentinel as a non-200 status so the caller can fall back.
        return -1
def classifier(msg):
    """Classify the user's opening (small-talk) intent with the local LR+GBDT model.

    Expected labels, per the author's note: greet, goodbye, deny, isbot.

    :param msg: raw user text.
    :return: whatever ``LR_GBDT.predict`` returns for ``msg``.
    """
    prediction = LR_GBDT.predict(msg)
    return prediction
def entity_link(mention, etype):
    """Link a recognized entity mention to its canonical knowledge-base entity.

    Not implemented yet: the mention is returned unchanged and ``etype`` is
    ignored. The intent is to map non-standard mentions onto a single
    standard entity name in the knowledge base.

    :param mention: surface text of the recognized entity.
    :param etype: entity type reported by the recognizer (currently unused).
    :return: ``mention`` itself.
    """
    # TODO: real entity linking; identity mapping for now.
    linked = mention
    return linked
def text_analysis(msg):
    """Parse a user message into a filled semantic-slot dict.

    Calls the remote intent classifier and slot recognizer, looks up the
    matching template in ``semantic_slot`` (presumably provided by
    ``from config import *``), fills its slots from the recognized entities
    and chooses a dialogue strategy from the intent confidence.

    :param msg: raw user text.
    :return: a private deep copy of a ``semantic_slot`` entry. For the
        "unrecognized" case (service failure, intent "其他", or an intent
        name with no template) it is the ``"unrecognized"`` entry as
        configured. Otherwise it additionally carries ``slot_values``
        (slot name -> linked entity, or None if no entity matched) and
        ``intent_strategy`` ("accept", "clarify" or "deny").
    """
    intent_msg = intent_classifier(msg)
    entity = slot_recognizer(msg)

    # Test the -1 error sentinels *before* calling .get(): -1 is an int with
    # no .get method, so checking the intent name first crashed on failure.
    if intent_msg == -1 or entity == -1 or intent_msg.get("name") == '其他':
        return copy.deepcopy(semantic_slot.get("unrecognized"))

    template = semantic_slot.get(intent_msg.get('name'))
    if template is None:
        # Intent name not present in the config: answer as unrecognized
        # instead of crashing on None.get below.
        return copy.deepcopy(semantic_slot.get("unrecognized"))
    # Work on a copy: the template belongs to the shared module-level config,
    # and both this function and get_answer() write fields into the returned
    # dict. Mutating the template would leak slot values and answers from one
    # query into later ones.
    slot_info = copy.deepcopy(template)

    # Semantic slot filling: a slot is filled by an entity whose type equals
    # the lower-cased slot name (e.g. "Disease" -> "disease"); if several
    # entities match, the last one wins.
    slot_values = {}
    for slot in slot_info.get('slot_list'):
        slot_values[slot] = None
        for ent_info in entity:
            for e in ent_info["entities"]:
                if slot.lower() == e['type']:
                    slot_values[slot] = entity_link(e['word'], e['type'])
    slot_info['slot_values'] = slot_values

    # Choose the strategy from the classifier confidence. A missing/None
    # confidence is treated as 0 (-> "deny") instead of raising TypeError.
    conf = intent_msg.get('confidence') or 0.0
    if conf >= intent_threshold_config["accept"]:
        slot_info["intent_strategy"] = "accept"
    elif conf >= intent_threshold_config["deny"]:
        slot_info["intent_strategy"] = "clarify"
    else:
        slot_info["intent_strategy"] = "deny"
    return slot_info
def _flatten_records(records):
    """Collect the values of Neo4j result records into one flat list.

    Per record: if its first value is a list, that list's items are taken
    (the record's other values are ignored); otherwise all of its values are.
    """
    values = []
    for record in records:
        columns = list(record.values())
        head = columns[0]
        values.extend(head if isinstance(head, list) else columns)
    return values


def neo4j_searcher(cql_list):
    """Run one or more Cypher queries against ``graph`` and render the results as text.

    :param cql_list: a single Cypher query string, or a list of query strings.
    :return: for a list, the concatenation of one "、"-joined line (each
        followed by a newline) per query that returned records; queries with
        no records contribute nothing. For a single query, its "、"-joined
        values with no trailing newline, or "" when it returns no records.
    """
    if not isinstance(cql_list, list):
        records = graph.run(cql_list).data()
        if not records:
            return ""
        return "、".join(str(value) for value in _flatten_records(records))

    lines = []
    for cql in cql_list:
        records = graph.run(cql).data()
        if records:
            lines.append("、".join(str(value) for value in _flatten_records(records)) + "\n")
    return "".join(lines)
def _render_cql(cql_template, slot_values):
    """Fill a Cypher template (or each template of a list) with the slot values."""
    if isinstance(cql_template, list):
        return [template.format(**slot_values) for template in cql_template]
    return cql_template.format(**slot_values)


def get_answer(slot_info):
    """Fill in the reply fields of ``slot_info`` according to its intent strategy.

    ``slot_info`` is mutated in place and returned:

    - no/empty ``slot_values``: returned unchanged;
    - "accept": runs the rendered ``cql_template`` through neo4j_searcher and
      stores ``reply_template`` + result in ``replay_answer`` (a fixed
      fallback message when the search returns nothing);
    - "clarify": stores the rendered ``ask_template`` in ``replay_answer`` and
      pre-computes the confirmed answer into ``choice_answer``; when the
      search returns nothing, ``replay_answer`` is overwritten with the
      fallback message instead;
    - "deny": copies ``deny_response`` into ``replay_answer``;
    - any other strategy: returned unchanged.

    NOTE(review): slots that text_analysis left as None are formatted into the
    templates as the text "None" — confirm whether they should short-circuit.
    """
    templates = {
        "cql": slot_info.get("cql_template"),
        "reply": slot_info.get("reply_template"),
        "ask": slot_info.get("ask_template"),
    }
    slot_values = slot_info.get("slot_values")
    strategy = slot_info.get("intent_strategy")
    fallback = "唔~我装满知识的大脑此刻很贫瘠"

    if not slot_values:
        return slot_info

    if strategy == "accept":
        cql = _render_cql(templates["cql"], slot_values)
        print(cql)  # debug output of the query about to run
        answer = neo4j_searcher(cql)
        if answer:
            slot_info["replay_answer"] = templates["reply"].format(**slot_values) + answer
        else:
            slot_info["replay_answer"] = fallback
        return slot_info

    if strategy == "clarify":
        # Ask the user to confirm the question first ...
        slot_info["replay_answer"] = templates["ask"].format(**slot_values)
        # ... and prepare the answer to give once they confirm.
        answer = neo4j_searcher(_render_cql(templates["cql"], slot_values))
        if answer:
            slot_info["choice_answer"] = templates["reply"].format(**slot_values) + answer
        else:
            slot_info["replay_answer"] = fallback
        return slot_info

    if strategy == "deny":
        slot_info["replay_answer"] = slot_info.get("deny_response")
    return slot_info
def chat_robot(intent):
    """Small-talk bot: pick a random canned reply for ``intent``.

    :param intent: small-talk intent label, used as a key into
        ``cheat_corpus`` (presumably provided by ``from config import *``).
    :return: one reply chosen with ``random.choice`` from that intent's
        reply list. NOTE(review): an intent missing from ``cheat_corpus``
        makes ``random.choice(None)`` raise TypeError.
    """
    replies = cheat_corpus.get(intent)
    return random.choice(replies)
def medical_robot(msg):
    """Medical Q&A bot: analyse ``msg`` and return the answered slot-info dict.

    :param msg: raw user text.
    :return: the dict produced by text_analysis() after get_answer() has
        filled in its reply fields.
    """
    # Named slot_info rather than semantic_slot so it does not shadow the
    # module-level semantic_slot table used by text_analysis().
    slot_info = text_analysis(msg)
    return get_answer(slot_info)
# if __name__=='__main__':
# medical_robot("你好我感冒了,应该吃什么药")