# ============================================================
# Module: Memory Import Engine (import_memory.py)
# 模块:历史记忆导入引擎
#
# Imports conversation history from various platforms into OB.
# 将各平台对话历史导入 OB 记忆系统。
#
# Supports: Claude JSON, ChatGPT export, DeepSeek, Markdown, plain text
# 支持格式：Claude JSON、ChatGPT 导出、DeepSeek、Markdown、纯文本
#
# Features:
# - Chunked processing with resume support
# - Progress persistence (import_state.json)
# - Raw preservation mode for special contexts
# - Post-import frequency pattern detection
# ============================================================
import os
import json
import hashlib
import logging
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Optional
from utils import count_tokens_approx, now_iso
logger = logging.getLogger("ombre_brain.import")
# ============================================================
# Format Parsers — normalize any format to conversation turns
# 格式解析器 — 将任意格式标准化为对话轮次
# ============================================================
def _parse_claude_json(data: dict | list) -> list[dict]:
"""Parse Claude.ai export JSON → [{role, content, timestamp}, ...]"""
turns = []
conversations = data if isinstance(data, list) else [data]
for conv in conversations:
messages = conv.get("chat_messages", conv.get("messages", []))
for msg in messages:
if not isinstance(msg, dict):
continue
content = msg.get("text", msg.get("content", ""))
if isinstance(content, list):
content = " ".join(
p.get("text", "") for p in content if isinstance(p, dict)
)
if not content or not content.strip():
continue
role = msg.get("sender", msg.get("role", "user"))
ts = msg.get("created_at", msg.get("timestamp", ""))
turns.append({"role": role, "content": content.strip(), "timestamp": ts})
return turns
def _parse_chatgpt_json(data: list | dict) -> list[dict]:
"""Parse ChatGPT export JSON → [{role, content, timestamp}, ...]"""
turns = []
conversations = data if isinstance(data, list) else [data]
for conv in conversations:
mapping = conv.get("mapping", {})
if mapping:
# ChatGPT uses a tree structure with mapping
sorted_nodes = sorted(
mapping.values(),
key=lambda n: n.get("message", {}).get("create_time", 0) or 0,
)
for node in sorted_nodes:
msg = node.get("message")
if not msg or not isinstance(msg, dict):
continue
content_parts = msg.get("content", {}).get("parts", [])
content = " ".join(str(p) for p in content_parts if p)
if not content.strip():
continue
role = msg.get("author", {}).get("role", "user")
ts = msg.get("create_time", "")
if isinstance(ts, (int, float)):
ts = datetime.fromtimestamp(ts).isoformat()
turns.append({"role": role, "content": content.strip(), "timestamp": str(ts)})
else:
# Simpler format: list of messages
messages = conv.get("messages", [])
for msg in messages:
if not isinstance(msg, dict):
continue
content = msg.get("content", msg.get("text", ""))
if isinstance(content, dict):
content = " ".join(str(p) for p in content.get("parts", []))
if not content or not content.strip():
continue
role = msg.get("role", msg.get("author", {}).get("role", "user"))
ts = msg.get("timestamp", msg.get("create_time", ""))
turns.append({"role": role, "content": content.strip(), "timestamp": str(ts)})
return turns
def _parse_markdown(text: str) -> list[dict]:
"""Parse Markdown/plain text → [{role, content, timestamp}, ...]"""
# Try to detect conversation patterns
lines = text.split("\n")
turns = []
current_role = "user"
current_content = []
for line in lines:
stripped = line.strip()
# Detect role switches
if stripped.lower().startswith(("human:", "user:", "你:", "我:")):
if current_content:
turns.append({"role": current_role, "content": "\n".join(current_content).strip(), "timestamp": ""})
current_role = "user"
content_after = stripped.split(":", 1)[1].strip() if ":" in stripped else ""
current_content = [content_after] if content_after else []
elif stripped.lower().startswith(("assistant:", "claude:", "ai:", "gpt:", "bot:", "deepseek:")):
if current_content:
turns.append({"role": current_role, "content": "\n".join(current_content).strip(), "timestamp": ""})
current_role = "assistant"
content_after = stripped.split(":", 1)[1].strip() if ":" in stripped else ""
current_content = [content_after] if content_after else []
else:
current_content.append(line)
if current_content:
content = "\n".join(current_content).strip()
if content:
turns.append({"role": current_role, "content": content, "timestamp": ""})
# If no role patterns detected, treat entire text as one big chunk
if not turns:
turns = [{"role": "user", "content": text.strip(), "timestamp": ""}]
return turns
def detect_and_parse(raw_content: str, filename: str = "") -> list[dict]:
    """
    Auto-detect the export format and parse into normalized turns.

    JSON payloads are routed to the Claude or ChatGPT parser based on
    their structure; anything else falls back to the Markdown parser.
    """
    ext = Path(filename).suffix.lower() if filename else ""
    looks_like_json = ext in (".json", "") or raw_content.strip().startswith(("{", "["))
    if looks_like_json:
        try:
            data = json.loads(raw_content)
        except json.JSONDecodeError:
            data = None
        if data is not None:
            try:
                sample = (data[0] if data else {}) if isinstance(data, list) else data
                if isinstance(sample, dict):
                    if "chat_messages" in sample:
                        return _parse_claude_json(data)
                    if "mapping" in sample:
                        return _parse_chatgpt_json(data)
                    if "messages" in sample:
                        # Ambiguous — ChatGPT nests content parts in a dict,
                        # Claude keeps plain strings; probe the first message.
                        msgs = sample["messages"]
                        if msgs and isinstance(msgs[0], dict) and "content" in msgs[0]:
                            if isinstance(msgs[0]["content"], dict):
                                return _parse_chatgpt_json(data)
                        return _parse_claude_json(data)
                    if "role" in sample and "content" in sample:
                        # Single conversation object with role/content messages.
                        return _parse_claude_json(data)
            except (KeyError, IndexError):
                pass
    # Fall back to markdown / plain text.
    return _parse_markdown(raw_content)
