Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
from typing import BinaryIO

import httpx

from .exceptions import (
    AuthenticationError,
    ErGPTError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
    ServerError,
    ValidationError,
)
from .models import (
    ChunkPagination,
    CreateKnowledgeBaseRequest,
    DocumentPagePagination,
    DocumentWithContentUrl,
    HTTPValidationError,
    KnowledgeBase,
    KnowledgeBaseAggregate,
    PagePagination,
    SearchRequest,
    SearchResponse,
    UpdateKnowledgeBaseRequest,
)
# Production API root; callers may override via the client's ``base_url``.
DEFAULT_BASE_URL = "https://api.er-gpt.ru"
# Per-request timeout in seconds, applied to the whole httpx request.
DEFAULT_TIMEOUT = 30.0
def _handle_response_error(response: httpx.Response) -> None:
    """Translate an HTTP error response into the matching ErGPT exception.

    Args:
        response: An ``httpx.Response`` whose status code is >= 400.

    Raises:
        AuthenticationError: On 401.
        PermissionDeniedError: On 403.
        NotFoundError: On 404.
        ValidationError: On 422, carrying parsed ``HTTPValidationError``
            details when the body is a valid payload, otherwise the raw
            response text.
        RateLimitError: On 429.
        ServerError: On any 5xx status.
        ErGPTError: On any other error status.
    """
    status = response.status_code
    if status == 401:
        raise AuthenticationError("Ошибка аутентификации. Проверьте API токен.")
    if status == 403:
        raise PermissionDeniedError("Доступ запрещен.")
    if status == 404:
        raise NotFoundError(f"Ресурс не найден: {response.url}")
    if status == 422:
        # BUG FIX: previously the ValidationError with parsed details was
        # raised *inside* the try block, so `except Exception` immediately
        # caught it and re-raised the raw-text fallback instead — the parsed
        # details were always discarded. Parse first, raise afterwards.
        try:
            error = HTTPValidationError.model_validate(response.json())
            details = [e.model_dump() for e in error.detail]
        except Exception:
            details = [{"msg": response.text}]
        raise ValidationError("Ошибка валидации", details=details) from None
    if status == 429:
        raise RateLimitError("Превышен лимит запросов. Попробуйте позже.")
    if status >= 500:
        raise ServerError(f"Ошибка сервера: {status}")
    raise ErGPTError(f"HTTP {status}: {response.text}")
