docs: update README/INTERNALS for import feature, harden .gitignore

This commit is contained in:
P0luz
2026-04-19 12:09:53 +08:00
parent a09fbfe13a
commit 821546d5de
27 changed files with 5365 additions and 479 deletions

758
import_memory.py Normal file
View File

@@ -0,0 +1,758 @@
# ============================================================
# Module: Memory Import Engine (import_memory.py)
# 模块:历史记忆导入引擎
#
# Imports conversation history from various platforms into OB.
# 将各平台对话历史导入 OB 记忆系统。
#
# Supports: Claude JSON, ChatGPT export, DeepSeek, Markdown, plain text
# 支持格式:Claude JSON、ChatGPT 导出、DeepSeek、Markdown、纯文本
#
# Features:
# - Chunked processing with resume support
# - Progress persistence (import_state.json)
# - Raw preservation mode for special contexts
# - Post-import frequency pattern detection
# ============================================================
import os
import json
import hashlib
import logging
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Optional
from utils import count_tokens_approx, now_iso
logger = logging.getLogger("ombre_brain.import")
# ============================================================
# Format Parsers — normalize any format to conversation turns
# 格式解析器 — 将任意格式标准化为对话轮次
# ============================================================
def _parse_claude_json(data: dict | list) -> list[dict]:
"""Parse Claude.ai export JSON → [{role, content, timestamp}, ...]"""
turns = []
conversations = data if isinstance(data, list) else [data]
for conv in conversations:
messages = conv.get("chat_messages", conv.get("messages", []))
for msg in messages:
if not isinstance(msg, dict):
continue
content = msg.get("text", msg.get("content", ""))
if isinstance(content, list):
content = " ".join(
p.get("text", "") for p in content if isinstance(p, dict)
)
if not content or not content.strip():
continue
role = msg.get("sender", msg.get("role", "user"))
ts = msg.get("created_at", msg.get("timestamp", ""))
turns.append({"role": role, "content": content.strip(), "timestamp": ts})
return turns
def _parse_chatgpt_json(data: list | dict) -> list[dict]:
"""Parse ChatGPT export JSON → [{role, content, timestamp}, ...]"""
turns = []
conversations = data if isinstance(data, list) else [data]
for conv in conversations:
mapping = conv.get("mapping", {})
if mapping:
# ChatGPT uses a tree structure with mapping
sorted_nodes = sorted(
mapping.values(),
key=lambda n: n.get("message", {}).get("create_time", 0) or 0,
)
for node in sorted_nodes:
msg = node.get("message")
if not msg or not isinstance(msg, dict):
continue
content_parts = msg.get("content", {}).get("parts", [])
content = " ".join(str(p) for p in content_parts if p)
if not content.strip():
continue
role = msg.get("author", {}).get("role", "user")
ts = msg.get("create_time", "")
if isinstance(ts, (int, float)):
ts = datetime.fromtimestamp(ts).isoformat()
turns.append({"role": role, "content": content.strip(), "timestamp": str(ts)})
else:
# Simpler format: list of messages
messages = conv.get("messages", [])
for msg in messages:
if not isinstance(msg, dict):
continue
content = msg.get("content", msg.get("text", ""))
if isinstance(content, dict):
content = " ".join(str(p) for p in content.get("parts", []))
if not content or not content.strip():
continue
role = msg.get("role", msg.get("author", {}).get("role", "user"))
ts = msg.get("timestamp", msg.get("create_time", ""))
turns.append({"role": role, "content": content.strip(), "timestamp": str(ts)})
return turns
def _parse_markdown(text: str) -> list[dict]:
"""Parse Markdown/plain text → [{role, content, timestamp}, ...]"""
# Try to detect conversation patterns
lines = text.split("\n")
turns = []
current_role = "user"
current_content = []
for line in lines:
stripped = line.strip()
# Detect role switches
if stripped.lower().startswith(("human:", "user:", "你:", "我:")):
if current_content:
turns.append({"role": current_role, "content": "\n".join(current_content).strip(), "timestamp": ""})
current_role = "user"
content_after = stripped.split(":", 1)[1].strip() if ":" in stripped else ""
current_content = [content_after] if content_after else []
elif stripped.lower().startswith(("assistant:", "claude:", "ai:", "gpt:", "bot:", "deepseek:")):
if current_content:
turns.append({"role": current_role, "content": "\n".join(current_content).strip(), "timestamp": ""})
current_role = "assistant"
content_after = stripped.split(":", 1)[1].strip() if ":" in stripped else ""
current_content = [content_after] if content_after else []
else:
current_content.append(line)
if current_content:
content = "\n".join(current_content).strip()
if content:
turns.append({"role": current_role, "content": content, "timestamp": ""})
# If no role patterns detected, treat entire text as one big chunk
if not turns:
turns = [{"role": "user", "content": text.strip(), "timestamp": ""}]
return turns
def detect_and_parse(raw_content: str, filename: str = "") -> list[dict]:
    """
    Auto-detect the export format and parse it into normalized turns.

    JSON (Claude / ChatGPT exports) is tried first when the extension or
    the leading character suggests it; anything else falls back to the
    Markdown/plain-text parser.
    """
    ext = Path(filename).suffix.lower() if filename else ""
    looks_like_json = ext in (".json", "") or raw_content.strip().startswith(("{", "["))
    if looks_like_json:
        try:
            data = json.loads(raw_content)
            # Inspect one representative object to pick a parser.
            sample = (data[0] if data else {}) if isinstance(data, list) else data
            if isinstance(sample, dict):
                if "chat_messages" in sample:
                    return _parse_claude_json(data)
                if "mapping" in sample:
                    return _parse_chatgpt_json(data)
                if "messages" in sample:
                    # Ambiguous: dict-shaped message content means ChatGPT,
                    # anything else gets the Claude layout.
                    msgs = sample["messages"]
                    if (
                        msgs
                        and isinstance(msgs[0], dict)
                        and "content" in msgs[0]
                        and isinstance(msgs[0]["content"], dict)
                    ):
                        return _parse_chatgpt_json(data)
                    return _parse_claude_json(data)
                if "role" in sample and "content" in sample:
                    # Bare object/list of {role, content} messages.
                    return _parse_claude_json(data)
        except (json.JSONDecodeError, KeyError, IndexError):
            pass
    # Fall back to markdown/text
    return _parse_markdown(raw_content)