# ============================================================
# Chunking — split turns into ~10k token windows
# 分窗 — 按对话轮次边界切为 ~10k token 窗口
# ============================================================
def chunk_turns(turns: list[dict], target_tokens: int = 10000) -> list[dict]:
    """
    Group conversation turns into windows of roughly ``target_tokens``.

    Windows never split a turn; a single turn larger than 1.5x the target
    becomes its own standalone chunk. Returns a list of
    {content, timestamp_start, timestamp_end, turn_count} dicts.
    """
    chunks: list[dict] = []
    pending: list[str] = []
    pending_tokens = 0
    window_start_ts = ""
    window_end_ts = ""
    pending_turns = 0

    def flush_window():
        # Close the current window (if any) and reset the accumulators.
        nonlocal pending, pending_tokens, pending_turns, window_start_ts
        if pending:
            chunks.append({
                "content": "\n".join(pending),
                "timestamp_start": window_start_ts,
                "timestamp_end": window_end_ts,
                "turn_count": pending_turns,
            })
            pending = []
            pending_tokens = 0
            pending_turns = 0
            window_start_ts = ""

    for turn in turns:
        speaker = "用户" if turn["role"] in ("user", "human") else "AI"
        rendered = f"[{speaker}] {turn['content']}"
        cost = count_tokens_approx(rendered)

        # A single oversized turn is emitted as its own chunk.
        if cost > target_tokens * 1.5:
            flush_window()
            ts = turn.get("timestamp", "")
            chunks.append({
                "content": rendered,
                "timestamp_start": ts,
                "timestamp_end": ts,
                "turn_count": 1,
            })
            continue

        # Close the window when the next turn would overflow it.
        if pending and pending_tokens + cost > target_tokens:
            flush_window()

        if not window_start_ts:
            window_start_ts = turn.get("timestamp", "")
        window_end_ts = turn.get("timestamp", "")
        pending.append(rendered)
        pending_tokens += cost
        pending_turns += 1

    flush_window()
    return chunks
