#!/usr/bin/env python3
"""Query the local command card index, directly or through a query server.

The script has two roles:

* Client: send a natural-language query to a server listening on a Unix
  socket and print the formatted results.  If the server cannot be reached,
  the query is answered in-process instead.
* Server (``--serve``): keep one ``SearchEngine`` alive and answer
  ``query`` / ``reload`` / ``status`` / ``stop`` requests.

Wire format as implemented here: one JSON object per connection, terminated
by a newline, answered by one newline-terminated JSON object.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import signal
import socket
import sys
from typing import Any

from cheat_runtime import DEFAULT_SOCKET_PATH, MAX_MESSAGE_BYTES, SearchEngine, send_server_request

# Upper bound on any single blocking read/write on an accepted connection.
# Without it a client that connects and then stays silent would block the
# single-threaded server loop indefinitely (and make it ignore stop signals).
CLIENT_TIMEOUT_SECONDS = 5.0

# Mutually exclusive command-line flags that control or start the server.
_CONTROL_FLAGS = ("serve", "reload", "stop", "status")


def search(query: str, top_k: int = 5) -> list[dict[str, Any]]:
    """Answer *query* in-process with a freshly constructed ``SearchEngine``.

    Args:
        query: Natural-language query text.
        top_k: Maximum number of results requested from the engine.

    Returns:
        Whatever ``SearchEngine.search`` returns (a list of result dicts).
    """
    engine = SearchEngine()
    return engine.search(query, top_k=top_k)


def format_results(results: list[dict[str, Any]]) -> str:
    """Render search results as a human-readable, numbered text listing.

    Each result must provide the keys ``score``, ``id``, ``command``,
    ``explanation``, ``alternatives`` and ``intent``.  The alternatives line
    is omitted when the list is empty, and at most three intents are shown.

    Returns:
        The listing without trailing whitespace; an empty string when
        *results* is empty.
    """
    lines: list[str] = []
    for i, result in enumerate(results, start=1):
        lines.append(f"[{i}] score={result['score']:.4f} id={result['id']}")
        lines.append(f" command: {result['command']}")
        lines.append(f" explanation: {result['explanation']}")
        if result["alternatives"]:
            lines.append(f" alternatives: {', '.join(result['alternatives'])}")
        lines.append(f" intent: {', '.join(result['intent'][:3])}")
        # Blank separator between entries; the final one is stripped below.
        lines.append("")
    return "\n".join(lines).rstrip()


def query_via_server(query: str, top_k: int, socket_path: Path) -> list[dict[str, Any]]:
    """Send a ``query`` request to the server at *socket_path*.

    Returns:
        The ``results`` list from the server's response.

    Raises:
        RuntimeError: The server answered but reported a failure.
        OSError: Propagated from ``send_server_request``; ``main`` treats
            this as "server not reachable".
    """
    response = send_server_request(
        {"action": "query", "query": query, "top_k": top_k},
        socket_path=socket_path,
    )
    if not response.get("ok"):
        raise RuntimeError(response.get("error", "Query server returned an error."))
    return response["results"]


def read_request(conn: socket.socket) -> dict[str, Any]:
    """Read one newline-terminated JSON request object from *conn*.

    Reading stops at end-of-stream or as soon as a received chunk contains a
    newline.

    Raises:
        RuntimeError: The request exceeds ``MAX_MESSAGE_BYTES``, is empty, or
            is valid JSON but not an object.
        ValueError: The payload is not valid UTF-8 or not valid JSON.
        OSError: The socket read failed or timed out.
    """
    chunks: list[bytes] = []
    total_bytes = 0
    while True:
        chunk = conn.recv(65536)
        if not chunk:
            break
        chunks.append(chunk)
        total_bytes += len(chunk)
        if total_bytes > MAX_MESSAGE_BYTES:
            raise RuntimeError("Request exceeded the maximum allowed size.")
        if b"\n" in chunk:
            break

    message = b"".join(chunks).decode("utf-8").strip()
    if not message:
        raise RuntimeError("Received an empty request.")
    request = json.loads(message)
    # Callers use ``request.get(...)``; reject other JSON types explicitly
    # instead of letting them fail later with an AttributeError.
    if not isinstance(request, dict):
        raise RuntimeError("Request must be a JSON object.")
    return request


def send_response(conn: socket.socket, response: dict[str, Any]) -> None:
    """Serialize *response* as JSON plus a trailing newline and send it on *conn*."""
    encoded = (json.dumps(response) + "\n").encode("utf-8")
    conn.sendall(encoded)


def _remove_socket_file(socket_path: Path) -> None:
    """Delete *socket_path*, tolerating that it may already be gone."""
    try:
        socket_path.unlink()
    except FileNotFoundError:
        pass


def prepare_socket(socket_path: Path) -> socket.socket:
    """Create the listening Unix socket for the query server.

    If *socket_path* already exists, a ``status`` request is sent to it.  A
    successful answer means another server is alive; any other outcome means
    the file is stale and it is removed before binding.

    Returns:
        A bound, listening socket with a 1-second accept timeout.

    Raises:
        RuntimeError: A server is already running at *socket_path*.
        OSError: Creating the directory, binding or listening failed.
    """
    socket_path.parent.mkdir(parents=True, exist_ok=True)
    if socket_path.exists():
        try:
            status = send_server_request({"action": "status"}, socket_path=socket_path, timeout=0.5)
        except OSError:
            # Nobody is answering on the socket: treat the file as stale.
            pass
        else:
            if status.get("ok"):
                raise RuntimeError(f"Query server is already running at {socket_path}")
        _remove_socket_file(socket_path)

    server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        server_sock.bind(str(socket_path))
        server_sock.listen()
        # Short accept timeout so the serve loop can notice stop requests.
        server_sock.settimeout(1.0)
    except Exception:
        # Do not leak the descriptor when binding/listening fails.
        server_sock.close()
        raise
    return server_sock


def _handle_request(
    engine: SearchEngine,
    request: dict[str, Any],
    socket_path: Path,
) -> tuple[dict[str, Any], bool]:
    """Execute one decoded request and return ``(response, should_stop)``."""
    action = request.get("action")
    if action == "query":
        if "query" not in request:
            raise RuntimeError("Query request is missing the 'query' field.")
        results = engine.search(
            str(request["query"]),
            top_k=int(request.get("top_k", 5)),
        )
        return {"ok": True, "results": results}, False
    if action == "reload":
        engine.reload()
        return {"ok": True, "message": f"Reloaded {len(engine.cards)} cards."}, False
    if action == "status":
        return {
            "ok": True,
            "card_count": len(engine.cards),
            "model_name": engine.model_name,
            "socket_path": str(socket_path),
        }, False
    if action == "stop":
        return {"ok": True, "message": "Stopping query server."}, True
    return {"ok": False, "error": f"Unknown action: {action!r}"}, False


def serve(socket_path: Path) -> None:
    """Run the query server on *socket_path* until asked to stop.

    The loop ends after a ``stop`` request or on SIGINT/SIGTERM.  On exit the
    previous signal handlers are restored, the listening socket is closed and
    the socket file is removed.  Connections are handled one at a time.

    Raises:
        RuntimeError: Another server already owns *socket_path*.
    """
    engine = SearchEngine()
    server_sock = prepare_socket(socket_path)
    stop_requested = False

    def request_stop(_signum: int, _frame: Any) -> None:
        nonlocal stop_requested
        stop_requested = True

    previous_sigint = signal.signal(signal.SIGINT, request_stop)
    previous_sigterm = signal.signal(signal.SIGTERM, request_stop)
    print(f"Serving cheat queries on {socket_path}")
    print("Use --stop to shut the server down.")
    try:
        while not stop_requested:
            try:
                conn, _ = server_sock.accept()
            except socket.timeout:
                # Periodic wake-up so ``stop_requested`` is re-checked.
                continue
            with conn:
                # Accepted sockets are blocking by default; bound the wait so
                # one stalled client cannot freeze the whole server.
                conn.settimeout(CLIENT_TIMEOUT_SECONDS)
                should_stop = False
                try:
                    request = read_request(conn)
                    response, should_stop = _handle_request(engine, request, socket_path)
                except Exception as exc:
                    # Request boundary: report any failure to the client
                    # rather than taking the server down.
                    response = {"ok": False, "error": str(exc)}
                try:
                    send_response(conn, response)
                except OSError as exc:
                    # The client went away before reading its answer; that
                    # must not terminate the server.
                    print(f"Failed to send response: {exc}", file=sys.stderr)
                if should_stop:
                    stop_requested = True
    finally:
        signal.signal(signal.SIGINT, previous_sigint)
        signal.signal(signal.SIGTERM, previous_sigterm)
        server_sock.close()
        _remove_socket_file(socket_path)


def control_server(action: str, socket_path: Path) -> int:
    """Send a control *action* to a running server and print the outcome.

    For ``status`` the socket path, model name and card count from the
    response are printed; for other actions the response ``message`` (or the
    action name when there is none).

    Returns:
        ``0`` on success, ``1`` when the server is unreachable or reports a
        failure (the reason is written to stderr).
    """
    try:
        response = send_server_request({"action": action}, socket_path=socket_path)
    except OSError as exc:
        print(f"Unable to reach query server at {socket_path}: {exc}", file=sys.stderr)
        return 1

    if not response.get("ok"):
        print(response.get("error", "Query server request failed."), file=sys.stderr)
        return 1

    if action == "status":
        print(f"socket: {response['socket_path']}")
        print(f"model: {response['model_name']}")
        print(f"cards: {response['card_count']}")
    else:
        print(response.get("message", action))
    return 0


def main() -> None:
    """Parse the command line and run a query or a server control action.

    Exits with status 1 when a control action fails or when the server
    reports an error for a query.
    """
    parser = argparse.ArgumentParser(description="Query the local command card index.")
    parser.add_argument("query", nargs="?", type=str, help="Natural language query")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results to return")
    parser.add_argument(
        "--socket",
        type=Path,
        default=DEFAULT_SOCKET_PATH,
        help=f"Unix socket path for the query server (default: {DEFAULT_SOCKET_PATH})",
    )
    parser.add_argument("--serve", action="store_true", help="Run the long-lived query server")
    parser.add_argument("--reload", action="store_true", help="Reload a running query server")
    parser.add_argument("--stop", action="store_true", help="Stop a running query server")
    parser.add_argument("--status", action="store_true", help="Show query server status")
    args = parser.parse_args()

    chosen = [name for name in _CONTROL_FLAGS if getattr(args, name)]
    if len(chosen) > 1:
        parser.error("Choose only one of --serve, --reload, --stop, or --status.")

    if chosen:
        action = chosen[0]
        if args.query is not None:
            parser.error(f"Query text cannot be used with --{action}.")
        if action == "serve":
            serve(args.socket)
            return
        raise SystemExit(control_server(action, args.socket))

    if args.query is None:
        parser.error("A query is required unless you use a server control flag.")

    try:
        results = query_via_server(args.query, args.top_k, args.socket)
    except OSError:
        print(
            "Query server not running; loading the model in-process for this query.",
            file=sys.stderr,
        )
        results = search(args.query, top_k=args.top_k)
    except RuntimeError as exc:
        # The server was reachable but rejected the query: report it the same
        # way control_server does instead of dumping a traceback.
        print(exc, file=sys.stderr)
        raise SystemExit(1) from None

    output = format_results(results)
    if output:
        print(output)


if __name__ == "__main__":
    main()