aboutsummaryrefslogtreecommitdiff
path: root/cheat_runtime.py
diff options
context:
space:
mode:
author twells46 <173561638+twells46@users.noreply.github.com> 2026-04-01 15:40:57 -0500
committer twells46 <173561638+twells46@users.noreply.github.com> 2026-04-01 15:40:57 -0500
commit 1e72f44f28b97ef3f28627421bedca379d136a76 (patch)
tree 3aa3ebf7ab7ace97744eca4e1784700f68135463 /cheat_runtime.py
parent 2f37974a4c84f7ffdd07e2c223eba2d8bd981b61 (diff)
Long-running
Diffstat (limited to 'cheat_runtime.py')
-rw-r--r-- cheat_runtime.py 183
1 files changed, 183 insertions, 0 deletions
diff --git a/cheat_runtime.py b/cheat_runtime.py
new file mode 100644
index 0000000..c82ec70
--- /dev/null
+++ b/cheat_runtime.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+import socket
+import sqlite3
+from typing import Any
+
+import numpy as np
+
# Directory used as the local model cache. The path is relative, so the
# ``resolve()`` calls below anchor it to the current working directory.
LOCAL_CACHE_DIR = Path("models/hf")
# ``setdefault`` leaves any value already present in the environment untouched.
# These run before the ``sentence_transformers`` import that follows them.
# NOTE(review): presumably the library reads these variables when it is
# imported, which is why the import is placed after this setup — confirm.
os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
+
+from sentence_transformers import SentenceTransformer
+
# SQLite database file; the default ``db_path`` for ``SearchEngine``.
DB_PATH = Path("cheat.db")
# Unix socket path used by default in ``send_server_request``; can be
# overridden through the ``CHEAT_SOCKET_PATH`` environment variable.
DEFAULT_SOCKET_PATH = Path(os.environ.get("CHEAT_SOCKET_PATH", "/tmp/cheat.sock"))
# Upper bound (1 MiB) on the number of response bytes ``send_server_request``
# accepts before raising.
MAX_MESSAGE_BYTES = 1024 * 1024
+
+
def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray:
    """Decode a packed float32 byte blob into a 1-D vector of length ``dim``.

    Args:
        blob: Raw bytes holding consecutive float32 values.
        dim: Number of float32 elements the blob is expected to contain.

    Returns:
        A 1-D ``float32`` array created with ``np.frombuffer`` over ``blob``.

    Raises:
        ValueError: If the blob's byte length is not a whole number of
            float32 values, or if the decoded length differs from ``dim``.
    """
    itemsize = np.dtype(np.float32).itemsize
    nbytes = memoryview(blob).nbytes
    # Check alignment ourselves so a truncated blob produces an error that
    # names the expected dimension instead of NumPy's generic buffer message.
    if nbytes % itemsize:
        raise ValueError(
            f"Embedding blob has {nbytes} bytes, which is not a multiple of "
            f"{itemsize}; expected {dim} float32 values."
        )

    vec = np.frombuffer(blob, dtype=np.float32)
    if vec.shape[0] != dim:
        raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}")
    return vec
+
+
def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]:
    """Load every card and its embedding from the database.

    Reads the ``cards`` table joined with ``card_embeddings`` (ordered by card
    id), decodes the JSON columns, and stacks the embeddings into one matrix
    whose row ``i`` belongs to ``cards[i]``.

    Args:
        conn: Open SQLite connection to the index database.

    Returns:
        A ``(cards, matrix, model_name)`` tuple: the card dictionaries, the
        stacked embedding matrix, and the single embedding model name shared
        by all rows.

    Raises:
        RuntimeError: If no rows are found, if rows disagree on the embedding
            model, or if the model name is missing.
        ValueError: If rows disagree on the embedding dimension, or if a blob
            does not match its recorded dimension.
    """
    rows = conn.execute("""
        SELECT
            c.id,
            c.command,
            c.explanation,
            c.intent_json,
            c.alternatives_json,
            c.requires_json,
            c.packages_json,
            c.tags_json,
            c.platform_json,
            c.shell_json,
            c.safety,
            e.model_name,
            e.embedding_blob,
            e.embedding_dim
        FROM cards c
        JOIN card_embeddings e ON c.id = e.card_id
        ORDER BY c.id
    """).fetchall()

    if not rows:
        raise RuntimeError("No indexed cards found. Run build_index.py first.")

    # Validate index-wide invariants before decoding anything. Comparing the
    # full set of names also catches a NULL model name mixed with a real one,
    # which a "first non-None wins" check would let through.
    model_names = {row[11] for row in rows}
    if len(model_names) > 1:
        raise RuntimeError("Mixed embedding models found in the index.")
    (model_name,) = model_names
    if not model_name:
        raise RuntimeError("Indexed embeddings do not record a model name.")

    embedding_dims = {row[13] for row in rows}
    if len(embedding_dims) > 1:
        # np.vstack would also reject this, but with a shape error that does
        # not point at the index as the cause.
        raise ValueError("Mixed embedding dimensions found in the index.")

    cards: list[dict[str, Any]] = []
    vectors: list[np.ndarray] = []

    for (
        card_id,
        command,
        explanation,
        intent_json,
        alternatives_json,
        requires_json,
        packages_json,
        tags_json,
        platform_json,
        shell_json,
        safety,
        _row_model_name,
        embedding_blob,
        embedding_dim,
    ) in rows:
        cards.append({
            "id": card_id,
            "command": command,
            "explanation": explanation,
            "intent": json.loads(intent_json),
            "alternatives": json.loads(alternatives_json),
            "requires": json.loads(requires_json),
            "packages": json.loads(packages_json),
            "tags": json.loads(tags_json),
            "platform": json.loads(platform_json),
            "shell": json.loads(shell_json),
            "safety": safety,
        })
        vectors.append(deserialize_embedding(embedding_blob, embedding_dim))

    matrix = np.vstack(vectors)
    return cards, matrix, model_name
