docs: update README/INTERNALS for import feature, harden .gitignore

This commit is contained in:
P0luz
2026-04-19 12:09:53 +08:00
parent a09fbfe13a
commit 821546d5de
27 changed files with 5365 additions and 479 deletions

188
embedding_engine.py Normal file
View File

@@ -0,0 +1,188 @@
# ============================================================
# Module: Embedding Engine (embedding_engine.py)
# 模块:向量化引擎
#
# Generates embeddings via Gemini API (OpenAI-compatible),
# stores them in SQLite, and provides cosine similarity search.
# 通过 Gemini APIOpenAI 兼容)生成 embedding
# 存储在 SQLite 中,提供余弦相似度搜索。
#
# Depended on by: server.py, bucket_manager.py
# 被谁依赖server.py, bucket_manager.py
# ============================================================
import os
import json
import math
import sqlite3
import logging
import asyncio
from pathlib import Path
from openai import AsyncOpenAI
logger = logging.getLogger("ombre_brain.embedding")
class EmbeddingEngine:
"""
Embedding generation + SQLite vector storage + cosine search.
向量生成 + SQLite 向量存储 + 余弦搜索。
"""
def __init__(self, config: dict):
dehy_cfg = config.get("dehydration", {})
embed_cfg = config.get("embedding", {})
self.api_key = dehy_cfg.get("api_key", "")
self.base_url = dehy_cfg.get("base_url", "https://generativelanguage.googleapis.com/v1beta/openai/")
self.model = embed_cfg.get("model", "gemini-embedding-001")
self.enabled = bool(self.api_key) and embed_cfg.get("enabled", True)
# --- SQLite path: buckets_dir/embeddings.db ---
db_path = os.path.join(config["buckets_dir"], "embeddings.db")
self.db_path = db_path
# --- Initialize client ---
if self.enabled:
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=30.0,
)
else:
self.client = None
# --- Initialize SQLite ---
self._init_db()
def _init_db(self):
"""Create embeddings table if not exists."""
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
conn = sqlite3.connect(self.db_path)
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
bucket_id TEXT PRIMARY KEY,
embedding TEXT NOT NULL,
updated_at TEXT NOT NULL
)
""")
conn.commit()
conn.close()
async def generate_and_store(self, bucket_id: str, content: str) -> bool:
"""
Generate embedding for content and store in SQLite.
为内容生成 embedding 并存入 SQLite。
Returns True on success, False on failure.
"""
if not self.enabled or not content or not content.strip():
return False
try:
embedding = await self._generate_embedding(content)
if not embedding:
return False
self._store_embedding(bucket_id, embedding)
return True
except Exception as e:
logger.warning(f"Embedding generation failed for {bucket_id}: {e}")
return False
async def _generate_embedding(self, text: str) -> list[float]:
"""Call API to generate embedding vector."""
# Truncate to avoid token limits
truncated = text[:2000]
try:
response = await self.client.embeddings.create(
model=self.model,
input=truncated,
)
if response.data and len(response.data) > 0:
return response.data[0].embedding
return []
except Exception as e:
logger.warning(f"Embedding API call failed: {e}")
return []
def _store_embedding(self, bucket_id: str, embedding: list[float]):
"""Store embedding in SQLite."""
from utils import now_iso
conn = sqlite3.connect(self.db_path)
conn.execute(
"INSERT OR REPLACE INTO embeddings (bucket_id, embedding, updated_at) VALUES (?, ?, ?)",
(bucket_id, json.dumps(embedding), now_iso()),
)
conn.commit()
conn.close()
def delete_embedding(self, bucket_id: str):
"""Remove embedding when bucket is deleted."""
conn = sqlite3.connect(self.db_path)
conn.execute("DELETE FROM embeddings WHERE bucket_id = ?", (bucket_id,))
conn.commit()
conn.close()
async def get_embedding(self, bucket_id: str) -> list[float] | None:
"""Retrieve stored embedding for a bucket. Returns None if not found."""
conn = sqlite3.connect(self.db_path)
row = conn.execute(
"SELECT embedding FROM embeddings WHERE bucket_id = ?", (bucket_id,)
).fetchone()
conn.close()
if row:
try:
return json.loads(row[0])
except json.JSONDecodeError:
return None
return None
async def search_similar(self, query: str, top_k: int = 10) -> list[tuple[str, float]]:
"""
Search for buckets similar to query text.
Returns list of (bucket_id, similarity_score) sorted by score desc.
搜索与查询文本相似的桶。返回 (bucket_id, 相似度分数) 列表。
"""
if not self.enabled:
return []
try:
query_embedding = await self._generate_embedding(query)
if not query_embedding:
return []
except Exception as e:
logger.warning(f"Query embedding failed: {e}")
return []
# Load all embeddings from SQLite
conn = sqlite3.connect(self.db_path)
rows = conn.execute("SELECT bucket_id, embedding FROM embeddings").fetchall()
conn.close()
if not rows:
return []
# Calculate cosine similarity
results = []
for bucket_id, emb_json in rows:
try:
stored_embedding = json.loads(emb_json)
sim = self._cosine_similarity(query_embedding, stored_embedding)
results.append((bucket_id, sim))
except (json.JSONDecodeError, Exception):
continue
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
@staticmethod
def _cosine_similarity(a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
if len(a) != len(b) or not a:
return 0.0
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)