# ============================================================
# Module: Memory Import Engine (import_memory.py)
#
# Imports conversation history from various platforms into OB.
#
# Supports: Claude JSON, ChatGPT export, DeepSeek, Markdown, plain text
#
# Features:
# - Chunked processing with resume support
# - Progress persistence (import_state.json)
# - Raw preservation mode for special contexts
# - Post-import frequency pattern detection
# ============================================================

import os
import json
import math
import hashlib
import logging
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Optional

from utils import count_tokens_approx, now_iso

logger = logging.getLogger("ombre_brain.import")


# ============================================================
# Format Parsers — normalize any format to conversation turns
# Every parser returns [{"role", "content", "timestamp"}, ...] where
# content is stripped and non-empty and timestamp is a string.
# ============================================================

# Line prefixes (matched case-insensitively) that open a new speaker turn
# in Markdown / plain-text transcripts.
_USER_PREFIXES = ("human:", "user:", "你:", "我:")
_ASSISTANT_PREFIXES = ("assistant:", "claude:", "ai:", "gpt:", "bot:", "deepseek:")


def _as_dict(value) -> dict:
    """Return ``value`` if it is a dict, otherwise an empty dict.

    Export files may hold ``null`` (or another odd value) where a nested
    object is expected; this keeps chained ``.get`` calls from raising.
    """
    return value if isinstance(value, dict) else {}


def _format_ts(ts) -> str:
    """Normalize a raw timestamp value into a string.

    ``None`` / ``""`` become ``""``. int/float values are treated as Unix
    epoch seconds and rendered with ``datetime.fromtimestamp(...).isoformat()``
    (naive local time); values out of range for that fall back to ``str(ts)``.
    Anything else is returned as ``str(ts)``.
    """
    if ts is None or ts == "":
        return ""
    if isinstance(ts, (int, float)) and not isinstance(ts, bool):
        try:
            return datetime.fromtimestamp(ts).isoformat()
        except (OverflowError, OSError, ValueError):
            return str(ts)
    return str(ts)


def _epoch_or_zero(value) -> float:
    """Sort key helper: numeric ``create_time`` values as-is, anything else as 0."""
    return value if isinstance(value, (int, float)) else 0


def _join_chatgpt_parts(parts) -> str:
    """Join a ChatGPT ``content.parts`` list into plain text.

    Falsy parts are dropped. Dict parts contribute only their ``"text"``
    field (if any), so a dict's Python repr never ends up in memory text;
    other parts are stringified.
    """
    pieces = []
    for part in parts or []:
        if isinstance(part, dict):
            part = part.get("text") or ""
        if part:
            pieces.append(str(part))
    return " ".join(pieces)


def _parse_claude_json(data: dict | list) -> list[dict]:
    """Parse Claude.ai export JSON → [{role, content, timestamp}, ...]

    ``data`` is one conversation object or a list of them. Messages are read
    from ``chat_messages`` (falling back to ``messages``); message text comes
    from ``text`` (falling back to ``content``, which may be a list of
    ``{"text": ...}`` parts). Non-dict entries and empty messages are skipped.
    """
    turns = []
    conversations = data if isinstance(data, list) else [data]
    for conv in conversations:
        if not isinstance(conv, dict):
            continue
        messages = conv.get("chat_messages", conv.get("messages")) or []
        for msg in messages:
            if not isinstance(msg, dict):
                continue
            # `or` rather than a .get default: an empty/null "text" still
            # falls back to "content".
            content = msg.get("text") or msg.get("content") or ""
            if isinstance(content, list):
                content = " ".join(
                    str(part.get("text") or "")
                    for part in content
                    if isinstance(part, dict)
                )
            if not isinstance(content, str) or not content.strip():
                continue
            role = msg.get("sender", msg.get("role", "user"))
            ts = msg.get("created_at", msg.get("timestamp"))
            turns.append({
                "role": role,
                "content": content.strip(),
                "timestamp": _format_ts(ts),
            })
    return turns


