sdk/ergpt/kb/async_client.py
mrmamongo 5a4e9a6f12 feat: add sync and async httpx clients
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-20 12:08:26 +03:00

204 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import BinaryIO
import httpx
from .exceptions import (
AuthenticationError,
ErGPTError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
ServerError,
ValidationError,
)
from .models import (
ChunkPagination,
CreateKnowledgeBaseRequest,
DocumentPagePagination,
DocumentWithContentUrl,
HTTPValidationError,
KnowledgeBase,
KnowledgeBaseAggregate,
PagePagination,
SearchRequest,
SearchResponse,
UpdateKnowledgeBaseRequest,
)
# Base URL used by the client when the caller does not pass `base_url`.
DEFAULT_BASE_URL = "https://api.er-gpt.ru"
# Timeout handed to httpx when the caller does not pass `timeout`
# (httpx interprets a bare float as seconds).
DEFAULT_TIMEOUT = 30.0
def _handle_response_error(response: httpx.Response) -> None:
    """Raise the SDK exception that corresponds to an error response.

    The function always raises; it never returns normally.

    Args:
        response: The HTTP response to translate. Callers in this module
            only pass responses whose status code is >= 400.

    Raises:
        AuthenticationError: Status 401.
        PermissionDeniedError: Status 403.
        NotFoundError: Status 404 (the message includes the request URL).
        ValidationError: Status 422. ``details`` holds the parsed per-field
            errors when the body matches ``HTTPValidationError``; otherwise
            a single ``{"msg": <raw body text>}`` entry.
        RateLimitError: Status 429.
        ServerError: Any status >= 500.
        ErGPTError: Any other status code.
    """
    status = response.status_code

    if status == 401:
        raise AuthenticationError("Ошибка аутентификации. Проверьте API токен.")
    if status == 403:
        raise PermissionDeniedError("Доступ запрещен.")
    if status == 404:
        raise NotFoundError(f"Ресурс не найден: {response.url}")
    if status == 422:
        # Only the parsing lives inside the ``try``. Raising the detailed
        # ValidationError from inside it would be caught by the handler
        # below and silently replaced by the raw-text fallback.
        try:
            error = HTTPValidationError.model_validate(response.json())
            details = [e.model_dump() for e in error.detail]
        except Exception:
            # Deliberate best-effort: any body that is not valid JSON or
            # does not match the expected schema is reported verbatim.
            details = [{"msg": response.text}]
        raise ValidationError("Ошибка валидации", details=details)
    if status == 429:
        raise RateLimitError("Превышен лимит запросов. Попробуйте позже.")
    if status >= 500:
        raise ServerError(f"Ошибка сервера: {response.status_code}")
    raise ErGPTError(f"HTTP {response.status_code}: {response.text}")
