diff options
| author | twells46 <173561638+twells46@users.noreply.github.com> | 2026-04-01 15:20:50 -0500 |
|---|---|---|
| committer | twells46 <173561638+twells46@users.noreply.github.com> | 2026-04-01 15:20:50 -0500 |
| commit | 2f37974a4c84f7ffdd07e2c223eba2d8bd981b61 (patch) | |
| tree | 1741f17884077e9d4e0dbfe5908305fc21661ced /build_index.py | |
Initial commit
Diffstat (limited to 'build_index.py')
| -rw-r--r-- | build_index.py | 192 |
1 file changed, 192 insertions, 0 deletions
#!/usr/bin/env python3
"""Build the retrieval index for command "cards".

Reads newline-delimited JSON cards from CARDS_PATH, validates that each
card carries every field in REQUIRED_FIELDS, embeds a compact per-card
search text with a sentence-transformers model, and upserts both the
card rows and their embedding vectors into the SQLite database at
DB_PATH (schema created beforehand by scripts/init_db.py).
"""
from __future__ import annotations

import json
import os
import sqlite3
from pathlib import Path
from typing import Any

import numpy as np

DB_PATH = Path("cheat.db")
CARDS_PATH = Path("./cards.jsonl")
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LOCAL_CACHE_DIR = Path("models/hf")

# Point the Hugging Face caches at the local model directory. This must
# happen before sentence_transformers is imported — which is why that
# import is deferred into main() below — otherwise the library may have
# already resolved its cache paths from the old environment.
os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))

# Every card in cards.jsonl must carry all of these keys; load_cards
# rejects the file otherwise.
REQUIRED_FIELDS = {
    "id",
    "intent",
    "command",
    "alternatives",
    "explanation",
    "requires",
    "packages",
    "tags",
    "platform",
    "shell",
    "safety",
}


def load_cards(path: Path) -> list[dict[str, Any]]:
    """Load and validate cards from a JSONL file.

    Blank lines are skipped. Raises ValueError (carrying the 1-based
    line number) on malformed JSON or on a card missing a required
    field, so a bad file fails fast instead of indexing partially.
    """
    cards: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                card = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON on line {line_no}: {e}") from e

            missing = REQUIRED_FIELDS - set(card.keys())
            if missing:
                raise ValueError(f"Missing required fields on line {line_no}: {sorted(missing)}")

            cards.append(card)
    return cards


def build_search_text(card: dict[str, Any]) -> str:
    """
    Build a compact semantic representation for embedding.
    This is what the retriever will search over.
    """
    parts: list[str] = []

    intents = card.get("intent", [])
    tags = card.get("tags", [])
    command = card.get("command", "")
    explanation = card.get("explanation", "")
    alternatives = card.get("alternatives", [])
    requires = card.get("requires", [])
    platform = card.get("platform", [])

    # Only sections with content are emitted, so sparse cards don't
    # embed empty "Tags:" / "Requires:" noise.
    if intents:
        parts.append("Intents: " + " | ".join(intents))
    if tags:
        parts.append("Tags: " + ", ".join(tags))
    if command:
        parts.append("Command: " + command)
    if alternatives:
        parts.append("Alternatives: " + " | ".join(alternatives))
    if explanation:
        parts.append("Explanation: " + explanation)
    if requires:
        parts.append("Requires: " + ", ".join(requires))
    if platform:
        parts.append("Platform: " + ", ".join(platform))

    return "\n".join(parts)


def serialize_embedding(vec: np.ndarray) -> bytes:
    """Serialize an embedding as raw little-per-element float32 bytes.

    NOTE(review): tobytes() uses native byte order — fine as long as
    the index is written and read on the same machine; confirm if the
    DB file is ever shipped cross-platform.
    """
    return vec.astype(np.float32).tobytes()


