#!/usr/bin/env python3 from __future__ import annotations import json import os from pathlib import Path import socket import sqlite3 from typing import Any import numpy as np LOCAL_CACHE_DIR = Path("models/hf") os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) from sentence_transformers import SentenceTransformer DB_PATH = Path("cheat.db") DEFAULT_SOCKET_PATH = Path(os.environ.get("CHEAT_SOCKET_PATH", "/tmp/cheat.sock")) MAX_MESSAGE_BYTES = 1024 * 1024 def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray: vec = np.frombuffer(blob, dtype=np.float32) if vec.shape[0] != dim: raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}") return vec def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]: rows = conn.execute(""" SELECT c.id, c.command, c.explanation, c.intent_json, c.alternatives_json, c.requires_json, c.packages_json, c.tags_json, c.platform_json, c.shell_json, c.safety, e.model_name, e.embedding_blob, e.embedding_dim FROM cards c JOIN card_embeddings e ON c.id = e.card_id ORDER BY c.id """).fetchall() if not rows: raise RuntimeError("No indexed cards found. Run build_index.py first.") cards: list[dict[str, Any]] = [] vectors: list[np.ndarray] = [] model_name: str | None = None for row in rows: ( card_id, command, explanation, intent_json, alternatives_json, requires_json, packages_json, tags_json, platform_json, shell_json, safety, row_model_name, embedding_blob, embedding_dim, ) = row if model_name is None: model_name = row_model_name elif model_name != row_model_name: raise RuntimeError("Mixed embedding models found in the index.") cards.append({ "id": card_id, "command": command, "explanation": explanation, "intent": json.loads(intent_json), "alternatives": json.loads(alternatives_json), "requires": json.loads(requires_json), "packages": json.loads(packages_json), "tags": json.loads(tags_json), "platform": json.loads(platform_json), "shell": json.loads(shell_json), "safety": safety, }) vectors.append(deserialize_embedding(embedding_blob, embedding_dim)) matrix = np.vstack(vectors) return cards, matrix, model_name class SearchEngine: def __init__(self, db_path: Path = DB_PATH) -> None: self.db_path = db_path self.cards: list[dict[str, Any]] = [] self.matrix = np.empty((0, 0), dtype=np.float32) self.model_name = "" self._model: SentenceTransformer | None = None self.reload() def reload(self) -> None: conn = sqlite3.connect(self.db_path) try: cards, matrix, model_name = load_index(conn) finally: conn.close() model = self._model if model is None or model_name != self.model_name: model = SentenceTransformer( model_name, cache_folder=str(LOCAL_CACHE_DIR.resolve()), local_files_only=True, ) self.cards = cards self.matrix = matrix self.model_name = model_name self._model = model def search(self, query: str, top_k: int = 5) -> list[dict[str, Any]]: if top_k < 1: raise ValueError("top_k must be at least 1") if self._model is None: raise RuntimeError("Search engine is not initialized.") qvec = self._model.encode( [query], normalize_embeddings=True, convert_to_numpy=True, )[0] scores = self.matrix @ qvec top_indices = np.argsort(scores)[::-1][:top_k] results: list[dict[str, Any]] = [] for idx in top_indices: card = dict(self.cards[idx]) card["score"] = float(scores[idx]) results.append(card) return results def send_server_request( payload: dict[str, Any], socket_path: Path = DEFAULT_SOCKET_PATH, timeout: float = 5.0, ) -> dict[str, Any]: encoded = (json.dumps(payload) + "\n").encode("utf-8") with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: client.settimeout(timeout) client.connect(str(socket_path)) client.sendall(encoded) client.shutdown(socket.SHUT_WR) chunks: list[bytes] = [] total_bytes = 0 while True: chunk = client.recv(65536) if not chunk: break chunks.append(chunk) total_bytes += len(chunk) if total_bytes > MAX_MESSAGE_BYTES: raise RuntimeError("Server response exceeded the maximum allowed size.") if b"\n" in chunk: break message = b"".join(chunks).decode("utf-8").strip() if not message: raise RuntimeError("Server closed the connection without responding.") return json.loads(message)