# ============================================================
# Chunking — split turns into ~10k token windows
# 分窗 — 按对话轮次边界切为 ~10k token 窗口
# ============================================================
def chunk_turns(turns: list[dict], target_tokens: int = 10000) -> list[dict]:
    """
    Group conversation turns into windows of roughly ``target_tokens``.

    Splits only at turn boundaries. A single turn larger than 1.5x the
    target is never split internally — it becomes a chunk of its own.
    Returns a list of {content, timestamp_start, timestamp_end, turn_count}.
    """
    chunks: list[dict] = []
    buf: list[str] = []
    buf_tokens = 0
    ts_start = ""
    ts_end = ""
    buf_turns = 0

    def flush_buffer():
        # Emit the current window and reset the accumulators.
        nonlocal buf, buf_tokens, buf_turns, ts_start
        chunks.append({
            "content": "\n".join(buf),
            "timestamp_start": ts_start,
            "timestamp_end": ts_end,
            "turn_count": buf_turns,
        })
        buf = []
        buf_tokens = 0
        buf_turns = 0
        ts_start = ""

    for turn in turns:
        speaker = "用户" if turn["role"] in ("user", "human") else "AI"
        rendered = f"[{speaker}] {turn['content']}"
        rendered_tokens = count_tokens_approx(rendered)
        if rendered_tokens > target_tokens * 1.5:
            # Oversized turn: flush what we have, then give this turn a
            # dedicated chunk of its own.
            if buf:
                flush_buffer()
            chunks.append({
                "content": rendered,
                "timestamp_start": turn.get("timestamp", ""),
                "timestamp_end": turn.get("timestamp", ""),
                "turn_count": 1,
            })
            continue
        if buf and buf_tokens + rendered_tokens > target_tokens:
            flush_buffer()
        if not ts_start:
            ts_start = turn.get("timestamp", "")
        ts_end = turn.get("timestamp", "")
        buf.append(rendered)
        buf_tokens += rendered_tokens
        buf_turns += 1
    if buf:
        flush_buffer()
    return chunks