# ============================================================
# Import State — persistent progress tracking
# 导入状态 — 持久化进度追踪
# ============================================================
class ImportState:
    """Tracks import progress with atomic file-based persistence.

    State lives in ``<state_dir>/import_state.json`` so an interrupted
    import can be resumed after a restart.
    """

    def __init__(self, state_dir: str):
        self.state_file = os.path.join(state_dir, "import_state.json")
        self.data = self._blank_state()

    @staticmethod
    def _blank_state() -> dict:
        """Return a fresh state dict — single source of truth for the schema.

        Previously this dict was duplicated in __init__ and reset(), which
        risked the two copies drifting apart.
        """
        return {
            "source_file": "",
            "source_hash": "",
            "total_chunks": 0,
            "processed": 0,
            "api_calls": 0,
            "memories_created": 0,
            "memories_merged": 0,
            "memories_raw": 0,
            "errors": [],
            "status": "idle",  # idle | running | paused | completed | error
            "started_at": "",
            "updated_at": "",
        }

    def load(self) -> bool:
        """Load state from disk. Returns True if a readable state file exists."""
        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, "r", encoding="utf-8") as f:
                    saved = json.load(f)
                self.data.update(saved)
                return True
            except (json.JSONDecodeError, OSError):
                return False
        return False

    def save(self):
        """Persist state to disk atomically (write temp file, then rename)."""
        self.data["updated_at"] = now_iso()
        os.makedirs(os.path.dirname(self.state_file), exist_ok=True)
        tmp = self.state_file + ".tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(self.data, f, ensure_ascii=False, indent=2)
        os.replace(tmp, self.state_file)  # atomic replace, no torn writes

    def reset(self, source_file: str, source_hash: str, total_chunks: int):
        """Reset state for a brand-new import run."""
        self.data = self._blank_state()
        self.data.update({
            "source_file": source_file,
            "source_hash": source_hash,
            "total_chunks": total_chunks,
            "status": "running",
            "started_at": now_iso(),
            "updated_at": now_iso(),
        })

    @property
    def can_resume(self) -> bool:
        """True when a paused/running import still has unprocessed chunks."""
        return (
            self.data["status"] in ("paused", "running")
            and self.data["processed"] < self.data["total_chunks"]
        )

    def to_dict(self) -> dict:
        """Return a shallow copy of the current state."""
        return dict(self.data)
# ============================================================
# Import extraction prompt
# 导入提取提示词
# ============================================================
# System prompt for the LLM memory-extraction call (used by _extract_memories).
# NOTE: valence/arousal are documented as 0~1 here, matching the clamping in
# ImportEngine._parse_extraction — the previous text had lost its parentheses
# and read "0~10", contradicting the validation range.
IMPORT_EXTRACT_PROMPT = """你是一个对话记忆提取专家。从以下对话片段中提取值得长期记住的信息。
提取规则:
1. 提取用户的事实、偏好、习惯、重要事件、情感时刻
2. 同一话题的零散信息整合为一条记忆
3. 过滤掉纯技术调试输出、代码块、重复问答、无意义寒暄
4. 如果对话中有特殊暗号、仪式性行为、关键承诺等,标记 preserve_raw=true
5. 如果内容是用户和AI之间的习惯性互动模式(例如打招呼方式、告别习惯),标记 is_pattern=true
6. 每条记忆不少于30字
7. 总条目数控制在 0~5 个(没有值得记的就返回空数组)
8. 在 content 中对人名、地名、专有名词用 [[双链]] 标记
输出格式(纯 JSON 数组,无其他内容):
[
  {
    "name": "条目标题(10字以内)",
    "content": "整理后的内容",
    "domain": ["主题域1"],
    "valence": 0.7,
    "arousal": 0.4,
    "tags": ["核心词1", "核心词2", "扩展词1"],
    "importance": 5,
    "preserve_raw": false,
    "is_pattern": false
  }
]
主题域可选(选 1~2 个):
日常: ["饮食", "穿搭", "出行", "居家", "购物"]
人际: ["家庭", "恋爱", "友谊", "社交"]
成长: ["工作", "学习", "考试", "求职"]
身心: ["健康", "心理", "睡眠", "运动"]
兴趣: ["游戏", "影视", "音乐", "阅读", "创作", "手工"]
数字: ["编程", "AI", "硬件", "网络"]
事务: ["财务", "计划", "待办"]
内心: ["情绪", "回忆", "梦境", "自省"]
importance: 1-10
valence: 0~1 (0=消极, 0.5=中性, 1=积极)
arousal: 0~1 (0=平静, 0.5=普通, 1=激动)
preserve_raw: true = 特殊情境/暗号/仪式,保留原文不摘要
is_pattern: true = 反复出现的习惯性行为模式"""
# ============================================================
# Import Engine — core processing logic
# 导入引擎 — 核心处理逻辑
# ============================================================
class ImportEngine:
    """
    Processes conversation history files into OB memory buckets.

    Pipeline: auto-detect format → parse into turns → chunk → LLM
    extraction → store (raw, or via merge-or-create). Progress is
    persisted through ImportState, so runs can be paused and resumed.
    """