+
+
class SearchEngine:
    """In-memory search over the card index using sentence embeddings.

    The constructor loads the index immediately; call ``reload`` to pick up
    changes to the database afterwards.
    """

    def __init__(self, db_path: Path = DB_PATH) -> None:
        """Start with empty state, then populate it from ``db_path``."""
        self.db_path = db_path
        self.cards: list[dict[str, Any]] = []
        self.matrix = np.empty((0, 0), dtype=np.float32)
        self.model_name = ""
        self._model: SentenceTransformer | None = None
        self.reload()

    def reload(self) -> None:
        """Re-read the index and swap it in, reusing the model when possible.

        A new ``SentenceTransformer`` is only constructed when none is loaded
        yet or the index names a different model. Instance attributes are
        assigned only after both the index and the model are ready.
        """
        connection = sqlite3.connect(self.db_path)
        try:
            loaded_cards, loaded_matrix, loaded_model_name = load_index(connection)
        finally:
            connection.close()

        encoder = self._model
        model_changed = loaded_model_name != self.model_name
        if encoder is None or model_changed:
            # local_files_only=True restricts loading to files already on disk.
            encoder = SentenceTransformer(
                loaded_model_name,
                cache_folder=str(LOCAL_CACHE_DIR.resolve()),
                local_files_only=True,
            )

        self.cards = loaded_cards
        self.matrix = loaded_matrix
        self.model_name = loaded_model_name
        self._model = encoder

    def search(self, query: str, top_k: int = 5) -> list[dict[str, Any]]:
        """Return up to ``top_k`` cards with the highest scores for ``query``.

        Each score is the dot product of a stored embedding with the
        normalized query embedding.
        NOTE(review): this is cosine similarity only if the stored embeddings
        are unit-normalized too — confirm against the index builder.

        Raises:
            ValueError: If ``top_k`` is less than 1.
            RuntimeError: If no model has been loaded.
        """
        if top_k < 1:
            raise ValueError("top_k must be at least 1")
        encoder = self._model
        if encoder is None:
            raise RuntimeError("Search engine is not initialized.")

        encoded = encoder.encode(
            [query],
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        scores = self.matrix @ encoded[0]

        # Ascending argsort reversed gives best-first order.
        best_first = np.argsort(scores)[::-1]
        return [
            {**self.cards[position], "score": float(scores[position])}
            for position in best_first[:top_k]
        ]
+
+
def send_server_request(
    payload: dict[str, Any],
    socket_path: Path = DEFAULT_SOCKET_PATH,
    timeout: float = 5.0,
) -> dict[str, Any]:
    """Send one JSON request over a Unix socket and return the decoded reply.

    The payload is written as a single newline-terminated JSON line, after
    which the write side of the socket is shut down. The reply is read until
    the peer closes the connection or a received chunk contains a newline.

    Args:
        payload: JSON-serializable request body.
        socket_path: Filesystem path of the Unix domain socket.
        timeout: Timeout in seconds applied to the socket.

    Returns:
        The value parsed from the server's JSON response.

    Raises:
        RuntimeError: If the response exceeds ``MAX_MESSAGE_BYTES`` or the
            server closes the connection without sending anything.
    """
    request_line = f"{json.dumps(payload)}\n".encode("utf-8")
    received = bytearray()

    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as connection:
        connection.settimeout(timeout)
        connection.connect(str(socket_path))
        connection.sendall(request_line)
        # Signal end-of-request to the server while keeping the read side open.
        connection.shutdown(socket.SHUT_WR)

        # An empty read means the peer closed the connection.
        while piece := connection.recv(65536):
            received += piece
            if len(received) > MAX_MESSAGE_BYTES:
                raise RuntimeError("Server response exceeded the maximum allowed size.")
            if b"\n" in piece:
                break

    text = received.decode("utf-8").strip()
    if not text:
        raise RuntimeError("Server closed the connection without responding.")
    return json.loads(text)