||
class AsyncKnowledgeBaseClient:
    """Asynchronous client for the ER-GPT knowledge base HTTP API.

    Lazily creates a shared ``httpx.AsyncClient`` with bearer-token
    authentication. Use as an async context manager (``async with``) so the
    underlying connection pool is closed deterministically.
    """

    def __init__(
        self,
        api_token: str,
        base_url: str = DEFAULT_BASE_URL,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Store connection settings; no network activity happens here.

        Args:
            api_token: Bearer token sent in the ``Authorization`` header.
            base_url: API root URL; a trailing slash is stripped.
            timeout: Per-request timeout in seconds.
        """
        self.api_token = api_token
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Created on first request; recreated if it has been closed.
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared AsyncClient, (re)creating it lazily."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=self.timeout,
                headers={
                    "Authorization": f"Bearer {self.api_token}",
                    "Accept": "application/json",
                },
            )
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client if it is open."""
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
        return False  # never suppress exceptions

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Send one request; raise a typed ErGPT error on any 4xx/5xx."""
        client = self._get_client()
        response = await client.request(method, path, **kwargs)
        if response.status_code >= 400:
            _handle_response_error(response)
        return response

    async def create_knowledge_base(
        self, name: str, description: str, chunk_size: int, chunk_overlap: int
    ) -> KnowledgeBase:
        """Create a knowledge base with the given chunking configuration."""
        request = CreateKnowledgeBaseRequest(
            name=name,
            description=description,
            config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap},
        )
        response = await self._request("POST", "/api/knowledge", json=request.model_dump())
        return KnowledgeBase.model_validate(response.json())

    async def list_knowledge_bases(
        self, search: str | None = None, current: int = 0, page_size: int = 10
    ) -> PagePagination:
        """List knowledge bases, optionally filtered by a search string."""
        params: dict = {"current": current, "page_size": page_size}
        if search:
            params["search"] = search
        response = await self._request("GET", "/api/knowledge", params=params)
        data = response.json()
        return PagePagination(
            result=[KnowledgeBaseAggregate.model_validate(item) for item in data["result"]],
            total=data["total"],
        )

    async def get_knowledge_base(self, kb_id: str) -> KnowledgeBaseAggregate:
        """Fetch a single knowledge base by its id."""
        response = await self._request("GET", f"/api/knowledge/{kb_id}")
        return KnowledgeBaseAggregate.model_validate(response.json())

    async def update_knowledge_base(
        self, kb_id: str, name: str, description: str
    ) -> KnowledgeBaseAggregate:
        """Update a knowledge base's name and description."""
        request = UpdateKnowledgeBaseRequest(name=name, description=description)
        response = await self._request(
            "PATCH", f"/api/knowledge/{kb_id}", json=request.model_dump()
        )
        return KnowledgeBaseAggregate.model_validate(response.json())

    async def delete_knowledge_base(self, kb_id: str) -> None:
        """Delete a knowledge base by its id."""
        await self._request("DELETE", f"/api/knowledge/{kb_id}")

    async def upload_document(
        self, kb_id: str, file: BinaryIO, filename: str | None = None
    ) -> DocumentWithContentUrl:
        """Upload a document file into a knowledge base.

        Args:
            kb_id: Target knowledge base id.
            file: Binary file-like object to upload.
            filename: Explicit filename; defaults to ``file.name`` when the
                object has one.

        Raises:
            ValueError: If no filename was given and ``file`` has no
                ``name`` attribute (e.g. ``io.BytesIO``).
        """
        # Robustness fix: plain BinaryIO objects (io.BytesIO, pipes) have no
        # `.name`; the original raised an opaque AttributeError here.
        name = filename or getattr(file, "name", None)
        if name is None:
            raise ValueError(
                "filename is required when the file object has no 'name' attribute"
            )
        files = {"file": (name, file)}
        response = await self._request("POST", f"/api/knowledge/upload/{kb_id}", files=files)
        return DocumentWithContentUrl.model_validate(response.json())

    async def list_documents(
        self,
        kb_id: str,
        statuses: list[str] | None = None,
        current: int = 0,
        page_size: int = 10,
    ) -> DocumentPagePagination:
        """List documents in a knowledge base, optionally filtered by status."""
        params: dict = {"current": current, "page_size": page_size}
        if statuses:
            params["statuses"] = statuses
        response = await self._request("GET", f"/api/knowledge/{kb_id}/documents", params=params)
        data = response.json()
        return DocumentPagePagination(
            result=[DocumentWithContentUrl.model_validate(item) for item in data["result"]],
            total=data["total"],
        )

    async def delete_document(self, document_id: str) -> None:
        """Delete a single document by its id."""
        await self._request("DELETE", f"/api/knowledge/document/{document_id}")

    async def delete_documents_bulk(self, document_ids: list[str]) -> None:
        """Delete several documents in one request."""
        # Repeated query parameters: document_ids=a&document_ids=b&...
        params = [("document_ids", doc_id) for doc_id in document_ids]
        await self._request("DELETE", "/api/knowledge/document/bulk", params=params)

    async def retry_document(self, document_id: str) -> None:
        """Re-queue processing of a failed document."""
        await self._request("POST", f"/api/knowledge/{document_id}/retry")

    async def get_document_chunks(
        self,
        document_id: str,
        search: str | None = None,
        cursor: str | None = None,
        limit: int = 10,
    ) -> ChunkPagination:
        """Page through a document's chunks using cursor-based pagination."""
        params: dict = {"limit": limit}
        if search:
            params["search"] = search
        if cursor:
            params["cursor"] = cursor
        response = await self._request(
            "GET", f"/api/knowledge/documents/{document_id}/chunks", params=params
        )
        return ChunkPagination.model_validate(response.json())

    async def search(
        self,
        kb_id: str,
        query: str,
        limit: int | None = None,
        score_threshold: float = 0.2,
    ) -> SearchResponse:
        """Run a semantic search over a knowledge base.

        Args:
            kb_id: Knowledge base to search in.
            query: Search query text.
            limit: Maximum number of results; server default when None.
            score_threshold: Minimum similarity score for a hit.
        """
        request = SearchRequest(
            query=query, kb_id=kb_id, limit=limit, score_threshold=score_threshold
        )
        response = await self._request(
            "POST",
            f"/api/knowledge/{kb_id}/search",
            # exclude_none so an unset `limit` falls back to the server default
            json=request.model_dump(exclude_none=True),
        )
        return SearchResponse.model_validate(response.json())