def __init__(self, config: dict, bucket_mgr, dehydrator, embedding_engine=None):
    """Wire up collaborators; progress state lives under config['buckets_dir']."""
    # External collaborators (bucket storage, LLM client, optional embeddings).
    self.config = config
    self.bucket_mgr = bucket_mgr
    self.dehydrator = dehydrator
    self.embedding_engine = embedding_engine
    # Persistent progress tracking (import_state.json).
    self.state = ImportState(config["buckets_dir"])
    # In-memory run state.
    self._paused = False
    self._running = False
    self._chunks: list[dict] = []
@property
def is_running(self) -> bool:
    # True while the chunk-processing loop is actively executing.
    return self._running

def pause(self):
    """Request pause — will stop after current chunk finishes."""
    self._paused = True

def get_status(self) -> dict:
    """Get current import status (a snapshot of the persisted counters)."""
    return self.state.to_dict()
async def start(
    self,
    raw_content: str,
    filename: str = "",
    preserve_raw: bool = False,
    resume: bool = False,
) -> dict:
    """
    Start a new import, or resume a previously interrupted one.

    Returns the final state dict, or {"error": ...} when an import is
    already running or the input yields nothing to process. On an
    unexpected failure the error is persisted to state and re-raised.
    """
    if self._running:
        return {"error": "Import already running"}
    self._running = True
    self._paused = False
    try:
        fingerprint = hashlib.sha256(raw_content.encode()).hexdigest()[:16]

        # Resume path: previous run was interrupted and the source matches.
        if resume and self.state.load() and self.state.can_resume:
            if self.state.data["source_hash"] == fingerprint:
                logger.info(f"Resuming import from chunk {self.state.data['processed']}/{self.state.data['total_chunks']}")
                # Re-parse and re-chunk deterministically to recover the same windows.
                self._chunks = chunk_turns(detect_and_parse(raw_content, filename))
                self.state.data["status"] = "running"
                self.state.save()
                return await self._process_chunks(preserve_raw)
            logger.warning("Source file changed, starting fresh import")

        # Fresh import path.
        turns = detect_and_parse(raw_content, filename)
        if not turns:
            self._running = False
            return {"error": "No conversation turns found in file"}
        self._chunks = chunk_turns(turns)
        if not self._chunks:
            self._running = False
            return {"error": "No processable chunks after splitting"}
        self.state.reset(filename, fingerprint, len(self._chunks))
        self.state.save()
        logger.info(f"Starting import: {len(turns)} turns → {len(self._chunks)} chunks")
        return await self._process_chunks(preserve_raw)
    except Exception as e:
        # Record the failure so status endpoints can surface it, then re-raise.
        self.state.data["status"] = "error"
        self.state.data["errors"].append(str(e))
        self.state.save()
        self._running = False
        raise
async def _process_chunks(self, preserve_raw: bool) -> dict:
    """Run the extraction loop from the last persisted chunk position."""
    total = len(self._chunks)
    for idx in range(self.state.data["processed"], total):
        # Honor a pause request between chunks (never mid-chunk).
        if self._paused:
            self.state.data["status"] = "paused"
            self.state.save()
            self._running = False
            logger.info(f"Import paused at chunk {idx}/{total}")
            return self.state.to_dict()
        try:
            await self._process_single_chunk(self._chunks[idx], preserve_raw)
        except Exception as e:
            err_msg = f"Chunk {idx}: {str(e)[:200]}"
            logger.warning(f"Import chunk error: {err_msg}")
            # Cap the stored error list so the state file cannot grow unbounded.
            if len(self.state.data["errors"]) < 100:
                self.state.data["errors"].append(err_msg)
        self.state.data["processed"] = idx + 1
        self.state.save()  # persist progress after every chunk
    self.state.data["status"] = "completed"
    self.state.save()
    self._running = False
    logger.info(f"Import completed: {self.state.data['memories_created']} created, {self.state.data['memories_merged']} merged")
    return self.state.to_dict()