def upsert_card(conn: sqlite3.Connection, card: dict[str, Any], search_text: str) -> None:
    """Insert or update one card row, keyed on its id.

    List-valued fields are stored as JSON text columns; updated_at is
    refreshed on every write.
    """
    conn.execute("""
        INSERT INTO cards (
            id, command, explanation, intent_json, alternatives_json, requires_json,
            packages_json, tags_json, platform_json, shell_json, safety, search_text, updated_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
        ON CONFLICT(id) DO UPDATE SET
            command=excluded.command,
            explanation=excluded.explanation,
            intent_json=excluded.intent_json,
            alternatives_json=excluded.alternatives_json,
            requires_json=excluded.requires_json,
            packages_json=excluded.packages_json,
            tags_json=excluded.tags_json,
            platform_json=excluded.platform_json,
            shell_json=excluded.shell_json,
            safety=excluded.safety,
            search_text=excluded.search_text,
            updated_at=CURRENT_TIMESTAMP
    """, (
        card["id"],
        card["command"],
        card["explanation"],
        json.dumps(card["intent"], ensure_ascii=False),
        json.dumps(card["alternatives"], ensure_ascii=False),
        json.dumps(card["requires"], ensure_ascii=False),
        json.dumps(card["packages"], ensure_ascii=False),
        json.dumps(card["tags"], ensure_ascii=False),
        json.dumps(card["platform"], ensure_ascii=False),
        json.dumps(card["shell"], ensure_ascii=False),
        card["safety"],
        search_text,
    ))


def upsert_embedding(
    conn: sqlite3.Connection,
    card_id: str,
    model_name: str,
    vec: np.ndarray,
) -> None:
    """Insert or update the embedding blob for one card.

    Assumes card_embeddings has a uniqueness constraint on card_id
    (required for ON CONFLICT(card_id)) — defined in scripts/init_db.py.
    """
    conn.execute("""
        INSERT INTO card_embeddings (
            card_id, model_name, embedding_blob, embedding_dim
        )
        VALUES (?, ?, ?, ?)
        ON CONFLICT(card_id) DO UPDATE SET
            model_name=excluded.model_name,
            embedding_blob=excluded.embedding_blob,
            embedding_dim=excluded.embedding_dim
    """, (
        card_id,
        model_name,
        serialize_embedding(vec),
        int(vec.shape[0]),
    ))


def main() -> None:
    """Entry point: load cards, embed them, and upsert into the DB.

    Raises FileNotFoundError if the database or the cards file is
    missing.
    """
    # Imported lazily so the HF_HOME / SENTENCE_TRANSFORMERS_HOME
    # defaults set at module level are in place before the library
    # resolves its cache paths, and so the pure helpers above can be
    # imported without the heavy ML dependency.
    from sentence_transformers import SentenceTransformer

    if not DB_PATH.exists():
        raise FileNotFoundError(
            f"Database not found at {DB_PATH}. Run scripts/init_db.py first."
        )
    if not CARDS_PATH.exists():
        raise FileNotFoundError(f"Cards file not found at {CARDS_PATH}")

    cards = load_cards(CARDS_PATH)

    # local_files_only: never hit the network — the model must already
    # be present in the local cache directory.
    model = SentenceTransformer(
        MODEL_NAME,
        cache_folder=str(LOCAL_CACHE_DIR.resolve()),
        local_files_only=True,
    )

    search_texts = [build_search_text(card) for card in cards]
    # normalize_embeddings=True: unit-length vectors, so cosine
    # similarity reduces to a dot product at query time.
    embeddings = model.encode(
        search_texts,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )

    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("PRAGMA foreign_keys=ON;")
        for card, vec, search_text in zip(cards, embeddings, search_texts):
            upsert_card(conn, card, search_text)
            upsert_embedding(conn, card["id"], MODEL_NAME, vec)
        # Single commit: the whole index update is one transaction.
        conn.commit()
        print(f"Indexed {len(cards)} cards into {DB_PATH}")
    finally:
        conn.close()


if __name__ == "__main__":
    main()