Files
TgMessageHook/app/storage.py
Andre Beging 1204f5dcde feat: Enhance hook management and session handling
- Update hook model to include last_triggered_at field.
- Modify API endpoints to support updating hooks with new fields.
- Implement session management UI improvements with toggle functionality.
- Add new JavaScript functions for better session detail visibility.
- Refactor hook storage logic to handle last triggered timestamps.
- Introduce new favicon and logo for branding.
- Update styles for improved layout and user experience.
- Enhance tests to cover new functionality and ensure reliability.
2025-10-07 13:39:07 +02:00

210 lines
7.6 KiB
Python

import json
import os
import re
import secrets
import string
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import List, Optional, Set

from fastapi.concurrency import run_in_threadpool

from .config import get_settings
from .models import HookCreate, HookRead
HOOK_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]{3,64}$")
class HookStore:
def __init__(self, storage_path: Path, hook_id_length: int) -> None:
self.storage_path = storage_path
self.hook_id_length = hook_id_length
self._lock = threading.Lock()
self._initialize()
def _deserialize_hook(self, item: dict) -> HookRead:
last_triggered = item.get("last_triggered_at")
return HookRead(
hook_id=item["hook_id"],
chat_id=item["chat_id"],
message=item["message"],
created_at=datetime.fromisoformat(item["created_at"]),
last_triggered_at=datetime.fromisoformat(last_triggered) if last_triggered else None,
)
def _initialize(self) -> None:
if not self.storage_path.exists():
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
self.storage_path.write_text("[]\n", encoding="utf-8")
else:
try:
self._load_raw()
except (json.JSONDecodeError, UnicodeDecodeError):
# Corrupted or non-UTF file; back it up and start fresh
backup = self.storage_path.with_suffix(self.storage_path.suffix + ".bak")
self.storage_path.replace(backup)
self.storage_path.write_text("[]\n", encoding="utf-8")
def _load_raw(self) -> List[dict]:
text = self.storage_path.read_text(encoding="utf-8")
data = json.loads(text or "[]")
if not isinstance(data, list):
raise json.JSONDecodeError("Hook store must contain a list", text, 0)
return data
def _save_raw(self, data: List[dict]) -> None:
payload = json.dumps(data, indent=2, ensure_ascii=False)
tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
tmp_path.write_text(payload + "\n", encoding="utf-8")
tmp_path.replace(self.storage_path)
def _generate_hook_id(self, existing_ids: Set[str]) -> str:
alphabet = string.ascii_lowercase + string.digits
while True:
candidate = "".join(secrets.choice(alphabet) for _ in range(self.hook_id_length))
if candidate not in existing_ids:
return candidate
def list_hooks(self) -> List[HookRead]:
raw_hooks = self._load_raw()
hooks = [self._deserialize_hook(item) for item in raw_hooks]
hooks.sort(key=lambda h: h.created_at, reverse=True)
return hooks
def create_hook(self, payload: HookCreate) -> HookRead:
created_at = datetime.now(UTC).replace(microsecond=0).isoformat()
with self._lock:
raw_hooks = self._load_raw()
existing_ids = {item["hook_id"] for item in raw_hooks}
hook_id = self._generate_hook_id(existing_ids)
raw_hooks.append(
{
"hook_id": hook_id,
"chat_id": payload.chat_id,
"message": payload.message,
"created_at": created_at,
"last_triggered_at": None,
}
)
self._save_raw(raw_hooks)
return HookRead(
hook_id=hook_id,
chat_id=payload.chat_id,
message=payload.message,
created_at=datetime.fromisoformat(created_at),
last_triggered_at=None,
)
def get_hook(self, hook_id: str) -> Optional[HookRead]:
raw_hooks = self._load_raw()
for item in raw_hooks:
if item.get("hook_id") == hook_id:
return self._deserialize_hook(item)
return None
def delete_hook(self, hook_id: str) -> bool:
with self._lock:
raw_hooks = self._load_raw()
new_hooks = [item for item in raw_hooks if item.get("hook_id") != hook_id]
if len(new_hooks) == len(raw_hooks):
return False
self._save_raw(new_hooks)
return True
def update_hook(
self,
current_id: str,
*,
new_hook_id: Optional[str] = None,
chat_id: Optional[str] = None,
message: Optional[str] = None,
) -> HookRead:
if new_hook_id is None and chat_id is None and message is None:
raise ValueError("No updates provided")
normalized_id = new_hook_id.strip() if new_hook_id is not None else None
normalized_chat = chat_id.strip() if chat_id is not None else None
normalized_message = message.strip() if message is not None else None
if normalized_id is not None:
if not normalized_id:
raise ValueError("Hook ID cannot be empty")
if not HOOK_ID_PATTERN.fullmatch(normalized_id):
raise ValueError("Hook ID must be 3-64 characters of letters, numbers, '_' or '-' only")
if normalized_chat is not None and not normalized_chat:
raise ValueError("Chat ID cannot be empty")
if normalized_message is not None and not normalized_message:
raise ValueError("Message cannot be empty")
with self._lock:
raw_hooks = self._load_raw()
exists = next((item for item in raw_hooks if item.get("hook_id") == current_id), None)
if not exists:
raise KeyError("Hook not found")
if normalized_id is not None and normalized_id != current_id:
if any(item.get("hook_id") == normalized_id for item in raw_hooks):
raise ValueError("Hook ID already exists")
exists["hook_id"] = normalized_id
if normalized_chat is not None:
exists["chat_id"] = normalized_chat
if normalized_message is not None:
exists["message"] = normalized_message
self._save_raw(raw_hooks)
return self._deserialize_hook(exists)
def mark_hook_triggered(self, hook_id: str) -> HookRead:
timestamp = datetime.now(UTC).replace(microsecond=0).isoformat()
with self._lock:
raw_hooks = self._load_raw()
exists = next((item for item in raw_hooks if item.get("hook_id") == hook_id), None)
if not exists:
raise KeyError("Hook not found")
exists["last_triggered_at"] = timestamp
self._save_raw(raw_hooks)
return self._deserialize_hook(exists)
# Shared HookStore used by the async helpers in this module. It is built at
# import time from application settings; construction creates the store file
# if it is missing (HookStore._initialize).
settings = get_settings()
store = HookStore(
    storage_path=settings.database_path,
    hook_id_length=settings.hook_id_length,
)
async def list_hooks_async() -> List[HookRead]:
    """Run ``store.list_hooks`` in a worker thread and return its result."""
    loader = store.list_hooks
    hooks = await run_in_threadpool(loader)
    return hooks
