aboutsummaryrefslogtreecommitdiff
path: root/query_index.py
diff options
context:
space:
mode:
authortwells46 <173561638+twells46@users.noreply.github.com>2026-04-01 15:20:50 -0500
committertwells46 <173561638+twells46@users.noreply.github.com>2026-04-01 15:20:50 -0500
commit2f37974a4c84f7ffdd07e2c223eba2d8bd981b61 (patch)
tree1741f17884077e9d4e0dbfe5908305fc21661ced /query_index.py
Initial commit
Diffstat (limited to 'query_index.py')
-rw-r--r--query_index.py148
1 files changed, 148 insertions, 0 deletions
diff --git a/query_index.py b/query_index.py
new file mode 100644
index 0000000..7dee4d3
--- /dev/null
+++ b/query_index.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import json
+import sqlite3
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+import os
+from pathlib import Path
+
+DB_PATH = Path("cheat.db")
+
+LOCAL_CACHE_DIR = Path("models/hf")
+os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
+os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))
+
+def deserialize_embedding(blob: bytes, dim: int) -> np.ndarray:
+ vec = np.frombuffer(blob, dtype=np.float32)
+ if vec.shape[0] != dim:
+ raise ValueError(f"Embedding length mismatch: expected {dim}, got {vec.shape[0]}")
+ return vec
+
+
+def load_index(conn: sqlite3.Connection) -> tuple[list[dict[str, Any]], np.ndarray, str]:
+ rows = conn.execute("""
+ SELECT
+ c.id,
+ c.command,
+ c.explanation,
+ c.intent_json,
+ c.alternatives_json,
+ c.requires_json,
+ c.packages_json,
+ c.tags_json,
+ c.platform_json,
+ c.shell_json,
+ c.safety,
+ e.model_name,
+ e.embedding_blob,
+ e.embedding_dim
+ FROM cards c
+ JOIN card_embeddings e ON c.id = e.card_id
+ ORDER BY c.id
+ """).fetchall()
+
+ if not rows:
+ raise RuntimeError("No indexed cards found. Run build_index.py first.")
+
+ cards: list[dict[str, Any]] = []
+ vectors: list[np.ndarray] = []
+ model_name: str | None = None
+
+ for row in rows:
+ (
+ card_id,
+ command,
+ explanation,
+ intent_json,
+ alternatives_json,
+ requires_json,
+ packages_json,
+ tags_json,
+ platform_json,
+ shell_json,
+ safety,
+ row_model_name,
+ embedding_blob,
+ embedding_dim,
+ ) = row
+
+ if model_name is None:
+ model_name = row_model_name
+ elif model_name != row_model_name:
+ raise RuntimeError("Mixed embedding models found in the index.")
+
+ cards.append({
+ "id": card_id,
+ "command": command,
+ "explanation": explanation,
+ "intent": json.loads(intent_json),
+ "alternatives": json.loads(alternatives_json),
+ "requires": json.loads(requires_json),
+ "packages": json.loads(packages_json),
+ "tags": json.loads(tags_json),
+ "platform": json.loads(platform_json),
+ "shell": json.loads(shell_json),
+ "safety": safety,
+ })
+ vectors.append(deserialize_embedding(embedding_blob, embedding_dim))
+
+ matrix = np.vstack(vectors)
+ return cards, matrix, model_name
+
+
+def search(
+ query: str,
+ top_k: int = 5,
+) -> list[dict[str, Any]]:
+ conn = sqlite3.connect(DB_PATH)
+ try:
+ cards, matrix, model_name = load_index(conn)
+ finally:
+ conn.close()
+
+ model = SentenceTransformer(
+ model_name,
+ cache_folder=str(LOCAL_CACHE_DIR.resolve()),
+ local_files_only=True,
+ )
+
+ qvec = model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
+
+ scores = matrix @ qvec
+ top_indices = np.argsort(scores)[::-1][:top_k]
+
+ results: list[dict[str, Any]] = []
+ for idx in top_indices:
+ card = dict(cards[idx])
+ card["score"] = float(scores[idx])
+ results.append(card)
+
+ return results
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Query the local command card index.")
+ parser.add_argument("query", type=str, help="Natural language query")
+ parser.add_argument("--top-k", type=int, default=5, help="Number of results to return")
+ args = parser.parse_args()
+
+ results = search(args.query, top_k=args.top_k)
+
+ for i, result in enumerate(results, start=1):
+ print(f"[{i}] score={result['score']:.4f} id={result['id']}")
+ print(f" command: {result['command']}")
+ print(f" explanation: {result['explanation']}")
+ if result["alternatives"]:
+ print(f" alternatives: {', '.join(result['alternatives'])}")
+ print(f" intent: {', '.join(result['intent'][:3])}")
+ print()
+
+
+if __name__ == "__main__":
+ main()