def _chatgpt_mapping_turns(mapping: dict) -> list[dict]:
    """Flatten a ChatGPT ``mapping`` tree into turns ordered by ``create_time``.

    Nodes whose ``message`` is missing or null are skipped; messages with no
    numeric ``create_time`` sort first (the sort is stable, so their mapping
    order is kept) and get an empty timestamp.
    """
    messages = [
        node.get("message")
        for node in mapping.values()
        if isinstance(node, dict)
    ]
    messages = [m for m in messages if isinstance(m, dict)]
    messages.sort(key=lambda m: _epoch_or_zero(m.get("create_time")))

    turns = []
    for msg in messages:
        content = _join_chatgpt_parts(_as_dict(msg.get("content")).get("parts", []))
        if not content.strip():
            continue
        role = _as_dict(msg.get("author")).get("role", "user")
        turns.append({
            "role": role,
            "content": content.strip(),
            "timestamp": _format_ts(msg.get("create_time")),
        })
    return turns


def _chatgpt_message_list_turns(messages) -> list[dict]:
    """Parse the simpler ChatGPT-like format: a flat list of message dicts.

    ``content`` may be a string or a ``{"parts": [...]}`` dict (``text`` is
    used when ``content`` is absent); the role comes from ``role`` or
    ``author.role``.
    """
    turns = []
    for msg in messages or []:
        if not isinstance(msg, dict):
            continue
        content = msg.get("content", msg.get("text", ""))
        if isinstance(content, dict):
            content = _join_chatgpt_parts(content.get("parts", []))
        if not isinstance(content, str) or not content.strip():
            continue
        role = msg.get("role") or _as_dict(msg.get("author")).get("role", "user")
        ts = msg.get("timestamp", msg.get("create_time"))
        turns.append({
            "role": role,
            "content": content.strip(),
            "timestamp": _format_ts(ts),
        })
    return turns


def _parse_chatgpt_json(data: list | dict) -> list[dict]:
    """Parse ChatGPT export JSON → [{role, content, timestamp}, ...]

    ``data`` is one conversation or a list of them. Conversations with a
    non-empty ``mapping`` tree are flattened in ``create_time`` order;
    otherwise the conversation's ``messages`` list is read directly.
    """
    turns = []
    conversations = data if isinstance(data, list) else [data]
    for conv in conversations:
        if not isinstance(conv, dict):
            continue
        mapping = conv.get("mapping")
        if isinstance(mapping, dict) and mapping:
            turns.extend(_chatgpt_mapping_turns(mapping))
        else:
            turns.extend(_chatgpt_message_list_turns(conv.get("messages")))
    return turns


def _parse_markdown(text: str) -> list[dict]:
    """Parse Markdown/plain text → [{role, content, timestamp}, ...]

    A line starting with one of ``_USER_PREFIXES`` / ``_ASSISTANT_PREFIXES``
    (case-insensitive) opens a new turn; the text after its first ``:`` is
    the turn's first line. Other lines extend the current turn (text before
    any marker is attributed to the user). Turns whose content is blank are
    dropped. If no non-blank turn is found but the text is not blank, the
    whole text becomes a single user turn; blank input yields ``[]``.
    """
    turns: list[dict] = []
    current_role = "user"
    current_lines: list[str] = []

    def flush() -> None:
        # Reads the *current* bindings of current_role/current_lines.
        body = "\n".join(current_lines).strip()
        if body:
            turns.append({"role": current_role, "content": body, "timestamp": ""})

    for line in text.split("\n"):
        stripped = line.strip()
        lowered = stripped.lower()
        if lowered.startswith(_USER_PREFIXES):
            next_role = "user"
        elif lowered.startswith(_ASSISTANT_PREFIXES):
            next_role = "assistant"
        else:
            current_lines.append(line)
            continue

        flush()
        current_role = next_role
        # Every prefix ends with ":", so the split always has two parts.
        first_line = stripped.split(":", 1)[1].strip()
        current_lines = [first_line] if first_line else []

    flush()

    # No speaker turns with content: treat the entire text as one chunk.
    if not turns and text.strip():
        turns = [{"role": "user", "content": text.strip(), "timestamp": ""}]
    return turns


