"""JSON-file-backed storage for hooks, plus async wrappers.

The blocking ``HookStore`` methods are exposed to async callers through
``run_in_threadpool`` wrappers defined at the bottom of this module.
"""

import json
import re
import secrets
import string
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import List, Optional, Set

from fastapi.concurrency import run_in_threadpool

from .config import get_settings
from .models import HookCreate, HookRead

# Shape required of caller-supplied hook IDs (enforced in HookStore.update_hook).
HOOK_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]{3,64}$")


class HookStore:
    """Store hooks as a JSON list in a single file.

    Mutating operations (create, delete, update, mark-triggered) run under a
    ``threading.Lock`` and follow a load / modify / save cycle. Read
    operations (``list_hooks``, ``get_hook``) load the file without taking
    the lock.
    """

    def __init__(self, storage_path: Path, hook_id_length: int) -> None:
        """Remember the configuration and make sure the store file is usable.

        Args:
            storage_path: Location of the JSON file holding the hooks.
            hook_id_length: Number of characters in generated hook IDs.
        """
        self.storage_path = storage_path
        self.hook_id_length = hook_id_length
        self._lock = threading.Lock()
        self._initialize()

    def _deserialize_hook(self, item: dict) -> HookRead:
        """Build a ``HookRead`` from one stored dict.

        Raises:
            KeyError: If ``hook_id``, ``chat_id``, ``message`` or
                ``created_at`` is missing from ``item``.
        """
        last_triggered = item.get("last_triggered_at")
        return HookRead(
            hook_id=item["hook_id"],
            chat_id=item["chat_id"],
            message=item["message"],
            created_at=datetime.fromisoformat(item["created_at"]),
            # A missing or falsy value means the hook has never been triggered.
            last_triggered_at=(
                datetime.fromisoformat(last_triggered) if last_triggered else None
            ),
        )

    def _initialize(self) -> None:
        """Create the store file if absent, or reset it if unreadable."""
        if not self.storage_path.exists():
            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
            self.storage_path.write_text("[]\n", encoding="utf-8")
        else:
            try:
                self._load_raw()
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Corrupted or non-UTF file; back it up and start fresh
                backup = self.storage_path.with_suffix(
                    self.storage_path.suffix + ".bak"
                )
                self.storage_path.replace(backup)
                self.storage_path.write_text("[]\n", encoding="utf-8")

    def _load_raw(self) -> List[dict]:
        """Read and parse the store file.

        An empty file is treated as an empty list.

        Raises:
            json.JSONDecodeError: If the content is not valid JSON or the
                top-level value is not a list.
        """
        text = self.storage_path.read_text(encoding="utf-8")
        data = json.loads(text or "[]")
        if not isinstance(data, list):
            raise json.JSONDecodeError("Hook store must contain a list", text, 0)
        return data

    def _save_raw(self, data: List[dict]) -> None:
        """Serialize ``data`` and swap it into place.

        The payload is written to a sibling ``.tmp`` file first and then
        moved over the store file with ``Path.replace``.
        """
        payload = json.dumps(data, indent=2, ensure_ascii=False)
        tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
        tmp_path.write_text(payload + "\n", encoding="utf-8")
        tmp_path.replace(self.storage_path)

    def _generate_hook_id(self, existing_ids: Set[str]) -> str:
        """Return a random lowercase-alphanumeric ID not in ``existing_ids``.

        IDs are ``self.hook_id_length`` characters long and drawn with
        ``secrets.choice``; candidates are retried until an unused one is
        found.
        """
        alphabet = string.ascii_lowercase + string.digits
        while True:
            candidate = "".join(
                secrets.choice(alphabet) for _ in range(self.hook_id_length)
            )
            if candidate not in existing_ids:
                return candidate

    def list_hooks(self) -> List[HookRead]:
        """Return all stored hooks, newest ``created_at`` first."""
        raw_hooks = self._load_raw()
        hooks = [self._deserialize_hook(item) for item in raw_hooks]
        hooks.sort(key=lambda h: h.created_at, reverse=True)
        return hooks

    def create_hook(self, payload: HookCreate) -> HookRead:
        """Persist a new hook with a freshly generated ID and return it.

        ``created_at`` is the current UTC time truncated to whole seconds.
        """
        created_at = datetime.now(UTC).replace(microsecond=0).isoformat()
        with self._lock:
            raw_hooks = self._load_raw()
            # Use .get like every other lookup in this class so that a stored
            # entry without "hook_id" does not abort creation with KeyError.
            existing_ids = {item.get("hook_id") for item in raw_hooks}
            hook_id = self._generate_hook_id(existing_ids)
            raw_hooks.append(
                {
                    "hook_id": hook_id,
                    "chat_id": payload.chat_id,
                    "message": payload.message,
                    "created_at": created_at,
                    "last_triggered_at": None,
                }
            )
            self._save_raw(raw_hooks)
        return HookRead(
            hook_id=hook_id,
            chat_id=payload.chat_id,
            message=payload.message,
            created_at=datetime.fromisoformat(created_at),
            last_triggered_at=None,
        )

    def get_hook(self, hook_id: str) -> Optional[HookRead]:
        """Return the hook whose ID equals ``hook_id``, or ``None``."""
        raw_hooks = self._load_raw()
        for item in raw_hooks:
            if item.get("hook_id") == hook_id:
                return self._deserialize_hook(item)
        return None

    def delete_hook(self, hook_id: str) -> bool:
        """Remove the hook with ``hook_id``.

        Returns:
            ``True`` if an entry was removed, ``False`` if none matched (in
            which case the file is not rewritten).
        """
        with self._lock:
            raw_hooks = self._load_raw()
            new_hooks = [item for item in raw_hooks if item.get("hook_id") != hook_id]
            if len(new_hooks) == len(raw_hooks):
                return False
            self._save_raw(new_hooks)
            return True

    def update_hook(
        self,
        current_id: str,
        *,
        new_hook_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        message: Optional[str] = None,
    ) -> HookRead:
        """Change the ID, chat ID and/or message of an existing hook.

        Supplied values are stripped of surrounding whitespace before being
        validated and stored.

        Args:
            current_id: ID of the hook to modify.
            new_hook_id: Replacement ID; must match ``HOOK_ID_PATTERN``.
            chat_id: Replacement chat ID; must not be blank.
            message: Replacement message; must not be blank.

        Returns:
            The hook as stored after the update.

        Raises:
            ValueError: If no field is supplied, a supplied value is blank or
                malformed, or ``new_hook_id`` is already used by another hook.
            KeyError: If no hook with ``current_id`` exists.
        """
        if new_hook_id is None and chat_id is None and message is None:
            raise ValueError("No updates provided")

        normalized_id = new_hook_id.strip() if new_hook_id is not None else None
        normalized_chat = chat_id.strip() if chat_id is not None else None
        normalized_message = message.strip() if message is not None else None

        # Validate everything before touching the file or taking the lock.
        if normalized_id is not None:
            if not normalized_id:
                raise ValueError("Hook ID cannot be empty")
            if not HOOK_ID_PATTERN.fullmatch(normalized_id):
                raise ValueError(
                    "Hook ID must be 3-64 characters of letters, numbers, '_' or '-' only"
                )
        if normalized_chat is not None and not normalized_chat:
            raise ValueError("Chat ID cannot be empty")
        if normalized_message is not None and not normalized_message:
            raise ValueError("Message cannot be empty")

        with self._lock:
            raw_hooks = self._load_raw()
            exists = next(
                (item for item in raw_hooks if item.get("hook_id") == current_id),
                None,
            )
            if not exists:
                raise KeyError("Hook not found")

            # Renaming to the same ID is a no-op; only a real change needs the
            # uniqueness check.
            if normalized_id is not None and normalized_id != current_id:
                if any(item.get("hook_id") == normalized_id for item in raw_hooks):
                    raise ValueError("Hook ID already exists")
                exists["hook_id"] = normalized_id
            if normalized_chat is not None:
                exists["chat_id"] = normalized_chat
            if normalized_message is not None:
                exists["message"] = normalized_message

            self._save_raw(raw_hooks)
            return self._deserialize_hook(exists)

    def mark_hook_triggered(self, hook_id: str) -> HookRead:
        """Set ``last_triggered_at`` of a hook to the current UTC time.

        The timestamp is truncated to whole seconds.

        Raises:
            KeyError: If no hook with ``hook_id`` exists.
        """
        timestamp = datetime.now(UTC).replace(microsecond=0).isoformat()
        with self._lock:
            raw_hooks = self._load_raw()
            exists = next(
                (item for item in raw_hooks if item.get("hook_id") == hook_id),
                None,
            )
            if not exists:
                raise KeyError("Hook not found")
            exists["last_triggered_at"] = timestamp
            self._save_raw(raw_hooks)
            return self._deserialize_hook(exists)


