# ============================================================ # Module: Common Utilities (utils.py) # 模块:通用工具函数 # # Provides config loading, logging init, path safety, ID generation, etc. # 提供配置加载、日志初始化、路径安全校验、ID 生成等基础能力 # # Depended on by: server.py, bucket_manager.py, dehydrator.py, decay_engine.py # 被谁依赖:server.py, bucket_manager.py, dehydrator.py, decay_engine.py # ============================================================ import os import re import uuid import yaml import logging from pathlib import Path from datetime import datetime def load_config(config_path: str = None) -> dict: """ Load configuration file. 加载配置文件。 Priority: environment variables > config.yaml > built-in defaults. 优先级:环境变量 > config.yaml > 内置默认值。 """ # --- Built-in defaults (fallback so it runs even without config.yaml) --- # --- 内置默认配置(兜底,保证即使没有 config.yaml 也能跑)--- defaults = { "transport": "stdio", "log_level": "INFO", "buckets_dir": os.path.join(os.path.dirname(os.path.abspath(__file__)), "buckets"), "merge_threshold": 75, "dehydration": { "model": "deepseek-chat", "base_url": "https://api.deepseek.com/v1", "api_key": "", "max_tokens": 1024, "temperature": 0.1, }, "decay": { "lambda": 0.05, "threshold": 0.3, "check_interval_hours": 24, "emotion_weights": { "base": 1.0, "arousal_boost": 0.8, }, }, "matching": { "fuzzy_threshold": 50, "max_results": 5, }, } # --- Load user config from YAML file --- # --- 从 YAML 文件加载用户自定义配置 --- if config_path is None: config_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "config.yaml" ) config = defaults.copy() if os.path.exists(config_path): try: with open(config_path, "r", encoding="utf-8") as f: file_config = yaml.safe_load(f) or {} if isinstance(file_config, dict): config = _deep_merge(defaults, file_config) else: logging.warning( f"Config file is not a valid YAML dict, using defaults / " f"配置文件不是有效的 YAML 字典,使用默认配置: {config_path}" ) except yaml.YAMLError as e: logging.warning( f"Failed to parse config file, using defaults / " f"配置文件解析失败,使用默认配置: {e}" ) # --- Environment variable overrides (highest priority) --- # --- 环境变量覆盖敏感/运行时配置(优先级最高)--- env_api_key = os.environ.get("OMBRE_API_KEY", "") if env_api_key: config.setdefault("dehydration", {})["api_key"] = env_api_key env_base_url = os.environ.get("OMBRE_BASE_URL", "") if env_base_url: config.setdefault("dehydration", {})["base_url"] = env_base_url env_transport = os.environ.get("OMBRE_TRANSPORT", "") if env_transport: config["transport"] = env_transport env_buckets_dir = os.environ.get("OMBRE_BUCKETS_DIR", "") if env_buckets_dir: config["buckets_dir"] = env_buckets_dir # --- Ensure bucket storage directories exist --- # --- 确保记忆桶存储目录存在 --- buckets_dir = config["buckets_dir"] for subdir in ["permanent", "dynamic", "archive"]: os.makedirs(os.path.join(buckets_dir, subdir), exist_ok=True) return config def _deep_merge(base: dict, override: dict) -> dict: """ Deep-merge two dicts; override values take precedence. 深度合并两个字典,override 的值覆盖 base。 """ result = base.copy() for key, value in override.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): result[key] = _deep_merge(result[key], value) else: result[key] = value return result def setup_logging(level: str = "INFO") -> None: """ Initialize logging system. 初始化日志系统。 Note: In MCP stdio mode, stdout is occupied by the protocol; logs must go to stderr. 注意:MCP stdio 模式下 stdout 被协议占用,日志只能走 stderr。 """ log_level = getattr(logging, level.upper(), None) if not isinstance(log_level, int): log_level = logging.INFO logging.basicConfig( level=log_level, format="[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", handlers=[logging.StreamHandler()], # StreamHandler defaults to stderr ) def generate_bucket_id() -> str: """ Generate a unique bucket ID (12-char short UUID for readability). 生成唯一的记忆桶 ID(12 位短 UUID,方便人类阅读)。 """ return uuid.uuid4().hex[:12] def sanitize_name(name: str) -> str: """ Sanitize bucket name, keeping only safe characters. Prevents path traversal attacks (e.g. ../../etc/passwd). 清洗桶名称,只保留安全字符。防止路径遍历攻击。 """ if not isinstance(name, str): return "unnamed" cleaned = re.sub(r"[^\w\s\u4e00-\u9fff-]", "", name, flags=re.UNICODE) cleaned = cleaned.strip()[:80] return cleaned if cleaned else "unnamed" def safe_path(base_dir: str, filename: str) -> Path: """ Construct a safe file path, ensuring it stays within base_dir. Prevents directory traversal. 构造安全的文件路径,确保最终路径始终在 base_dir 内部。 """ base = Path(base_dir).resolve() target = (base / filename).resolve() if not str(target).startswith(str(base)): raise ValueError( f"Path safety check failed / 路径安全检查失败: " f"{target} is not inside / 不在 {base} 内" ) return target def count_tokens_approx(text: str) -> int: """ Rough token count estimate. 粗略估算 token 数。 Chinese ≈ 1 char = 1.5 tokens, English ≈ 1 word = 1.3 tokens. Used to decide whether dehydration is needed; precision not required. 中文 ≈ 1字=1.5token,英文 ≈ 1词=1.3token。 用于判断是否需要脱水压缩,不追求精确。 """ if not text: return 0 chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) english_words = len(re.findall(r"[a-zA-Z]+", text)) return int(chinese_chars * 1.5 + english_words * 1.3 + len(text) * 0.05) def now_iso() -> str: """ Return current time as ISO format string. 返回当前时间的 ISO 格式字符串。 """ return datetime.now().isoformat(timespec="seconds")