def _parse_json_export(data) -> Optional[list[dict]]:
    """Dispatch decoded JSON to the matching parser.

    The first element of a list (or the object itself) is inspected:
    ``chat_messages`` → Claude; ``mapping`` → ChatGPT; ``messages`` →
    ChatGPT when the first message's ``content`` is a dict, else Claude;
    a bare ``{role, content}`` message (or list of them) → Claude parser,
    wrapped as one conversation. Returns ``None`` for unrecognized shapes.
    """
    if isinstance(data, list):
        sample = data[0] if data else {}
    else:
        sample = data
    if not isinstance(sample, dict):
        return None

    if "chat_messages" in sample:
        return _parse_claude_json(data)
    if "mapping" in sample:
        return _parse_chatgpt_json(data)
    if "messages" in sample:
        msgs = sample["messages"]
        first = msgs[0] if isinstance(msgs, list) and msgs else None
        if isinstance(first, dict) and isinstance(first.get("content"), dict):
            return _parse_chatgpt_json(data)
        return _parse_claude_json(data)
    if "role" in sample and "content" in sample:
        # The data IS the message list; wrap it so the Claude parser sees
        # one conversation instead of treating each message as one.
        messages = data if isinstance(data, list) else [data]
        return _parse_claude_json({"messages": messages})
    return None


def detect_and_parse(raw_content: str, filename: str = "") -> list[dict]:
    """
    Auto-detect format and parse to normalized turns.

    JSON is attempted when the file has a ``.json`` (or no) extension or the
    content starts with ``{``/``[``. Unrecognized JSON shapes and content
    that is not JSON fall back to the Markdown/plain-text parser.
    """
    ext = Path(filename).suffix.lower() if filename else ""

    if ext in (".json", "") or raw_content.strip().startswith(("{", "[")):
        try:
            data = json.loads(raw_content)
        except json.JSONDecodeError:
            pass
        else:
            try:
                turns = _parse_json_export(data)
            except (KeyError, IndexError, TypeError, AttributeError) as e:
                logger.warning(f"Unrecognized JSON export structure ({e}); parsing as text")
                turns = None
            if turns is not None:
                return turns

    # Fall back to markdown/text
    return _parse_markdown(raw_content)


# ============================================================
# Chunking — split turns into ~10k token windows at turn boundaries
# ============================================================

def chunk_turns(turns: list[dict], target_tokens: int = 10000) -> list[dict]:
    """
    Group conversation turns into chunks of ~target_tokens.

    Each turn is rendered as ``"[用户] ..."`` (role ``user``/``human``) or
    ``"[AI] ..."`` (any other role); turns are never split across chunks.
    A turn whose rendered size exceeds ``1.5 * target_tokens`` is NOT split —
    it becomes a chunk of its own. Token counts come from
    ``count_tokens_approx``.

    Returns list of {content, timestamp_start, timestamp_end, turn_count}.
    """
    chunks: list[dict] = []
    lines: list[str] = []
    tokens = 0
    first_ts = ""
    last_ts = ""

    def flush() -> None:
        """Emit the pending lines (if any) as a chunk and reset the window."""
        nonlocal lines, tokens, first_ts, last_ts
        if lines:
            chunks.append({
                "content": "\n".join(lines),
                "timestamp_start": first_ts,
                "timestamp_end": last_ts,
                "turn_count": len(lines),  # one rendered line per turn
            })
        lines, tokens, first_ts, last_ts = [], 0, "", ""

    for turn in turns:
        role_label = "用户" if turn["role"] in ("user", "human") else "AI"
        line = f"[{role_label}] {turn['content']}"
        line_tokens = count_tokens_approx(line)
        ts = turn.get("timestamp", "")

        # Oversized turn: close the current window, then emit the turn alone.
        if line_tokens > target_tokens * 1.5:
            flush()
            chunks.append({
                "content": line,
                "timestamp_start": ts,
                "timestamp_end": ts,
                "turn_count": 1,
            })
            continue

        if lines and tokens + line_tokens > target_tokens:
            flush()

        if not first_ts:
            first_ts = ts
        last_ts = ts
        lines.append(line)
        tokens += line_tokens

    flush()
    return chunks