# ============================================================
# Import State — persistent progress tracking
# 导入状态 — 持久化进度追踪
# ============================================================
class ImportState:
    """Tracks import progress and persists it as ``import_state.json``."""

    @staticmethod
    def _blank(source_file: str, source_hash: str, total_chunks: int,
               status: str, started_at: str, updated_at: str) -> dict:
        # Single place that defines the full schema of the progress record.
        return {
            "source_file": source_file,
            "source_hash": source_hash,
            "total_chunks": total_chunks,
            "processed": 0,
            "api_calls": 0,
            "memories_created": 0,
            "memories_merged": 0,
            "memories_raw": 0,
            "errors": [],
            "status": status,  # idle | running | paused | completed | error
            "started_at": started_at,
            "updated_at": updated_at,
        }

    def __init__(self, state_dir: str):
        self.state_file = os.path.join(state_dir, "import_state.json")
        self.data = self._blank("", "", 0, "idle", "", "")

    def load(self) -> bool:
        """Merge persisted state into memory; True iff a valid file existed."""
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                self.data.update(json.load(f))
            return True
        except (json.JSONDecodeError, OSError):
            # Missing, unreadable, or corrupt file all count as "no state".
            return False

    def save(self):
        """Persist state atomically (write a temp file, then rename over)."""
        self.data["updated_at"] = now_iso()
        os.makedirs(os.path.dirname(self.state_file), exist_ok=True)
        tmp_path = f"{self.state_file}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.state_file)

    def reset(self, source_file: str, source_hash: str, total_chunks: int):
        """Start tracking a brand-new import run."""
        self.data = self._blank(source_file, source_hash, total_chunks,
                                "running", now_iso(), now_iso())

    @property
    def can_resume(self) -> bool:
        """True when a paused/running import still has unprocessed chunks."""
        unfinished = self.data["processed"] < self.data["total_chunks"]
        return self.data["status"] in ("paused", "running") and unfinished

    def to_dict(self) -> dict:
        """Return a shallow copy of the progress record."""
        return dict(self.data)
