#!/usr/bin/env python3 from __future__ import annotations import json import os from pathlib import Path import sqlite3 import sys from typing import Any LOCAL_CACHE_DIR = Path("models/hf") os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve())) os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve())) import numpy as np from sentence_transformers import SentenceTransformer from cheat_runtime import DEFAULT_SOCKET_PATH, send_server_request DB_PATH = Path("cheat.db") CARDS_PATH = Path("./cards.jsonl") MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" REQUIRED_FIELDS = { "id", "intent", "command", "alternatives", "explanation", "requires", "packages", "tags", "platform", "shell", "safety", } def load_cards(path: Path) -> list[dict[str, Any]]: cards: list[dict[str, Any]] = [] with path.open("r", encoding="utf-8") as f: for line_no, line in enumerate(f, start=1): line = line.strip() if not line: continue try: card = json.loads(line) except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON on line {line_no}: {e}") from e missing = REQUIRED_FIELDS - set(card.keys()) if missing: raise ValueError(f"Missing required fields on line {line_no}: {sorted(missing)}") cards.append(card) return cards def build_search_text(card: dict[str, Any]) -> str: """ Build a compact semantic representation for embedding. This is what the retriever will search over. """ parts: list[str] = [] intents = card.get("intent", []) tags = card.get("tags", []) command = card.get("command", "") explanation = card.get("explanation", "") alternatives = card.get("alternatives", []) requires = card.get("requires", []) platform = card.get("platform", []) if intents: parts.append("Intents: " + " | ".join(intents)) if tags: parts.append("Tags: " + ", ".join(tags)) if command: parts.append("Command: " + command) if alternatives: parts.append("Alternatives: " + " | ".join(alternatives)) if explanation: parts.append("Explanation: " + explanation) if requires: parts.append("Requires: " + ", ".join(requires)) if platform: parts.append("Platform: " + ", ".join(platform)) return "\n".join(parts) def serialize_embedding(vec: np.ndarray) -> bytes: return vec.astype(np.float32).tobytes() def upsert_card(conn: sqlite3.Connection, card: dict[str, Any], search_text: str) -> None: conn.execute(""" INSERT INTO cards ( id, command, explanation, intent_json, alternatives_json, requires_json, packages_json, tags_json, platform_json, shell_json, safety, search_text, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) ON CONFLICT(id) DO UPDATE SET command=excluded.command, explanation=excluded.explanation, intent_json=excluded.intent_json, alternatives_json=excluded.alternatives_json, requires_json=excluded.requires_json, packages_json=excluded.packages_json, tags_json=excluded.tags_json, platform_json=excluded.platform_json, shell_json=excluded.shell_json, safety=excluded.safety, search_text=excluded.search_text, updated_at=CURRENT_TIMESTAMP """, ( card["id"], card["command"], card["explanation"], json.dumps(card["intent"], ensure_ascii=False), json.dumps(card["alternatives"], ensure_ascii=False), json.dumps(card["requires"], ensure_ascii=False), json.dumps(card["packages"], ensure_ascii=False), json.dumps(card["tags"], ensure_ascii=False), json.dumps(card["platform"], ensure_ascii=False), json.dumps(card["shell"], ensure_ascii=False), card["safety"], search_text, )) def upsert_embedding( conn: sqlite3.Connection, card_id: str, model_name: str, vec: np.ndarray, ) -> None: conn.execute(""" INSERT INTO card_embeddings ( card_id, model_name, embedding_blob, embedding_dim ) VALUES (?, ?, ?, ?) ON CONFLICT(card_id) DO UPDATE SET model_name=excluded.model_name, embedding_blob=excluded.embedding_blob, embedding_dim=excluded.embedding_dim """, ( card_id, model_name, serialize_embedding(vec), int(vec.shape[0]), )) def main() -> None: if not DB_PATH.exists(): raise FileNotFoundError(f"Database not found at {DB_PATH}. Run init_db.py first.") if not CARDS_PATH.exists(): raise FileNotFoundError(f"Cards file not found at {CARDS_PATH}") cards = load_cards(CARDS_PATH) model = SentenceTransformer( MODEL_NAME, cache_folder=str(LOCAL_CACHE_DIR.resolve()), local_files_only=True, ) search_texts = [build_search_text(card) for card in cards] embeddings = model.encode( search_texts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True, ) conn = sqlite3.connect(DB_PATH) try: conn.execute("PRAGMA foreign_keys=ON;") for card, vec, search_text in zip(cards, embeddings, search_texts): upsert_card(conn, card, search_text) upsert_embedding(conn, card["id"], MODEL_NAME, vec) conn.commit() print(f"Indexed {len(cards)} cards into {DB_PATH}") finally: conn.close() try: response = send_server_request( {"action": "reload"}, socket_path=DEFAULT_SOCKET_PATH, timeout=1.0, ) except OSError: return if response.get("ok"): print(f"Reloaded query server at {DEFAULT_SOCKET_PATH}") else: print( f"Warning: query server reload failed: {response.get('error', 'unknown error')}", file=sys.stderr, ) if __name__ == "__main__": main()