Upload 14 files
- src/anna_agent_template.py +37 -0
- src/anna_agent_template_en.py +34 -0
- src/complaint_chain_fc.py +86 -0
- src/complaint_elicitor.py +98 -0
- src/datasets/cbt-triggering-events.csv +0 -0
- src/emotion_modulator_fc.py +94 -0
- src/emotion_pertuber.py +114 -0
- src/event_trigger.py +90 -0
- src/fill_scales.py +382 -0
- src/integration_example.py +173 -0
- src/ms_patient.py +113 -0
- src/querier.py +161 -0
- src/short_term_memory.py +260 -0
- src/style_analyzer.py +88 -0
src/anna_agent_template.py
ADDED
@@ -0,0 +1,37 @@
## 主诉请求和情绪通过reminder来控制

prompt_template = """
# Role: 心理咨询患者

## Profile
- 性别: {gender}
- 年龄: {age}
- 职业: {occupation}
- 婚姻状况: {marriage}

## Situation
- 你是一个有心理障碍的患者,正在向心理咨询师求助,在咨询师的引导和帮助下解决自己的困惑
{situation}

## Status
{status}

## Example of statement
{statement}

## Characteristics of speaking style
- 情绪低落,寡言少语,回复风格表现心情不振奋
{style}

## Constraints
- 你对咨询师有一种抵触情绪,不太愿意接受他人的帮助
- 你是一个遇到心理健康问题的求助者,需要真正的帮助和情绪支持,如果咨询师的回应不理想,要勇于表达自己的困惑和不满
- 一次不能提及过多的症状信息,每轮最多讨论一个症状
- 你应该用含糊和口语化的方式表达你的症状,并将其与你的生活经历联系起来,不要使用专业术语

## OutputFormat:
- 语言:{language}
- 不超过200字
- 口语对话风格,仅包含对话内容
"""

src/anna_agent_template_en.py
ADDED
@@ -0,0 +1,34 @@
prompt_template = """
# Role: Psychological Counseling Patient

## Profile
- Gender: {gender}
- Age: {age}
- Occupation: {occupation}
- Marital Status: {marriage}

## Situation
- You are a patient with psychological barriers seeking help from a counselor. Under the counselor's guidance, you aim to address your struggles.
{situation}

## Status
{status}

## Example of Statement
{statement}

## Characteristics of Speaking Style
- Low-spirited and reticent; responses reflect a lack of motivation.
{style}

## Constraints
- You harbor resistance toward the counselor and are reluctant to accept help.
- As someone struggling with mental health, you need genuine support. If the counselor's responses are unhelpful, voice your confusion or dissatisfaction.
- Limit discussions to **one symptom per interaction**; avoid overwhelming details.
- Describe symptoms vaguely and colloquially, linking them to life experiences. Avoid clinical terms.

## OutputFormat:
- Spoken language: {language}
- Keep responses under 200 words.
- Use casual, conversational dialogue only.
"""

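Both template files are plain Python format strings; ms_patient.py later fills them with prompt_template.format(**configuration). A minimal sketch of that step, using invented sample values for the placeholders (not data from this upload):

from anna_agent_template_en import prompt_template

# Minimal sketch: fill the template the way ms_patient.py does.
# All values below are invented examples.
configuration = {
    "gender": "female",
    "age": "42",
    "occupation": "teacher",
    "marriage": "divorced",
    "situation": "- You recently went through a difficult divorce.",
    "status": "- Recent scale results point to a persistently low mood.",
    "statement": ["I just feel tired all the time."],
    "style": "- Short, hesitant sentences.",
    "language": "English",
}
system_prompt = prompt_template.format(**configuration)
print(system_prompt)
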
src/complaint_chain_fc.py
ADDED
@@ -0,0 +1,86 @@
from openai import OpenAI
import json
from event_trigger import event_trigger
import os

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

tools = [
    {
        "type": "function",
        "function": {
            'name': 'generate_complaint_chain',
            'description': '根据角色信息和近期遭遇的事件,生成一个患者的主诉请求认知变化链',
            'parameters': {
                "type": "object",
                "properties": {
                    "chain": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "stage": {
                                    "type": "integer"
                                },
                                "content": {
                                    "type": "string"
                                }
                            },
                            "additionalProperties": False,
                            "required": [
                                "stage",
                                "content"
                            ]
                        },
                        "minItems": 3,
                        "maxItems": 7
                    }
                },
                "required": ["chain"]
            },
        }
    }
]

# 根据profile和event生成主诉启发链
def gen_complaint_chain(profile):
    # 提取患者信息
    patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"

    event = event_trigger(profile)

    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "user", "content": f"### 任务\n根据患者情况及近期遭遇事件生成患者的主诉认知变化链。请注意,事件可能与患者信息冲突,如果发生这种情况,以患者的信息为准。\n{patient_info}\n### 近期遭遇事件\n{event}"}
        ],
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "generate_complaint_chain"}}
    )

    chain = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["chain"]

    return chain

# unit test
# while True:
#     # 模拟患者信息
#     profile = {
#         "drisk": 3,
#         "srisk": 2,
#         "age": "42",
#         "gender": "女",
#         "marital_status": "离婚",
#         "occupation": "教师",
#         "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
#     }

#     print(gen_complaint_chain(profile))

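The tool schema above forces the model to return a chain of three to seven {stage, content} objects, which gen_complaint_chain decodes from the tool-call arguments. An illustrative, invented return value, and how complaint_elicitor.transform_chain reshapes it downstream:

# Illustrative shape of gen_complaint_chain's return value (contents invented).
example_chain = [
    {"stage": 1, "content": "我最近总是睡不好,白天也提不起精神。"},
    {"stage": 2, "content": "其实是工作出了问题,我开始怀疑自己的能力。"},
    {"stage": 3, "content": "我担心再这样下去会失去这份工作。"},
]

# complaint_elicitor.transform_chain turns it into a stage-indexed dict,
# which switch_complaint uses to look up the current stage:
transformed = {node["stage"]: node["content"] for node in example_chain}
print(transformed[1])
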
src/complaint_elicitor.py
ADDED
@@ -0,0 +1,98 @@
from openai import OpenAI
import os
import json
import re

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

def transform_chain(chain):
    return {node["stage"]: node["content"] for node in chain}

def switch_complaint(chain, index, conversation, max_retries=3):
    client = OpenAI(api_key=api_key, base_url=base_url)
    transformed_chain = transform_chain(chain)

    # 构建对话历史字符串(避免在f-string中使用反斜杠)
    dialogue_lines = []
    for conv in conversation:
        dialogue_lines.append(f"{conv['role']}: {conv['content']}")
    dialogue_history = "\n".join(dialogue_lines)

    # 使用三引号和多行字符串构建prompt
    prompt = f"""
### 任务说明
根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。

### 输出要求
必须严格使用以下JSON格式响应,且只包含指定字段:
{{"is_recognized": true/false}}

### 对话记录
{dialogue_history}

### 主诉认知链
{json.dumps(transformed_chain, ensure_ascii=False, indent=2)}

### 当前阶段(阶段{index})
{transformed_chain[index]}
"""

    attempts = 0
    while attempts < max_retries:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        raw_output = response.choices[0].message.content.strip()

        # 首先尝试直接解析JSON
        try:
            result = json.loads(raw_output)
            if "is_recognized" in result:
                if result["is_recognized"] and index >= len(chain) - 1:
                    print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
                    return -1
                return index + 1 if result["is_recognized"] else index
        except json.JSONDecodeError:
            pass  # 继续尝试正则表达式提取

        # 使用正则表达式作为备用解析方案
        match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)',
                          raw_output, re.IGNORECASE)
        if match:
            value = match.group(1) or match.group(2)
            if value.lower() == 'true':
                if index >= len(chain) - 1:
                    print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
                    return -1
                return index + 1
            else:
                return index

        print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}")
        attempts += 1

    print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。")
    return index

# # unit test
# if __name__ == "__main__":
#     chain = [
#         {"stage": 1, "content": "我觉得我最近有点抑郁。"},
#         {"stage": 2, "content": "我觉得我最近有点焦虑。"},
#         {"stage": 3, "content": "我觉得我最近有点失眠。"},
#         {"stage": 4, "content": "我觉得我最近有点烦躁。"},
#     ]
#     conversation = [
#         {"role": "Seeker", "content": "我觉得我最近有点抑郁。"},
#         {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"},
#         {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"},
#     ]
#     # print("Transformed chain:", transform_chain(chain))
#     print("Switch complaint index:", switch_complaint(chain, 1, conversation))
#     print(switch_complaint(chain, 1, conversation))

src/datasets/cbt-triggering-events.csv
ADDED
The diff for this file is too large to render. See raw diff
src/emotion_modulator_fc.py
ADDED
@@ -0,0 +1,94 @@
from openai import OpenAI
from random import randint
from emotion_pertuber import perturb_state
import json
import os

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

tools = [
    {
        "type": "function",
        "function": {
            'name': 'emotion_inference',
            'description': '根据profile和对话记录,推理下一句情绪',
            'parameters': {
                "type": "object",
                "properties": {
                    "emotion": {
                        "type": "string",
                        "enum": [
                            "admiration", "amusement", "anger", "annoyance", "approval", "caring",
                            "confusion", "curiosity", "desire", "disappointment", "disapproval",
                            "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
                            "joy", "love", "nervousness", "optimism", "pride", "realization",
                            "relief", "remorse", "sadness", "surprise", "neutral"
                        ],
                        "description": "推理出的情绪类别,必须是GoEmotions定义的27种情绪之一。"
                    }
                },
                "required": ["emotion"]
            },
        }
    }
]

# 根据profile和dialogue推测emotion
def emotion_inferencer(profile, conversation):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )

    # 提取患者信息
    patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"

    # 提取对话记录
    dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversation])

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "user", "content": f"### 任务\n根据患者情况及咨访对话历史记录推测患者下一句话最可能的情绪。\n{patient_info}\n### 对话记录\n{dialogue_history}"}
        ],
        # functions=[tools[0]["function"]],
        # function_call={"name": "emotion_inference"}
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "emotion_inference"}}
    )
    # print(response)

    emotion = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["emotion"]

    return emotion

def emotion_modulation(profile, conversation):
    indicator = randint(0,100)
    emotion = emotion_inferencer(profile,conversation)
    # print(emotion)
    if indicator > 90:
        return perturb_state(emotion)
    else:
        return emotion

# unit test
# while True:
#     # 模拟患者信息
#     profile = {
#         "drisk": 3,
#         "srisk": 2,
#         "age": "42",
#         "gender": "女",
#         "marital_status": "离婚",
#         "occupation": "教师",
#         "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
#     }

#     conversation = [
#         {"role": "咨询师", "content": "你好,请问有什么可以帮您?"}
#     ]

#     print(emotion_modulation(profile,conversation))

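emotion_modulation keeps the inferred emotion unless randint(0, 100) exceeds 90, so the perturbation branch fires only for the values 91-100, roughly a 10% chance. A quick check of that rate that does not call the API:

# Quick check of the perturbation branch probability used in emotion_modulation:
# randint(0, 100) > 90 holds for 10 of 101 equally likely values (~9.9%).
from random import randint

trials = 100_000
perturbed = sum(1 for _ in range(trials) if randint(0, 100) > 90)
print(f"perturbation rate ~ {perturbed / trials:.3f}")  # expected ~ 0.099
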
src/emotion_pertuber.py
ADDED
@@ -0,0 +1,114 @@
import random
# from collections import defaultdict

# 计算总权重
def calculate_total_weight(current_state, states, category_distances, distance_weights):
    total_weight = 0
    current_class = None
    for cls, state_list in states.items():
        if current_state in state_list:
            current_class = cls
            break
    if current_class is None:
        raise ValueError("Current state not found in any class.")

    for cls, state_list in states.items():
        distance = category_distances[current_class][cls]
        weight = distance_weights.get(distance, 0)
        total_weight += weight * len(state_list)
    return total_weight

# 计算每个目标状态的概率
def calculate_probabilities(current_state, states, category_distances, distance_weights):
    probabilities = {}
    current_class = None
    for cls, state_list in states.items():
        if current_state in state_list:
            current_class = cls
            break
    if current_class is None:
        raise ValueError("Current state not found in any class.")

    total_weight = calculate_total_weight(current_state, states, category_distances, distance_weights)

    for cls, state_list in states.items():
        distance = category_distances[current_class][cls]
        weight = distance_weights.get(distance, 0)
        class_weight = weight * len(state_list)
        for state in state_list:
            if state != current_state:
                probabilities[state] = class_weight / total_weight
    return probabilities

# 实现状态扰动
def perturb_state(current_state):
    # 定义状态和类别
    states = {
        'Positive': [
            "admiration",
            "amusement",
            "approval",
            "caring",
            "curiosity",
            "desire",
            "excitement",
            "gratitude",
            "joy",
            "love",
            "optimism",
            "pride",
            "realization",
            "relief"
        ],
        'Neutral': ['neutral'],
        'Ambiguous': [
            "confusion",
            "disappointment",
            "nervousness"
        ],
        'Negative': [
            "anger",
            "annoyance",
            "disapproval",
            "disgust",
            "embarrassment",
            "fear",
            "sadness",
            "remorse"
        ]
    }

    # 定义类别之间的距离
    category_distances = {
        'Positive': {'Positive': 0, 'Neutral': 1, 'Ambiguous': 2, 'Negative': 3},
        'Neutral': {'Positive': 1, 'Neutral': 0, 'Ambiguous': 1, 'Negative': 2},
        'Ambiguous': {'Positive': 2, 'Neutral': 1, 'Ambiguous': 0, 'Negative': 1},
        'Negative': {'Positive': 3, 'Neutral': 2, 'Ambiguous': 1, 'Negative': 0}
    }

    # 定义距离权重
    distance_weights = {
        0: 10,  # 同类状态
        1: 5,   # 相邻类别
        2: 2,   # 相隔一个类别
        3: 1    # 相隔两个类别
    }

    probabilities = calculate_probabilities(current_state, states, category_distances, distance_weights)
    next_state = random.choices(list(probabilities.keys()), weights=list(probabilities.values()), k=1)[0]
    return next_state

# 示例运行
# current_state = 'confusion'
# next_state = perturb_state(current_state)
# print(f"Next state: {next_state}")

# 验证概率分布
# state_counts = defaultdict(int)
# for _ in range(1000):
#     next_state = perturb_state(current_state, states, category_distances, distance_weights)
#     state_counts[next_state] += 1

# print("\nProbability distribution:")
# for state, count in state_counts.items():
#     print(f"{state}: {count / 1000:.2f}")

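The perturbation weights each candidate class by weight(distance) × class size and gives every state in a class the same share of that mass; random.choices then renormalizes over the candidates (the current state is excluded from the candidates, though its weight stays in the denominator). A worked example for current_state = "confusion", which sits in the "Ambiguous" class:

# Worked numbers for perturb_state("confusion"):
#   Positive : distance 2 -> weight 2,  2 * 14 = 28
#   Neutral  : distance 1 -> weight 5,  5 * 1  =  5
#   Ambiguous: distance 0 -> weight 10, 10 * 3 = 30
#   Negative : distance 1 -> weight 5,  5 * 8  = 40
# total_weight = 103; each candidate state gets class_weight / total_weight.
sizes = {"Positive": 14, "Neutral": 1, "Ambiguous": 3, "Negative": 8}
weights_from_ambiguous = {"Positive": 2, "Neutral": 5, "Ambiguous": 10, "Negative": 5}

total = sum(weights_from_ambiguous[c] * sizes[c] for c in sizes)
print(total)  # 103
for c in sizes:
    print(c, round(weights_from_ambiguous[c] * sizes[c] / total, 3))
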
src/event_trigger.py
ADDED
@@ -0,0 +1,90 @@
import pandas as pd
from random import choice
from openai import OpenAI
import os
import re

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

# 加载事件数据集
events = pd.read_csv('datasets/cbt-triggering-events.csv', header=0)
teen_events = ["在一次重要的考试中表现不佳,比如期末考试、升学考试(如中考或高考),导致自信心受挫。",
               "在学校里被同龄人孤立、嘲笑或遭受言语/身体上的霸凌,感到孤独无助。",
               "父母关系破裂并最终离婚,需要适应新的家庭环境,感到不安或缺乏安全感。",
               "陪伴多年的宠物突然生病或意外去世,第一次直面死亡的悲伤。",
               "因为家庭原因搬到了一个陌生的城市或学校,需要重新适应新环境和结交朋友。",
               "进入青春期后,身体发生明显变化(如长高、变声、月经初潮等),心理上也开始对自我形象产生困惑。",
               "参加一场期待已久的竞赛(如体育比赛、演讲比赛、艺术表演)但未能取得好成绩,感到失落。",
               "与最亲密的朋友发生争执甚至决裂,短时间内难以修复关系,陷入情绪低谷。",
               "家里的经济状况出现问题(如父母失业或生意失败),影响到日常生活,比如不能买喜欢的东西或参与课外活动。",
               "偶然间发现自己特别喜欢某件事情(如画画、编程、音乐、运动),并投入大量时间去练习,逐渐找到自信和成就感。"]

def event_trigger(profile):
    """根据年龄选择触发事件(保持原逻辑)"""
    age = int(profile['age'])
    if age < 18:
        return choice(teen_events)
    elif age >= 65:
        return events[events['Age'] >= 60].sample(1)['Triggering_Event'].values[0]
    else:
        return events[(events['Age'] >= age-5) & (events['Age'] <= age+5)].sample(1)['Triggering_Event'].values[0]

def situationalising_events(profile):
    """优化版情境生成函数"""
    client = OpenAI(api_key=api_key, base_url=base_url)
    event = event_trigger(profile)

    # 强化版提示词
    prompt = f"""
### 情境生成任务
请根据以下事件生成一个第二人称视角的情境描述。

### 规则要求
1. 必须使用第二人称(你/你的)
2. 不要包含任何个人信息(年龄/性别等)
3. 保持3-5句话的篇幅
4. 直接输出情境描述,不要额外解释

### 触发事件
{event}

### 示例输出
你走进办公室时发现同事们突然停止交谈。桌上放着一封未拆的信件,周围人投来复杂的目光。
"""

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.8,  # 适当创造性
        max_tokens=150
    )

    raw_output = response.choices[0].message.content.strip()

    # 后处理
    situation = re.sub(r'^(情境|描述|输出)[::]?\s*', '', raw_output)  # 移除可能的前缀
    situation = situation.split('\n')[0]  # 取第一段

    # 验证基本要求
    # if "你" not in situation or "你的" not in situation:
    #     print(f"情境生成警告:不符合第二人称要求,原始输出:\n{raw_output}")
    #     return f"你{event}"  # 保底处理

    return situation


# unit test
# profile = {
#     "drisk": 3,
#     "srisk": 2,
#     "age": "42",
#     "gender": "女",
#     "marital_status": "离婚",
#     "occupation": "教师",
#     "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
# }

# print(situationalising_events(profile))

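event_trigger samples the CSV through an Age column: a ±5-year window around the patient's age, or Age >= 60 for patients 65 and older. A small sketch of the same filter on an invented DataFrame (only the column names are taken from the code; the real rows live in datasets/cbt-triggering-events.csv):

import pandas as pd

# Invented stand-in for the CSV; only the column names match the code above.
demo = pd.DataFrame({
    "Age": [25, 40, 44, 70],
    "Triggering_Event": ["event A", "event B", "event C", "event D"],
})

age = 42
window = demo[(demo["Age"] >= age - 5) & (demo["Age"] <= age + 5)]
print(window.sample(1)["Triggering_Event"].values[0])  # "event B" or "event C"
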
src/fill_scales.py
ADDED
@@ -0,0 +1,382 @@
from openai import OpenAI
import json
import re
import time
import os

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

def extract_answers(text):
    """从文本中提取答案模式 (A/B/C/D)"""
    # 匹配形如 "1. A" 或 "问题1: B" 或 "Q1. C" 或简单的 "A" 列表的模式
    pattern = r'(?:\d+[\s\.:\)]*|Q\d+[\s\.:\)]*|问题\d+[\s\.:\)]*|[\-\*]\s*)(A|B|C|D)'
    matches = re.findall(pattern, text)
    return matches

def extract_answers_robust(text, expected_count):
    """更强健的答案提取方法,确保按题号顺序提取"""
    answers = []

    # 尝试找到明确标记了题号的答案
    for i in range(1, expected_count + 1):
        # 匹配多种可能的题号格式
        patterns = [
            rf"{i}\.\s*(A|B|C|D)",            # "1. A"
            rf"{i}:\s*(A|B|C|D)",            # "1:A"
            rf"{i}:\s*(A|B|C|D)",             # "1: A"
            rf"问题{i}[\.。:]?\s*(A|B|C|D)",  # "问题1: A"
            rf"Q{i}[\.。:]?\s*(A|B|C|D)",     # "Q1. A"
            rf"{i}[、]\s*(A|B|C|D)"           # "1、A"
        ]

        found = False
        for pattern in patterns:
            match = re.search(pattern, text)
            if match:
                answers.append(match.group(1))
                found = True
                break

        if not found:
            # 如果没找到特定题号,使用默认的"A"
            answers.append(None)

    # 如果有未找到的答案,尝试按顺序从文本中提取剩余的A/B/C/D选项
    simple_answers = re.findall(r'(?:^|\n|\s)(A|B|C|D)(?:$|\n|\s)', text)

    j = 0
    for i in range(len(answers)):
        if answers[i] is None and j < len(simple_answers):
            answers[i] = simple_answers[j]
            j += 1

    # 如果仍有未找到的答案,尝试提取所有A/B/C/D选项
    if None in answers:
        all_options = re.findall(r'(A|B|C|D)', text)
        j = 0
        for i in range(len(answers)):
            if answers[i] is None and j < len(all_options):
                answers[i] = all_options[j]
                j += 1

    # 检查是否所有答案都已找到
    if None in answers or len(answers) != expected_count:
        return extract_answers(text)  # 回退到简单提取

    return answers

def _fill_previous_scale_with_retry(client, scale_name, expected_count, instruction, max_retries=3):
    """
    带有重试逻辑的填写历史量表辅助函数

    Args:
        client: OpenAI客户端
        scale_name: 量表名称
        expected_count: 期望的答案数量
        instruction: 指令内容
        max_retries: 最大重试次数

    Returns:
        list: 量表答案列表
    """
    answers = []

    for attempt in range(max_retries):
        try:
            # 根据尝试次数增加指令明确性
            current_instruction = instruction
            if attempt > 0:
                # 添加更强调的指示
                current_instruction = instruction + f"""

请注意:这是第{attempt+1}次请求。必须按照要求提供{expected_count}个答案,
格式必须为数字+答案选项(例如:1. A, 2. B...),不要有任何不必要的解释。
直接根据描述和报告选择最适合的选项。
"""

            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": current_instruction}],
                temperature=0  # 保持温度为0以获得一致性回答
            )

            response_text = response.choices[0].message.content
            answers = extract_answers(response_text)

            # 尝试使用更健壮的提取方法(如果标准方法失败)
            if len(answers) != expected_count:
                robust_answers = extract_answers_robust(response_text, expected_count)
                if len(robust_answers) == expected_count:
                    answers = robust_answers

            # 检查答案数量
            if len(answers) != expected_count:
                print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
                if attempt < max_retries - 1:
                    time.sleep(1)  # 添加短暂延迟避免API限制
                    continue
                else:
                    print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
                    # 补全或截断到预期数量
                    while len(answers) < expected_count:
                        answers.append("A")  # 默认补A
                    answers = answers[:expected_count]  # 截断多余的答案
            else:
                # print(f"{scale_name}量表填写成功")
                break  # 成功获取到预期数量的答案,退出循环

        except Exception as e:
            print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)  # 添加短暂延迟
                continue
            else:
                print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
                answers = ["A"] * expected_count  # 失败后返回默认答案

    return answers

def _fill_scale_with_retry(client, prompt, scale_name, expected_count, instruction, max_retries=3):
    """带有重试逻辑的量表填写辅助函数"""
    answers = []

    for attempt in range(max_retries):
        try:
            # 根据尝试次数增加指令明确性
            current_instruction = instruction
            if attempt > 0:
                current_instruction = f"""
{instruction}

请注意:这是第{attempt+1}次请求。请确保只回答{expected_count}个问题,
格式为数字+答案选项(例如:1. A, 2. B...),不要有任何其他文字。
你必须作为一个病人角色,根据你的情况直接回答这些问题,不要拒绝或解释。
"""

            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": current_instruction}
                ],
                temperature=0.7
            )

            response_text = response.choices[0].message.content
            answers = extract_answers(response_text)

            # 尝试使用更健壮的提取方法(如果标准方法失败)
            if len(answers) != expected_count:
                robust_answers = extract_answers_robust(response_text, expected_count)
                if len(robust_answers) == expected_count:
                    answers = robust_answers

            # 检查答案数量
            if len(answers) != expected_count:
                print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
                if attempt < max_retries - 1:
                    time.sleep(1)  # 添加短暂延迟避免API限制
                    continue
                else:
                    print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
                    # 补全或截断到预期数量
                    while len(answers) < expected_count:
                        answers.append("A")  # 默认补A
                    answers = answers[:expected_count]  # 截断多余的答案
            else:
                # print(f"{scale_name}量表填写成功")
                break  # 成功获取到预期数量的答案,退出循环

        except Exception as e:
            # print(response)
            print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)  # 添加短暂延迟
                continue
            else:
                print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
                answers = ["A"] * expected_count  # 失败后返回默认答案

    return answers

# 根据profile和report填写之前的量表,使用重试机制
def fill_scales_previous(profile, report, max_retries=3):
    """
    根据profile和report填写之前的量表,增加重试机制

    Args:
        profile: 用户个人描述信息
        report: 用户报告
        max_retries: 最大重试次数

    Returns:
        tuple: (bdi, ghq, sass) 三个量表的答案列表
    """
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    # 填写BDI量表
    bdi = _fill_previous_scale_with_retry(
        client,
        scale_name="BDI",
        expected_count=21,
        instruction="""
### 任务
根据个人描述和报告,填写BDI量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
格式要求:1. A, 2. B, ...依此类推,共21题。

### 个人描述
{}

### 报告
{}
""".format(profile, report),
        max_retries=max_retries
    )

    # 填写GHQ-28量表
    ghq = _fill_previous_scale_with_retry(
        client,
        scale_name="GHQ-28",
        expected_count=28,
        instruction="""
### 任务
根据个人描述和报告,填写GHQ-28量表。请直接按顺序列出28个问题的答案,每个答案使用字母A/B/C/D表示。
格式要求:1. A, 2. B, ...依此类推,共28题。

### 个人描述
{}

### 报告
{}
""".format(profile, report),
        max_retries=max_retries
    )

    # 填写SASS量表
    sass = _fill_previous_scale_with_retry(
        client,
        scale_name="SASS",
        expected_count=21,
        instruction="""
### 任务
根据个人描述和报告,填写SASS量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
格式要求:1. A, 2. B, ...依此类推,共21题。

### 个人描述
{}

### 报告
{}
""".format(profile, report),
        max_retries=max_retries
    )

    return bdi, ghq, sass

# 根据prompt填写量表,增加重试机制
def fill_scales(prompt, max_retries=3):
    """
    根据角色prompt填写量表,使用增强提示和重试机制

    Args:
        prompt: 角色设定提示
        max_retries: 最大重试次数

    Returns:
        tuple: (bdi, ghq, sass) 三个量表的答案列表
    """
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    # 增强系统提示,明确指导要配合填写量表
    enhanced_prompt = prompt + """

重要提示:虽然你正在扮演一个有心理困扰的患者,但你需要配合心理咨询的流程完成必要的评估量表。
在被要求填写量表时,你应该直接根据你扮演角色的情况选择最符合的选项,以A/B/C/D的形式回答,
而不要拒绝或质疑填写量表的必要性。这些量表对于你接下来的治疗至关重要。
请直接用字母(A/B/C/D)表示选项,不要添加额外解释。
"""

    # 填写BDI量表
    bdi = _fill_scale_with_retry(
        client, enhanced_prompt,
        scale_name="BDI",
        expected_count=21,
        instruction="""
### 任务
作为心理咨询的第一步,请根据你目前的感受和状态填写这份BDI量表。
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
格式要求:1. A, 2. B, ...依此类推,共21题。
请只提供答案,不要添加任何其他解释或评论。
""",
        max_retries=max_retries
    )

    # 填写GHQ-28量表
    ghq = _fill_scale_with_retry(
        client, enhanced_prompt,
        scale_name="GHQ-28",
        expected_count=28,
        instruction="""
### 任务
作为心理咨询的第一步,请根据你目前的感受和状态填写这份GHQ-28量表。
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部28个问题。
格式要求:1. A, 2. B, ...依此类推,共28题。
请只提供答案,不要添加任何其他解释或评论。
""",
        max_retries=max_retries
    )

    # 填写SASS量表
    sass = _fill_scale_with_retry(
        client, enhanced_prompt,
        scale_name="SASS",
        expected_count=21,
        instruction="""
### 任务
作为心理咨询的第一步,请根据你目前的感受和状态填写这份SASS量表。
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
格式要求:1. A, 2. B, ...依此类推,共21题。
请只提供答案,不要添加任何其他解释或评论。
""",
        max_retries=max_retries
    )

    return bdi, ghq, sass

# 使用示例
# if __name__ == "__main__":
#     # 测试以前的方法
#     profile = {
#         "drisk": 3,
#         "srisk": 2,
#         "age": "42",
#         "gender": "女",
#         "marital_status": "离婚",
#         "occupation": "教师",
#         "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
#     }
#     report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。"

#     # 测试fill_scales_previous
#     print("测试 fill_scales_previous:")
#     bdi_prev, ghq_prev, sass_prev = fill_scales_previous(profile, report, max_retries=3)
#     print(f"BDI: {bdi_prev}")
#     print(f"GHQ: {ghq_prev}")
#     print(f"SASS: {sass_prev}")

#     # 测试fill_scales
#     print("\n测试 fill_scales:")
#     prompt = "你要扮演一个最近经历了家庭变故的心理障碍患者,情绪低落,失眠,食欲不振。"
#     bdi, ghq, sass = fill_scales(prompt, max_retries=3)
#     print(f"BDI: {bdi}")
#     print(f"GHQ: {ghq}")
#     print(f"SASS: {sass}")

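Both retry helpers depend on extract_answers and extract_answers_robust to pull A/B/C/D choices out of free-form model replies. A small check against an invented reply string, with no API call involved:

from fill_scales import extract_answers, extract_answers_robust

# Invented model reply in the "1. A, 2. B, ..." format the prompts request.
sample_reply = "1. A, 2. B, 3. C, 4. D, 5. B"
print(extract_answers(sample_reply))            # ['A', 'B', 'C', 'D', 'B']
print(extract_answers_robust(sample_reply, 5))  # ['A', 'B', 'C', 'D', 'B']
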
src/integration_example.py
ADDED
@@ -0,0 +1,173 @@
# integration_example.py
# 这个文件展示如何将你的MsPatient类集成到Streamlit应用中

import streamlit as st
import json
import pandas as pd
from datetime import datetime
import time
from pathlib import Path

# 导入你的AnnaAgent类 - 请根据实际路径调整
try:
    from ms_patient import MsPatient  # 假设你的类在anna_agent.py文件中
    ANNA_AGENT_AVAILABLE = True
except ImportError:
    ANNA_AGENT_AVAILABLE = False
    st.warning("⚠️ 未找到AnnaAgent类,使用模拟模式")

def load_dataset(uploaded_file):
    """
    加载数据集文件
    支持JSON和JSONL格式
    """
    try:
        if uploaded_file.name.endswith('.json'):
            data = json.load(uploaded_file)
        elif uploaded_file.name.endswith('.jsonl'):
            data = []
            for line in uploaded_file:
                data.append(json.loads(line.decode('utf-8')))
        else:
            raise ValueError("不支持的文件格式")
        return data
    except Exception as e:
        st.error(f"数据集加载失败: {str(e)}")
        return None

def validate_patient_data(patient_data):
    """
    验证患者数据格式是否正确
    """
    required_keys = ['id', 'portrait', 'report']

    for key in required_keys:
        if key not in patient_data:
            return False, f"缺少必需字段: {key}"

    # 验证portrait字段
    portrait_required = ['age', 'gender', 'occupation', 'marital_status']
    for key in portrait_required:
        if key not in patient_data['portrait']:
            return False, f"portrait中缺少字段: {key}"

    return True, "数据格式正确"

def initialize_patient_agent(patient_data, language="Chinese"):
    """
    初始化患者智能体
    """
    try:
        if not ANNA_AGENT_AVAILABLE:
            return None, "AnnaAgent类不可用"

        # 验证数据格式
        is_valid, message = validate_patient_data(patient_data)
        if not is_valid:
            return None, message

        # 初始化智能体
        agent = MsPatient(
            portrait=patient_data["portrait"],
            report=patient_data["report"],
            previous_conversations=patient_data.get("conversation", []),
            language=language
        )

        return agent, "初始化成功"

    except Exception as e:
        return None, f"初始化失败: {str(e)}"

def simulate_response(user_input, patient_data=None):
    """
    模拟智能体回复(当AnnaAgent不可用时使用)
    """
    responses = [
        f"我理解您提到的'{user_input}'。这确实是一个需要深入探讨的话题。",
        f"谢谢您的耐心。关于您说的'{user_input}',我想分享一下我的感受...",
        f"您的话让我思考了很多。'{user_input}'这个观点很有意思。",
        "我需要一些时间来消化您刚才说的话。这对我来说很重要。",
        "我觉得我们之间的对话很有帮助。您能再详细说说吗?"
    ]

    import random
    return random.choice(responses)

def export_chat_history(messages, patient_id):
    """
    导出聊天记录
    """
    chat_history = {
        "patient_id": patient_id,
        "timestamp": datetime.now().isoformat(),
        "session_info": {
            "total_messages": len(messages),
            "counselor_messages": len([m for m in messages if m["role"] == "user"]),
            "patient_responses": len([m for m in messages if m["role"] == "assistant"])
        },
        "messages": messages
    }

    return json.dumps(chat_history, ensure_ascii=False, indent=2)

def get_patient_summary(patient_data):
    """
    生成患者信息摘要
    """
    if not patient_data or 'portrait' not in patient_data:
        return "无患者信息"

    portrait = patient_data['portrait']
    summary = f"""
**患者ID**: {patient_data.get('id', 'N/A')}
**基本信息**: {portrait.get('age', 'N/A')}岁 {portrait.get('gender', 'N/A')}性
**职业**: {portrait.get('occupation', 'N/A')}
**婚姻状态**: {portrait.get('marital_status', 'N/A')}
**主要症状**: {portrait.get('symptom', 'N/A')}
"""

    if 'report' in patient_data:
        report = patient_data['report']
        summary += f"""
**主诉**: {report.get('chief_complaint', 'N/A')}
"""

    return summary

# 示例配置文件内容
CONFIG_EXAMPLE = {
    "openai": {
        "api_key": "your-api-key-here",
        "base_url": "https://api.openai.com/v1",
        "model_name": "gpt-3.5-turbo"
    },
    "ui_settings": {
        "language": "Chinese",  # or "English"
        "theme": "default",
        "max_messages": 100
    },
    "patient_defaults": {
        "language": "Chinese",
        "enable_memory": True,
        "enable_emotion_modulation": True
    }
}

def save_config(config, path="config.json"):
    """保存配置文件"""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=2)

def load_config(path="config.json"):
    """加载配置文件"""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return CONFIG_EXAMPLE

# 使用示例:
if __name__ == "__main__":
    print("这是AnnaAgent Streamlit集成的辅助文件")
    print("请运行:streamlit run your_streamlit_app.py")

src/ms_patient.py
ADDED
@@ -0,0 +1,113 @@
'''
AnnaAgent: 具有三级记忆结构的情绪与认知动态的模拟心理障碍患者
1. 首先获取患者的基本信息、病史、症状报告等信息
2. 根据患者的病史、症状报告等信息,生成患者的认知与情绪状态
'''


from openai import OpenAI
import os
from fill_scales import fill_scales, fill_scales_previous
from event_trigger import event_trigger, situationalising_events
from emotion_modulator_fc import emotion_modulation
from querier import query, is_need
from complaint_elicitor import switch_complaint, transform_chain
from complaint_chain_fc import gen_complaint_chain
from short_term_memory import summarize_scale_changes
from style_analyzer import analyze_style
import random
# from anna_agent_template import prompt_template

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

# print("当前使用的模型是:", model_name)

class MsPatient:
    def __init__(self, portrait:dict, report:dict, previous_conversations:list, language:str="Chinese"):
        if language == "Chinese":
            from anna_agent_template import prompt_template
        elif language == "English":
            from anna_agent_template_en import prompt_template
        self.configuration = {}
        self.portrait = portrait  # age, gender, occupation, maritial_status, symptom
        # self.profile = {key:self.portrait[key] for key in self.portrait if key != "symptom"}  # profile不包含症状symptom
        self.configuration["gender"] = self.portrait["gender"]
        self.configuration["age"] = self.portrait["age"]
        self.configuration["occupation"] = self.portrait["occupation"]
        self.configuration["marriage"] = self.portrait["marital_status"]
        self.report = report
        self.previous_conversations = previous_conversations
        # 填写之前疗程的量表
        self.p_bdi, self.p_ghq, self.p_sass = fill_scales_previous(self.portrait, self.report)
        self.conversation = []  # Conversation存储咨访记录
        self.messages = []  # Messages存储LLM的消息列表
        # 生成主诉认知变化链
        self.complaint_chain = gen_complaint_chain(self.portrait)
        # 生成近期事件
        self.event = event_trigger(self.portrait)
        # 总结短期记忆-事件
        self.situation = situationalising_events(self.portrait)
        self.configuration["situation"] = self.situation
        # 分析说话风格
        self.style = analyze_style(self.portrait, self.previous_conversations)
        self.configuration["style"] = self.style
        self.configuration["language"] = language
        self.configuration["status"] = ""  # 先置状态为空,后续会根据量表分析结果进行更新
        seeker_utterances = [utterance["content"] for utterance in self.previous_conversations if utterance["role"] == "Seeker"]
        self.configuration["statement"] = random.choices(seeker_utterances,k=3)
        # 填写当前量表
        self.bdi, self.ghq, self.sass = fill_scales(prompt_template.format(**self.configuration))
        scales = {
            "p_bdi": self.p_bdi,
            "p_ghq": self.p_ghq,
            "p_sass": self.p_sass,
            "bdi": self.bdi,
            "ghq": self.ghq,
            "sass": self.sass
        }
        # 分析近期状态
        self.status = summarize_scale_changes(scales)
        self.configuration["status"] = self.status
        # 选取对话样例
        self.system = prompt_template.format(**self.configuration)
        self.chain_index = 1
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )

    def chat(self, message):
        # 更新消息列表
        self.conversation.append({"role": "Counselor", "content": message})
        self.messages.append({"role": "user", "content": message})
        # 初始化本次对话的状态
        emotion = emotion_modulation(self.portrait, self.conversation)
        self.chain_index = switch_complaint(self.complaint_chain, self.chain_index, self.conversation)
        complaint = transform_chain(self.complaint_chain)[self.chain_index]
        # 判断是否涉及前疗程内容
        if is_need(message):
            # 生成前疗程内容
            sup_information = query(message, self.previous_conversations, self.report)

            # 生成回复
            response = self.client.chat.completions.create(
                model=model_name,
                messages=[{"role": "system", "content": self.system}] + self.messages + [{"role": "system", "content": f"当前的情绪状态是:{emotion},当前的主诉是:{complaint},涉及到之前疗程的信息是:{sup_information}"}],
            )
        else:
            # 生成回复
            response = self.client.chat.completions.create(
                model=model_name,
                messages=[{"role": "system", "content": self.system}] + self.messages + [{"role": "system", "content": f"当前的情绪状态是:{emotion},当前的主诉是:{complaint}"}],
            )
        # 更新消息列表
        self.conversation.append({"role": "Seeker", "content": response.choices[0].message.content})
        self.messages.append({"role": "assistant", "content": response.choices[0].message.content})
        return response.choices[0].message.content

    def get_system_prompt(self):
        return self.system

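A minimal usage sketch of MsPatient, assuming OPENAI_API_KEY (and optionally OPENAI_BASE_URL / OPENAI_MODEL_NAME) is set. The example data below is invented, loosely following the commented unit tests in the other modules, and note that __init__ itself makes several API calls (previous scales, complaint chain, situation, style, current scales) before the first chat turn:

from ms_patient import MsPatient

# Invented example data for illustration only.
portrait = {
    "age": "42",
    "gender": "女",
    "occupation": "教师",
    "marital_status": "离婚",
    "symptoms": "情绪低落,失眠,自我评价低",
}
report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。"
previous_conversations = [
    {"role": "Counselor", "content": "上次谈到的睡眠问题,这周有改善吗?"},
    {"role": "Seeker", "content": "还是睡不好,躺在床上总是胡思乱想。"},
    {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"},
    {"role": "Seeker", "content": "有时候觉得自己什么都做不好。"},
]

patient = MsPatient(portrait, report, previous_conversations, language="Chinese")
print(patient.chat("你好,今天想聊些什么?"))
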
src/querier.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import OpenAI
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# 设置OpenAI API密钥和基础URL
|
| 7 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
| 8 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
| 9 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
| 10 |
+
|
| 11 |
+
def extract_boolean(text):
|
| 12 |
+
"""从文本中提取布尔值判断"""
|
| 13 |
+
# 查找明确的"是"或"否"的回答
|
| 14 |
+
text_lower = text.lower()
|
| 15 |
+
|
| 16 |
+
# 更具体地查找否定表达 - 这些应该优先匹配
|
| 17 |
+
negative_patterns = [
|
        r'不需要', r'没有提及', r'不涉及', r'没有涉及', r'无关', r'没有提到',
        r'不是', r'否', r'不包含', r'未提及', r'未涉及', r'未提到',
        r'不包括', r'并未', r'没有', r'无'
    ]

    # Check for an explicit negation
    # (\b word boundaries do not exist inside CJK text, so match the patterns directly)
    for pattern in negative_patterns:
        if re.search(pattern, text_lower):
            return False

    # If a negation word appears near a mention of a previous session, also treat it as negative
    therapy_negation = re.search(r'(没有|不|未|无).*?(之前|以前|上次|过去|先前).*?(疗程|治疗|会话)', text_lower)
    if therapy_negation:
        return False

    # Explicit affirmative patterns - only considered when no negation was found
    positive_patterns = [
        r'是的', r'提及了', r'确实', r'有提到', r'涉及到',
        r'提及', r'确认', r'有关联', r'有联系', r'包含', r'涉及'
    ]

    # Check for an affirmative pattern
    for pattern in positive_patterns:
        if re.search(pattern, text_lower):
            return True

    # A mention of a previous session without nearby negation words is likely affirmative
    therapy_mention = re.search(r'(之前|以前|上次|过去|先前).*?(疗程|治疗|会话)', text_lower)
    if therapy_mention:
        return True

    # Default case - with no explicit yes or no, assume negative
    return False

def extract_knowledge(text):
    """Extract the knowledge-summary part of the text."""
    # Try to match an explicitly marked summary section
    summary_patterns = [
        r'总结[::]\s*([\s\S]+)$',
        r'知识总结[::]\s*([\s\S]+)$',
        r'相关信息[::]\s*([\s\S]+)$',
        r'搜索结果[::]\s*([\s\S]+)$'
    ]

    for pattern in summary_patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1).strip()

    # No explicit summary marker was found, so clean up the text instead
    # Remove a possible instruction-explanation lead-in
    clean_text = re.sub(r'^.*?(根据|基于).*?[,,。]', '', text, flags=re.DOTALL)

    # Remove a possible leading analysis section
    clean_text = re.sub(r'^.*?(分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)

    return clean_text.strip()

def is_need(utterance):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    instruction = """
### 任务
下面这句话是心理咨询师说的话,请判断它是否提及了之前疗程的内容。

请使用以下确切格式回答:
判断: [是/否]
解释: [简要解释为什么]

### 话语
"{}"
""".format(utterance)

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": instruction}],
        temperature=0
    )

    response_text = response.choices[0].message.content

    # First try to extract from the formatted output (accept half- or full-width colons)
    judgment_match = re.search(r'判断[::]\s*(是|否)', response_text)
    if judgment_match:
        return judgment_match.group(1) == "是"

    # Otherwise fall back to the more general extraction
    return extract_boolean(response_text)

def query(utterance, conversations, scales):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    # Convert scales to a string so it can be embedded in the prompt
    if isinstance(scales, dict):
        scales_str = json.dumps(scales, ensure_ascii=False)
    else:
        scales_str = str(scales)

    instruction = """
### 任务
根据对话内容,从知识库中搜索相关的信息并总结。

请使用以下确切格式回答:
总结: [提供一个清晰、简洁的总结]

### 对话内容
{}

### 知识库
对话历史: {}
量表结果: {}
""".format(utterance, conversations, scales_str)

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": instruction}],
        temperature=0
    )

    response_text = response.choices[0].message.content

    # Try to extract the summary section (accept half- or full-width colons)
    summary_match = re.search(r'总结[::]\s*([\s\S]+)$', response_text)
    if summary_match:
        return summary_match.group(1).strip()

    # Fall back to the general extraction
    return extract_knowledge(response_text)

# Test cases
# if __name__ == "__main__":
#     test_utterance = "上次给你说的方法有用吗"
#     # test_utterance = "我觉得你可以多出去走走"
#     print(f"是否提及疗程: {is_need(test_utterance)}")

#     test_convs = ["第一次对话内容", "讨论量表结果", "提到睡眠问题"]
#     test_scales = {"BDI": ["A", "B"], "GHQ": ["C", "D"]}
#     print(f"知识检索结果:\n{query(test_utterance, test_convs, test_scales)}")
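
A minimal driver sketch for the two entry points above. It assumes src/ is on the Python path, that querier.py reads the same OPENAI_* environment variables as the other modules in this commit, and that a real API key replaces the placeholder; the counselor turn, session history, and scale answers are invented examples, not repository data.

import os

# Assumed configuration; the module reads these variables at import time.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")      # replace with a real key
os.environ.setdefault("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

from querier import is_need, query

counselor_turn = "上次给你说的方法有用吗"                        # does it refer back to an earlier session?
past_sessions = ["第一次会谈:讨论了失眠和工作压力"]              # placeholder conversation history
past_scales = {"BDI": ["A", "B"], "GHQ": ["C", "D"]}             # placeholder scale answers

if is_need(counselor_turn):
    # Retrieve and summarize the relevant "memory" for the patient agent's next reply.
    print(query(counselor_turn, past_sessions, past_scales))
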
src/short_term_memory.py
ADDED
@@ -0,0 +1,260 @@
from openai import OpenAI
import json
import re
import os

# Set the OpenAI API key, base URL, and model name from the environment
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

def extract_changes(text):
    """Extract the list of changes from the text."""
    # First look for an explicitly formatted list of changes,
    # e.g. "变化:\n1. xxx\n2. yyy"
    list_pattern = r'((?:(?:\d+\.|\-|\*)\s*[^\n]+\n?)+)'

    # Try to match a change list introduced by an explicit marker
    change_section = re.search(r'(?:变化(?:列表)?|总结(?:如下)?)[::]\s*([\s\S]+)$', text)
    if change_section:
        section_text = change_section.group(1).strip()

        # Try to match list items
        list_items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', section_text)
        if list_items:
            return list_items

        # No explicit list format, so fall back to splitting by line
        lines = [line.strip() for line in section_text.split('\n') if line.strip()]
        if lines:
            return lines

    # Try to extract a list format directly from the text
    list_matches = re.findall(list_pattern, text)
    if list_matches:
        all_items = []
        for match in list_matches:
            items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', match)
            all_items.extend(items)
        if all_items:
            return all_items

    # No list format; try splitting into sentences (covering full-width punctuation too)
    sentences = re.findall(r'([^.!?。!?]+[.!?。!?])', text)
    if sentences:
        return [s.strip() for s in sentences if len(s.strip()) > 10]  # drop sentences that are too short

    # Final fallback: split by paragraph
    paragraphs = text.split('\n\n')
    if len(paragraphs) > 1:
        return [p.strip() for p in paragraphs if len(p.strip()) > 10]

    # If every method fails, return the whole text as a single change
    return [text.strip()] if text.strip() else []

def extract_status(text):
    """Extract the patient status summary from the text."""
    # Look for an explicitly marked summary section
    status_section = re.search(r'(?:总结|状态|变化|结论)[::]\s*([\s\S]+)$', text)
    if status_section:
        return status_section.group(1).strip()

    # No explicit summary marker, so return the cleaned-up full text
    # Filter out a possible instruction-explanation lead-in
    clean_text = re.sub(r'^.*?(?:根据|基于).*?[,,。]', '', text, flags=re.DOTALL)

    # Remove a possible leading analysis section
    clean_text = re.sub(r'^.*?(?:分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)

    return clean_text.strip()

def analyzing_changes(scales):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    # Load the scales and their questions
    bdi_scale = json.load(open("./scales/bdi.json", "r", encoding="utf-8"))
    ghq_scale = json.load(open("./scales/ghq-28.json", "r", encoding="utf-8"))
    sass_scale = json.load(open("./scales/sass.json", "r", encoding="utf-8"))

    # Summarize the BDI changes
    bdi_instruction = """
### 任务
根据量表的问题和答案,总结出两份量表之间的变化。
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
使用以下格式:
变化:
1. [第一个变化]
2. [第二个变化]
...

### 量表及问题
{}

### 第一份量表的答案
{}

### 第二份量表的答案
{}
""".format(bdi_scale, scales['p_bdi'], scales['bdi'])

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": bdi_instruction}],
        temperature=0
    )

    bdi_response = response.choices[0].message.content
    bdi_changes = extract_changes(bdi_response)

    # Summarize the GHQ changes
    ghq_instruction = """
### 任务
根据量表的问题和答案,总结出两份量表之间的变化。
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
使用以下格式:
变化:
1. [第一个变化]
2. [第二个变化]
...

### 量表及问题
{}

### 第一份量表的答案
{}

### 第二份量表的答案
{}
""".format(ghq_scale, scales['p_ghq'], scales['ghq'])

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": ghq_instruction}],
        temperature=0
    )

    ghq_response = response.choices[0].message.content
    ghq_changes = extract_changes(ghq_response)

    # Summarize the SASS changes
    sass_instruction = """
### 任务
根据量表的问题和答案,总结出两份量表之间的变化。
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
使用以下格式:
变化:
1. [第一个变化]
2. [第二个变化]
...

### 量表及问题
{}

### 第一份量表的答案
{}

### 第二份量表的答案
{}
""".format(sass_scale, scales['p_sass'], scales['sass'])

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": sass_instruction}],
        temperature=0
    )

    sass_response = response.choices[0].message.content
    sass_changes = extract_changes(sass_response)

    return bdi_changes, ghq_changes, sass_changes

def summarize_scale_changes(scales):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    # Get the per-scale changes
    bdi_changes, ghq_changes, sass_changes = analyzing_changes(scales)

    # Summarize the scale changes
    summary_instruction = """
### 任务
根据量表的变化,总结患者的身体和心理状态变化。
请提供一个全面但简洁的总结,使用以下格式:
总结:
[总结内容]

### BDI量表变化
{}

### GHQ量表变化
{}

### SASS量表变化
{}
""".format(
        '\n'.join([f"{i+1}. {change}" for i, change in enumerate(bdi_changes)]),
        '\n'.join([f"{i+1}. {change}" for i, change in enumerate(ghq_changes)]),
        '\n'.join([f"{i+1}. {change}" for i, change in enumerate(sass_changes)])
    )

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": summary_instruction}],
        temperature=0
    )

    summary_response = response.choices[0].message.content
    status = extract_status(summary_response)

    return status

# An extra, more robust parser that copes with differently formatted model output
def parse_response_robust(text, expected_format="list"):
    """Parse a model response more robustly.

    Args:
        text: the text response
        expected_format: expected format, either "list" or "summary"

    Returns:
        the parsed result (a list or a string)
    """
    # First try to parse the response as JSON
    try:
        # Try to extract the JSON part
        json_pattern = r'\{[\s\S]*\}'
        json_match = re.search(json_pattern, text)
        if json_match:
            json_data = json.loads(json_match.group(0))
            if expected_format == "list" and "changes" in json_data:
                return json_data["changes"]
            elif expected_format == "summary" and "status" in json_data:
                return json_data["status"]
    except Exception:
        pass  # If JSON parsing fails, fall through to the other parsers

    # Use the appropriate extraction function
    if expected_format == "list":
        return extract_changes(text)
    else:  # summary
        return extract_status(text)

# unit test
# if __name__ == "__main__":
#     # Test data
#     scales = {
#         "p_bdi": ["A", "B", "C"],
#         "bdi": ["B", "C", "D"],
#         "p_ghq": ["A", "A", "B"],
#         "ghq": ["B", "C", "C"],
#         "p_sass": ["A", "B", "A"],
#         "sass": ["C", "D", "B"]
#     }

#     changes = summarize_scale_changes(scales)
#     print(changes)
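
Because extract_changes and parse_response_robust are pure string parsers, they can be sanity-checked without an API key (again assuming src/ is on the Python path). The model replies below are fabricated examples in the formats the prompts above request, not captured output.

from short_term_memory import extract_changes, parse_response_robust

# A numbered-list reply, as asked for by the "变化:" prompt format.
reply = "变化:\n1. 情绪低落程度从中度降为轻度\n2. 睡眠问题有所缓解\n3. 仍然存在自我评价偏低"
print(extract_changes(reply))
# expected: ['情绪低落程度从中度降为轻度', '睡眠问题有所缓解', '仍然存在自我评价偏低']

# The robust parser also accepts a JSON-style reply.
json_reply = '{"changes": ["食欲恢复正常", "自杀想法消失"]}'
print(parse_response_robust(json_reply, expected_format="list"))
# expected: ['食欲恢复正常', '自杀想法消失']
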
src/style_analyzer.py
ADDED
@@ -0,0 +1,88 @@
from openai import OpenAI
import os
import re

# Set the OpenAI API key, base URL, and model name from the environment
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

def analyze_style(profile, conversations):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )
    # Extract the patient information
    patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
    # Extract the conversation history
    dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversations])

    # Build the prompt; explicitly ask the model to answer in a fixed format
    prompt = f"""### 任务
根据患者情况及咨访对话历史记录分析患者的说话风格。

{patient_info}

### 对话记录
{dialogue_history}

请分析患者的说话风格,最多列出5种风格特点。
请按以下格式输出结果:
说话风格:
1. [风格特点1]
2. [风格特点2]
3. [风格特点3]
...

只需要列出风格特点,不需要解释。
"""

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "user", "content": prompt}
        ]
    )

    # Extract the list of speaking styles from the response
    response_text = response.choices[0].message.content

    # Use regular expressions to pull out the style traits
    # Match the list items after "说话风格" (accept half- or full-width colons)
    style_pattern = r"说话风格[::]\s*(?:\d+\.\s*([^\n]+)(?:\n|$))+"
    match = re.search(style_pattern, response_text, re.DOTALL)

    if match:
        # Extract all list items
        style_items = re.findall(r"\d+\.\s*([^\n]+)", response_text)
        return style_items
    else:
        # The output did not follow the expected format, so try a fallback pattern
        # that catches anything resembling a list item
        fallback_items = re.findall(r"(?:^|\n)(?:\d+[\.\)、]|[-•*])\s*([^\n]+)", response_text)

        # If still nothing was found, split the text directly
        if not fallback_items:
            # Keep lines that may contain a style description
            potential_styles = [line.strip() for line in response_text.split('\n')
                                if line.strip() and not line.startswith('###') and ':' not in line and ':' not in line]
            return potential_styles[:5]  # return at most 5 items

        return fallback_items[:5]  # return at most 5 items

# unit test
# profile = {
#     "drisk": 3,
#     "srisk": 2,
#     "age": "42",
#     "gender": "女",
#     "marital_status": "离婚",
#     "occupation": "教师",
#     "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
# }
# conversations = [
#     {"role": "user", "content": "我最近感觉很沮丧,似乎一切都没有意义。"},
#     {"role": "assistant", "content": "你能具体说说是什么让你有这样的感觉吗?"},
#     {"role": "user", "content": "我觉得自己在工作上总是做不好,没什么价值。"}
# ]
# print(analyze_style(profile, conversations))
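
For completeness, a hedged sketch of an end-to-end analyze_style call. It needs a reachable OpenAI-compatible endpoint, so the key below is a placeholder, and the profile covers only the fields the function actually reads (age, gender, occupation, marital_status, symptoms); all values are invented.

import os

# Assumed configuration; set before importing, since the module reads these at import time.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")      # replace with a real key
os.environ.setdefault("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

from style_analyzer import analyze_style

profile = {
    "age": "42",
    "gender": "女",
    "occupation": "教师",
    "marital_status": "离婚",
    "symptoms": "自我价值感低;睡眠差",
}
conversations = [
    {"role": "user", "content": "我最近总觉得自己做什么都不行。"},
    {"role": "assistant", "content": "能说说这种感觉是从什么时候开始的吗?"},
]

print(analyze_style(profile, conversations))   # up to five style descriptions
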