# ============================================================
# Import extraction prompt
# 导入提取提示词
#
# System prompt for the LLM pass that distills one conversation
# chunk into 0-5 long-term memory items (strict JSON array).
# Fix: restored the lost "(" in the valence/arousal range lines —
# they previously read "0~10=..." instead of "0~1(0=...)".
# ============================================================
IMPORT_EXTRACT_PROMPT = """你是一个对话记忆提取专家。从以下对话片段中提取值得长期记住的信息。
提取规则:
1. 提取用户的事实、偏好、习惯、重要事件、情感时刻
2. 同一话题的零散信息整合为一条记忆
3. 过滤掉纯技术调试输出、代码块、重复问答、无意义寒暄
4. 如果对话中有特殊暗号、仪式性行为、关键承诺等,标记 preserve_raw=true
5. 如果内容是用户和AI之间的习惯性互动模式例如打招呼方式、告别习惯标记 is_pattern=true
6. 每条记忆不少于30字
7. 总条目数控制在 0~5 个(没有值得记的就返回空数组)
8. 在 content 中对人名、地名、专有名词用 [[双链]] 标记
输出格式(纯 JSON 数组,无其他内容):
[
{
"name": "条目标题10字以内",
"content": "整理后的内容",
"domain": ["主题域1"],
"valence": 0.7,
"arousal": 0.4,
"tags": ["核心词1", "核心词2", "扩展词1"],
"importance": 5,
"preserve_raw": false,
"is_pattern": false
}
]
主题域可选(选 1~2 个):
日常: ["饮食", "穿搭", "出行", "居家", "购物"]
人际: ["家庭", "恋爱", "友谊", "社交"]
成长: ["工作", "学习", "考试", "求职"]
身心: ["健康", "心理", "睡眠", "运动"]
兴趣: ["游戏", "影视", "音乐", "阅读", "创作", "手工"]
数字: ["编程", "AI", "硬件", "网络"]
事务: ["财务", "计划", "待办"]
内心: ["情绪", "回忆", "梦境", "自省"]
importance: 1-10
valence: 0~1(0=消极, 0.5=中性, 1=积极)
arousal: 0~1(0=平静, 0.5=普通, 1=激动)
preserve_raw: true = 特殊情境/暗号/仪式,保留原文不摘要
is_pattern: true = 反复出现的习惯性行为模式"""
# ============================================================
# Import Engine — core processing logic
# 导入引擎 — 核心处理逻辑
# ============================================================
class ImportEngine:
    """
    Processes conversation history files into OB memory buckets.

    Pipeline: detect/parse the source format, chunk it by token budget,
    run one LLM extraction pass per chunk, then store each extracted item
    either raw or through a merge-or-create pipeline. Progress is
    persisted via ImportState so a run can be paused and resumed.
    """
    def __init__(self, config: dict, bucket_mgr, dehydrator, embedding_engine=None):
        # config must contain "buckets_dir"; "merge_threshold" is optional
        # (defaults to 75 in _merge_or_create_item).
        self.config = config
        self.bucket_mgr = bucket_mgr  # async memory-bucket store (create/update/search/list_all)
        self.dehydrator = dehydrator  # LLM wrapper: .client, .model, .merge(), .api_available
        self.embedding_engine = embedding_engine  # optional; vectors + pattern detection
        self.state = ImportState(config["buckets_dir"])
        self._paused = False   # cooperative pause flag, checked between chunks
        self._running = False  # guards against concurrent start() calls
        self._chunks: list[dict] = []  # output of chunk_turns() for the current run
    @property
    def is_running(self) -> bool:
        # True while start() is actively processing chunks.
        return self._running
    def pause(self):
        """Request pause — will stop after current chunk finishes."""
        self._paused = True
    def get_status(self) -> dict:
        """Return a snapshot of the current import state."""
        return self.state.to_dict()
    async def start(
        self,
        raw_content: str,
        filename: str = "",
        preserve_raw: bool = False,
        resume: bool = False,
    ) -> dict:
        """
        Start or resume an import.

        Args:
            raw_content: full text of the uploaded history file.
            filename: original name; used for format detection and state.
            preserve_raw: force raw (non-summarized) storage for all items.
            resume: continue a previously interrupted run — honored only
                when the persisted source hash still matches raw_content.

        Returns:
            The final (or paused) state dict, or {"error": ...} on bad
            input / concurrent start.
        """
        if self._running:
            return {"error": "Import already running"}
        self._running = True
        self._paused = False
        try:
            # Short content hash lets resume detect an edited source file.
            source_hash = hashlib.sha256(raw_content.encode()).hexdigest()[:16]
            # Check for resume
            if resume and self.state.load() and self.state.can_resume:
                if self.state.data["source_hash"] == source_hash:
                    logger.info(f"Resuming import from chunk {self.state.data['processed']}/{self.state.data['total_chunks']}")
                    # Re-parse and re-chunk to get the same chunks
                    # (parsing/chunking is deterministic, so indices line up).
                    turns = detect_and_parse(raw_content, filename)
                    self._chunks = chunk_turns(turns)
                    self.state.data["status"] = "running"
                    self.state.save()
                    return await self._process_chunks(preserve_raw)
                else:
                    logger.warning("Source file changed, starting fresh import")
            # Fresh import
            turns = detect_and_parse(raw_content, filename)
            if not turns:
                self._running = False
                return {"error": "No conversation turns found in file"}
            self._chunks = chunk_turns(turns)
            if not self._chunks:
                self._running = False
                return {"error": "No processable chunks after splitting"}
            self.state.reset(filename, source_hash, len(self._chunks))
            self.state.save()
            logger.info(f"Starting import: {len(turns)} turns → {len(self._chunks)} chunks")
            return await self._process_chunks(preserve_raw)
        except Exception as e:
            # Record the failure in persisted state, then re-raise so the
            # caller still sees the original exception.
            self.state.data["status"] = "error"
            self.state.data["errors"].append(str(e))
            self.state.save()
            self._running = False
            raise
    async def _process_chunks(self, preserve_raw: bool) -> dict:
        """Process chunks from the current saved position to the end."""
        start_idx = self.state.data["processed"]
        for i in range(start_idx, len(self._chunks)):
            if self._paused:
                # Honor a pause request between chunks; progress is saved
                # so a later start(resume=True) can continue from here.
                self.state.data["status"] = "paused"
                self.state.save()
                self._running = False
                logger.info(f"Import paused at chunk {i}/{len(self._chunks)}")
                return self.state.to_dict()
            chunk = self._chunks[i]
            try:
                await self._process_single_chunk(chunk, preserve_raw)
            except Exception as e:
                # A bad chunk must not abort the whole import; cap the
                # error log at 100 entries to bound state-file growth.
                err_msg = f"Chunk {i}: {str(e)[:200]}"
                logger.warning(f"Import chunk error: {err_msg}")
                if len(self.state.data["errors"]) < 100:
                    self.state.data["errors"].append(err_msg)
            self.state.data["processed"] = i + 1
            # Save progress every chunk
            self.state.save()
        self.state.data["status"] = "completed"
        self.state.save()
        self._running = False
        logger.info(f"Import completed: {self.state.data['memories_created']} created, {self.state.data['memories_merged']} merged")
        return self.state.to_dict()
    async def _process_single_chunk(self, chunk: dict, preserve_raw: bool):
        """Extract memories from a single chunk and store each item."""
        content = chunk["content"]
        if not content.strip():
            return
        # --- LLM extraction ---
        try:
            items = await self._extract_memories(content)
            self.state.data["api_calls"] += 1
        except Exception as e:
            # Count the attempted call even on failure so cost tracking
            # reflects real API usage; this chunk is simply skipped.
            logger.warning(f"LLM extraction failed: {e}")
            self.state.data["api_calls"] += 1
            return
        if not items:
            return
        # --- Store each extracted memory ---
        for item in items:
            try:
                # Caller-level preserve_raw overrides per-item flags.
                should_preserve = preserve_raw or item.get("preserve_raw", False)
                if should_preserve:
                    # Raw mode: store original content without summarization
                    bucket_id = await self.bucket_mgr.create(
                        content=item["content"],
                        tags=item.get("tags", []),
                        importance=item.get("importance", 5),
                        domain=item.get("domain", ["未分类"]),
                        valence=item.get("valence", 0.5),
                        arousal=item.get("arousal", 0.3),
                        name=item.get("name"),
                    )
                    if self.embedding_engine:
                        try:
                            await self.embedding_engine.generate_and_store(bucket_id, item["content"])
                        except Exception:
                            # Embeddings are best-effort; bucket exists regardless.
                            pass
                    self.state.data["memories_raw"] += 1
                    self.state.data["memories_created"] += 1
                else:
                    # Normal mode: go through merge-or-create pipeline
                    is_merged = await self._merge_or_create_item(item)
                    if is_merged:
                        self.state.data["memories_merged"] += 1
                    else:
                        self.state.data["memories_created"] += 1
                # Patch timestamp if available
                if chunk.get("timestamp_start"):
                    # We don't have update support for created, so skip
                    pass
            except Exception as e:
                # One bad item must not sink the rest of the chunk.
                logger.warning(f"Failed to store memory: {item.get('name', '?')}: {e}")
    async def _extract_memories(self, chunk_content: str) -> list[dict]:
        """Use the LLM to extract memory items from a conversation chunk.

        Raises:
            RuntimeError: when the dehydrator reports the API unavailable.
        """
        if not self.dehydrator.api_available:
            raise RuntimeError("API not available")
        response = await self.dehydrator.client.chat.completions.create(
            model=self.dehydrator.model,
            messages=[
                {"role": "system", "content": IMPORT_EXTRACT_PROMPT},
                # Defensive cap — note this truncates by characters, while
                # chunks are sized by approximate tokens.
                {"role": "user", "content": chunk_content[:12000]},
            ],
            max_tokens=2048,
            temperature=0.0,
        )
        if not response.choices:
            return []
        raw = response.choices[0].message.content or ""
        if not raw.strip():
            return []
        return self._parse_extraction(raw)
    @staticmethod
    def _parse_extraction(raw: str) -> list[dict]:
        """Parse and validate the LLM extraction result.

        Tolerates a Markdown code fence around the JSON, clamps numeric
        fields into range, and truncates over-long names/domains/tags.
        Returns [] on any parse failure rather than raising.
        """
        try:
            cleaned = raw.strip()
            if cleaned.startswith("```"):
                # Strip a ```json ... ``` fence if the model added one.
                cleaned = cleaned.split("\n", 1)[-1].rsplit("```", 1)[0]
            items = json.loads(cleaned)
        except (json.JSONDecodeError, IndexError, ValueError):
            logger.warning(f"Import extraction JSON parse failed: {raw[:200]}")
            return []
        if not isinstance(items, list):
            return []
        validated = []
        for item in items:
            # Items without content are useless; drop them silently.
            if not isinstance(item, dict) or not item.get("content"):
                continue
            try:
                importance = max(1, min(10, int(item.get("importance", 5))))
            except (ValueError, TypeError):
                importance = 5
            try:
                valence = max(0.0, min(1.0, float(item.get("valence", 0.5))))
                arousal = max(0.0, min(1.0, float(item.get("arousal", 0.3))))
            except (ValueError, TypeError):
                valence, arousal = 0.5, 0.3
            validated.append({
                "name": str(item.get("name", ""))[:20],
                "content": str(item["content"]),
                "domain": item.get("domain", ["未分类"])[:3],
                "valence": valence,
                "arousal": arousal,
                "tags": [str(t) for t in item.get("tags", [])][:10],
                "importance": importance,
                "preserve_raw": bool(item.get("preserve_raw", False)),
                "is_pattern": bool(item.get("is_pattern", False)),
            })
        return validated
    async def _merge_or_create_item(self, item: dict) -> bool:
        """Merge the item into a similar existing bucket, or create a new one.

        Returns:
            True when merged into an existing bucket, False when a new
            bucket was created.
        """
        content = item["content"]
        domain = item.get("domain", ["未分类"])
        tags = item.get("tags", [])
        importance = item.get("importance", 5)
        valence = item.get("valence", 0.5)
        arousal = item.get("arousal", 0.3)
        name = item.get("name", "")
        try:
            existing = await self.bucket_mgr.search(content, limit=1, domain_filter=domain or None)
        except Exception:
            # Search failure degrades to "no match" → create a new bucket.
            existing = []
        merge_threshold = self.config.get("merge_threshold", 75)
        if existing and existing[0].get("score", 0) > merge_threshold:
            bucket = existing[0]
            # Pinned/protected buckets are never rewritten by imports.
            if not (bucket["metadata"].get("pinned") or bucket["metadata"].get("protected")):
                try:
                    merged = await self.dehydrator.merge(bucket["content"], content)
                    self.state.data["api_calls"] += 1
                    old_v = bucket["metadata"].get("valence", 0.5)
                    old_a = bucket["metadata"].get("arousal", 0.3)
                    await self.bucket_mgr.update(
                        bucket["id"],
                        content=merged,
                        tags=list(set(bucket["metadata"].get("tags", []) + tags)),
                        importance=max(bucket["metadata"].get("importance", 5), importance),
                        domain=list(set(bucket["metadata"].get("domain", []) + domain)),
                        # Emotional scores are averaged with the existing bucket's.
                        valence=round((old_v + valence) / 2, 2),
                        arousal=round((old_a + arousal) / 2, 2),
                    )
                    if self.embedding_engine:
                        try:
                            await self.embedding_engine.generate_and_store(bucket["id"], merged)
                        except Exception:
                            pass
                    return True
                except Exception as e:
                    # Merge failed; fall through to creating a new bucket.
                    logger.warning(f"Merge failed during import: {e}")
                    self.state.data["api_calls"] += 1
        # Create new
        bucket_id = await self.bucket_mgr.create(
            content=content,
            tags=tags,
            importance=importance,
            domain=domain,
            valence=valence,
            arousal=arousal,
            name=name or None,
        )
        if self.embedding_engine:
            try:
                await self.embedding_engine.generate_and_store(bucket_id, content)
            except Exception:
                pass
        return False
    async def detect_patterns(self) -> list[dict]:
        """
        Post-import: detect high-frequency patterns via embedding clustering.

        Greedy single-link clustering over unpinned, unresolved "dynamic"
        buckets: cosine similarity > 0.7 joins a cluster; clusters of 3+
        members are reported. Returns up to 20 dicts of
        {pattern_content, pattern_name, count, bucket_ids, suggested_action},
        sorted by cluster size descending.
        """
        if not self.embedding_engine:
            return []
        all_buckets = await self.bucket_mgr.list_all(include_archive=False)
        dynamic_buckets = [
            b for b in all_buckets
            if b["metadata"].get("type") == "dynamic"
            and not b["metadata"].get("pinned")
            and not b["metadata"].get("resolved")
        ]
        if len(dynamic_buckets) < 5:
            # Too few candidates for clustering to be meaningful.
            return []
        # Get embeddings
        embeddings = {}
        for b in dynamic_buckets:
            emb = await self.embedding_engine.get_embedding(b["id"])
            if emb is not None:
                embeddings[b["id"]] = emb
        if len(embeddings) < 5:
            return []
        # Find clusters: group by pairwise similarity > 0.7
        import numpy as np
        ids = list(embeddings.keys())
        clusters: dict[str, list[str]] = {}
        visited = set()
        for i, id_a in enumerate(ids):
            if id_a in visited:
                continue
            cluster = [id_a]
            visited.add(id_a)
            emb_a = np.array(embeddings[id_a])
            norm_a = np.linalg.norm(emb_a)
            if norm_a == 0:
                # Zero vector: cosine similarity undefined; skip as a lead.
                continue
            for j in range(i + 1, len(ids)):
                id_b = ids[j]
                if id_b in visited:
                    continue
                emb_b = np.array(embeddings[id_b])
                norm_b = np.linalg.norm(emb_b)
                if norm_b == 0:
                    continue
                sim = float(np.dot(emb_a, emb_b) / (norm_a * norm_b))
                if sim > 0.7:
                    cluster.append(id_b)
                    visited.add(id_b)
            if len(cluster) >= 3:
                clusters[id_a] = cluster
        # Format results
        patterns = []
        for lead_id, cluster_ids in clusters.items():
            lead_bucket = next((b for b in dynamic_buckets if b["id"] == lead_id), None)
            if not lead_bucket:
                continue
            patterns.append({
                "pattern_content": lead_bucket["content"][:200],
                "pattern_name": lead_bucket["metadata"].get("name", lead_id),
                "count": len(cluster_ids),
                "bucket_ids": cluster_ids,
                # 5+ occurrences → suggest pinning; otherwise human review.
                "suggested_action": "pin" if len(cluster_ids) >= 5 else "review",
            })
        patterns.sort(key=lambda p: p["count"], reverse=True)
        return patterns[:20]