You said:
Does the memo provided by background_provider.py include any conversation history?
Also check whether the code contains any city-related functionality, and remove it if it does.
Explain how the LLM generates the memo.
######
# background_provider.py
import os
import threadpoolctl
import multiprocessing
num_cpus = multiprocessing.cpu_count()
threadpoolctl.threadpool_limits(limits=min(4, num_cpus), user_api='openmp')
import asyncio
import json
import logging
import configparser
import uuid
import traceback
import time  # needed for time.time() in wait_for_memo_notification
from dotenv import load_dotenv
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from redis_manager import RedisManager
from langchain_groq import ChatGroq
from langchain.schema import SystemMessage, HumanMessage
import spacy
import faiss
import struct
import hashlib  # used to hash background content for change detection
import aiofiles  # make sure aiofiles is installed: pip install aiofiles
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("background_provider")
# Load environment variables
load_dotenv()
# Load configuration file
config = configparser.ConfigParser()
config.read('config.ini')
REDIS_HOST = config.get('Redis', 'Host', fallback='localhost')
REDIS_PORT = config.getint('Redis', 'Port', fallback=6379)
REDIS_DB = config.getint('Redis', 'DB', fallback=0)
BACKGROUND_PROCESSING_STREAM = config.get('Streams', 'ProcessingStream', fallback='processing_stream')
BACKGROUND_MEMO_STREAM = config.get('Streams', 'BackgroundMemoStream', fallback='background_memo_stream')
BACKGROUND_MEMO_READY_CHANNEL = config.get('Chat', 'BackgroundMemoReadyChannel', fallback='background_memo_ready_channel')
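# For reference, a config.ini that satisfies the reads above might look like the sketch
# below. The section and key names are taken from the config.get calls; the values shown
# are simply the fallbacks, so adjust them for your own deployment:
#
# [Redis]
# Host = localhost
# Port = 6379
# DB = 0
#
# [Streams]
# ProcessingStream = processing_stream
# BackgroundMemoStream = background_memo_stream
#
# [Chat]
# BackgroundMemoReadyChannel = background_memo_ready_channel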
GROQ_API_KEY = os.getenv('GROQ_API_KEY_2')
if not GROQ_API_KEY:
logger.error("GROQ_API_KEY is not set. Please set it in your environment variables.")
exit(1)
# Initialize ONNX Runtime and Tokenizer
ONNX_MODEL_PATH = "multilingual_e5_large.onnx" # Ensure this path is correct
if not os.path.exists(ONNX_MODEL_PATH):
logger.error(f"ONNX 模型文件未找到: {ONNX_MODEL_PATH}")
exit(1)
try:
ort_session = ort.InferenceSession(ONNX_MODEL_PATH)
logger.info("成功加载 ONNX 模型。")
except Exception as e:
logger.error(f"加载 ONNX 模型失败: {e}")
exit(1)
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
MAX_LENGTH = 128  # fixed sequence length; must match background_datamaker_ONNX.py
class BackgroundProvider:
def __init__(self):
self.processing_stream = BACKGROUND_PROCESSING_STREAM
self.memo_stream = BACKGROUND_MEMO_STREAM
self.memo_ready_channel = BACKGROUND_MEMO_READY_CHANNEL
self.redis_manager = RedisManager(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB
)
# Initialize ChatGroq client for memo generation
self.chat_groq = ChatGroq(
temperature=0.5,
model_name="llama-3.1-8b-instant",
groq_api_key=GROQ_API_KEY
)
# Initialize ONNX Runtime session and tokenizer
self.ort_session = ort_session
self.tokenizer = tokenizer
self.max_length = MAX_LENGTH
# Load spaCy model for entity recognition
try:
self.nlp = spacy.load('en_core_web_sm')  # or another installed pipeline (e.g. en_core_web_trf), depending on your setup
logger.info("spaCy model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load spaCy model: {e}")
self.nlp = None  # set to None if loading fails
# Consumer group related variables
self.consumer_group = 'background_provider_group'
self.consumer_name = f'background_provider_consumer_{uuid.uuid4()}'
self.max_retries = 3 # Maximum retry attempts
# FAISS-related initialization
self.faiss_index = None
self.faiss_index_path = "faiss.index"  # path of the FAISS index file
self.id_to_background_path = "id_to_background.json"  # file mapping FAISS index IDs to background info
self.id_to_background = {}  # maps FAISS index IDs to background info
# Pub/Sub-related initialization
self.memo_ready_queue = asyncio.Queue()
def compute_hash(self, content: str) -> str:
"""計算內容的 SHA-256 哈希值。"""
return hashlib.sha256(content.encode('utf-8')).hexdigest()
async def listen_memo_ready(self):
"""持久化地聽取備忘錄就緒通知,並將消息放入異步隊列中。"""
try:
pubsub = self.redis_manager.redis_client.pubsub()
await pubsub.subscribe(self.memo_ready_channel.encode('utf-8'))
logger.info(f"Subscribed to {self.memo_ready_channel}")
while True:
try:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
if message:
data = json.loads(message['data'].decode('utf-8'))
await self.memo_ready_queue.put(data)
logger.debug(f"Received memo ready message: {data}")
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error in listen_memo_ready: {e}")
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Failed to subscribe to {self.memo_ready_channel}: {e}")
await asyncio.sleep(5)
asyncio.create_task(self.listen_memo_ready())  # retry the subscription
async def wait_for_memo_notification(self, user_id, message_id, timeout=5):
"""等待特定的備忘錄就緒通知,並設置超時。"""
start_time = time.time()
while time.time() - start_time < timeout:
try:
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
data = await asyncio.wait_for(self.memo_ready_queue.get(), timeout=remaining)
if data['user_id'] == user_id and data['message_id'] == message_id:
logger.info(f"Received memo ready notification for user_id {user_id} and message_id {message_id}")
return True
except asyncio.TimeoutError:
break
except Exception as e:
logger.error(f"Error while waiting for memo notification: {e}")
logger.info(f"Timeout waiting for memo ready notification for user_id {user_id} and message_id {message_id}")
return False
async def connect(self):
await self.redis_manager.connect()
logger.info(f"Connected to Redis: {REDIS_HOST}:{REDIS_PORT}")
# Create consumer group if it doesn't exist
try:
await self.redis_manager.redis_client.xgroup_create(
name=self.processing_stream.encode('utf-8'),
groupname=self.consumer_group.encode('utf-8'),
id='$',
mkstream=True
)
logger.info(f"Consumer group '{self.consumer_group}' created for stream '{self.processing_stream}'.")
except Exception as e:
if "BUSYGROUP" in str(e):
logger.info(f"Consumer group '{self.consumer_group}' already exists for stream '{self.processing_stream}'.")
else:
logger.error(f"Error creating consumer group: {e}")
raise
# Initialize the FAISS index
self._load_or_create_faiss_index()
# Load background information into FAISS
await self.load_background_info()
def _load_or_create_faiss_index(self):
"""加载现有的 FAISS 索引或创建一个新的索引。"""
if os.path.exists(self.faiss_index_path) and os.path.exists(self.id_to_background_path):
logger.info(f"Loading FAISS index from {self.faiss_index_path}")
self.faiss_index = faiss.read_index(self.faiss_index_path)
with open(self.id_to_background_path, "r", encoding="utf-8") as f:
self.id_to_background = json.load(f)
logger.info("FAISS index loaded successfully.")
else:
logger.info("Creating new FAISS index.")
dimension = 1024  # must match the model's embedding dimension
# use an inner-product index; with L2-normalized vectors this is equivalent to cosine similarity
self.faiss_index = faiss.IndexFlatIP(dimension)
self.id_to_background = {}
logger.info("FAISS index created with IndexFlatIP.")
async def load_background_info(self):
"""加载 background_info.txt 中的背景信息,并添加到 FAISS 和 Redis。"""
logger.info("Loading background information from background_info.txt")
background_file = "background_info.txt"
if not os.path.exists(background_file):
logger.error(f"{background_file} not found.")
return
with open(background_file, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
background_id, content = self._parse_background_line(line)
if not background_id or not content:
continue
# Compute the hash of the current content
current_hash = self.compute_hash(content)
# Fetch the hash previously stored in Redis
stored_hash = await self.redis_manager.get_hash(background_id)
if stored_hash == current_hash:
logger.info(f"Background ID {background_id} is unchanged; skipping vectorization.")
continue  # skip vectorization
# Vectorize the content
vector = await self.vectorize_content(content)
if not vector:
logger.error(f"Failed to vectorize content for background_id: {background_id}")
continue
# Convert the vector to a NumPy array and L2-normalize it
vector_np = np.array(vector).astype('float32').reshape(1, -1)
faiss.normalize_L2(vector_np)
# Add to the FAISS index
self.faiss_index.add(vector_np)
faiss_id = self.faiss_index.ntotal - 1
self.id_to_background[str(faiss_id)] = {
"background_id": background_id,
"content": content
}
# Insert into Redis
await self.redis_manager.insert_background(background_id, content, vector)
# Update the stored hash
await self.redis_manager.set_hash(background_id, current_hash)
logger.info(f"Added/Updated background_id: {background_id} to FAISS and Redis.")
# Persist the FAISS index and the ID mapping
faiss.write_index(self.faiss_index, self.faiss_index_path)
with open(self.id_to_background_path, "w", encoding="utf-8") as f:
json.dump(self.id_to_background, f, ensure_ascii=False, indent=4)
logger.info("Background information loaded and FAISS index updated.")
def _parse_background_line(self, line: str):
"""
Parse a single line of background_info.txt.
Expected format: background_id\tcontent
"""
parts = line.strip().split('\t')
if len(parts) != 2:
logger.warning(f"Invalid line format: {line}")
return None, None
return parts[0], parts[1]
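# Example of a background_info.txt line this parser expects (tab-separated; the values
# below are purely illustrative):
#
#   bg_001<TAB>John has worked as a software engineer at Example Corp since 2019.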
async def vectorize_content(self, content: str) -> list:
"""使用 ONNX 模型将内容转换为向量。"""
try:
logger.info("使用 ONNX 模型开始向量化内容。")
# Tokenize the input text with fixed max_length
encoded = self.tokenizer(
content,
padding='max_length',
truncation=True,
max_length=self.max_length,
return_tensors="np"
)
input_ids = encoded['input_ids'].astype(np.int64)
attention_mask = encoded['attention_mask'].astype(np.int64)
# Log shapes for debugging
logger.debug(f"input_ids shape: {input_ids.shape}, attention_mask shape: {attention_mask.shape}")
# Run the model asynchronously
loop = asyncio.get_event_loop()
outputs = await loop.run_in_executor(
None,
lambda: self.ort_session.run(None, {
"input_ids": input_ids,
"attention_mask": attention_mask
})
)
last_hidden_state = outputs[0] # Assuming 'last_hidden_state' is the first output
# Generate sentence embeddings by averaging token embeddings
embedding = last_hidden_state[0][attention_mask[0].astype(bool)].mean(axis=0)
embedding = embedding.tolist()
logger.debug(f"Vector embedding: {embedding}")
# Dynamically check the embedding length based on the model's output
expected_embedding_length = last_hidden_state.shape[-1] # Get the last dimension of the model output
if len(embedding) != expected_embedding_length:
logger.error(f"Generated embedding length mismatch: Expected {expected_embedding_length}, got {len(embedding)}")
return []
logger.info(f"向量化成功,向量长度为 {expected_embedding_length}。")
return embedding
except Exception as e:
logger.error(f"向量化过程中出错: {e}")
logger.error(traceback.format_exc())
return []
def extract_entities(self, text):
if self.nlp:
doc = self.nlp(text)
entities = [ent.text for ent in doc.ents if ent.label_ in ['PERSON', 'ORG', 'GPE', 'LOC', 'EVENT', 'PRODUCT']]
return list(set(entities))  # remove duplicates
else:
return []
async def generate_memo(self, user_id: str, content: str, message_id: str) -> str:
"""生成备忘录的方法,调用 process_search_results 并传递必要参数。"""
try:
# Step 1: Vectorize user input content
logger.info(f"为 user_id: {user_id}, message_id: {message_id} 向量化内容")
vector = await self.vectorize_content(content)
if not vector:
logger.error("向量化失败。返回默认备忘录。")
return "无法生成备忘录。"
# Normalize the vector for cosine similarity
vector_np = np.array(vector).astype('float32').reshape(1, -1)
faiss.normalize_L2(vector_np)
# Step 2: Search relevant background data in FAISS using the generated vector
logger.info("使用生成的向量在 FAISS 中搜索相关背景数据。")
faiss_results = self.search_faiss(vector_np, top_k=5)
logger.info(f"FAISS 搜索完成。找到 {len(faiss_results)} 个结果。")
# Step 3: Search relevant background data in Redis using keywords/entities
logger.info("使用关键词/实体在 Redis 中搜索相关背景数据。")
keyword_results = await self.search_background_keywords(content, top_k=5)
logger.info(f"关键词搜索完成。找到 {len(keyword_results)} 个结果。")
# Step 4: Combine FAISS 和关键词搜索结果
combined_results = self.combine_results(faiss_results, keyword_results)
logger.info(f"合并搜索结果完成。共 {len(combined_results)} 个结果。")
if not combined_results:
logger.warning("未找到相关的背景信息。")
return "未找到相关的背景信息。"
# Step 5: Generate memo based on search results
logger.info(f"为 user_id: {user_id}, message_id: {message_id} 生成备忘录")
memo = await self.process_search_results(combined_results, content, user_id, message_id) # 传递所有必要参数
return memo
except Exception as e:
logger.error(f"在 generate_memo 中出错: {e}")
logger.error(traceback.format_exc())
return "无法生成备忘录。"
async def process_search_results(self, search_results: list, user_question: str, user_id: str, message_id: str) -> str:
if not search_results:
logger.warning("没有可处理的搜索结果。")
return "未找到相关的背景信息。"
# Compile search results into a readable format
search_context = "\n".join([f"- {item['content']}" for item in search_results])
prompt = (
"Using the following collected background information and the user's question, "
"provide a detailed and comprehensive answer that directly addresses the user's question. "
"The answer should be structured with clear headings and bullet points where appropriate. "
"Do not format the answer as a letter, memo, or include any letter-like elements such as 'To:', 'From:', 'Subject:', 'Date:', 'Dear', 'Best regards', or any signatures. "
"Do not mention any personal limitations or refer to yourself in the answer. "
"Answer from the perspective of the person described in the background information.\n\n"
"Example Format:\n"
"# Main Heading\n"
"- Point 1\n"
"- Point 2\n\n"
f"User's question: {user_question}\n"
"Collected background information:\n"
f"{search_context}\n\n"
"Answer:\n"
)
system_prompt = (
"You are to act as the person described in the background information. "
"Do not include any letter-like elements such as 'To:', 'From:', 'Subject:', 'Date:', 'Dear', 'Best regards', or any signatures."
)
logger.debug(f"最终发送给模型的提示(prompt):\n{prompt}")
try:
logger.info("向 ChatGroq 发送提示以生成备忘录。")
response = await asyncio.to_thread(
self.chat_groq.invoke,
[
SystemMessage(content=system_prompt),
HumanMessage(content=prompt)
]
)
memo_text = response.content.strip()
if not memo_text:
logger.error("ChatGroq returned an empty memo.")
return "无法生成备忘录。"
logger.debug(f"Generated memo: {memo_text}")
logger.info("Memo generated successfully.")
return memo_text
except Exception as e:
logger.error(f"Error while generating the memo: {e}")
logger.error(traceback.format_exc())
return "无法生成备忘录。"
async def handle_message(self, message: dict):
user_id = message.get('user_id')
content = message.get('content')
message_id = message.get('message_id')
if not user_id or not content or not message_id:
logger.error(f"收到无效的消息格式: {message}")
return
logger.info(f"Processing message - user_id: {user_id}, message_id: {message_id}, content: {content}")
try:
# Generate memo
memo = await self.generate_memo(user_id, content, message_id)
# Construct memo message
memo_message = {
"user_id": user_id,
"message_id": message_id,
"memo": memo
}
# Store memo in Redis for main_chatbot to retrieve
user_memo_key = f"user_memo:{user_id}"
await self.redis_manager.redis_client.hset(
user_memo_key,
message_id,
memo
)
logger.info(f"Stored memo in Redis under key {user_memo_key} with message_id {message_id}.")
# Send memo to background_memo_stream
await self.redis_manager.send_to_stream(self.memo_stream, memo_message)
logger.info(f"已将备忘录发送到 {self.memo_stream}: {memo_message}")
# Publish memo ready notification to background_memo_ready_channel
notification = {
'user_id': user_id,
'message_id': message_id
}
await self.redis_manager.publish(self.memo_ready_channel, json.dumps(notification))
logger.info(f"已发布备忘录就绪通知,user_id: {user_id}, message_id: {message_id}")
except Exception as e:
logger.error(f"在处理消息时出错: {e}")
logger.error(traceback.format_exc())
async def listen_background_stream(self):
while True:
try:
# Read messages from Redis Stream
entries = await self.redis_manager.redis_client.xreadgroup(
groupname=self.consumer_group.encode('utf-8'),
consumername=self.consumer_name.encode('utf-8'),
streams={self.processing_stream.encode('utf-8'): '>'},
block=5000, # 5 seconds timeout
count=10
)
if entries:
for stream, messages in entries:
for message_id, message_data in messages:
logger.debug(f"Received message from {stream.decode('utf-8')}: ID={message_id.decode('utf-8')}, Data={message_data}")
# Decode message_data from bytes to string
message_data_decoded = {
k.decode('utf-8'): v.decode('utf-8') for k, v in message_data.items()
}
logger.debug(f"Decoded message data: {message_data_decoded}")
# Process the message
await self.handle_message(message_data_decoded)
# Acknowledge the message as processed
await self.redis_manager.redis_client.xack(
self.processing_stream.encode('utf-8'),
self.consumer_group.encode('utf-8'),
message_id
)
logger.debug(f"Acknowledged message {message_id.decode('utf-8')} on stream {self.processing_stream}")
except Exception as e:
logger.error(f"Error while listening to {self.processing_stream}: {e}")
logger.error(traceback.format_exc())
await asyncio.sleep(1) # Wait before retrying
def search_faiss(self, vector_np: np.ndarray, top_k: int = 10) -> list:
"""使用 FAISS 进行向量搜索。"""
try:
if self.faiss_index is None:
logger.error("FAISS index is not initialized.")
return []
distances, indices = self.faiss_index.search(vector_np, top_k)
results = []
for distance, idx in zip(distances[0], indices[0]):
if idx == -1:
continue
background_info = self.id_to_background.get(str(idx), {})
if background_info:
results.append({
"background_id": background_info["background_id"],
"content": background_info["content"],
"distance": distance # 内积值越大,相似度越高
})
logger.info(f"FAISS search retrieved {len(results)} results.")
return results
except Exception as e:
logger.error(f"在 FAISS 搜索过程中出错: {e}")
logger.error(traceback.format_exc())
return []
def extract_keywords(self, text):
"""从文本中提取关键词。"""
if self.nlp:
doc = self.nlp(text)
keywords = [token.text for token in doc if token.pos_ in ['NOUN', 'PROPN']]
return list(set(keywords))
else:
return text.split()  # fall back to simple whitespace splitting if the NLP model is unavailable
async def search_background_keywords(self, query: str, top_k: int = 10) -> list:
"""使用关键词/实体在 Redis 中搜索相关背景数据。"""
try:
entities = self.extract_entities(query)
logger.info(f"从用户问题中提取的实体: {entities}")
keyword_results = []
if entities:
for entity in entities:
results = await self.redis_manager.search_keywords(entity, top_k=top_k)
keyword_results.extend(results)
else:
# Fall back to keyword extraction when no entities were found
keywords = self.extract_keywords(query)
logger.info(f"No entities extracted. Extracted keywords: {keywords}")
for keyword in keywords:
results = await self.redis_manager.search_keywords(keyword, top_k=top_k)
keyword_results.extend(results)
logger.info(f"关键词搜索完成。找到 {len(keyword_results)} 个结果。")
return keyword_results
except Exception as e:
logger.error(f"在关键词搜索过程中出错: {e}")
logger.error(traceback.format_exc())
return []
def combine_results(self, faiss_results: list, keyword_results: list) -> list:
"""合并 FAISS 搜索结果和关键词搜索结果,去重并根据相关性排序。"""
all_results = {}
for item in faiss_results + keyword_results:
content = item['content']
if content not in all_results:
all_results[content] = item
else:
# Same content: keep the entry with the higher inner-product score (i.e. higher relevance)
if item.get('distance', 0) > all_results[content].get('distance', 0):
all_results[content] = item
combined = list(all_results.values())
# Sort by 'distance' (the inner-product score); larger values mean higher relevance
combined.sort(key=lambda x: x.get('distance', 0), reverse=True)
logger.info(f"{len(combined)} unique results after merging.")
return combined
async def run(self):
try:
await self.connect()
logger.info(f"Connected to Redis: {REDIS_HOST}:{REDIS_PORT}")
# create the async listener task once the event loop is running
asyncio.create_task(self.listen_memo_ready())
# Start listening to background_processing_stream
await self.listen_background_stream()
finally:
await self.redis_manager.close()
async def add_background(self, background_id: str, content: str):
"""
Add a new piece of background information to FAISS and Redis.
"""
try:
# Vectorize the content
vector = await self.vectorize_content(content)
if not vector:
logger.error(f"Failed to vectorize content for background_id: {background_id}")
return
# Convert the vector to a NumPy array and L2-normalize it
vector_np = np.array(vector).astype('float32').reshape(1, -1)
faiss.normalize_L2(vector_np)
# Add to the FAISS index
self.faiss_index.add(vector_np)
faiss_id = self.faiss_index.ntotal - 1
self.id_to_background[str(faiss_id)] = {
"background_id": background_id,
"content": content
}
# Insert into Redis
await self.redis_manager.insert_background(background_id, content, vector)
# Persist the FAISS index and the ID mapping
faiss.write_index(self.faiss_index, self.faiss_index_path)
async with aiofiles.open(self.id_to_background_path, "w", encoding="utf-8") as f:
await f.write(json.dumps(self.id_to_background, ensure_ascii=False, indent=4))
logger.info(f"Added new background with ID {background_id}")
except Exception as e:
logger.error(f"在添加新的背景信息时出错: {e}")
logger.error(traceback.format_exc())
async def close(self):
await self.redis_manager.close()
if __name__ == "__main__":
background_provider = BackgroundProvider()
try:
asyncio.run(background_provider.run())
except KeyboardInterrupt:
logger.info("BackgroundProvider 已被手動終止。")