Spaces:
Sleeping
Sleeping
| import random | |
| # from collections import defaultdict | |
| # 计算总权重 | |
def calculate_total_weight(current_state, states, category_distances, distance_weights):
    """Return the summed transition weight of every state, seen from ``current_state``.

    Args:
        current_state: The state whose class is used as the reference point.
        states: Mapping of class name -> list of states belonging to that class.
        category_distances: Nested mapping ``{class_a: {class_b: distance}}``.
        distance_weights: Mapping of distance -> weight. A distance that has
            no entry contributes a weight of 0.

    Returns:
        The sum over all classes of ``weight(distance to class) * class size``.
        The current state itself is counted as a member of its own class.

    Raises:
        ValueError: If ``current_state`` is not listed in any class.
    """
    # The first class (in mapping order) that lists the state is its home class.
    home_class = next(
        (name for name, members in states.items() if current_state in members),
        None,
    )
    if home_class is None:
        raise ValueError("Current state not found in any class.")

    distances_from_home = category_distances[home_class]
    return sum(
        distance_weights.get(distances_from_home[name], 0) * len(members)
        for name, members in states.items()
    )
| # 计算每个目标状态的概率 | |
def calculate_probabilities(current_state, states, category_distances, distance_weights):
    """Return the transition probability to every state other than ``current_state``.

    Each candidate state receives the weight associated with the distance
    between its class and the class of ``current_state``. The weights are then
    normalised over the candidate states only, so the returned values sum to 1.

    Args:
        current_state: The state being left; it is excluded from the result.
        states: Mapping of class name -> list of states belonging to that class.
        category_distances: Nested mapping ``{class_a: {class_b: distance}}``.
        distance_weights: Mapping of distance -> weight. A distance that has
            no entry is treated as weight 0.

    Returns:
        Dict mapping each candidate state to its probability. Empty when
        ``current_state`` is the only state defined.

    Raises:
        ValueError: If ``current_state`` is not listed in any class, or if
            candidates exist but their weights add up to zero (no transition
            is possible).
    """
    current_class = next(
        (cls for cls, state_list in states.items() if current_state in state_list),
        None,
    )
    if current_class is None:
        raise ValueError("Current state not found in any class.")

    distances_from_current = category_distances[current_class]

    # Per-state (not per-class) weights: a state's chance must not be
    # multiplied by the size of the class it happens to belong to.
    raw_weights = {}
    for cls, state_list in states.items():
        weight = distance_weights.get(distances_from_current[cls], 0)
        for state in state_list:
            if state != current_state:
                raw_weights[state] = weight

    if not raw_weights:
        return {}

    # Normalise over the candidates actually present in the result so the
    # probabilities sum to 1 (the current state is not part of the total).
    normalizer = sum(raw_weights.values())
    if normalizer <= 0:
        raise ValueError("No reachable target state: all transition weights are zero.")

    return {state: weight / normalizer for state, weight in raw_weights.items()}
| # 实现状态扰动 | |
def perturb_state(current_state):
    """Randomly pick a different emotion label to replace ``current_state``.

    The candidate labels and their relative weights come from
    ``calculate_probabilities``, called with the category tables built below.
    One label is then drawn with ``random.choices`` using those weights.

    Args:
        current_state: One of the emotion labels listed in the table below.

    Returns:
        The label that was drawn; never ``current_state`` itself.

    Raises:
        ValueError: If ``current_state`` is not one of the known labels
            (raised by ``calculate_probabilities``).
    """
    # Emotion labels grouped by sentiment category. The order of the groups
    # is significant: it defines each category's position on the spectrum
    # used to derive the distances below.
    emotion_groups = {
        'Positive': [
            "admiration", "amusement", "approval", "caring", "curiosity",
            "desire", "excitement", "gratitude", "joy", "love", "optimism",
            "pride", "realization", "relief",
        ],
        'Neutral': ['neutral'],
        'Ambiguous': ["confusion", "disappointment", "nervousness"],
        'Negative': [
            "anger", "annoyance", "disapproval", "disgust", "embarrassment",
            "fear", "sadness", "remorse",
        ],
    }

    # Distance between two categories = how many steps apart they sit on the
    # Positive -> Neutral -> Ambiguous -> Negative spectrum.
    spectrum = list(emotion_groups)
    category_distances = {
        source: {target: abs(i - j) for j, target in enumerate(spectrum)}
        for i, source in enumerate(spectrum)
    }

    # Weight per distance: 0 = same category, 1 = adjacent category,
    # 2 = one category in between, 3 = two categories in between.
    distance_weights = {0: 10, 1: 5, 2: 2, 3: 1}

    probabilities = calculate_probabilities(
        current_state, emotion_groups, category_distances, distance_weights
    )
    candidates = list(probabilities.keys())
    weights = list(probabilities.values())
    drawn = random.choices(candidates, weights=weights, k=1)
    return drawn[0]
| # 示例运行 | |
| # current_state = 'confusion' | |
| # next_state = perturb_state(current_state) | |
| # print(f"Next state: {next_state}") | |
| # 验证概率分布 | |
| # state_counts = defaultdict(int) | |
| # for _ in range(1000): | |
| # next_state = perturb_state(current_state) | |
| # state_counts[next_state] += 1 | |
| # print("\nProbability distribution:") | |
| # for state, count in state_counts.items(): | |
| # print(f"{state}: {count / 1000:.2f}") |