async def create_hook_async(payload: HookCreate) -> HookRead:
    """Run ``store.create_hook(payload)`` in a worker thread; return the new hook."""
    creator = store.create_hook
    created = await run_in_threadpool(creator, payload)
    return created
async def get_hook_async(hook_id: str) -> Optional[HookRead]:
    """Run ``store.get_hook(hook_id)`` in a worker thread; ``None`` if absent."""
    lookup = store.get_hook
    found = await run_in_threadpool(lookup, hook_id)
    return found
async def delete_hook_async(hook_id: str) -> bool:
    """Run ``store.delete_hook(hook_id)`` in a worker thread; True if it existed."""
    remover = store.delete_hook
    removed = await run_in_threadpool(remover, hook_id)
    return removed
async def update_hook_async(
    current_hook_id: str,
    *,
    new_hook_id: Optional[str] = None,
    chat_id: Optional[str] = None,
    message: Optional[str] = None,
) -> HookRead:
    """Run ``store.update_hook`` in a worker thread and return the updated hook.

    Exceptions raised by ``store.update_hook`` propagate unchanged.
    """
    changes = {"new_hook_id": new_hook_id, "chat_id": chat_id, "message": message}
    updated = await run_in_threadpool(store.update_hook, current_hook_id, **changes)
    return updated
async def record_hook_trigger_async(hook_id: str) -> HookRead:
    """Run ``store.mark_hook_triggered(hook_id)`` in a worker thread.

    Returns the hook with its refreshed ``last_triggered_at``; exceptions from
    the store (e.g. ``KeyError`` for an unknown ID) propagate unchanged.
    """
    marker = store.mark_hook_triggered
    refreshed = await run_in_threadpool(marker, hook_id)
    return refreshed