Source code for tenso.client

"""
High-level HTTP client for Tenso-powered FastAPI endpoints.

Provides TensoFastAPIClient that natively streams and unpacks
TensoResponse chunks, handling all protocol details under the hood.

Requires ``httpx`` (``pip install httpx``).
"""

from typing import Any, Dict, Optional, Union

import numpy as np

from .core import dumps, loads


class TensoFastAPIClient:
    """
    Client for communicating with FastAPI endpoints that use TensoResponse
    and get_tenso_data.

    Example::

        client = TensoFastAPIClient("http://localhost:8000")
        result = client.predict("/infer", np.random.randn(1, 224, 224, 3).astype(np.float32))
        print(result.shape)

        # Async usage
        result = await client.apredict("/infer", tensor)
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        check_integrity: bool = False,
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Parameters
        ----------
        base_url : str
            Root URL of the FastAPI server (e.g. "http://localhost:8000").
        timeout : float
            Per-request timeout in seconds, applied to both sync and
            async clients.
        check_integrity : bool
            Whether to embed integrity checks when serializing requests
            (forwarded to ``dumps``).
        headers : dict, optional
            Extra headers attached to every request.
        """
        # Strip a trailing slash so endpoint paths like "/infer" join cleanly.
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._check_integrity = check_integrity
        self._extra_headers = headers or {}
        # httpx clients are created lazily (see _get_*_client) so that
        # constructing this object never imports httpx or opens sockets.
        self._sync_client = None
        self._async_client = None

    def _get_sync_client(self):
        """Lazily create and cache the synchronous httpx client."""
        if self._sync_client is None:
            import httpx

            self._sync_client = httpx.Client(
                base_url=self._base_url,
                timeout=self._timeout,
                headers=self._extra_headers,
            )
        return self._sync_client

    def _get_async_client(self):
        """Lazily create and cache the asynchronous httpx client."""
        if self._async_client is None:
            import httpx

            self._async_client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self._timeout,
                headers=self._extra_headers,
            )
        return self._async_client

    def _make_headers(self) -> dict:
        """Per-request headers for a binary Tenso payload."""
        return {"content-type": "application/octet-stream"}

    def predict(
        self,
        endpoint: str,
        tensor: Union[np.ndarray, dict],
        compress: bool = False,
    ) -> Any:
        """
        Send a tensor to a Tenso-powered endpoint and return the
        deserialized response.

        Parameters
        ----------
        endpoint : str
            The API path (e.g. "/infer").
        tensor : np.ndarray or dict
            The input tensor or bundle.
        compress : bool
            Whether to LZ4-compress the request body.

        Returns
        -------
        Any
            The deserialized response (np.ndarray, dict, or sparse matrix).

        Raises
        ------
        httpx.HTTPStatusError
            If the server returns a 4xx/5xx status.
        """
        packet = bytes(dumps(
            tensor,
            check_integrity=self._check_integrity,
            compress=compress,
        ))
        client = self._get_sync_client()
        response = client.post(
            endpoint,
            content=packet,
            headers=self._make_headers(),
        )
        response.raise_for_status()
        return loads(response.content)

    async def apredict(
        self,
        endpoint: str,
        tensor: Union[np.ndarray, dict],
        compress: bool = False,
    ) -> Any:
        """Async version of predict(). Same parameters, return value,
        and error behavior."""
        packet = bytes(dumps(
            tensor,
            check_integrity=self._check_integrity,
            compress=compress,
        ))
        client = self._get_async_client()
        response = await client.post(
            endpoint,
            content=packet,
            headers=self._make_headers(),
        )
        response.raise_for_status()
        return loads(response.content)

    def stream_predict(
        self,
        endpoint: str,
        tensor: Union[np.ndarray, dict],
        compress: bool = False,
    ) -> Any:
        """
        Send a tensor and stream the response, reassembling chunks into
        the final deserialized object.

        Parameters
        ----------
        endpoint : str
            The API path (e.g. "/infer").
        tensor : np.ndarray or dict
            The input tensor or bundle.
        compress : bool
            Whether to LZ4-compress the request body (matches
            ``predict``/``apredict``; defaults to False for backward
            compatibility).

        Returns
        -------
        Any
            The deserialized response (np.ndarray, dict, or sparse matrix).
        """
        packet = bytes(dumps(
            tensor,
            check_integrity=self._check_integrity,
            compress=compress,
        ))
        client = self._get_sync_client()
        chunks = []
        with client.stream(
            "POST", endpoint, content=packet, headers=self._make_headers()
        ) as response:
            response.raise_for_status()
            # Accumulate raw chunks and deserialize only once the full
            # payload has arrived -- loads() needs the complete packet.
            for chunk in response.iter_bytes():
                chunks.append(chunk)
        return loads(b"".join(chunks))

    def close(self):
        """Close sync client. For async client, use ``async with`` or
        ``await client.aclose()``."""
        if self._sync_client is not None:
            self._sync_client.close()
            self._sync_client = None

    async def aclose(self):
        """Close both sync and async clients."""
        if self._async_client is not None:
            await self._async_client.aclose()
            self._async_client = None
        # Also release the sync client so aclose() fully shuts down
        # this object from async code.
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.aclose()