docs: update README/INTERNALS for import feature, harden .gitignore
This commit is contained in:
188
embedding_engine.py
Normal file
188
embedding_engine.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# ============================================================
|
||||
# Module: Embedding Engine (embedding_engine.py)
|
||||
# 模块:向量化引擎
|
||||
#
|
||||
# Generates embeddings via Gemini API (OpenAI-compatible),
|
||||
# stores them in SQLite, and provides cosine similarity search.
|
||||
# 通过 Gemini API(OpenAI 兼容)生成 embedding,
|
||||
# 存储在 SQLite 中,提供余弦相似度搜索。
|
||||
#
|
||||
# Depended on by: server.py, bucket_manager.py
|
||||
# 被谁依赖:server.py, bucket_manager.py
|
||||
# ============================================================
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import sqlite3
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger("ombre_brain.embedding")
|
||||
|
||||
|
||||
class EmbeddingEngine:
|
||||
"""
|
||||
Embedding generation + SQLite vector storage + cosine search.
|
||||
向量生成 + SQLite 向量存储 + 余弦搜索。
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
dehy_cfg = config.get("dehydration", {})
|
||||
embed_cfg = config.get("embedding", {})
|
||||
|
||||
self.api_key = dehy_cfg.get("api_key", "")
|
||||
self.base_url = dehy_cfg.get("base_url", "https://generativelanguage.googleapis.com/v1beta/openai/")
|
||||
self.model = embed_cfg.get("model", "gemini-embedding-001")
|
||||
self.enabled = bool(self.api_key) and embed_cfg.get("enabled", True)
|
||||
|
||||
# --- SQLite path: buckets_dir/embeddings.db ---
|
||||
db_path = os.path.join(config["buckets_dir"], "embeddings.db")
|
||||
self.db_path = db_path
|
||||
|
||||
# --- Initialize client ---
|
||||
if self.enabled:
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
timeout=30.0,
|
||||
)
|
||||
else:
|
||||
self.client = None
|
||||
|
||||
# --- Initialize SQLite ---
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
"""Create embeddings table if not exists."""
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS embeddings (
|
||||
bucket_id TEXT PRIMARY KEY,
|
||||
embedding TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
async def generate_and_store(self, bucket_id: str, content: str) -> bool:
|
||||
"""
|
||||
Generate embedding for content and store in SQLite.
|
||||
为内容生成 embedding 并存入 SQLite。
|
||||
Returns True on success, False on failure.
|
||||
"""
|
||||
if not self.enabled or not content or not content.strip():
|
||||
return False
|
||||
|
||||
try:
|
||||
embedding = await self._generate_embedding(content)
|
||||
if not embedding:
|
||||
return False
|
||||
self._store_embedding(bucket_id, embedding)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Embedding generation failed for {bucket_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _generate_embedding(self, text: str) -> list[float]:
|
||||
"""Call API to generate embedding vector."""
|
||||
# Truncate to avoid token limits
|
||||
truncated = text[:2000]
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=self.model,
|
||||
input=truncated,
|
||||
)
|
||||
if response.data and len(response.data) > 0:
|
||||
return response.data[0].embedding
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Embedding API call failed: {e}")
|
||||
return []
|
||||
|
||||
def _store_embedding(self, bucket_id: str, embedding: list[float]):
|
||||
"""Store embedding in SQLite."""
|
||||
from utils import now_iso
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO embeddings (bucket_id, embedding, updated_at) VALUES (?, ?, ?)",
|
||||
(bucket_id, json.dumps(embedding), now_iso()),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_embedding(self, bucket_id: str):
|
||||
"""Remove embedding when bucket is deleted."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.execute("DELETE FROM embeddings WHERE bucket_id = ?", (bucket_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
async def get_embedding(self, bucket_id: str) -> list[float] | None:
|
||||
"""Retrieve stored embedding for a bucket. Returns None if not found."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
row = conn.execute(
|
||||
"SELECT embedding FROM embeddings WHERE bucket_id = ?", (bucket_id,)
|
||||
).fetchone()
|
||||
conn.close()
|
||||
if row:
|
||||
try:
|
||||
return json.loads(row[0])
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def search_similar(self, query: str, top_k: int = 10) -> list[tuple[str, float]]:
|
||||
"""
|
||||
Search for buckets similar to query text.
|
||||
Returns list of (bucket_id, similarity_score) sorted by score desc.
|
||||
搜索与查询文本相似的桶。返回 (bucket_id, 相似度分数) 列表。
|
||||
"""
|
||||
if not self.enabled:
|
||||
return []
|
||||
|
||||
try:
|
||||
query_embedding = await self._generate_embedding(query)
|
||||
if not query_embedding:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Query embedding failed: {e}")
|
||||
return []
|
||||
|
||||
# Load all embeddings from SQLite
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
rows = conn.execute("SELECT bucket_id, embedding FROM embeddings").fetchall()
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Calculate cosine similarity
|
||||
results = []
|
||||
for bucket_id, emb_json in rows:
|
||||
try:
|
||||
stored_embedding = json.loads(emb_json)
|
||||
sim = self._cosine_similarity(query_embedding, stored_embedding)
|
||||
results.append((bucket_id, sim))
|
||||
except (json.JSONDecodeError, Exception):
|
||||
continue
|
||||
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
if len(a) != len(b) or not a:
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
Reference in New Issue
Block a user