Files
Ombre_Brain/utils.py

205 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ============================================================
# Module: Common Utilities (utils.py)
# 模块:通用工具函数
#
# Provides config loading, logging init, path safety, ID generation, etc.
# 提供配置加载、日志初始化、路径安全校验、ID 生成等基础能力
#
# Depended on by: server.py, bucket_manager.py, dehydrator.py, decay_engine.py
# 被谁依赖:server.py, bucket_manager.py, dehydrator.py, decay_engine.py
# ============================================================
import copy
import logging
import os
import re
import uuid
from datetime import datetime
from pathlib import Path

import yaml
def load_config(config_path: str = None) -> dict:
    """
    Load the service configuration.

    Priority (highest first): environment variables > config.yaml >
    built-in defaults.

    Args:
        config_path: Path of the YAML config file. When None, ``config.yaml``
            located next to this module is used.

    Returns:
        The merged configuration dict.

    Side effects:
        Creates the ``permanent``, ``dynamic`` and ``archive`` sub-directories
        under ``config["buckets_dir"]`` when they do not exist yet, and logs a
        warning (falling back to the defaults) when the config file cannot be
        read, decoded or parsed, or does not contain a mapping.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))

    # --- Built-in defaults (fallback so it runs even without config.yaml) ---
    defaults = {
        "transport": "stdio",
        "log_level": "INFO",
        "buckets_dir": os.path.join(module_dir, "buckets"),
        "merge_threshold": 75,
        "dehydration": {
            "model": "deepseek-chat",
            "base_url": "https://api.deepseek.com/v1",
            "api_key": "",
            "max_tokens": 1024,
            "temperature": 0.1,
        },
        "decay": {
            "lambda": 0.05,
            "threshold": 0.3,
            "check_interval_hours": 24,
            "emotion_weights": {
                "base": 1.0,
                "arousal_boost": 0.8,
            },
        },
        "matching": {
            "fuzzy_threshold": 50,
            "max_results": 5,
        },
    }

    # --- Load user config from the YAML file ---
    if config_path is None:
        config_path = os.path.join(module_dir, "config.yaml")

    config = defaults.copy()
    if os.path.exists(config_path):
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                file_config = yaml.safe_load(f) or {}
        except (OSError, UnicodeDecodeError) as e:
            # The path exists but cannot be read as UTF-8 text (directory,
            # permission problem, wrong encoding). Treat it like a parse
            # failure instead of crashing: keep the defaults.
            logging.warning(
                f"Failed to read config file, using defaults / "
                f"配置文件读取失败,使用默认配置: {e}"
            )
        except yaml.YAMLError as e:
            logging.warning(
                f"Failed to parse config file, using defaults / "
                f"配置文件解析失败,使用默认配置: {e}"
            )
        else:
            if isinstance(file_config, dict):
                config = _deep_merge(defaults, file_config)
            else:
                logging.warning(
                    f"Config file is not a valid YAML dict, using defaults / "
                    f"配置文件不是有效的 YAML 字典,使用默认配置: {config_path}"
                )

    # --- Environment variable overrides (highest priority) ---
    # Empty environment variables are treated as "not set".
    dehydration_env = {
        "api_key": os.environ.get("OMBRE_API_KEY", ""),
        "base_url": os.environ.get("OMBRE_BASE_URL", ""),
    }
    dehydration_env = {k: v for k, v in dehydration_env.items() if v}
    if dehydration_env:
        # The file may define `dehydration` as a non-mapping (e.g. an empty
        # `dehydration:` key parses to None); replace it so the override
        # can be stored instead of raising TypeError.
        if not isinstance(config.get("dehydration"), dict):
            config["dehydration"] = {}
        config["dehydration"].update(dehydration_env)

    top_level_env = (
        ("transport", "OMBRE_TRANSPORT"),
        ("buckets_dir", "OMBRE_BUCKETS_DIR"),
    )
    for key, env_name in top_level_env:
        env_value = os.environ.get(env_name, "")
        if env_value:
            config[key] = env_value

    # --- Ensure bucket storage directories exist ---
    buckets_dir = config["buckets_dir"]
    for subdir in ("permanent", "dynamic", "archive"):
        os.makedirs(os.path.join(buckets_dir, subdir), exist_ok=True)

    return config
def _deep_merge(base: dict, override: dict) -> dict:
"""
Deep-merge two dicts; override values take precedence.
深度合并两个字典override 的值覆盖 base。
"""
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
def setup_logging(level: str = "INFO") -> None:
    """
    Initialize the logging system on the root logger.

    Note: in MCP stdio mode stdout is occupied by the protocol, so logs must
    go to stderr (the default stream of ``logging.StreamHandler``).

    Args:
        level: Level name such as ``"INFO"`` or ``"debug"`` (case-insensitive).
            A numeric logging level is accepted as well. Anything that does
            not resolve to a level falls back to ``logging.INFO``.
    """
    if isinstance(level, int) and not isinstance(level, bool):
        # Already a numeric level (e.g. logging.DEBUG or an int from config).
        log_level = level
    else:
        # str() keeps None and other non-string values from crashing on
        # .upper(); unknown names fall through to INFO below.
        log_level = getattr(logging, str(level).upper(), None)
        if not isinstance(log_level, int):
            log_level = logging.INFO

    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler()],  # StreamHandler defaults to stderr
    )
    # basicConfig() does nothing when the root logger already has handlers
    # (for instance after an earlier module-level logging.warning() call),
    # so apply the requested level explicitly.
    logging.getLogger().setLevel(log_level)
def generate_bucket_id() -> str:
    """
    Create a new bucket identifier.

    Returns:
        The first 12 hexadecimal characters of a random UUID4, kept short
        so the ID stays readable.
    """
    full_hex = uuid.uuid4().hex
    short_id = full_hex[:12]
    return short_id
def sanitize_name(name: str) -> str:
    """
    Sanitize a bucket name, keeping only safe characters.

    Everything except word characters, whitespace, CJK ideographs
    (U+4E00-U+9FFF) and hyphens is removed. This drops path separators and
    dots, which prevents path traversal input such as ``../../etc/passwd``.

    Args:
        name: Raw bucket name.

    Returns:
        The cleaned name, at most 80 characters long and without leading or
        trailing whitespace. ``"unnamed"`` is returned when ``name`` is not a
        string or nothing is left after cleaning.
    """
    if not isinstance(name, str):
        return "unnamed"
    cleaned = re.sub(r"[^\w\s\u4e00-\u9fff-]", "", name, flags=re.UNICODE)
    # Strip again after truncating: cutting at 80 characters can land right
    # after an inner space and would otherwise leave trailing whitespace.
    cleaned = cleaned.strip()[:80].rstrip()
    return cleaned or "unnamed"
def safe_path(base_dir: str, filename: str) -> Path:
    """
    Build a file path that is guaranteed to stay inside ``base_dir``.

    Both paths are resolved first, so ``..`` segments and symlinks are taken
    into account before the containment check.

    Args:
        base_dir: Directory the result must live in.
        filename: File name or relative path to place under ``base_dir``.

    Returns:
        The resolved target path.

    Raises:
        ValueError: If the resolved target is not ``base_dir`` itself or a
            path below it (directory traversal attempt).
    """
    base = Path(base_dir).resolve()
    target = (base / filename).resolve()
    try:
        # Compare whole path components. A plain string-prefix test would
        # wrongly accept sibling directories such as "<base>_evil/...".
        target.relative_to(base)
    except ValueError:
        raise ValueError(
            f"Path safety check failed / 路径安全检查失败: "
            f"{target} is not inside / 不在 {base}"
        ) from None
    return target
def count_tokens_approx(text: str) -> int:
    """
    Estimate the number of tokens in ``text``.

    Heuristic: each CJK ideograph (U+4E00-U+9FFF) counts as 1.5 tokens, each
    run of ASCII letters as 1.3 tokens, plus 0.05 tokens per character of the
    whole text. It is only meant to decide whether dehydration is needed, so
    precision is not a goal.

    Args:
        text: Text to measure. Empty or falsy input yields 0.

    Returns:
        The estimate, truncated to an int.
    """
    if not text:
        return 0

    cjk_count = sum(1 for _ in re.finditer(r"[\u4e00-\u9fff]", text))
    word_count = sum(1 for _ in re.finditer(r"[a-zA-Z]+", text))

    estimate = cjk_count * 1.5
    estimate += word_count * 1.3
    estimate += len(text) * 0.05
    return int(estimate)
def now_iso() -> str:
    """
    Return the current local time as an ISO 8601 string.

    Returns:
        A naive (no timezone offset) timestamp with second precision, in the
        form ``YYYY-MM-DDTHH:MM:SS``.
    """
    # With microseconds zeroed, isoformat() emits no fractional part.
    current = datetime.now().replace(microsecond=0)
    return current.isoformat()