# --- diff: new file ergpt/kb/async_client.py (mode 100644, index 0000000..1a47f21) ---
from typing import BinaryIO

import httpx

from .exceptions import (
    AuthenticationError,
    ErGPTError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
    ServerError,
    ValidationError,
)
from .models import (
    ChunkPagination,
    CreateKnowledgeBaseRequest,
    DocumentPagePagination,
    DocumentWithContentUrl,
    HTTPValidationError,
    KnowledgeBase,
    KnowledgeBaseAggregate,
    PagePagination,
    SearchRequest,
    SearchResponse,
    UpdateKnowledgeBaseRequest,
)

DEFAULT_BASE_URL = "https://api.er-gpt.ru"
DEFAULT_TIMEOUT = 30.0


def _handle_response_error(response: httpx.Response) -> None:
    """Map an HTTP error response onto the matching ErGPT exception.

    Raises:
        AuthenticationError: on 401.
        PermissionDeniedError: on 403.
        NotFoundError: on 404.
        ValidationError: on 422, with parsed detail when the body is a
            valid ``HTTPValidationError`` payload.
        RateLimitError: on 429.
        ServerError: on any 5xx.
        ErGPTError: for every other error status.
    """
    status = response.status_code
    if status == 401:
        raise AuthenticationError("Ошибка аутентификации. Проверьте API токен.")
    if status == 403:
        raise PermissionDeniedError("Доступ запрещен.")
    if status == 404:
        raise NotFoundError(f"Ресурс не найден: {response.url}")
    if status == 422:
        # BUG FIX: the specific ValidationError used to be raised *inside*
        # the try block, so `except Exception` always caught it and replaced
        # it with the generic fallback — the parsed details never escaped.
        # Parse first, then raise once outside the try.
        try:
            error = HTTPValidationError.model_validate(response.json())
            details = [e.model_dump() for e in error.detail]
        except Exception:
            details = [{"msg": response.text}]
        raise ValidationError("Ошибка валидации", details=details) from None
    if status == 429:
        raise RateLimitError("Превышен лимит запросов. Попробуйте позже.")
    if status >= 500:
        raise ServerError(f"Ошибка сервера: {status}")
    raise ErGPTError(f"HTTP {status}: {response.text}")
