diff options
| -rw-r--r-- | .codex | 0 | ||||
| -rw-r--r-- | .gitignore | 3 | ||||
| -rw-r--r-- | README.md | 36 | ||||
| -rw-r--r-- | build_index.py | 37 | ||||
| -rw-r--r-- | cheat.db | bin | 118784 -> 118784 bytes | |||
| -rw-r--r-- | cheat_runtime.py | 183 | ||||
| -rw-r--r-- | query_index.py | 329 |
7 files changed, 450 insertions, 138 deletions
@@ -1,2 +1,3 @@ venv -models
\ No newline at end of file +models +__pycache__ @@ -7,22 +7,48 @@ Local command-helper retrieval system using JSONL, SQLite, and sentence-transfor ```sh export HF_HOME="$PWD/models/hf" export SENTENCE_TRANSFORMERS_HOME="$PWD/models/hf" -python -m venv venv +python -m venv .venv source .venv/bin/activate pip install -U pip pip install -r requirements.txt -python scripts/init_db.py -python scripts/build_index.py +python init_db.py +python build_index.py ``` Then run a query like this: ```sh -python scripts/query_index.py "get free disk space" +python query_index.py "get free disk space" ``` To add commands, add to `./cards.jsonl` and rebuild the index: ```sh -python scripts/build_index.py +python build_index.py +``` + +For fast repeated queries, start the Unix-socket server once in another terminal: + +```sh +python query_index.py --serve +``` + +Then keep using the same query command: + +```sh +python query_index.py "get free disk space" +``` + +To run the server in the background: + +```sh +nohup python query_index.py --serve >/tmp/cheat.log 2>&1 & +``` + +Useful controls: + +```sh +python query_index.py --status +python query_index.py --reload +python query_index.py --stop ``` diff --git a/build_index.py b/build_index.py index 8597c68..d25f446 100644 --- a/build_index.py +++ b/build_index.py @@ -2,24 +2,24 @@ from __future__ import annotations import json -import sqlite3 +import os from pathlib import Path +import sqlite3 +import sys from typing import Any +LOCAL_CACHE_DIR = Path("models/hf") +os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) +os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) + import numpy as np from sentence_transformers import SentenceTransformer -import os -from pathlib import Path +from cheat_runtime import DEFAULT_SOCKET_PATH, send_server_request DB_PATH = Path("cheat.db") CARDS_PATH = Path("./cards.jsonl") MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" -LOCAL_CACHE_DIR = Path("models/hf") - -LOCAL_CACHE_DIR = Path("models/hf") -os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) -os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) REQUIRED_FIELDS = { "id", @@ -154,9 +154,7 @@ def upsert_embedding( def main() -> None: if not DB_PATH.exists(): - raise FileNotFoundError( - f"Database not found at {DB_PATH}. Run scripts/init_db.py first." - ) + raise FileNotFoundError(f"Database not found at {DB_PATH}. Run init_db.py first.") if not CARDS_PATH.exists(): raise FileNotFoundError(f"Cards file not found at {CARDS_PATH}") @@ -187,6 +185,23 @@ def main() -> None: finally: conn.close() + try: + response = send_server_request( + {"action": "reload"}, + socket_path=DEFAULT_SOCKET_PATH, + timeout=1.0, + ) + except OSError: + return + + if response.get("ok"): + print(f"Reloaded query server at {DEFAULT_SOCKET_PATH}") + else: + print( + f"Warning: query server reload failed: {response.get('error', 'unknown error')}", + file=sys.stderr, + ) + if __name__ == "__main__": main() Binary files differdiff --git a/cheat_runtime.py b/cheat_runtime.py new file mode 100644 index 0000000..c82ec70 --- /dev/null +++ b/cheat_runtime.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import os +from pathlib import Path +import socket +import sqlite3 +from typing import Any + +import numpy as np + +LOCAL_CACHE_DIR = Path("models/hf") +os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) +os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) + +from sentence_transformers import SentenceTransformer + +DB_PATH = Path("cheat.db") +DEFAULT_SOCKET_PATH = Path(os.environ.get("CHEAT_SOCKET_PATH", "/tmp/cheat.sock")) +MAX_MESSAGE_BYTES = 1024 * 1024 + + +def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray: + vec = np.frombuffer(blob, dtype=np.float32) + if vec.shape[0] != dim: + raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}") + return vec + + +def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]: + rows = conn.execute(""" + SELECT + c.id, + c.command, + c.explanation, + c.intent_json, + c.alternatives_json, + c.requires_json, + c.packages_json, + c.tags_json, + c.platform_json, + c.shell_json, + c.safety, + e.model_name, + e.embedding_blob, + e.embedding_dim + FROM cards c + JOIN card_embeddings e ON c.id = e.card_id + ORDER BY c.id + """).fetchall() + + if not rows: + raise RuntimeError("No indexed cards found. Run build_index.py first.") + + cards: list[dict[str, Any]] = [] + vectors: list[np.ndarray] = [] + model_name: str | None = None + + for row in rows: + ( + card_id, + command, + explanation, + intent_json, + alternatives_json, + requires_json, + packages_json, + tags_json, + platform_json, + shell_json, + safety, + row_model_name, + embedding_blob, + embedding_dim, + ) = row + + if model_name is None: + model_name = row_model_name + elif model_name != row_model_name: + raise RuntimeError("Mixed embedding models found in the index.") + + cards.append({ + "id": card_id, + "command": command, + "explanation": explanation, + "intent": json.loads(intent_json), + "alternatives": json.loads(alternatives_json), + "requires": json.loads(requires_json), + "packages": json.loads(packages_json), + "tags": json.loads(tags_json), + "platform": json.loads(platform_json), + "shell": json.loads(shell_json), + "safety": safety, + }) + vectors.append(deserialize_embedding(embedding_blob, embedding_dim)) + + matrix = np.vstack(vectors) + return cards, matrix, model_name + + +class SearchEngine: + def __init__(self, db_path: Path = DB_PATH) -> None: + self.db_path = db_path + self.cards: list[dict[str, Any]] = [] + self.matrix = np.empty((0, 0), dtype=np.float32) + self.model_name = "" + self._model: SentenceTransformer | None = None + self.reload() + + def reload(self) -> None: + conn = sqlite3.connect(self.db_path) + try: + cards, matrix, model_name = load_index(conn) + finally: + conn.close() + + model = self._model + if model is None or model_name != self.model_name: + model = SentenceTransformer( + model_name, + cache_folder=str(LOCAL_CACHE_DIR.resolve()), + local_files_only=True, + ) + + self.cards = cards + self.matrix = matrix + self.model_name = model_name + self._model = model + + def search(self, query: str, top_k: int = 5) -> list[dict[str, Any]]: + if top_k < 1: + raise ValueError("top_k must be at least 1") + if self._model is None: + raise RuntimeError("Search engine is not initialized.") + + qvec = self._model.encode( + [query], + normalize_embeddings=True, + convert_to_numpy=True, + )[0] + + scores = self.matrix @ qvec + top_indices = np.argsort(scores)[::-1][:top_k] + + results: list[dict[str, Any]] = [] + for idx in top_indices: + card = dict(self.cards[idx]) + card["score"] = float(scores[idx]) + results.append(card) + + return results + + +def send_server_request( + payload: dict[str, Any], + socket_path: Path = DEFAULT_SOCKET_PATH, + timeout: float = 5.0, +) -> dict[str, Any]: + encoded = (json.dumps(payload) + "\n").encode("utf-8") + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: + client.settimeout(timeout) + client.connect(str(socket_path)) + client.sendall(encoded) + client.shutdown(socket.SHUT_WR) + + chunks: list[bytes] = [] + total_bytes = 0 + while True: + chunk = client.recv(65536) + if not chunk: + break + chunks.append(chunk) + total_bytes += len(chunk) + if total_bytes > MAX_MESSAGE_BYTES: + raise RuntimeError("Server response exceeded the maximum allowed size.") + if b"\n" in chunk: + break + + message = b"".join(chunks).decode("utf-8").strip() + if not message: + raise RuntimeError("Server closed the connection without responding.") + return json.loads(message) diff --git a/query_index.py b/query_index.py index 7dee4d3..a6794dc 100644 --- a/query_index.py +++ b/query_index.py @@ -3,145 +3,232 @@ from __future__ import annotations import argparse import json -import sqlite3 from pathlib import Path +import signal +import socket +import sys from typing import Any -import numpy as np -from sentence_transformers import SentenceTransformer -import os -from pathlib import Path +from cheat_runtime import DEFAULT_SOCKET_PATH, MAX_MESSAGE_BYTES, SearchEngine, send_server_request -DB_PATH = Path("cheat.db") - -LOCAL_CACHE_DIR = Path("models/hf") -os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) -os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) - -def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray: - vec = np.frombuffer(blob, dtype=np.float32) - if vec.shape[0] != dim: - raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}") - return vec - - -def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]: - rows = conn.execute(""" - SELECT - c.id, - c.command, - c.explanation, - c.intent_json, - c.alternatives_json, - c.requires_json, - c.packages_json, - c.tags_json, - c.platform_json, - c.shell_json, - c.safety, - e.model_name, - e.embedding_blob, - e.embedding_dim - FROM cards c - JOIN card_embeddings e ON c.id = e.card_id - ORDER BY c.id - """).fetchall() - - if not rows: - raise RuntimeError("No indexed cards found. Run build_index.py first.") - - cards: list[dict[str, Any]] = [] - vectors: list[np.ndarray] = [] - model_name: str | None = None - - for row in rows: - ( - card_id, - command, - explanation, - intent_json, - alternatives_json, - requires_json, - packages_json, - tags_json, - platform_json, - shell_json, - safety, - row_model_name, - embedding_blob, - embedding_dim, - ) = row - - if model_name is None: - model_name = row_model_name - elif model_name != row_model_name: - raise RuntimeError("Mixed embedding models found in the index.") - - cards.append({ - "id": card_id, - "command": command, - "explanation": explanation, - "intent": json.loads(intent_json), - "alternatives": json.loads(alternatives_json), - "requires": json.loads(requires_json), - "packages": json.loads(packages_json), - "tags": json.loads(tags_json), - "platform": json.loads(platform_json), - "shell": json.loads(shell_json), - "safety": safety, - }) - vectors.append(deserialize_embedding(embedding_blob, embedding_dim)) - - matrix = np.vstack(vectors) - return cards, matrix, model_name - - -def search( - query: str, - top_k: int = 5, -) -> list[dict[str, Any]]: - conn = sqlite3.connect(DB_PATH) - try: - cards, matrix, model_name = load_index(conn) - finally: - conn.close() - model = SentenceTransformer( - model_name, - cache_folder=str(LOCAL_CACHE_DIR.resolve()), - local_files_only=True, +def search(query: str, top_k: int = 5) -> list[dict[str, Any]]: + engine = SearchEngine() + return engine.search(query, top_k=top_k) + + +def format_results(results: list[dict[str, Any]]) -> str: + lines: list[str] = [] + for i, result in enumerate(results, start=1): + lines.append(f"[{i}] score={result['score']:.4f} id={result['id']}") + lines.append(f" command: {result['command']}") + lines.append(f" explanation: {result['explanation']}") + if result["alternatives"]: + lines.append(f" alternatives: {', '.join(result['alternatives'])}") + lines.append(f" intent: {', '.join(result['intent'][:3])}") + lines.append("") + return "\n".join(lines).rstrip() + + +def query_via_server(query: str, top_k: int, socket_path: Path) -> list[dict[str, Any]]: + response = send_server_request( + {"action": "query", "query": query, "top_k": top_k}, + socket_path=socket_path, ) + if not response.get("ok"): + raise RuntimeError(response.get("error", "Query server returned an error.")) + return response["results"] + + +def read_request(conn: socket.socket) -> dict[str, Any]: + chunks: list[bytes] = [] + total_bytes = 0 + while True: + chunk = conn.recv(65536) + if not chunk: + break + chunks.append(chunk) + total_bytes += len(chunk) + if total_bytes > MAX_MESSAGE_BYTES: + raise RuntimeError("Request exceeded the maximum allowed size.") + if b"\n" in chunk: + break + + message = b"".join(chunks).decode("utf-8").strip() + if not message: + raise RuntimeError("Received an empty request.") + return json.loads(message) + + +def send_response(conn: socket.socket, response: dict[str, Any]) -> None: + encoded = (json.dumps(response) + "\n").encode("utf-8") + conn.sendall(encoded) + + +def prepare_socket(socket_path: Path) -> socket.socket: + socket_path.parent.mkdir(parents=True, exist_ok=True) + if socket_path.exists(): + try: + status = send_server_request({"action": "status"}, socket_path=socket_path, timeout=0.5) + except OSError: + socket_path.unlink() + else: + if status.get("ok"): + raise RuntimeError(f"Query server is already running at {socket_path}") + socket_path.unlink() + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(str(socket_path)) + server_sock.listen() + server_sock.settimeout(1.0) + return server_sock + + +def serve(socket_path: Path) -> None: + engine = SearchEngine() + server_sock = prepare_socket(socket_path) + stop_requested = False + + def request_stop(_signum: int, _frame: Any) -> None: + nonlocal stop_requested + stop_requested = True + + previous_sigint = signal.signal(signal.SIGINT, request_stop) + previous_sigterm = signal.signal(signal.SIGTERM, request_stop) + + print(f"Serving cheat queries on {socket_path}") + print("Use --stop to shut the server down.") + + try: + while not stop_requested: + try: + conn, _ = server_sock.accept() + except socket.timeout: + continue + + with conn: + should_stop = False + try: + request = read_request(conn) + action = request.get("action") + if action == "query": + response = { + "ok": True, + "results": engine.search( + str(request["query"]), + top_k=int(request.get("top_k", 5)), + ), + } + elif action == "reload": + engine.reload() + response = { + "ok": True, + "message": f"Reloaded {len(engine.cards)} cards.", + } + elif action == "status": + response = { + "ok": True, + "card_count": len(engine.cards), + "model_name": engine.model_name, + "socket_path": str(socket_path), + } + elif action == "stop": + response = {"ok": True, "message": "Stopping query server."} + should_stop = True + else: + response = {"ok": False, "error": f"Unknown action: {action!r}"} + except Exception as exc: + response = {"ok": False, "error": str(exc)} + + send_response(conn, response) + if should_stop: + stop_requested = True + finally: + signal.signal(signal.SIGINT, previous_sigint) + signal.signal(signal.SIGTERM, previous_sigterm) + server_sock.close() + try: + socket_path.unlink() + except FileNotFoundError: + pass - qvec = model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0] - scores = matrix @ qvec - top_indices = np.argsort(scores)[::-1][:top_k] +def control_server(action: str, socket_path: Path) -> int: + try: + response = send_server_request({"action": action}, socket_path=socket_path) + except OSError as exc: + print(f"Unable to reach query server at {socket_path}: {exc}", file=sys.stderr) + return 1 - results: list[dict[str, Any]] = [] - for idx in top_indices: - card = dict(cards[idx]) - card["score"] = float(scores[idx]) - results.append(card) + if not response.get("ok"): + print(response.get("error", "Query server request failed."), file=sys.stderr) + return 1 - return results + if action == "status": + print(f"socket: {response['socket_path']}") + print(f"model: {response['model_name']}") + print(f"cards: {response['card_count']}") + else: + print(response.get("message", action)) + return 0 def main() -> None: parser = argparse.ArgumentParser(description="Query the local command card index.") - parser.add_argument("query", type=str, help="Natural language query") + parser.add_argument("query", nargs="?", type=str, help="Natural language query") parser.add_argument("--top-k", type=int, default=5, help="Number of results to return") + parser.add_argument( + "--socket", + type=Path, + default=DEFAULT_SOCKET_PATH, + help=f"Unix socket path for the query server (default: {DEFAULT_SOCKET_PATH})", + ) + parser.add_argument("--serve", action="store_true", help="Run the long-lived query server") + parser.add_argument("--reload", action="store_true", help="Reload a running query server") + parser.add_argument("--stop", action="store_true", help="Stop a running query server") + parser.add_argument("--status", action="store_true", help="Show query server status") args = parser.parse_args() - results = search(args.query, top_k=args.top_k) + actions = [args.serve, args.reload, args.stop, args.status] + if sum(bool(action) for action in actions) > 1: + parser.error("Choose only one of --serve, --reload, --stop, or --status.") - for i, result in enumerate(results, start=1): - print(f"[{i}] score={result['score']:.4f} id={result['id']}") - print(f" command: {result['command']}") - print(f" explanation: {result['explanation']}") - if result["alternatives"]: - print(f" alternatives: {', '.join(result['alternatives'])}") - print(f" intent: {', '.join(result['intent'][:3])}") - print() + if args.serve: + if args.query is not None: + parser.error("Query text cannot be used with --serve.") + serve(args.socket) + return + + if args.reload: + if args.query is not None: + parser.error("Query text cannot be used with --reload.") + raise SystemExit(control_server("reload", args.socket)) + + if args.stop: + if args.query is not None: + parser.error("Query text cannot be used with --stop.") + raise SystemExit(control_server("stop", args.socket)) + + if args.status: + if args.query is not None: + parser.error("Query text cannot be used with --status.") + raise SystemExit(control_server("status", args.socket)) + + if args.query is None: + parser.error("A query is required unless you use a server control flag.") + + try: + results = query_via_server(args.query, args.top_k, args.socket) + except OSError: + print( + "Query server not running; loading the model in-process for this query.", + file=sys.stderr, + ) + results = search(args.query, top_k=args.top_k) + + output = format_results(results) + if output: + print(output) if __name__ == "__main__": |