diff options
Diffstat (limited to 'cheat_runtime.py')
| -rw-r--r-- | cheat_runtime.py | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/cheat_runtime.py b/cheat_runtime.py new file mode 100644 index 0000000..c82ec70 --- /dev/null +++ b/cheat_runtime.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import os +from pathlib import Path +import socket +import sqlite3 +from typing import Any + +import numpy as np + +LOCAL_CACHE_DIR = Path("models/hf") +os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) +os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) + +from sentence_transformers import SentenceTransformer + +DB_PATH = Path("cheat.db") +DEFAULT_SOCKET_PATH = Path(os.environ.get("CHEAT_SOCKET_PATH", "/tmp/cheat.sock")) +MAX_MESSAGE_BYTES = 1024 * 1024 + + +def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray: + vec = np.frombuffer(blob, dtype=np.float32) + if vec.shape[0] != dim: + raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}") + return vec + + +def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]: + rows = conn.execute(""" + SELECT + c.id, + c.command, + c.explanation, + c.intent_json, + c.alternatives_json, + c.requires_json, + c.packages_json, + c.tags_json, + c.platform_json, + c.shell_json, + c.safety, + e.model_name, + e.embedding_blob, + e.embedding_dim + FROM cards c + JOIN card_embeddings e ON c.id = e.card_id + ORDER BY c.id + """).fetchall() + + if not rows: + raise RuntimeError("No indexed cards found. Run build_index.py first.") + + cards: list[dict[str, Any]] = [] + vectors: list[np.ndarray] = [] + model_name: str | None = None + + for row in rows: + ( + card_id, + command, + explanation, + intent_json, + alternatives_json, + requires_json, + packages_json, + tags_json, + platform_json, + shell_json, + safety, + row_model_name, + embedding_blob, + embedding_dim, + ) = row + + if model_name is None: + model_name = row_model_name + elif model_name != row_model_name: + raise RuntimeError("Mixed embedding models found in the index.") + + cards.append({ + "id": card_id, + "command": command, + "explanation": explanation, + "intent": json.loads(intent_json), + "alternatives": json.loads(alternatives_json), + "requires": json.loads(requires_json), + "packages": json.loads(packages_json), + "tags": json.loads(tags_json), + "platform": json.loads(platform_json), + "shell": json.loads(shell_json), + "safety": safety, + }) + vectors.append(deserialize_embedding(embedding_blob, embedding_dim)) + + matrix = np.vstack(vectors) + return cards, matrix, model_name + + +class SearchEngine: + def __init__(self, db_path: Path = DB_PATH) -> None: + self.db_path = db_path + self.cards: list[dict[str, Any]] = [] + self.matrix = np.empty((0, 0), dtype=np.float32) + self.model_name = "" + self._model: SentenceTransformer | None = None + self.reload() + + def reload(self) -> None: + conn = sqlite3.connect(self.db_path) + try: + cards, matrix, model_name = load_index(conn) + finally: + conn.close() + + model = self._model + if model is None or model_name != self.model_name: + model = SentenceTransformer( + model_name, + cache_folder=str(LOCAL_CACHE_DIR.resolve()), + local_files_only=True, + ) + + self.cards = cards + self.matrix = matrix + self.model_name = model_name + self._model = model + + def search(self, query: str, top_k: int = 5) -> list[dict[str, Any]]: + if top_k < 1: + raise ValueError("top_k must be at least 1") + if self._model is None: + raise RuntimeError("Search engine is not initialized.") + + qvec = self._model.encode( + [query], + normalize_embeddings=True, + convert_to_numpy=True, + )[0] + + scores = self.matrix @ qvec + top_indices = np.argsort(scores)[::-1][:top_k] + + results: list[dict[str, Any]] = [] + for idx in top_indices: + card = dict(self.cards[idx]) + card["score"] = float(scores[idx]) + results.append(card) + + return results + + +def send_server_request( + payload: dict[str, Any], + socket_path: Path = DEFAULT_SOCKET_PATH, + timeout: float = 5.0, +) -> dict[str, Any]: + encoded = (json.dumps(payload) + "\n").encode("utf-8") + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: + client.settimeout(timeout) + client.connect(str(socket_path)) + client.sendall(encoded) + client.shutdown(socket.SHUT_WR) + + chunks: list[bytes] = [] + total_bytes = 0 + while True: + chunk = client.recv(65536) + if not chunk: + break + chunks.append(chunk) + total_bytes += len(chunk) + if total_bytes > MAX_MESSAGE_BYTES: + raise RuntimeError("Server response exceeded the maximum allowed size.") + if b"\n" in chunk: + break + + message = b"".join(chunks).decode("utf-8").strip() + if not message: + raise RuntimeError("Server closed the connection without responding.") + return json.loads(message) |