async def _process_single_chunk(self, chunk: dict, preserve_raw: bool):
    """Extract memories from one chunk and store each of them.

    preserve_raw forces raw storage for every item; otherwise each item's
    own preserve_raw flag decides. Extraction failures are logged and the
    chunk is skipped (the caller still marks it processed). Individual
    storage failures only lose that one item.
    """
    content = chunk["content"]
    if not content.strip():
        return
    # --- LLM extraction ---
    try:
        items = await self._extract_memories(content)
        self.state.data["api_calls"] += 1
    except Exception as e:
        logger.warning(f"LLM extraction failed: {e}")
        self.state.data["api_calls"] += 1
        return
    if not items:
        return
    # --- Store each extracted memory ---
    for item in items:
        try:
            should_preserve = preserve_raw or item.get("preserve_raw", False)
            if should_preserve:
                # Raw mode: store original content without summarization.
                bucket_id = await self.bucket_mgr.create(
                    content=item["content"],
                    tags=item.get("tags", []),
                    importance=item.get("importance", 5),
                    domain=item.get("domain", ["未分类"]),
                    valence=item.get("valence", 0.5),
                    arousal=item.get("arousal", 0.3),
                    name=item.get("name"),
                )
                if self.embedding_engine:
                    # Best-effort: a missing embedding must not fail the import.
                    try:
                        await self.embedding_engine.generate_and_store(bucket_id, item["content"])
                    except Exception:
                        pass
                self.state.data["memories_raw"] += 1
                self.state.data["memories_created"] += 1
            else:
                # Normal mode: go through the merge-or-create pipeline.
                # (Removed a dead `if chunk.get("timestamp_start"): pass`
                # branch that did nothing.)
                is_merged = await self._merge_or_create_item(item)
                if is_merged:
                    self.state.data["memories_merged"] += 1
                else:
                    self.state.data["memories_created"] += 1
        except Exception as e:
            logger.warning(f"Failed to store memory: {item.get('name', '?')}: {e}")
async def _extract_memories(self, chunk_content: str) -> list[dict]:
    """Ask the LLM to extract memory items from one conversation chunk.

    Raises RuntimeError when the dehydrator's API is unavailable.
    """
    if not self.dehydrator.api_available:
        raise RuntimeError("API not available")
    # Truncate the chunk defensively; temperature 0 for deterministic output.
    response = await self.dehydrator.client.chat.completions.create(
        model=self.dehydrator.model,
        messages=[
            {"role": "system", "content": IMPORT_EXTRACT_PROMPT},
            {"role": "user", "content": chunk_content[:12000]},
        ],
        max_tokens=2048,
        temperature=0.0,
    )
    if not response.choices:
        return []
    reply = response.choices[0].message.content or ""
    return self._parse_extraction(reply) if reply.strip() else []
@staticmethod
def _parse_extraction(raw: str) -> list[dict]:
"""Parse and validate LLM extraction result."""
try:
cleaned = raw.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[-1].rsplit("```", 1)[0]
items = json.loads(cleaned)
except (json.JSONDecodeError, IndexError, ValueError):
logger.warning(f"Import extraction JSON parse failed: {raw[:200]}")
return []
if not isinstance(items, list):
return []
validated = []
for item in items:
if not isinstance(item, dict) or not item.get("content"):
continue
try:
importance = max(1, min(10, int(item.get("importance", 5))))
except (ValueError, TypeError):
importance = 5
try:
valence = max(0.0, min(1.0, float(item.get("valence", 0.5))))
arousal = max(0.0, min(1.0, float(item.get("arousal", 0.3))))
except (ValueError, TypeError):
valence, arousal = 0.5, 0.3
validated.append({
"name": str(item.get("name", ""))[:20],
"content": str(item["content"]),
"domain": item.get("domain", ["未分类"])[:3],
"valence": valence,
"arousal": arousal,
"tags": [str(t) for t in item.get("tags", [])][:10],
"importance": importance,
"preserve_raw": bool(item.get("preserve_raw", False)),
"is_pattern": bool(item.get("is_pattern", False)),
})
return validated
async def _merge_or_create_item(self, item: dict) -> bool:
    """Merge the item into a sufficiently similar existing bucket, or create a new one.

    Returns True when the item was merged, False when a new bucket was
    created. Pinned/protected buckets are never merged into. A failed
    merge attempt logs a warning and falls through to creating a new
    bucket, so the item is never lost.
    """
    content = item["content"]
    domain = item.get("domain", ["未分类"])
    tags = item.get("tags", [])
    importance = item.get("importance", 5)
    valence = item.get("valence", 0.5)
    arousal = item.get("arousal", 0.3)
    name = item.get("name", "")
    try:
        existing = await self.bucket_mgr.search(content, limit=1, domain_filter=domain or None)
    except Exception:
        existing = []  # search failure → just create a new bucket
    merge_threshold = self.config.get("merge_threshold", 75)
    if existing and existing[0].get("score", 0) > merge_threshold:
        bucket = existing[0]
        if not (bucket["metadata"].get("pinned") or bucket["metadata"].get("protected")):
            try:
                try:
                    merged = await self.dehydrator.merge(bucket["content"], content)
                finally:
                    # Count the merge API call exactly once whether it succeeds
                    # or raises. Previously a failure AFTER a successful merge
                    # call (e.g. in bucket_mgr.update) double-counted it:
                    # once here and once in the outer except handler.
                    self.state.data["api_calls"] += 1
                old_v = bucket["metadata"].get("valence", 0.5)
                old_a = bucket["metadata"].get("arousal", 0.3)
                await self.bucket_mgr.update(
                    bucket["id"],
                    content=merged,
                    tags=list(set(bucket["metadata"].get("tags", []) + tags)),
                    importance=max(bucket["metadata"].get("importance", 5), importance),
                    domain=list(set(bucket["metadata"].get("domain", []) + domain)),
                    valence=round((old_v + valence) / 2, 2),
                    arousal=round((old_a + arousal) / 2, 2),
                )
                if self.embedding_engine:
                    # Best-effort embedding refresh; never fails the merge.
                    try:
                        await self.embedding_engine.generate_and_store(bucket["id"], merged)
                    except Exception:
                        pass
                return True
            except Exception as e:
                logger.warning(f"Merge failed during import: {e}")
    # Create new bucket (no similar bucket, merge refused, or merge failed).
    bucket_id = await self.bucket_mgr.create(
        content=content,
        tags=tags,
        importance=importance,
        domain=domain,
        valence=valence,
        arousal=arousal,
        name=name or None,
    )
    if self.embedding_engine:
        try:
            await self.embedding_engine.generate_and_store(bucket_id, content)
        except Exception:
            pass
    return False
