aboutsummaryrefslogtreecommitdiff
path: root/query_index.py
diff options
context:
space:
mode:
Diffstat (limited to 'query_index.py')
-rw-r--r--query_index.py329
1 file changed, 208 insertions(+), 121 deletions(-)
diff --git a/query_index.py b/query_index.py
index 7dee4d3..a6794dc 100644
--- a/query_index.py
+++ b/query_index.py
@@ -3,145 +3,232 @@ from __future__ import annotations
import argparse
import json
-import sqlite3
from pathlib import Path
+import signal
+import socket
+import sys
from typing import Any
-import numpy as np
-from sentence_transformers import SentenceTransformer
-import os
-from pathlib import Path
+from cheat_runtime import DEFAULT_SOCKET_PATH, MAX_MESSAGE_BYTES, SearchEngine, send_server_request
-DB_PATH = Path("cheat.db")
-
-LOCAL_CACHE_DIR = Path("models/hf")
-os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
-os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
-
-def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray:
- vec = np.frombuffer(blob, dtype=np.float32)
- if vec.shape[0] != dim:
- raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}")
- return vec
-
-
-def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]:
- rows = conn.execute("""
- SELECT
- c.id,
- c.command,
- c.explanation,
- c.intent_json,
- c.alternatives_json,
- c.requires_json,
- c.packages_json,
- c.tags_json,
- c.platform_json,
- c.shell_json,
- c.safety,
- e.model_name,
- e.embedding_blob,
- e.embedding_dim
- FROM cards c
- JOIN card_embeddings e ON c.id = e.card_id
- ORDER BY c.id
- """).fetchall()
-
- if not rows:
- raise RuntimeError("No indexed cards found. Run build_index.py first.")
-
- cards: list[dict[str, Any]] = []
- vectors: list[np.ndarray] = []
- model_name: str | None = None
-
- for row in rows:
- (
- card_id,
- command,
- explanation,
- intent_json,
- alternatives_json,
- requires_json,
- packages_json,
- tags_json,
- platform_json,
- shell_json,
- safety,
- row_model_name,
- embedding_blob,
- embedding_dim,
- ) = row
-
- if model_name is None:
- model_name = row_model_name
- elif model_name != row_model_name:
- raise RuntimeError("Mixed embedding models found in the index.")
-
- cards.append({
- "id": card_id,
- "command": command,
- "explanation": explanation,
- "intent": json.loads(intent_json),
- "alternatives": json.loads(alternatives_json),
- "requires": json.loads(requires_json),
- "packages": json.loads(packages_json),
- "tags": json.loads(tags_json),
- "platform": json.loads(platform_json),
- "shell": json.loads(shell_json),
- "safety": safety,
- })
- vectors.append(deserialize_embedding(embedding_blob, embedding_dim))
-
- matrix = np.vstack(vectors)
- return cards, matrix, model_name
-
-
-def search(
- query: str,
- top_k: int = 5,
-) -> list[dict[str, Any]]:
- conn = sqlite3.connect(DB_PATH)
- try:
- cards, matrix, model_name = load_index(conn)
- finally:
- conn.close()
- model = SentenceTransformer(
- model_name,
- cache_folder=str(LOCAL_CACHE_DIR.resolve()),
- local_files_only=True,
def search(query: str, top_k: int = 5) -> list[dict[str, Any]]:
    """Run a one-off search without a server.

    Builds an in-process SearchEngine (loads the index and model each
    call) and returns its results for *query*, best-first, at most
    *top_k* entries.
    """
    return SearchEngine().search(query, top_k=top_k)
+
+
def format_results(results: list[dict[str, Any]]) -> str:
    """Render search results as a numbered, human-readable report.

    Each hit becomes a small indented block (score/id, command,
    explanation, optional alternatives, up to three intent phrases),
    with a blank line between hits and no trailing newline.
    Returns "" for an empty result list.
    """
    rendered: list[str] = []
    for rank, hit in enumerate(results, start=1):
        block = [
            f"[{rank}] score={hit['score']:.4f} id={hit['id']}",
            f" command: {hit['command']}",
            f" explanation: {hit['explanation']}",
        ]
        if hit["alternatives"]:
            block.append(f" alternatives: {', '.join(hit['alternatives'])}")
        block.append(f" intent: {', '.join(hit['intent'][:3])}")
        block.append("")
        rendered.extend(block)
    return "\n".join(rendered).rstrip()
+
+
def query_via_server(query: str, top_k: int, socket_path: Path) -> list[dict[str, Any]]:
    """Ask the running query server for results.

    Sends a "query" request over the UNIX socket at *socket_path* and
    returns the server's result list.  Raises RuntimeError when the
    server answers with ok=False (OSErrors from the transport propagate
    to the caller).
    """
    request = {"action": "query", "query": query, "top_k": top_k}
    response = send_server_request(request, socket_path=socket_path)
    if response.get("ok"):
        return response["results"]
    raise RuntimeError(response.get("error", "Query server returned an error."))
+
+
def read_request(conn: socket.socket) -> dict[str, Any]:
    """Read one newline-terminated JSON request from *conn*.

    Receives until a chunk containing the b"\\n" terminator arrives or
    the peer closes the connection, enforcing MAX_MESSAGE_BYTES on the
    total size.  Only the first line is parsed, so any bytes trailing
    the terminator are ignored.

    Raises:
        RuntimeError: oversized or empty request.
        json.JSONDecodeError / UnicodeDecodeError: malformed payload.
    """
    chunks: list[bytes] = []
    total_bytes = 0
    while True:
        chunk = conn.recv(65536)
        if not chunk:
            break  # peer closed without a newline; parse what we have
        chunks.append(chunk)
        total_bytes += len(chunk)
        if total_bytes > MAX_MESSAGE_BYTES:
            raise RuntimeError("Request exceeded the maximum allowed size.")
        if b"\n" in chunk:
            break

    # Fix: parse only up to the first newline.  Parsing the whole buffer
    # made json.loads fail with "Extra data" whenever any bytes trailed
    # the terminator in the final chunk.
    first_line = b"".join(chunks).split(b"\n", 1)[0]
    message = first_line.decode("utf-8").strip()
    if not message:
        raise RuntimeError("Received an empty request.")
    return json.loads(message)
+
+
def send_response(conn: socket.socket, response: dict[str, Any]) -> None:
    """Serialize *response* as one newline-terminated JSON line and send it all."""
    payload = json.dumps(response) + "\n"
    conn.sendall(payload.encode("utf-8"))
+
+
def prepare_socket(socket_path: Path) -> socket.socket:
    """Create the listening UNIX socket for the query server.

    If a socket file already exists, probe it with a short "status"
    request: a live server makes this a RuntimeError, while a stale
    file (connection fails, or server replies not-ok) is unlinked and
    replaced.  The returned socket has a 1s accept timeout so the serve
    loop can periodically re-check its stop flag.
    """
    socket_path.parent.mkdir(parents=True, exist_ok=True)

    if socket_path.exists():
        try:
            status = send_server_request(
                {"action": "status"}, socket_path=socket_path, timeout=0.5
            )
        except OSError:
            # Nobody answered: leftover socket file from a dead server.
            socket_path.unlink()
        else:
            if status.get("ok"):
                raise RuntimeError(f"Query server is already running at {socket_path}")
            socket_path.unlink()

    listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    listener.bind(str(socket_path))
    listener.listen()
    listener.settimeout(1.0)
    return listener
+
+
def serve(socket_path: Path) -> None:
    """Run the long-lived query server on a UNIX domain socket.

    Loads a SearchEngine once up front, then answers newline-delimited
    JSON requests — actions "query", "reload", "status", and "stop" —
    until a "stop" request or SIGINT/SIGTERM arrives.  On shutdown the
    previous signal handlers are restored and the socket file removed.
    """
    engine = SearchEngine()
    server_sock = prepare_socket(socket_path)
    stop_requested = False

    def request_stop(_signum: int, _frame: Any) -> None:
        # Signal handler: only flips the flag; the accept loop polls it.
        nonlocal stop_requested
        stop_requested = True

    # Remember the previous handlers so the finally block can restore them.
    previous_sigint = signal.signal(signal.SIGINT, request_stop)
    previous_sigterm = signal.signal(signal.SIGTERM, request_stop)

    print(f"Serving cheat queries on {socket_path}")
    print("Use --stop to shut the server down.")

    try:
        while not stop_requested:
            try:
                # accept() times out every second (prepare_socket sets a
                # 1.0s timeout), so stop_requested is re-checked regularly.
                conn, _ = server_sock.accept()
            except socket.timeout:
                continue

            with conn:
                should_stop = False
                try:
                    request = read_request(conn)
                    action = request.get("action")
                    if action == "query":
                        # NOTE(review): engine.search presumably returns
                        # JSON-serializable dicts — defined in cheat_runtime.
                        response = {
                            "ok": True,
                            "results": engine.search(
                                str(request["query"]),
                                top_k=int(request.get("top_k", 5)),
                            ),
                        }
                    elif action == "reload":
                        engine.reload()
                        response = {
                            "ok": True,
                            "message": f"Reloaded {len(engine.cards)} cards.",
                        }
                    elif action == "status":
                        response = {
                            "ok": True,
                            "card_count": len(engine.cards),
                            "model_name": engine.model_name,
                            "socket_path": str(socket_path),
                        }
                    elif action == "stop":
                        response = {"ok": True, "message": "Stopping query server."}
                        should_stop = True
                    else:
                        response = {"ok": False, "error": f"Unknown action: {action!r}"}
                except Exception as exc:
                    # Report per-request failures to the client instead of
                    # letting one bad request kill the server.
                    response = {"ok": False, "error": str(exc)}

                # A response is always sent, even for malformed requests.
                send_response(conn, response)
                if should_stop:
                    stop_requested = True
    finally:
        signal.signal(signal.SIGINT, previous_sigint)
        signal.signal(signal.SIGTERM, previous_sigterm)
        server_sock.close()
        try:
            # Best-effort cleanup of the socket file.
            socket_path.unlink()
        except FileNotFoundError:
            pass
- qvec = model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
- scores = matrix @ qvec
- top_indices = np.argsort(scores)[::-1][:top_k]
def control_server(action: str, socket_path: Path) -> int:
    """Send a control *action* ("reload", "stop", or "status") to the server.

    Prints the outcome and returns a process exit code: 0 on success,
    1 when the server is unreachable or answers with an error.
    """
    try:
        response = send_server_request({"action": action}, socket_path=socket_path)
    except OSError as exc:
        print(f"Unable to reach query server at {socket_path}: {exc}", file=sys.stderr)
        return 1

    if not response.get("ok"):
        print(response.get("error", "Query server request failed."), file=sys.stderr)
        return 1

    if action == "status":
        status_fields = (
            ("socket", "socket_path"),
            ("model", "model_name"),
            ("cards", "card_count"),
        )
        for label, key in status_fields:
            print(f"{label}: {response[key]}")
    else:
        print(response.get("message", action))
    return 0
def main() -> None:
    """CLI entry point.

    Without flags, runs a single query: tries the running query server
    first and falls back to an in-process search if the server is not
    reachable.  Exactly one of --serve/--reload/--stop/--status selects
    server mode instead, and none of them accepts query text.
    """
    parser = argparse.ArgumentParser(description="Query the local command card index.")
    parser.add_argument("query", nargs="?", type=str, help="Natural language query")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results to return")
    parser.add_argument(
        "--socket",
        type=Path,
        default=DEFAULT_SOCKET_PATH,
        help=f"Unix socket path for the query server (default: {DEFAULT_SOCKET_PATH})",
    )
    parser.add_argument("--serve", action="store_true", help="Run the long-lived query server")
    parser.add_argument("--reload", action="store_true", help="Reload a running query server")
    parser.add_argument("--stop", action="store_true", help="Stop a running query server")
    parser.add_argument("--status", action="store_true", help="Show query server status")
    args = parser.parse_args()

    # Consolidated flag handling: the original repeated the query-conflict
    # check once per flag.  Flag names double as server action names.
    selected = [name for name in ("serve", "reload", "stop", "status") if getattr(args, name)]
    if len(selected) > 1:
        parser.error("Choose only one of --serve, --reload, --stop, or --status.")

    if selected:
        flag = selected[0]
        if args.query is not None:
            parser.error(f"Query text cannot be used with --{flag}.")
        if flag == "serve":
            serve(args.socket)
            return
        raise SystemExit(control_server(flag, args.socket))

    if args.query is None:
        parser.error("A query is required unless you use a server control flag.")

    try:
        results = query_via_server(args.query, args.top_k, args.socket)
    except OSError:
        # Server not running: do a slower one-shot in-process search.
        print(
            "Query server not running; loading the model in-process for this query.",
            file=sys.stderr,
        )
        results = search(args.query, top_k=args.top_k)

    output = format_results(results)
    if output:
        print(output)
if __name__ == "__main__":