Files
Ombre_Brain/embedding_engine.py

189 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ============================================================
# Module: Embedding Engine (embedding_engine.py)
# 模块:向量化引擎
#
# Generates embeddings via Gemini API (OpenAI-compatible),
# stores them in SQLite, and provides cosine similarity search.
# 通过 Gemini API(OpenAI 兼容)生成 embedding,
# 存储在 SQLite 中,提供余弦相似度搜索。
#
# Depended on by: server.py, bucket_manager.py
# 被谁依赖:server.py, bucket_manager.py
# ============================================================
import os
import json
import math
import sqlite3
import logging
import asyncio
from pathlib import Path
from openai import AsyncOpenAI
# Module-level logger; the warnings emitted by this file's embedding code go through it.
logger = logging.getLogger("ombre_brain.embedding")
class EmbeddingEngine:
"""
Embedding generation + SQLite vector storage + cosine search.
向量生成 + SQLite 向量存储 + 余弦搜索。
"""
def __init__(self, config: dict):
dehy_cfg = config.get("dehydration", {})
embed_cfg = config.get("embedding", {})
self.api_key = dehy_cfg.get("api_key", "")
self.base_url = dehy_cfg.get("base_url", "https://generativelanguage.googleapis.com/v1beta/openai/")
self.model = embed_cfg.get("model", "gemini-embedding-001")
self.enabled = bool(self.api_key) and embed_cfg.get("enabled", True)
# --- SQLite path: buckets_dir/embeddings.db ---
db_path = os.path.join(config["buckets_dir"], "embeddings.db")
self.db_path = db_path
# --- Initialize client ---
if self.enabled:
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=30.0,
)
else:
self.client = None
# --- Initialize SQLite ---
self._init_db()
def _init_db(self):
"""Create embeddings table if not exists."""
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
conn = sqlite3.connect(self.db_path)
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
bucket_id TEXT PRIMARY KEY,
embedding TEXT NOT NULL,
updated_at TEXT NOT NULL
)
""")
conn.commit()
conn.close()
async def generate_and_store(self, bucket_id: str, content: str) -> bool:
"""
Generate embedding for content and store in SQLite.
为内容生成 embedding 并存入 SQLite。
Returns True on success, False on failure.
"""
if not self.enabled or not content or not content.strip():
return False
try:
embedding = await self._generate_embedding(content)
if not embedding:
return False
self._store_embedding(bucket_id, embedding)
return True
except Exception as e:
logger.warning(f"Embedding generation failed for {bucket_id}: {e}")
return False
async def _generate_embedding(self, text: str) -> list[float]:
"""Call API to generate embedding vector."""
# Truncate to avoid token limits
truncated = text[:2000]
try:
response = await self.client.embeddings.create(
model=self.model,
input=truncated,
)
if response.data and len(response.data) > 0:
return response.data[0].embedding
return []
except Exception as e:
logger.warning(f"Embedding API call failed: {e}")
return []
def _store_embedding(self, bucket_id: str, embedding: list[float]):
"""Store embedding in SQLite."""
from utils import now_iso
conn = sqlite3.connect(self.db_path)
conn.execute(
"INSERT OR REPLACE INTO embeddings (bucket_id, embedding, updated_at) VALUES (?, ?, ?)",
(bucket_id, json.dumps(embedding), now_iso()),
)
conn.commit()
conn.close()
def delete_embedding(self, bucket_id: str):
"""Remove embedding when bucket is deleted."""
conn = sqlite3.connect(self.db_path)
conn.execute("DELETE FROM embeddings WHERE bucket_id = ?", (bucket_id,))
conn.commit()
conn.close()
async def get_embedding(self, bucket_id: str) -> list[float] | None:
"""Retrieve stored embedding for a bucket. Returns None if not found."""
conn = sqlite3.connect(self.db_path)
row = conn.execute(
"SELECT embedding FROM embeddings WHERE bucket_id = ?", (bucket_id,)
).fetchone()
conn.close()
if row:
try:
return json.loads(row[0])
except json.JSONDecodeError:
return None
return None
async def search_similar(self, query: str, top_k: int = 10) -> list[tuple[str, float]]:
"""
Search for buckets similar to query text.
Returns list of (bucket_id, similarity_score) sorted by score desc.
搜索与查询文本相似的桶。返回 (bucket_id, 相似度分数) 列表。
"""
if not self.enabled:
return []
try:
query_embedding = await self._generate_embedding(query)
if not query_embedding:
return []
except Exception as e:
logger.warning(f"Query embedding failed: {e}")
return []
# Load all embeddings from SQLite
conn = sqlite3.connect(self.db_path)
rows = conn.execute("SELECT bucket_id, embedding FROM embeddings").fetchall()
conn.close()
if not rows:
return []
# Calculate cosine similarity
results = []
for bucket_id, emb_json in rows:
try:
stored_embedding = json.loads(emb_json)
sim = self._cosine_similarity(query_embedding, stored_embedding)
results.append((bucket_id, sim))
except (json.JSONDecodeError, Exception):
continue
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
@staticmethod
def _cosine_similarity(a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
if len(a) != len(b) or not a:
return 0.0
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)