async def detect_patterns(self) -> list[dict]:
    """
    Post-import pass: find clusters of highly similar dynamic memories.

    Greedy single-pass clustering over bucket embeddings: each unvisited
    bucket seeds a cluster and absorbs every later bucket whose cosine
    similarity exceeds 0.7. Clusters of 3+ members are reported, largest
    first (top 20). Returns a list of
    {pattern_content, pattern_name, count, bucket_ids, suggested_action}.
    """
    if not self.embedding_engine:
        return []
    all_buckets = await self.bucket_mgr.list_all(include_archive=False)
    candidates = [
        b for b in all_buckets
        if b["metadata"].get("type") == "dynamic"
        and not b["metadata"].get("pinned")
        and not b["metadata"].get("resolved")
    ]
    if len(candidates) < 5:
        return []
    # Collect embeddings; buckets without one are ignored.
    vectors = {}
    for b in candidates:
        emb = await self.embedding_engine.get_embedding(b["id"])
        if emb is not None:
            vectors[b["id"]] = emb
    if len(vectors) < 5:
        return []
    import numpy as np
    ids = list(vectors.keys())
    clusters: dict[str, list[str]] = {}
    claimed = set()
    for i, lead in enumerate(ids):
        if lead in claimed:
            continue
        members = [lead]
        claimed.add(lead)
        vec_lead = np.array(vectors[lead])
        lead_norm = np.linalg.norm(vec_lead)
        if lead_norm == 0:
            continue  # zero vector: cosine similarity undefined, skip as lead
        for other in ids[i + 1:]:
            if other in claimed:
                continue
            vec_other = np.array(vectors[other])
            other_norm = np.linalg.norm(vec_other)
            if other_norm == 0:
                continue
            cosine = float(np.dot(vec_lead, vec_other) / (lead_norm * other_norm))
            if cosine > 0.7:
                members.append(other)
                claimed.add(other)
        if len(members) >= 3:
            clusters[lead] = members
    # Shape the result records.
    patterns = []
    for lead_id, member_ids in clusters.items():
        lead_bucket = next((b for b in candidates if b["id"] == lead_id), None)
        if not lead_bucket:
            continue
        patterns.append({
            "pattern_content": lead_bucket["content"][:200],
            "pattern_name": lead_bucket["metadata"].get("name", lead_id),
            "count": len(member_ids),
            "bucket_ids": member_ids,
            "suggested_action": "pin" if len(member_ids) >= 5 else "review",
        })
    patterns.sort(key=lambda p: p["count"], reverse=True)
    return patterns[:20]