# ============================================================
# Module: Embedding Engine (embedding_engine.py)
# 模块:向量化引擎
#
# Generates embeddings via the Gemini API (OpenAI-compatible endpoint),
# stores them in SQLite, and provides cosine-similarity search.
# 通过 Gemini API(OpenAI 兼容)生成 embedding,
# 存储在 SQLite 中,提供余弦相似度搜索。
#
# Depended on by: server.py, bucket_manager.py
# 被谁依赖:server.py, bucket_manager.py
# ============================================================

import asyncio
import json
import logging
import math
import os
import sqlite3
from contextlib import closing
from pathlib import Path

from openai import AsyncOpenAI

# Named as a child of the "ombre_brain" logger, so by default (propagation on,
# no level set here) handlers and levels configured on that parent apply.
_LOGGER_NAME = "ombre_brain.embedding"
logger = logging.getLogger(_LOGGER_NAME)


class EmbeddingEngine:
|
||
"""
|
||
Embedding generation + SQLite vector storage + cosine search.
|
||
向量生成 + SQLite 向量存储 + 余弦搜索。
|
||
"""
|
||
|
||
def __init__(self, config: dict):
|
||
dehy_cfg = config.get("dehydration", {})
|
||
embed_cfg = config.get("embedding", {})
|
||
|
||
self.api_key = dehy_cfg.get("api_key", "")
|
||
self.base_url = dehy_cfg.get("base_url", "https://generativelanguage.googleapis.com/v1beta/openai/")
|
||
self.model = embed_cfg.get("model", "gemini-embedding-001")
|
||
self.enabled = bool(self.api_key) and embed_cfg.get("enabled", True)
|
||
|
||
# --- SQLite path: buckets_dir/embeddings.db ---
|
||
db_path = os.path.join(config["buckets_dir"], "embeddings.db")
|
||
self.db_path = db_path
|
||
|
||
# --- Initialize client ---
|
||
if self.enabled:
|
||
self.client = AsyncOpenAI(
|
||
api_key=self.api_key,
|
||
base_url=self.base_url,
|
||
timeout=30.0,
|
||
)
|
||
else:
|
||
self.client = None
|
||
|
||
# --- Initialize SQLite ---
|
||
self._init_db()
|
||
|
||
def _init_db(self):
|
||
"""Create embeddings table if not exists."""
|
||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||
conn = sqlite3.connect(self.db_path)
|
||
conn.execute("""
|
||
CREATE TABLE IF NOT EXISTS embeddings (
|
||
bucket_id TEXT PRIMARY KEY,
|
||
embedding TEXT NOT NULL,
|
||
updated_at TEXT NOT NULL
|
||
)
|
||
""")
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
async def generate_and_store(self, bucket_id: str, content: str) -> bool:
|
||
"""
|
||
Generate embedding for content and store in SQLite.
|
||
为内容生成 embedding 并存入 SQLite。
|
||
Returns True on success, False on failure.
|
||
"""
|
||
if not self.enabled or not content or not content.strip():
|
||
return False
|
||
|
||
try:
|
||
embedding = await self._generate_embedding(content)
|
||
if not embedding:
|
||
return False
|
||
self._store_embedding(bucket_id, embedding)
|
||
return True
|
||
except Exception as e:
|
||
logger.warning(f"Embedding generation failed for {bucket_id}: {e}")
|
||
return False
|
||
|
||
async def _generate_embedding(self, text: str) -> list[float]:
|
||
"""Call API to generate embedding vector."""
|
||
# Truncate to avoid token limits
|
||
truncated = text[:2000]
|
||
try:
|
||
response = await self.client.embeddings.create(
|
||
model=self.model,
|
||
input=truncated,
|
||
)
|
||
if response.data and len(response.data) > 0:
|
||
return response.data[0].embedding
|
||
return []
|
||
except Exception as e:
|
||
logger.warning(f"Embedding API call failed: {e}")
|
||
return []
|
||
|
||
def _store_embedding(self, bucket_id: str, embedding: list[float]):
|
||
"""Store embedding in SQLite."""
|
||
from utils import now_iso
|
||
conn = sqlite3.connect(self.db_path)
|
||
conn.execute(
|
||
"INSERT OR REPLACE INTO embeddings (bucket_id, embedding, updated_at) VALUES (?, ?, ?)",
|
||
(bucket_id, json.dumps(embedding), now_iso()),
|
||
)
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
def delete_embedding(self, bucket_id: str):
|
||
"""Remove embedding when bucket is deleted."""
|
||
conn = sqlite3.connect(self.db_path)
|
||
conn.execute("DELETE FROM embeddings WHERE bucket_id = ?", (bucket_id,))
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
async def get_embedding(self, bucket_id: str) -> list[float] | None:
|
||
"""Retrieve stored embedding for a bucket. Returns None if not found."""
|
||
conn = sqlite3.connect(self.db_path)
|
||
row = conn.execute(
|
||
"SELECT embedding FROM embeddings WHERE bucket_id = ?", (bucket_id,)
|
||
).fetchone()
|
||
conn.close()
|
||
if row:
|
||
try:
|
||
return json.loads(row[0])
|
||
except json.JSONDecodeError:
|
||
return None
|
||
return None
|
||
|
||
async def search_similar(self, query: str, top_k: int = 10) -> list[tuple[str, float]]:
|
||
"""
|
||
Search for buckets similar to query text.
|
||
Returns list of (bucket_id, similarity_score) sorted by score desc.
|
||
搜索与查询文本相似的桶。返回 (bucket_id, 相似度分数) 列表。
|
||
"""
|
||
if not self.enabled:
|
||
return []
|
||
|
||
try:
|
||
query_embedding = await self._generate_embedding(query)
|
||
if not query_embedding:
|
||
return []
|
||
except Exception as e:
|
||
logger.warning(f"Query embedding failed: {e}")
|
||
return []
|
||
|
||
# Load all embeddings from SQLite
|
||
conn = sqlite3.connect(self.db_path)
|
||
rows = conn.execute("SELECT bucket_id, embedding FROM embeddings").fetchall()
|
||
conn.close()
|
||
|
||
if not rows:
|
||
return []
|
||
|
||
# Calculate cosine similarity
|
||
results = []
|
||
for bucket_id, emb_json in rows:
|
||
try:
|
||
stored_embedding = json.loads(emb_json)
|
||
sim = self._cosine_similarity(query_embedding, stored_embedding)
|
||
results.append((bucket_id, sim))
|
||
except (json.JSONDecodeError, Exception):
|
||
continue
|
||
|
||
results.sort(key=lambda x: x[1], reverse=True)
|
||
return results[:top_k]
|
||
|
||
@staticmethod
|
||
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||
"""Calculate cosine similarity between two vectors."""
|
||
if len(a) != len(b) or not a:
|
||
return 0.0
|
||
dot = sum(x * y for x, y in zip(a, b))
|
||
norm_a = math.sqrt(sum(x * x for x in a))
|
||
norm_b = math.sqrt(sum(x * x for x in b))
|
||
if norm_a == 0 or norm_b == 0:
|
||
return 0.0
|
||
return dot / (norm_a * norm_b)
|