about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  .codex            0
-rw-r--r--  .gitignore        3
-rw-r--r--  README.md        36
-rw-r--r--  build_index.py   37
-rw-r--r--  cheat.db         bin 118784 -> 118784 bytes
-rw-r--r--  cheat_runtime.py 183
-rw-r--r--  query_index.py   329
7 files changed, 450 insertions(+), 138 deletions(-)
diff --git a/.codex b/.codex
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/.codex
diff --git a/.gitignore b/.gitignore
index 5b7a10c..7a49393 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
venv
-models
\ No newline at end of file
+models
+__pycache__
diff --git a/README.md b/README.md
index 534001e..65f5a36 100644
--- a/README.md
+++ b/README.md
@@ -7,22 +7,48 @@ Local command-helper retrieval system using JSONL, SQLite, and sentence-transfor
```sh
export HF_HOME="$PWD/models/hf"
export SENTENCE_TRANSFORMERS_HOME="$PWD/models/hf"
-python -m venv venv
+python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -r requirements.txt
-python scripts/init_db.py
-python scripts/build_index.py
+python init_db.py
+python build_index.py
```
Then run a query like this:
```sh
-python scripts/query_index.py "get free disk space"
+python query_index.py "get free disk space"
```
To add commands, add to `./cards.jsonl` and rebuild the index:
```sh
-python scripts/build_index.py
+python build_index.py
+```
+
+For fast repeated queries, start the Unix-socket server once in another terminal:
+
+```sh
+python query_index.py --serve
+```
+
+Then keep using the same query command:
+
+```sh
+python query_index.py "get free disk space"
+```
+
+To run the server in the background:
+
+```sh
+nohup python query_index.py --serve >/tmp/cheat.log 2>&1 &
+```
+
+Useful controls:
+
+```sh
+python query_index.py --status
+python query_index.py --reload
+python query_index.py --stop
```
diff --git a/build_index.py b/build_index.py
index 8597c68..d25f446 100644
--- a/build_index.py
+++ b/build_index.py
@@ -2,24 +2,24 @@
from __future__ import annotations
import json
-import sqlite3
+import os
from pathlib import Path
+import sqlite3
+import sys
from typing import Any
+LOCAL_CACHE_DIR = Path("models/hf")
+os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
+os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
+
import numpy as np
from sentence_transformers import SentenceTransformer
-import os
-from pathlib import Path
+from cheat_runtime import DEFAULT_SOCKET_PATH, send_server_request
DB_PATH = Path("cheat.db")
CARDS_PATH = Path("./cards.jsonl")
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
-LOCAL_CACHE_DIR = Path("models/hf")
-
-LOCAL_CACHE_DIR = Path("models/hf")
-os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
-os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
REQUIRED_FIELDS = {
"id",
@@ -154,9 +154,7 @@ def upsert_embedding(
def main() -> None:
if not DB_PATH.exists():
- raise FileNotFoundError(
- f"Database not found at {DB_PATH}. Run scripts/init_db.py first."
- )
+ raise FileNotFoundError(f"Database not found at {DB_PATH}. Run init_db.py first.")
if not CARDS_PATH.exists():
raise FileNotFoundError(f"Cards file not found at {CARDS_PATH}")
@@ -187,6 +185,23 @@ def main() -> None:
finally:
conn.close()
+ try:
+ response = send_server_request(
+ {"action": "reload"},
+ socket_path=DEFAULT_SOCKET_PATH,
+ timeout=1.0,
+ )
+ except OSError:
+ return
+
+ if response.get("ok"):
+ print(f"Reloaded query server at {DEFAULT_SOCKET_PATH}")
+ else:
+ print(
+ f"Warning: query server reload failed: {response.get('error', 'unknown error')}",
+ file=sys.stderr,
+ )
+
if __name__ == "__main__":
main()
diff --git a/cheat.db b/cheat.db
index ce146cf..f3d4cb5 100644
--- a/cheat.db
+++ b/cheat.db
Binary files differ
diff --git a/cheat_runtime.py b/cheat_runtime.py
new file mode 100644
index 0000000..c82ec70
--- /dev/null
+++ b/cheat_runtime.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+import socket
+import sqlite3
+from typing import Any
+
+import numpy as np
+
+LOCAL_CACHE_DIR = Path("models/hf")
+os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
+os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
+
+from sentence_transformers import SentenceTransformer
+
+DB_PATH = Path("cheat.db")
+DEFAULT_SOCKET_PATH = Path(os.environ.get("CHEAT_SOCKET_PATH", "/tmp/cheat.sock"))
+MAX_MESSAGE_BYTES = 1024 * 1024
+
+
+def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray:
+ vec = np.frombuffer(blob, dtype=np.float32)
+ if vec.shape[0] != dim:
+ raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}")
+ return vec
+
+
+def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]:
+ rows = conn.execute("""
+ SELECT
+ c.id,
+ c.command,
+ c.explanation,
+ c.intent_json,
+ c.alternatives_json,
+ c.requires_json,
+ c.packages_json,
+ c.tags_json,
+ c.platform_json,
+ c.shell_json,
+ c.safety,
+ e.model_name,
+ e.embedding_blob,
+ e.embedding_dim
+ FROM cards c
+ JOIN card_embeddings e ON c.id = e.card_id
+ ORDER BY c.id
+ """).fetchall()
+
+ if not rows:
+ raise RuntimeError("No indexed cards found. Run build_index.py first.")
+
+ cards: list[dict[str, Any]] = []
+ vectors: list[np.ndarray] = []
+ model_name: str | None = None
+
+ for row in rows:
+ (
+ card_id,
+ command,
+ explanation,
+ intent_json,
+ alternatives_json,
+ requires_json,
+ packages_json,
+ tags_json,
+ platform_json,
+ shell_json,
+ safety,
+ row_model_name,
+ embedding_blob,
+ embedding_dim,
+ ) = row
+
+ if model_name is None:
+ model_name = row_model_name
+ elif model_name != row_model_name:
+ raise RuntimeError("Mixed embedding models found in the index.")
+
+ cards.append({
+ "id": card_id,
+ "command": command,
+ "explanation": explanation,
+ "intent": json.loads(intent_json),
+ "alternatives": json.loads(alternatives_json),
+ "requires": json.loads(requires_json),
+ "packages": json.loads(packages_json),
+ "tags": json.loads(tags_json),
+ "platform": json.loads(platform_json),
+ "shell": json.loads(shell_json),
+ "safety": safety,
+ })
+ vectors.append(deserialize_embedding(embedding_blob, embedding_dim))
+
+ matrix = np.vstack(vectors)
+ return cards, matrix, model_name
+
+
+class SearchEngine:
+ def __init__(self, db_path: Path = DB_PATH) -> None:
+ self.db_path = db_path
+ self.cards: list[dict[str, Any]] = []
+ self.matrix = np.empty((0, 0), dtype=np.float32)
+ self.model_name = ""
+ self._model: SentenceTransformer | None = None
+ self.reload()
+
+ def reload(self) -> None:
+ conn = sqlite3.connect(self.db_path)
+ try:
+ cards, matrix, model_name = load_index(conn)
+ finally:
+ conn.close()
+
+ model = self._model
+ if model is None or model_name != self.model_name:
+ model = SentenceTransformer(
+ model_name,
+ cache_folder=str(LOCAL_CACHE_DIR.resolve()),
+ local_files_only=True,
+ )
+
+ self.cards = cards
+ self.matrix = matrix
+ self.model_name = model_name
+ self._model = model
+
+ def search(self, query: str, top_k: int = 5) -> list[dict[str, Any]]:
+ if top_k < 1:
+ raise ValueError("top_k must be at least 1")
+ if self._model is None:
+ raise RuntimeError("Search engine is not initialized.")
+
+ qvec = self._model.encode(
+ [query],
+ normalize_embeddings=True,
+ convert_to_numpy=True,
+ )[0]
+
+ scores = self.matrix @ qvec
+ top_indices = np.argsort(scores)[::-1][:top_k]
+
+ results: list[dict[str, Any]] = []
+ for idx in top_indices:
+ card = dict(self.cards[idx])
+ card["score"] = float(scores[idx])
+ results.append(card)
+
+ return results
+
+
+def send_server_request(
+ payload: dict[str, Any],
+ socket_path: Path = DEFAULT_SOCKET_PATH,
+ timeout: float = 5.0,
+) -> dict[str, Any]:
+ encoded = (json.dumps(payload) + "\n").encode("utf-8")
+ with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
+ client.settimeout(timeout)
+ client.connect(str(socket_path))
+ client.sendall(encoded)
+ client.shutdown(socket.SHUT_WR)
+
+ chunks: list[bytes] = []
+ total_bytes = 0
+ while True:
+ chunk = client.recv(65536)
+ if not chunk:
+ break
+ chunks.append(chunk)
+ total_bytes += len(chunk)
+ if total_bytes > MAX_MESSAGE_BYTES:
+ raise RuntimeError("Server response exceeded the maximum allowed size.")
+ if b"\n" in chunk:
+ break
+
+ message = b"".join(chunks).decode("utf-8").strip()
+ if not message:
+ raise RuntimeError("Server closed the connection without responding.")
+ return json.loads(message)
diff --git a/query_index.py b/query_index.py
index 7dee4d3..a6794dc 100644
--- a/query_index.py
+++ b/query_index.py
@@ -3,145 +3,232 @@ from __future__ import annotations
import argparse
import json
-import sqlite3
from pathlib import Path
+import signal
+import socket
+import sys
from typing import Any
-import numpy as np
-from sentence_transformers import SentenceTransformer
-import os
-from pathlib import Path
+from cheat_runtime import DEFAULT_SOCKET_PATH, MAX_MESSAGE_BYTES, SearchEngine, send_server_request
-DB_PATH = Path("cheat.db")
-
-LOCAL_CACHE_DIR = Path("models/hf")
-os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
-os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
-
-def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray:
- vec = np.frombuffer(blob, dtype=np.float32)
- if vec.shape[0] != dim:
- raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}")
- return vec
-
-
-def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]:
- rows = conn.execute("""
- SELECT
- c.id,
- c.command,
- c.explanation,
- c.intent_json,
- c.alternatives_json,
- c.requires_json,
- c.packages_json,
- c.tags_json,
- c.platform_json,
- c.shell_json,
- c.safety,
- e.model_name,
- e.embedding_blob,
- e.embedding_dim
- FROM cards c
- JOIN card_embeddings e ON c.id = e.card_id
- ORDER BY c.id
- """).fetchall()
-
- if not rows:
- raise RuntimeError("No indexed cards found. Run build_index.py first.")
-
- cards: list[dict[str, Any]] = []
- vectors: list[np.ndarray] = []
- model_name: str | None = None
-
- for row in rows:
- (
- card_id,
- command,
- explanation,
- intent_json,
- alternatives_json,
- requires_json,
- packages_json,
- tags_json,
- platform_json,
- shell_json,
- safety,
- row_model_name,
- embedding_blob,
- embedding_dim,
- ) = row
-
- if model_name is None:
- model_name = row_model_name
- elif model_name != row_model_name:
- raise RuntimeError("Mixed embedding models found in the index.")
-
- cards.append({
- "id": card_id,
- "command": command,
- "explanation": explanation,
- "intent": json.loads(intent_json),
- "alternatives": json.loads(alternatives_json),
- "requires": json.loads(requires_json),
- "packages": json.loads(packages_json),
- "tags": json.loads(tags_json),
- "platform": json.loads(platform_json),
- "shell": json.loads(shell_json),
- "safety": safety,
- })
- vectors.append(deserialize_embedding(embedding_blob, embedding_dim))
-
- matrix = np.vstack(vectors)
- return cards, matrix, model_name
-
-
-def search(
- query: str,
- top_k: int = 5,
-) -> list[dict[str, Any]]:
- conn = sqlite3.connect(DB_PATH)
- try:
- cards, matrix, model_name = load_index(conn)
- finally:
- conn.close()
- model = SentenceTransformer(
- model_name,
- cache_folder=str(LOCAL_CACHE_DIR.resolve()),
- local_files_only=True,
+def search(query: str, top_k: int = 5) -> list[dict[str, Any]]:
+ engine = SearchEngine()
+ return engine.search(query, top_k=top_k)
+
+
+def format_results(results: list[dict[str, Any]]) -> str:
+ lines: list[str] = []
+ for i, result in enumerate(results, start=1):
+ lines.append(f"[{i}] score={result['score']:.4f} id={result['id']}")
+ lines.append(f" command: {result['command']}")
+ lines.append(f" explanation: {result['explanation']}")
+ if result["alternatives"]:
+ lines.append(f" alternatives: {', '.join(result['alternatives'])}")
+ lines.append(f" intent: {', '.join(result['intent'][:3])}")
+ lines.append("")
+ return "\n".join(lines).rstrip()
+
+
+def query_via_server(query: str, top_k: int, socket_path: Path) -> list[dict[str, Any]]:
+ response = send_server_request(
+ {"action": "query", "query": query, "top_k": top_k},
+ socket_path=socket_path,
)
+ if not response.get("ok"):
+ raise RuntimeError(response.get("error", "Query server returned an error."))
+ return response["results"]
+
+
+def read_request(conn: socket.socket) -> dict[str, Any]:
+ chunks: list[bytes] = []
+ total_bytes = 0
+ while True:
+ chunk = conn.recv(65536)
+ if not chunk:
+ break
+ chunks.append(chunk)
+ total_bytes += len(chunk)
+ if total_bytes > MAX_MESSAGE_BYTES:
+ raise RuntimeError("Request exceeded the maximum allowed size.")
+ if b"\n" in chunk:
+ break
+
+ message = b"".join(chunks).decode("utf-8").strip()
+ if not message:
+ raise RuntimeError("Received an empty request.")
+ return json.loads(message)
+
+
+def send_response(conn: socket.socket, response: dict[str, Any]) -> None:
+ encoded = (json.dumps(response) + "\n").encode("utf-8")
+ conn.sendall(encoded)
+
+
+def prepare_socket(socket_path: Path) -> socket.socket:
+ socket_path.parent.mkdir(parents=True, exist_ok=True)
+ if socket_path.exists():
+ try:
+ status = send_server_request({"action": "status"}, socket_path=socket_path, timeout=0.5)
+ except OSError:
+ socket_path.unlink()
+ else:
+ if status.get("ok"):
+ raise RuntimeError(f"Query server is already running at {socket_path}")
+ socket_path.unlink()
+
+ server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ server_sock.bind(str(socket_path))
+ server_sock.listen()
+ server_sock.settimeout(1.0)
+ return server_sock
+
+
+def serve(socket_path: Path) -> None:
+ engine = SearchEngine()
+ server_sock = prepare_socket(socket_path)
+ stop_requested = False
+
+ def request_stop(_signum: int, _frame: Any) -> None:
+ nonlocal stop_requested
+ stop_requested = True
+
+ previous_sigint = signal.signal(signal.SIGINT, request_stop)
+ previous_sigterm = signal.signal(signal.SIGTERM, request_stop)
+
+ print(f"Serving cheat queries on {socket_path}")
+ print("Use --stop to shut the server down.")
+
+ try:
+ while not stop_requested:
+ try:
+ conn, _ = server_sock.accept()
+ except socket.timeout:
+ continue
+
+ with conn:
+ should_stop = False
+ try:
+ request = read_request(conn)
+ action = request.get("action")
+ if action == "query":
+ response = {
+ "ok": True,
+ "results": engine.search(
+ str(request["query"]),
+ top_k=int(request.get("top_k", 5)),
+ ),
+ }
+ elif action == "reload":
+ engine.reload()
+ response = {
+ "ok": True,
+ "message": f"Reloaded {len(engine.cards)} cards.",
+ }
+ elif action == "status":
+ response = {
+ "ok": True,
+ "card_count": len(engine.cards),
+ "model_name": engine.model_name,
+ "socket_path": str(socket_path),
+ }
+ elif action == "stop":
+ response = {"ok": True, "message": "Stopping query server."}
+ should_stop = True
+ else:
+ response = {"ok": False, "error": f"Unknown action: {action!r}"}
+ except Exception as exc:
+ response = {"ok": False, "error": str(exc)}
+
+ send_response(conn, response)
+ if should_stop:
+ stop_requested = True
+ finally:
+ signal.signal(signal.SIGINT, previous_sigint)
+ signal.signal(signal.SIGTERM, previous_sigterm)
+ server_sock.close()
+ try:
+ socket_path.unlink()
+ except FileNotFoundError:
+ pass
- qvec = model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
- scores = matrix @ qvec
- top_indices = np.argsort(scores)[::-1][:top_k]
+def control_server(action: str, socket_path: Path) -> int:
+ try:
+ response = send_server_request({"action": action}, socket_path=socket_path)
+ except OSError as exc:
+ print(f"Unable to reach query server at {socket_path}: {exc}", file=sys.stderr)
+ return 1
- results: list[dict[str, Any]] = []
- for idx in top_indices:
- card = dict(cards[idx])
- card["score"] = float(scores[idx])
- results.append(card)
+ if not response.get("ok"):
+ print(response.get("error", "Query server request failed."), file=sys.stderr)
+ return 1
- return results
+ if action == "status":
+ print(f"socket: {response['socket_path']}")
+ print(f"model: {response['model_name']}")
+ print(f"cards: {response['card_count']}")
+ else:
+ print(response.get("message", action))
+ return 0
def main() -> None:
parser = argparse.ArgumentParser(description="Query the local command card index.")
- parser.add_argument("query", type=str, help="Natural language query")
+ parser.add_argument("query", nargs="?", type=str, help="Natural language query")
parser.add_argument("--top-k", type=int, default=5, help="Number of results to return")
+ parser.add_argument(
+ "--socket",
+ type=Path,
+ default=DEFAULT_SOCKET_PATH,
+ help=f"Unix socket path for the query server (default: {DEFAULT_SOCKET_PATH})",
+ )
+ parser.add_argument("--serve", action="store_true", help="Run the long-lived query server")
+ parser.add_argument("--reload", action="store_true", help="Reload a running query server")
+ parser.add_argument("--stop", action="store_true", help="Stop a running query server")
+ parser.add_argument("--status", action="store_true", help="Show query server status")
args = parser.parse_args()
- results = search(args.query, top_k=args.top_k)
+ actions = [args.serve, args.reload, args.stop, args.status]
+ if sum(bool(action) for action in actions) > 1:
+ parser.error("Choose only one of --serve, --reload, --stop, or --status.")
- for i, result in enumerate(results, start=1):
- print(f"[{i}] score={result['score']:.4f} id={result['id']}")
- print(f" command: {result['command']}")
- print(f" explanation: {result['explanation']}")
- if result["alternatives"]:
- print(f" alternatives: {', '.join(result['alternatives'])}")
- print(f" intent: {', '.join(result['intent'][:3])}")
- print()
+ if args.serve:
+ if args.query is not None:
+ parser.error("Query text cannot be used with --serve.")
+ serve(args.socket)
+ return
+
+ if args.reload:
+ if args.query is not None:
+ parser.error("Query text cannot be used with --reload.")
+ raise SystemExit(control_server("reload", args.socket))
+
+ if args.stop:
+ if args.query is not None:
+ parser.error("Query text cannot be used with --stop.")
+ raise SystemExit(control_server("stop", args.socket))
+
+ if args.status:
+ if args.query is not None:
+ parser.error("Query text cannot be used with --status.")
+ raise SystemExit(control_server("status", args.socket))
+
+ if args.query is None:
+ parser.error("A query is required unless you use a server control flag.")
+
+ try:
+ results = query_via_server(args.query, args.top_k, args.socket)
+ except OSError:
+ print(
+ "Query server not running; loading the model in-process for this query.",
+ file=sys.stderr,
+ )
+ results = search(args.query, top_k=args.top_k)
+
+ output = format_results(results)
+ if output:
+ print(output)
if __name__ == "__main__":