Попробуйте позже.") + elif response.status_code >= 500: + raise ServerError(f"Ошибка сервера: {response.status_code}") + else: + raise ErGPTError(f"HTTP {response.status_code}: {response.text}") + + +class AsyncKnowledgeBaseClient: + def __init__( + self, + api_token: str, + base_url: str = DEFAULT_BASE_URL, + timeout: float = DEFAULT_TIMEOUT, + ): + self.api_token = api_token + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._client: httpx.AsyncClient | None = None + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient( + base_url=self.base_url, + timeout=self.timeout, + headers={ + "Authorization": f"Bearer {self.api_token}", + "Accept": "application/json", + }, + ) + return self._client + + async def close(self) -> None: + if self._client is not None and not self._client.is_closed: + await self._client.aclose() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + return False + + async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: + client = self._get_client() + response = await client.request(method, path, **kwargs) + if response.status_code >= 400: + _handle_response_error(response) + return response + + async def create_knowledge_base( + self, name: str, description: str, chunk_size: int, chunk_overlap: int + ) -> KnowledgeBase: + request = CreateKnowledgeBaseRequest( + name=name, + description=description, + config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}, + ) + response = await self._request("POST", "/api/knowledge", json=request.model_dump()) + return KnowledgeBase.model_validate(response.json()) + + async def list_knowledge_bases( + self, search: str | None = None, current: int = 0, page_size: int = 10 + ) -> PagePagination: + params: dict = {"current": current, "page_size": page_size} + if search: + params["search"] = search + response = 
await self._request("GET", "/api/knowledge", params=params) + data = response.json() + return PagePagination( + result=[KnowledgeBaseAggregate.model_validate(item) for item in data["result"]], + total=data["total"], + ) + + async def get_knowledge_base(self, kb_id: str) -> KnowledgeBaseAggregate: + response = await self._request("GET", f"/api/knowledge/{kb_id}") + return KnowledgeBaseAggregate.model_validate(response.json()) + + async def update_knowledge_base( + self, kb_id: str, name: str, description: str + ) -> KnowledgeBaseAggregate: + request = UpdateKnowledgeBaseRequest(name=name, description=description) + response = await self._request( + "PATCH", f"/api/knowledge/{kb_id}", json=request.model_dump() + ) + return KnowledgeBaseAggregate.model_validate(response.json()) + + async def delete_knowledge_base(self, kb_id: str) -> None: + await self._request("DELETE", f"/api/knowledge/{kb_id}") + + async def upload_document( + self, kb_id: str, file: BinaryIO, filename: str | None = None + ) -> DocumentWithContentUrl: + files = {"file": (filename or file.name, file)} + response = await self._request("POST", f"/api/knowledge/upload/{kb_id}", files=files) + return DocumentWithContentUrl.model_validate(response.json()) + + async def list_documents( + self, + kb_id: str, + statuses: list[str] | None = None, + current: int = 0, + page_size: int = 10, + ) -> DocumentPagePagination: + params: dict = {"current": current, "page_size": page_size} + if statuses: + params["statuses"] = statuses + response = await self._request("GET", f"/api/knowledge/{kb_id}/documents", params=params) + data = response.json() + return DocumentPagePagination( + result=[DocumentWithContentUrl.model_validate(item) for item in data["result"]], + total=data["total"], + ) + + async def delete_document(self, document_id: str) -> None: + await self._request("DELETE", f"/api/knowledge/document/{document_id}") + + async def delete_documents_bulk(self, document_ids: list[str]) -> None: + params = 
[("document_ids", doc_id) for doc_id in document_ids] + await self._request("DELETE", "/api/knowledge/document/bulk", params=params) + + async def retry_document(self, document_id: str) -> None: + await self._request("POST", f"/api/knowledge/{document_id}/retry") + + async def get_document_chunks( + self, + document_id: str, + search: str | None = None, + cursor: str | None = None, + limit: int = 10, + ) -> ChunkPagination: + params: dict = {"limit": limit} + if search: + params["search"] = search + if cursor: + params["cursor"] = cursor + response = await self._request( + "GET", f"/api/knowledge/documents/{document_id}/chunks", params=params + ) + return ChunkPagination.model_validate(response.json()) + + async def search( + self, + kb_id: str, + query: str, + limit: int | None = None, + score_threshold: float = 0.2, + ) -> SearchResponse: + request = SearchRequest( + query=query, kb_id=kb_id, limit=limit, score_threshold=score_threshold + ) + response = await self._request( + "POST", + f"/api/knowledge/{kb_id}/search", + json=request.model_dump(exclude_none=True), + ) + return SearchResponse.model_validate(response.json()) diff --git a/ergpt/kb/client.py b/ergpt/kb/client.py new file mode 100644 index 0000000..8293d7f --- /dev/null +++ b/ergpt/kb/client.py @@ -0,0 +1,203 @@ +from typing import BinaryIO + +import httpx + +from .exceptions import ( + AuthenticationError, + ErGPTError, + NotFoundError, + PermissionDeniedError, + RateLimitError, + ServerError, + ValidationError, +) +from .models import ( + CreateKnowledgeBaseRequest, + DocumentPagePagination, + DocumentWithContentUrl, + HTTPValidationError, + KnowledgeBase, + KnowledgeBaseAggregate, + PagePagination, + SearchRequest, + SearchResponse, + UpdateKnowledgeBaseRequest, +) + +DEFAULT_BASE_URL = "https://api.er-gpt.ru" +DEFAULT_TIMEOUT = 30.0 + + +def _handle_response_error(response: httpx.Response) -> None: + if response.status_code == 401: + raise AuthenticationError("Ошибка аутентификации. 
class KnowledgeBaseClient:
    """Synchronous client for the ErGPT knowledge-base HTTP API."""

    def __init__(
        self,
        api_token: str,
        base_url: str = DEFAULT_BASE_URL,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        self.api_token = api_token
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Created lazily on first request; see _get_client().
        self._client: httpx.Client | None = None

    def _get_client(self) -> httpx.Client:
        # Lazily build the underlying httpx client, and rebuild it if a
        # previously used one has been closed.
        if self._client is None or self._client.is_closed:
            self._client = httpx.Client(
                base_url=self.base_url,
                timeout=self.timeout,
                headers={
                    "Authorization": f"Bearer {self.api_token}",
                    "Accept": "application/json",
                },
            )
        return self._client

    def close(self) -> None:
        """Close the underlying HTTP client if one is open."""
        if self._client is not None and not self._client.is_closed:
            self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        # Single choke point for all API calls: send the request and turn
        # HTTP error statuses into typed exceptions.
        response = self._get_client().request(method, path, **kwargs)
        if response.status_code >= 400:
            _handle_response_error(response)
        return response

    def create_knowledge_base(
        self, name: str, description: str, chunk_size: int, chunk_overlap: int
    ) -> KnowledgeBase:
        """Create a knowledge base with the given chunking configuration."""
        payload = CreateKnowledgeBaseRequest(
            name=name,
            description=description,
            config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap},
        ).model_dump()
        response = self._request("POST", "/api/knowledge", json=payload)
        return KnowledgeBase.model_validate(response.json())
KnowledgeBase: + request = CreateKnowledgeBaseRequest( + name=name, + description=description, + config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}, + ) + response = self._request("POST", "/api/knowledge", json=request.model_dump()) + return KnowledgeBase.model_validate(response.json()) + + def list_knowledge_bases( + self, search: str | None = None, current: int = 0, page_size: int = 10 + ) -> PagePagination: + params: dict = {"current": current, "page_size": page_size} + if search: + params["search"] = search + response = self._request("GET", "/api/knowledge", params=params) + data = response.json() + return PagePagination( + result=[KnowledgeBaseAggregate.model_validate(item) for item in data["result"]], + total=data["total"], + ) + + def get_knowledge_base(self, kb_id: str) -> KnowledgeBaseAggregate: + response = self._request("GET", f"/api/knowledge/{kb_id}") + return KnowledgeBaseAggregate.model_validate(response.json()) + + def update_knowledge_base( + self, kb_id: str, name: str, description: str + ) -> KnowledgeBaseAggregate: + request = UpdateKnowledgeBaseRequest(name=name, description=description) + response = self._request("PATCH", f"/api/knowledge/{kb_id}", json=request.model_dump()) + return KnowledgeBaseAggregate.model_validate(response.json()) + + def delete_knowledge_base(self, kb_id: str) -> None: + self._request("DELETE", f"/api/knowledge/{kb_id}") + + def upload_document( + self, kb_id: str, file: BinaryIO, filename: str | None = None + ) -> DocumentWithContentUrl: + files = {"file": (filename or file.name, file)} + response = self._request("POST", f"/api/knowledge/upload/{kb_id}", files=files) + return DocumentWithContentUrl.model_validate(response.json()) + + def list_documents( + self, + kb_id: str, + statuses: list[str] | None = None, + current: int = 0, + page_size: int = 10, + ) -> DocumentPagePagination: + params: dict = {"current": current, "page_size": page_size} + if statuses: + params["statuses"] = statuses + response = 
self._request("GET", f"/api/knowledge/{kb_id}/documents", params=params) + data = response.json() + return DocumentPagePagination( + result=[DocumentWithContentUrl.model_validate(item) for item in data["result"]], + total=data["total"], + ) + + def delete_document(self, document_id: str) -> None: + self._request("DELETE", f"/api/knowledge/document/{document_id}") + + def delete_documents_bulk(self, document_ids: list[str]) -> None: + params = [("document_ids", doc_id) for doc_id in document_ids] + self._request("DELETE", "/api/knowledge/document/bulk", params=params) + + def retry_document(self, document_id: str) -> None: + self._request("POST", f"/api/knowledge/{document_id}/retry") + + def get_document_chunks( + self, + document_id: str, + search: str | None = None, + cursor: str | None = None, + limit: int = 10, + ): + from .models import ChunkPagination + + params: dict = {"limit": limit} + if search: + params["search"] = search + if cursor: + params["cursor"] = cursor + response = self._request( + "GET", f"/api/knowledge/documents/{document_id}/chunks", params=params + ) + return ChunkPagination.model_validate(response.json()) + + def search( + self, + kb_id: str, + query: str, + limit: int | None = None, + score_threshold: float = 0.2, + ) -> SearchResponse: + request = SearchRequest( + query=query, kb_id=kb_id, limit=limit, score_threshold=score_threshold + ) + response = self._request( + "POST", + f"/api/knowledge/{kb_id}/search", + json=request.model_dump(exclude_none=True), + ) + return SearchResponse.model_validate(response.json())