# ============================================================
# Import State — persistent progress tracking
# ============================================================

class ImportState:
    """Manages import progress with file-based persistence.

    ``data`` holds the progress counters, the error list and a ``status`` of
    idle | running | paused | completed | error. It is saved to
    ``<state_dir>/import_state.json`` by writing a temp file and swapping it
    in with ``os.replace``.
    """

    def __init__(self, state_dir: str):
        self.state_file = os.path.join(state_dir, "import_state.json")
        self.data = self._blank()

    @staticmethod
    def _blank(**overrides) -> dict:
        """Return a fresh state dict (idle, zeroed counters) with ``overrides`` applied."""
        data = {
            "source_file": "",
            "source_hash": "",
            "total_chunks": 0,
            "processed": 0,
            "api_calls": 0,
            "memories_created": 0,
            "memories_merged": 0,
            "memories_raw": 0,
            "errors": [],
            "status": "idle",  # idle | running | paused | completed | error
            "started_at": "",
            "updated_at": "",
        }
        data.update(overrides)
        return data

    def load(self) -> bool:
        """Merge saved state from disk into ``data``.

        Returns True only if the file exists and holds a JSON object; a
        missing, unreadable, undecodable or non-object file returns False
        and leaves ``data`` untouched.
        """
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                saved = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSON/Unicode decode errors
            return False
        if not isinstance(saved, dict):
            return False
        self.data.update(saved)
        return True

    def save(self):
        """Persist state to file (stamps ``updated_at`` first)."""
        self.data["updated_at"] = now_iso()
        os.makedirs(os.path.dirname(self.state_file), exist_ok=True)
        tmp = self.state_file + ".tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(self.data, f, ensure_ascii=False, indent=2)
        os.replace(tmp, self.state_file)

    def reset(self, source_file: str, source_hash: str, total_chunks: int):
        """Reset state for a new import (status ``running``, counters zeroed)."""
        stamp = now_iso()
        self.data = self._blank(
            source_file=source_file,
            source_hash=source_hash,
            total_chunks=total_chunks,
            status="running",
            started_at=stamp,
            updated_at=stamp,
        )

    @property
    def can_resume(self) -> bool:
        """True when a paused/running import still has unprocessed chunks."""
        return (
            self.data["status"] in ("paused", "running")
            and self.data["processed"] < self.data["total_chunks"]
        )

    def to_dict(self) -> dict:
        """Shallow copy of the state (nested lists such as ``errors`` are shared)."""
        return dict(self.data)


# ============================================================
# Import extraction prompt (system prompt sent to the LLM)
# ============================================================

IMPORT_EXTRACT_PROMPT = """你是一个对话记忆提取专家。从以下对话片段中提取值得长期记住的信息。

提取规则:
1. 提取用户的事实、偏好、习惯、重要事件、情感时刻
2. 同一话题的零散信息整合为一条记忆
3. 过滤掉纯技术调试输出、代码块、重复问答、无意义寒暄
4. 如果对话中有特殊暗号、仪式性行为、关键承诺等,标记 preserve_raw=true
5. 如果内容是用户和AI之间的习惯性互动模式(例如打招呼方式、告别习惯),标记 is_pattern=true
6. 每条记忆不少于30字
7. 总条目数控制在 0~5 个(没有值得记的就返回空数组)
8. 在 content 中对人名、地名、专有名词用 [[双链]] 标记

输出格式(纯 JSON 数组,无其他内容):
[
  {
    "name": "条目标题(10字以内)",
    "content": "整理后的内容",
    "domain": ["主题域1"],
    "valence": 0.7,
    "arousal": 0.4,
    "tags": ["核心词1", "核心词2", "扩展词1"],
    "importance": 5,
    "preserve_raw": false,
    "is_pattern": false
  }
]

主题域可选(选 1~2 个):
  日常: ["饮食", "穿搭", "出行", "居家", "购物"]
  人际: ["家庭", "恋爱", "友谊", "社交"]
  成长: ["工作", "学习", "考试", "求职"]
  身心: ["健康", "心理", "睡眠", "运动"]
  兴趣: ["游戏", "影视", "音乐", "阅读", "创作", "手工"]
  数字: ["编程", "AI", "硬件", "网络"]
  事务: ["财务", "计划", "待办"]
  内心: ["情绪", "回忆", "梦境", "自省"]

importance: 1-10
valence: 0~1(0=消极, 0.5=中性, 1=积极)
arousal: 0~1(0=平静, 0.5=普通, 1=激动)
preserve_raw: true = 特殊情境/暗号/仪式,保留原文不摘要
is_pattern: true = 反复出现的习惯性行为模式"""