class AsyncKnowledgeBaseClient:
    """Asynchronous client for the knowledge-base HTTP API.

    The underlying ``httpx.AsyncClient`` is created lazily on the first
    request and re-created if it has been closed. The instance can be used
    as an async context manager; leaving the context closes the HTTP client.

    Every request method raises the SDK exceptions produced by
    ``_handle_response_error`` when the server answers with a status
    code >= 400.
    """

    def __init__(
        self,
        api_token: str,
        base_url: str = DEFAULT_BASE_URL,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> None:
        """Store connection settings; no network activity happens here.

        Args:
            api_token: Token sent as ``Authorization: Bearer <token>``.
            base_url: API root; trailing slashes are stripped.
            timeout: Timeout value passed to ``httpx.AsyncClient``.
        """
        self.api_token = api_token
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=self.timeout,
                headers={
                    "Authorization": f"Bearer {self.api_token}",
                    "Accept": "application/json",
                },
            )
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client if it is open.

        Safe to call repeatedly; a later request transparently opens a
        new client via ``_get_client``.
        """
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()

    async def __aenter__(self) -> "AsyncKnowledgeBaseClient":
        """Enter the async context and return this client."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Close the HTTP client; exceptions from the body are not suppressed."""
        await self.close()
        return False

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Send one request and translate error statuses into SDK exceptions.

        Args:
            method: HTTP method name.
            path: Path relative to ``base_url``.
            **kwargs: Forwarded unchanged to ``httpx.AsyncClient.request``.

        Returns:
            The response, when its status code is below 400.
        """
        client = self._get_client()
        response = await client.request(method, path, **kwargs)
        if response.status_code >= 400:
            _handle_response_error(response)
        return response

    async def create_knowledge_base(
        self, name: str, description: str, chunk_size: int, chunk_overlap: int
    ) -> KnowledgeBase:
        """Create a knowledge base via ``POST /api/knowledge``.

        ``chunk_size`` and ``chunk_overlap`` are sent under the ``config``
        key of the request body.
        """
        request = CreateKnowledgeBaseRequest(
            name=name,
            description=description,
            config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap},
        )
        response = await self._request("POST", "/api/knowledge", json=request.model_dump())
        return KnowledgeBase.model_validate(response.json())

    async def list_knowledge_bases(
        self, search: str | None = None, current: int = 0, page_size: int = 10
    ) -> PagePagination:
        """List knowledge bases via ``GET /api/knowledge``.

        Args:
            search: Optional search string; omitted from the query when falsy.
            current: Sent as the ``current`` query parameter
                (presumably the page index - confirm against the API).
            page_size: Sent as the ``page_size`` query parameter.

        Returns:
            A ``PagePagination`` built from the ``result`` and ``total``
            fields of the response body.
        """
        params: dict = {"current": current, "page_size": page_size}
        if search:
            params["search"] = search
        response = await self._request("GET", "/api/knowledge", params=params)
        data = response.json()
        return PagePagination(
            result=[KnowledgeBaseAggregate.model_validate(item) for item in data["result"]],
            total=data["total"],
        )

    async def get_knowledge_base(self, kb_id: str) -> KnowledgeBaseAggregate:
        """Fetch one knowledge base via ``GET /api/knowledge/{kb_id}``."""
        response = await self._request("GET", f"/api/knowledge/{kb_id}")
        return KnowledgeBaseAggregate.model_validate(response.json())

    async def update_knowledge_base(
        self, kb_id: str, name: str, description: str
    ) -> KnowledgeBaseAggregate:
        """Update name and description via ``PATCH /api/knowledge/{kb_id}``."""
        request = UpdateKnowledgeBaseRequest(name=name, description=description)
        response = await self._request(
            "PATCH", f"/api/knowledge/{kb_id}", json=request.model_dump()
        )
        return KnowledgeBaseAggregate.model_validate(response.json())

    async def delete_knowledge_base(self, kb_id: str) -> None:
        """Delete a knowledge base via ``DELETE /api/knowledge/{kb_id}``."""
        await self._request("DELETE", f"/api/knowledge/{kb_id}")

    async def upload_document(
        self, kb_id: str, file: BinaryIO, filename: str | None = None
    ) -> DocumentWithContentUrl:
        """Upload a file via ``POST /api/knowledge/upload/{kb_id}``.

        Args:
            kb_id: Target knowledge base identifier.
            file: Binary file-like object sent as the ``file`` multipart field.
            filename: Name reported to the server. When falsy, the stream's
                ``name`` attribute is used; if the stream has none (for
                example ``io.BytesIO``), the placeholder ``"upload"`` is sent.

        Returns:
            The response body parsed as ``DocumentWithContentUrl``.
        """
        # In-memory streams have no ``name`` attribute; reading it directly
        # would raise AttributeError before any request is made.
        upload_name = filename or getattr(file, "name", None) or "upload"
        files = {"file": (upload_name, file)}
        response = await self._request("POST", f"/api/knowledge/upload/{kb_id}", files=files)
        return DocumentWithContentUrl.model_validate(response.json())

    async def list_documents(
        self,
        kb_id: str,
        statuses: list[str] | None = None,
        current: int = 0,
        page_size: int = 10,
    ) -> DocumentPagePagination:
        """List documents via ``GET /api/knowledge/{kb_id}/documents``.

        Args:
            kb_id: Knowledge base identifier.
            statuses: Optional values for the ``statuses`` query parameter;
                omitted when empty or ``None``.
            current: Sent as the ``current`` query parameter
                (presumably the page index - confirm against the API).
            page_size: Sent as the ``page_size`` query parameter.

        Returns:
            A ``DocumentPagePagination`` built from the ``result`` and
            ``total`` fields of the response body.
        """
        params: dict = {"current": current, "page_size": page_size}
        if statuses:
            params["statuses"] = statuses
        response = await self._request("GET", f"/api/knowledge/{kb_id}/documents", params=params)
        data = response.json()
        return DocumentPagePagination(
            result=[DocumentWithContentUrl.model_validate(item) for item in data["result"]],
            total=data["total"],
        )

    async def delete_document(self, document_id: str) -> None:
        """Delete one document via ``DELETE /api/knowledge/document/{document_id}``."""
        await self._request("DELETE", f"/api/knowledge/document/{document_id}")

    async def delete_documents_bulk(self, document_ids: list[str]) -> None:
        """Delete several documents via ``DELETE /api/knowledge/document/bulk``.

        Each identifier is sent as a separate ``document_ids`` query parameter.
        """
        params = [("document_ids", doc_id) for doc_id in document_ids]
        await self._request("DELETE", "/api/knowledge/document/bulk", params=params)

    async def retry_document(self, document_id: str) -> None:
        """Issue ``POST /api/knowledge/{document_id}/retry`` for a document."""
        await self._request("POST", f"/api/knowledge/{document_id}/retry")

    async def get_document_chunks(
        self,
        document_id: str,
        search: str | None = None,
        cursor: str | None = None,
        limit: int = 10,
    ) -> ChunkPagination:
        """Fetch chunks via ``GET /api/knowledge/documents/{document_id}/chunks``.

        Args:
            document_id: Document identifier.
            search: Optional ``search`` query parameter; omitted when falsy.
            cursor: Optional ``cursor`` query parameter; omitted when falsy.
            limit: Sent as the ``limit`` query parameter.

        Returns:
            The response body parsed as ``ChunkPagination``.
        """
        params: dict = {"limit": limit}
        if search:
            params["search"] = search
        if cursor:
            params["cursor"] = cursor
        response = await self._request(
            "GET", f"/api/knowledge/documents/{document_id}/chunks", params=params
        )
        return ChunkPagination.model_validate(response.json())

    async def search(
        self,
        kb_id: str,
        query: str,
        limit: int | None = None,
        score_threshold: float = 0.2,
    ) -> SearchResponse:
        """Search a knowledge base via ``POST /api/knowledge/{kb_id}/search``.

        The request body is a ``SearchRequest`` dumped with ``None`` fields
        excluded, so ``limit`` is only sent when provided.

        Returns:
            The response body parsed as ``SearchResponse``.
        """
        request = SearchRequest(
            query=query, kb_id=kb_id, limit=limit, score_threshold=score_threshold
        )
        response = await self._request(
            "POST",
            f"/api/knowledge/{kb_id}/search",
            json=request.model_dump(exclude_none=True),
        )
        return SearchResponse.model_validate(response.json())