# Module-level singleton store, configured from application settings at import.
settings = get_settings()
store = HookStore(settings.database_path, settings.hook_id_length)


async def list_hooks_async() -> List[HookRead]:
    """Run ``store.list_hooks`` via ``run_in_threadpool``."""
    return await run_in_threadpool(store.list_hooks)


async def create_hook_async(payload: HookCreate) -> HookRead:
    """Run ``store.create_hook`` via ``run_in_threadpool``."""
    return await run_in_threadpool(store.create_hook, payload)


async def get_hook_async(hook_id: str) -> Optional[HookRead]:
    """Run ``store.get_hook`` via ``run_in_threadpool``."""
    return await run_in_threadpool(store.get_hook, hook_id)


async def delete_hook_async(hook_id: str) -> bool:
    """Run ``store.delete_hook`` via ``run_in_threadpool``."""
    return await run_in_threadpool(store.delete_hook, hook_id)


async def update_hook_async(
    current_hook_id: str,
    *,
    new_hook_id: Optional[str] = None,
    chat_id: Optional[str] = None,
    message: Optional[str] = None,
) -> HookRead:
    """Run ``store.update_hook`` via ``run_in_threadpool``."""
    return await run_in_threadpool(
        store.update_hook,
        current_hook_id,
        new_hook_id=new_hook_id,
        chat_id=chat_id,
        message=message,
    )


async def record_hook_trigger_async(hook_id: str) -> HookRead:
    """Run ``store.mark_hook_triggered`` via ``run_in_threadpool``."""
    return await run_in_threadpool(store.mark_hook_triggered, hook_id)