# ============================================================
# LLM output helpers
# ============================================================

def _load_llm_json(raw: str):
    """Decode the JSON payload of an LLM reply, or return None if impossible.

    A leading Markdown code fence (```/```json) is stripped. If the text is
    still not valid JSON, the outermost ``[...]`` span is tried, so stray
    prose around the array does not discard the whole extraction.
    """
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.split("\n", 1)[-1].rsplit("```", 1)[0]
    try:
        return json.loads(cleaned)
    except ValueError:
        pass
    start, end = cleaned.find("["), cleaned.rfind("]")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(cleaned[start:end + 1])
    except ValueError:
        return None


def _as_str_list(value) -> list[str]:
    """Coerce an LLM-provided list field into a list of strings.

    A bare string becomes a one-element list (instead of being iterated
    character by character); ``None`` and other scalars become ``[]``;
    ``None``/``""`` entries are dropped.
    """
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return []
    return [str(v) for v in value if v is not None and v != ""]


def _clamp_unit(value, default: float) -> float:
    """Clamp ``value`` to [0, 1]; non-numeric or NaN values yield ``default``."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if math.isnan(number):
        return default
    return max(0.0, min(1.0, number))


def _merge_unique(existing, new) -> list:
    """Concatenate two optional lists, dropping duplicates in first-seen order."""
    return list(dict.fromkeys([*(existing or []), *(new or [])]))


# ============================================================
# Import Engine — core processing logic
# ============================================================

class ImportEngine:
    """
    Processes conversation history files into OB memory buckets.

    Pipeline: parse → chunk → LLM extraction per chunk → store each item
    (raw create, or merge into a similar bucket / create). Progress is
    persisted after every chunk so an interrupted import can resume.
    """

    def __init__(self, config: dict, bucket_mgr, dehydrator, embedding_engine=None):
        self.config = config
        self.bucket_mgr = bucket_mgr
        self.dehydrator = dehydrator
        self.embedding_engine = embedding_engine
        self.state = ImportState(config["buckets_dir"])
        self._paused = False
        self._running = False
        self._chunks: list[dict] = []

    @property
    def is_running(self) -> bool:
        return self._running

    def pause(self):
        """Request pause — will stop after current chunk finishes."""
        self._paused = True

    def get_status(self) -> dict:
        """Get current import status."""
        return self.state.to_dict()

    async def start(
        self,
        raw_content: str,
        filename: str = "",
        preserve_raw: bool = False,
        resume: bool = False,
    ) -> dict:
        """
        Start or resume an import.

        Resuming requires ``resume=True``, a resumable saved state and the
        same content hash; otherwise a fresh import starts. Returns the
        final state dict, or ``{"error": ...}`` when an import is already
        running or the content yields no turns/chunks. Unexpected errors are
        recorded in the state (status ``error``) and re-raised. The running
        flag is always cleared on exit, including on errors or cancellation.
        """
        if self._running:
            return {"error": "Import already running"}

        self._running = True
        self._paused = False
        try:
            source_hash = hashlib.sha256(raw_content.encode()).hexdigest()[:16]

            if resume and self.state.load() and self.state.can_resume:
                if self.state.data["source_hash"] == source_hash:
                    logger.info(
                        f"Resuming import from chunk {self.state.data['processed']}"
                        f"/{self.state.data['total_chunks']}"
                    )
                    # Re-parse and re-chunk; identical input yields identical chunks.
                    self._chunks = chunk_turns(detect_and_parse(raw_content, filename))
                    if len(self._chunks) != self.state.data["total_chunks"]:
                        logger.warning(
                            f"Chunk count changed on resume: {len(self._chunks)} "
                            f"vs saved {self.state.data['total_chunks']}"
                        )
                    self.state.data["status"] = "running"
                    self.state.save()
                    return await self._process_chunks(preserve_raw)
                logger.warning("Source file changed, starting fresh import")

            # Fresh import
            turns = detect_and_parse(raw_content, filename)
            if not turns:
                return {"error": "No conversation turns found in file"}

            self._chunks = chunk_turns(turns)
            if not self._chunks:
                return {"error": "No processable chunks after splitting"}

            self.state.reset(filename, source_hash, len(self._chunks))
            self.state.save()
            logger.info(f"Starting import: {len(turns)} turns → {len(self._chunks)} chunks")
            return await self._process_chunks(preserve_raw)

        except Exception as e:
            self.state.data["status"] = "error"
            self.state.data["errors"].append(str(e))
            self.state.save()
            raise
        finally:
            self._running = False

    async def _process_chunks(self, preserve_raw: bool) -> dict:
        """Process chunks from ``state["processed"]`` onward.

        Stops early (status ``paused``) when a pause was requested. A failing
        chunk is logged and recorded (at most 100 errors are kept) and still
        counts as processed. State is saved after every chunk.
        """
        total = len(self._chunks)
        for i in range(self.state.data["processed"], total):
            if self._paused:
                self.state.data["status"] = "paused"
                self.state.save()
                self._running = False
                logger.info(f"Import paused at chunk {i}/{total}")
                return self.state.to_dict()

            try:
                await self._process_single_chunk(self._chunks[i], preserve_raw)
            except Exception as e:
                err_msg = f"Chunk {i}: {str(e)[:200]}"
                logger.warning(f"Import chunk error: {err_msg}")
                if len(self.state.data["errors"]) < 100:
                    self.state.data["errors"].append(err_msg)

            self.state.data["processed"] = i + 1
            # Save progress every chunk
            self.state.save()

        self.state.data["status"] = "completed"
        self.state.save()
        self._running = False
        logger.info(
            f"Import completed: {self.state.data['memories_created']} created, "
            f"{self.state.data['memories_merged']} merged"
        )
        return self.state.to_dict()

    async def _process_single_chunk(self, chunk: dict, preserve_raw: bool):
        """Extract memories from a single chunk and store them.

        Extraction failures propagate (wrapped in RuntimeError) so the caller
        records them; a failure to store one item is logged and the
        remaining items are still stored.

        NOTE: the chunk's timestamps are not currently applied to the
        buckets created from it.
        """
        content = chunk["content"]
        if not content.strip():
            return

        try:
            items = await self._extract_memories(content)
        except Exception as e:
            raise RuntimeError(f"LLM extraction failed: {e}") from e

        for item in items:
            try:
                if preserve_raw or item.get("preserve_raw", False):
                    await self._store_raw_item(item)
                    self.state.data["memories_raw"] += 1
                    self.state.data["memories_created"] += 1
                elif await self._merge_or_create_item(item):
                    self.state.data["memories_merged"] += 1
                else:
                    self.state.data["memories_created"] += 1
            except Exception as e:
                logger.warning(f"Failed to store memory: {item.get('name', '?')}: {e}")

    async def _store_raw_item(self, item: dict):
        """Create a bucket from ``item`` as-is, bypassing the merge pipeline."""
        bucket_id = await self.bucket_mgr.create(
            content=item["content"],
            tags=item.get("tags", []),
            importance=item.get("importance", 5),
            domain=item.get("domain", ["未分类"]),
            valence=item.get("valence", 0.5),
            arousal=item.get("arousal", 0.3),
            name=item.get("name") or None,
        )
        await self._embed(bucket_id, item["content"])
        return bucket_id

    async def _embed(self, bucket_id, text: str) -> None:
        """Best-effort embedding refresh: failures are logged, never raised."""
        if not self.embedding_engine:
            return
        try:
            await self.embedding_engine.generate_and_store(bucket_id, text)
        except Exception as e:
            logger.warning(f"Embedding failed for bucket {bucket_id}: {e}")

    async def _extract_memories(self, chunk_content: str) -> list[dict]:
        """Use LLM to extract memories from a conversation chunk.

        Raises RuntimeError when the API is unavailable (no call is counted
        then); every attempted call increments ``api_calls``. The chunk text
        is truncated to 12000 characters before sending.
        """
        if not self.dehydrator.api_available:
            raise RuntimeError("API not available")

        self.state.data["api_calls"] += 1
        response = await self.dehydrator.client.chat.completions.create(
            model=self.dehydrator.model,
            messages=[
                {"role": "system", "content": IMPORT_EXTRACT_PROMPT},
                {"role": "user", "content": chunk_content[:12000]},
            ],
            max_tokens=2048,
            temperature=0.0,
        )
        if not response.choices:
            return []
        raw = response.choices[0].message.content or ""
        if not raw.strip():
            return []
        return self._parse_extraction(raw)

    @staticmethod
    def _parse_extraction(raw: str) -> list[dict]:
        """Parse and validate LLM extraction result.

        Keeps only dict items with non-empty ``content`` and normalizes
        fields: name ≤ 20 chars, domain ≤ 3 and tags ≤ 10 strings,
        importance clamped to 1-10 (default 5), valence/arousal clamped to
        [0, 1] (defaults 0.5 / 0.3, each validated independently).
        """
        items = _load_llm_json(raw)
        if items is None:
            logger.warning(f"Import extraction JSON parse failed: {raw[:200]}")
            return []
        if not isinstance(items, list):
            return []

        validated = []
        for item in items:
            if not isinstance(item, dict) or not item.get("content"):
                continue
            try:
                importance = max(1, min(10, int(item.get("importance", 5))))
            except (ValueError, TypeError, OverflowError):
                importance = 5
            validated.append({
                "name": str(item.get("name") or "")[:20],
                "content": str(item["content"]),
                "domain": _as_str_list(item.get("domain", ["未分类"]))[:3],
                "valence": _clamp_unit(item.get("valence", 0.5), 0.5),
                "arousal": _clamp_unit(item.get("arousal", 0.3), 0.3),
                "tags": _as_str_list(item.get("tags", []))[:10],
                "importance": importance,
                "preserve_raw": bool(item.get("preserve_raw", False)),
                "is_pattern": bool(item.get("is_pattern", False)),
            })
        return validated

    async def _merge_or_create_item(self, item: dict) -> bool:
        """Try to merge with existing bucket, or create new. Returns is_merged.

        The best search hit is merged into when its score exceeds
        ``config["merge_threshold"]`` (default 75) and it is neither pinned
        nor protected. Tags/domains are unioned (order-preserving),
        importance takes the max, valence/arousal are averaged. Any merge
        failure falls back to creating a new bucket.
        """
        content = item["content"]
        domain = item.get("domain", ["未分类"])
        tags = item.get("tags", [])
        importance = item.get("importance", 5)
        valence = item.get("valence", 0.5)
        arousal = item.get("arousal", 0.3)
        name = item.get("name", "")

        try:
            existing = await self.bucket_mgr.search(content, limit=1, domain_filter=domain or None)
        except Exception as e:
            logger.warning(f"Similarity search failed during import: {e}")
            existing = []

        merge_threshold = self.config.get("merge_threshold", 75)
        if existing and (existing[0].get("score") or 0) > merge_threshold:
            bucket = existing[0]
            meta = bucket["metadata"]
            if not (meta.get("pinned") or meta.get("protected")):
                try:
                    # Count the merge call exactly once, whether or not it succeeds.
                    self.state.data["api_calls"] += 1
                    merged = await self.dehydrator.merge(bucket["content"], content)
                    old_v = meta.get("valence", 0.5)
                    old_a = meta.get("arousal", 0.3)
                    await self.bucket_mgr.update(
                        bucket["id"],
                        content=merged,
                        tags=_merge_unique(meta.get("tags"), tags),
                        importance=max(meta.get("importance", 5), importance),
                        domain=_merge_unique(meta.get("domain"), domain),
                        valence=round((old_v + valence) / 2, 2),
                        arousal=round((old_a + arousal) / 2, 2),
                    )
                except Exception as e:
                    logger.warning(f"Merge failed during import: {e}")
                else:
                    await self._embed(bucket["id"], merged)
                    return True

        # Create new
        bucket_id = await self.bucket_mgr.create(
            content=content,
            tags=tags,
            importance=importance,
            domain=domain,
            valence=valence,
            arousal=arousal,
            name=name or None,
        )
        await self._embed(bucket_id, content)
        return False

    async def detect_patterns(
        self,
        similarity_threshold: float = 0.7,
        min_cluster_size: int = 3,
    ) -> list[dict]:
        """
        Post-import: detect high-frequency patterns via embedding clustering.

        Considers dynamic, unpinned, unresolved, non-archived buckets (needs
        at least 5 with embeddings). Clustering is greedy: walking buckets in
        order, each not-yet-clustered bucket leads a cluster of every later
        unclustered bucket whose cosine similarity to it exceeds
        ``similarity_threshold``; zero-norm embeddings never join a cluster.
        Clusters with at least ``min_cluster_size`` members are reported.

        Returns up to 20 dicts of {pattern_content, pattern_name, count,
        bucket_ids, suggested_action}, largest clusters first;
        suggested_action is "pin" for clusters of 5+ and "review" otherwise.
        """
        if not self.embedding_engine:
            return []

        all_buckets = await self.bucket_mgr.list_all(include_archive=False)
        dynamic_buckets = [
            b for b in all_buckets
            if b["metadata"].get("type") == "dynamic"
            and not b["metadata"].get("pinned")
            and not b["metadata"].get("resolved")
        ]
        if len(dynamic_buckets) < 5:
            return []

        embeddings = {}
        for b in dynamic_buckets:
            emb = await self.embedding_engine.get_embedding(b["id"])
            if emb is not None:
                embeddings[b["id"]] = emb
        if len(embeddings) < 5:
            return []

        import numpy as np  # local import: only this method needs numpy

        ids = list(embeddings)
        # Convert every embedding once instead of once per compared pair.
        vectors = np.array([embeddings[bid] for bid in ids], dtype=float)
        norms = np.linalg.norm(vectors, axis=1)
        unvisited = np.ones(len(ids), dtype=bool)
        clusters: dict[str, list[str]] = {}

        for i, lead_id in enumerate(ids):
            if not unvisited[i]:
                continue
            unvisited[i] = False
            if norms[i] == 0:
                continue
            # Every index before i is already visited, so these are all j > i.
            candidates = np.flatnonzero(unvisited & (norms > 0))
            sims = (vectors[candidates] @ vectors[i]) / (norms[candidates] * norms[i])
            members = candidates[sims > similarity_threshold]
            unvisited[members] = False
            cluster = [lead_id, *(ids[j] for j in members)]
            if len(cluster) >= min_cluster_size:
                clusters[lead_id] = cluster

        bucket_by_id: dict = {}
        for b in dynamic_buckets:
            bucket_by_id.setdefault(b["id"], b)

        patterns = []
        for lead_id, cluster_ids in clusters.items():
            lead_bucket = bucket_by_id.get(lead_id)
            if lead_bucket is None:
                continue
            patterns.append({
                "pattern_content": lead_bucket["content"][:200],
                "pattern_name": lead_bucket["metadata"].get("name", lead_id),
                "count": len(cluster_ids),
                "bucket_ids": cluster_ids,
                "suggested_action": "pin" if len(cluster_ids) >= 5 else "review",
            })

        patterns.sort(key=lambda p: p["count"], reverse=True)
        return patterns[:20]