"""Client implementation for managing datasets in the Arize platform."""
from __future__ import annotations
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, cast
import pandas as pd
import pyarrow as pa
from arize._flight.client import ArizeFlightClient
from arize._generated.api_client import models
from arize.datasets.validation import validate_dataset_df
from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
from arize.pre_releases import ReleaseStage, prerelease_endpoint
from arize.utils.cache import cache_resource, load_cached_resource
from arize.utils.openinference_conversion import (
convert_boolean_columns_to_str,
convert_datetime_columns_to_int,
convert_default_columns_to_json_str,
)
from arize.utils.resolve import (
_find_dataset_id,
_find_space_id,
_resolve_resource,
)
from arize.utils.size import get_payload_size_mb
if TYPE_CHECKING:
# builtins is needed to use builtins.list in type annotations because
# the class has a list() method that shadows the built-in list type
import builtins
from arize._generated.api_client.api_client import ApiClient
from arize.config import SDKConfiguration
from arize.datasets.types import (
Dataset,
DatasetExampleListResponse,
DatasetListResponse,
)
# Module-level logger; the client methods below log upload/download progress
# and Flight errors through it.
logger = logging.getLogger(__name__)
class DatasetsClient:
    """Client for managing datasets including creation, retrieval, and example management.

    This class is primarily intended for internal use within the SDK. Users are
    highly encouraged to access resource-specific functionality via
    :class:`arize.ArizeClient`.

    The datasets client is a thin wrapper around the generated REST API client,
    using the shared generated API client owned by
    :class:`arize.config.SDKConfiguration`. Large uploads in :meth:`create`
    and full example downloads in :meth:`list_examples` (``all=True``) use
    Arrow Flight instead of REST.
    """

    def __init__(
        self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
    ) -> None:
        """
        Args:
            sdk_config: Resolved SDK configuration.
            generated_client: Shared generated API client instance.
        """  # noqa: D205, D212
        self._sdk_config = sdk_config

        # Import at runtime so it's still lazy and extras-gated by the parent
        from arize._generated import api_client as gen

        # Use the provided client directly
        self._api = gen.DatasetsApi(generated_client)
        self._spaces_api = gen.SpacesApi(generated_client)

    @prerelease_endpoint(key="datasets.list", stage=ReleaseStage.BETA)
    def list(
        self,
        *,
        name: str | None = None,
        space: str | None = None,
        limit: int = 100,
        cursor: str | None = None,
    ) -> DatasetListResponse:
        """List datasets the user has access to.

        Datasets are returned in descending creation order (most recently created
        first). Dataset versions are not included in this response; use `get()` to
        retrieve a dataset along with its versions.

        Args:
            name: Optional case-insensitive substring filter on the dataset name.
            space: Optional space filter. If the value is a base64-encoded resource ID it is
                treated as a space ID; otherwise it is used as a case-insensitive
                substring filter on the space name.
            limit: Maximum number of datasets to return. The server enforces an
                upper bound.
            cursor: Opaque pagination cursor returned from a previous response.

        Returns:
            A response object with the datasets and pagination information.

        Raises:
            ApiException: If the REST API
                returns an error response (e.g. 401/403/429).
        """
        resolved_space = _resolve_resource(space)
        return self._api.datasets_list(
            space_id=resolved_space.id,
            space_name=resolved_space.name,
            name=name,
            limit=limit,
            cursor=cursor,
        )

    @prerelease_endpoint(key="datasets.create", stage=ReleaseStage.BETA)
    def create(
        self,
        *,
        name: str,
        space: str,
        examples: builtins.list[dict[str, object]] | pd.DataFrame,
        force_http: bool = False,
    ) -> Dataset:
        """Create a dataset with JSON examples.

        Empty datasets are not allowed.

        Payload notes (server-enforced):
            - `name` must be unique within the given space.
            - Each example may contain arbitrary user-defined fields.
            - Do not include system-managed fields on create: `id`, `created_at`,
              `updated_at` (requests containing these fields will be rejected).
            - Each example must contain at least one property (i.e. `{}` is invalid).

        Transport selection:
            - If the payload is below the configured REST payload threshold (or
              `force_http=True`), this method uploads via REST.
            - Otherwise, it attempts a more efficient upload path via gRPC + Flight.

        Args:
            name: Dataset name (must be unique within the target space).
            space: Space ID or name to create the dataset in.
            examples: Dataset examples either as:
                - a list of JSON-like dicts, or
                - a :class:`pandas.DataFrame` (will be converted to records for REST).
            force_http: If True, force REST upload even if the payload exceeds the
                configured REST payload threshold.

        Returns:
            The created dataset object as returned by the API.

        Raises:
            TypeError: If `examples` is not a list of dicts or a :class:`pandas.DataFrame`.
            ValueError: If `examples` is empty.
            RuntimeError: If the Flight upload path is selected and the data fails
                client-side validation or the Flight request fails.
            pyarrow.ArrowInvalid: If the Flight upload path is selected and the
                examples cannot be converted to an Arrow table.
            ApiException: If the REST API
                returns an error response (e.g. 400/401/403/409/429).
        """
        # Reject bad input locally before making any network calls.
        self._validate_examples_type(examples)
        if len(examples) == 0:
            raise ValueError("Cannot create an empty dataset")

        space_id = _find_space_id(self._spaces_api, space)

        below_threshold = (
            get_payload_size_mb(examples)
            <= self._sdk_config.max_http_payload_size_mb
        )
        if below_threshold or force_http:
            from arize._generated import api_client as gen

            body = gen.DatasetsCreateRequest(
                name=name,
                space_id=space_id,
                examples=self._examples_to_records(examples),
            )
            return self._api.datasets_create(datasets_create_request=body)

        # If we have too many examples, try to convert to a dataframe
        # and log via gRPC + flight
        logger.info(
            "Uploading %d examples via REST may be slow. "
            "Trying to convert to DataFrame for more efficient upload via "
            "gRPC + Flight.",
            len(examples),
        )
        if not isinstance(examples, pd.DataFrame):
            examples = pd.DataFrame(examples)
        return self._create_dataset_via_flight(
            name=name,
            space_id=space_id,
            examples=examples,
        )

    @prerelease_endpoint(key="datasets.get", stage=ReleaseStage.BETA)
    def get(
        self,
        *,
        dataset: str,
        space: str | None = None,
    ) -> Dataset:
        """Get a dataset by ID or name.

        The returned dataset includes its dataset versions (sorted by creation time,
        most recent first). Dataset examples are not included; use `list_examples()`
        to retrieve examples.

        Args:
            dataset: Dataset ID or name.
            space: Space ID or name. Required when *dataset* is a name.

        Returns:
            The dataset object.

        Raises:
            ApiException: If the REST API
                returns an error response (e.g. 401/403/404/429).
        """
        dataset_id = _find_dataset_id(
            api=self._api,
            dataset=dataset,
            space=space,
        )
        return self._api.datasets_get(dataset_id=dataset_id)

    @prerelease_endpoint(key="datasets.delete", stage=ReleaseStage.BETA)
    def delete(
        self,
        *,
        dataset: str,
        space: str | None = None,
    ) -> None:
        """Delete a dataset by ID or name.

        This operation is irreversible.

        Args:
            dataset: Dataset ID or name.
            space: Space ID or name. Required when *dataset* is a name.

        Returns:
            This method returns None on success (common empty 204 response).

        Raises:
            ApiException: If the REST API
                returns an error response (e.g. 401/403/404/429).
        """
        dataset_id = _find_dataset_id(
            api=self._api,
            dataset=dataset,
            space=space,
        )
        return self._api.datasets_delete(dataset_id=dataset_id)

    @prerelease_endpoint(key="datasets.update", stage=ReleaseStage.ALPHA)
    def update(
        self,
        *,
        dataset: str,
        space: str | None = None,
        name: str,
    ) -> Dataset:
        """Rename a dataset.

        Args:
            dataset: Dataset ID or name.
            space: Space ID or name. Required when *dataset* is a name.
            name: New name for the dataset. Must be unique within the space.

        Returns:
            The updated dataset object.

        Raises:
            ApiException: If the REST API returns an error response
                (e.g. 400/401/403/404/409/429).
        """
        from arize._generated import api_client as gen

        dataset_id = _find_dataset_id(
            api=self._api, dataset=dataset, space=space
        )
        body = gen.DatasetsUpdateRequest(name=name)
        return self._api.datasets_update(
            dataset_id=dataset_id, datasets_update_request=body
        )

    @prerelease_endpoint(key="datasets.list_examples", stage=ReleaseStage.BETA)
    def list_examples(
        self,
        *,
        dataset: str,
        space: str | None = None,
        dataset_version_id: str | None = None,
        limit: int = 100,
        all: bool = False,
    ) -> DatasetExampleListResponse:
        """List examples for a dataset (optionally for a specific version).

        If `dataset_version_id` is not provided (None or empty), the server
        selects the latest dataset version.

        Pagination notes:
            - The response includes `pagination` for forward compatibility.
            - Cursor pagination may not be fully implemented by the server yet.
            - If `all=True`, this method retrieves all examples via the Flight path,
              and returns them in a single response with `has_more=False`.

        Caching (`all=True` only):
            When `enable_caching` is set in the SDK configuration, the full example
            set is looked up in / stored to the local cache using the dataset ID,
            the requested version ID (if any), and the dataset's `updated_at`.
            When caching is disabled the cache is neither read nor written.

        Args:
            dataset: Dataset ID or name.
            space: Space ID or name. Required when *dataset* is a name.
            dataset_version_id: Dataset version ID. If not provided, the latest
                version is selected.
            limit: Maximum number of examples to return when `all=False`. The server
                enforces an upper bound.
            all: If True, fetch all examples (ignores `limit`) via Flight and return a
                single response.

        Returns:
            A response object containing `examples` and `pagination` metadata.

        Raises:
            RuntimeError: If the Flight request fails or returns no response when
                `all=True`.
            ApiException: If the REST API
                returns an error response when `all=False` (e.g. 401/403/404/429).
        """
        dataset_id = _find_dataset_id(
            api=self._api,
            dataset=dataset,
            space=space,
        )
        if not all:
            return self._api.datasets_examples_list(
                dataset_id=dataset_id,
                dataset_version_id=dataset_version_id,
                limit=limit,
            )

        dataset_obj = self.get(dataset=dataset_id)
        dataset_updated_at = getattr(dataset_obj, "updated_at", None)
        # TODO(Kiko): Space ID should not be needed,
        # should work on server tech debt to remove this
        space_id = dataset_obj.space_id

        # A specific version gets its own cache entry; otherwise a request for
        # one version could be served another version's (or the latest) examples.
        cache_resource_id = (
            f"{dataset_id}_{dataset_version_id}"
            if dataset_version_id
            else dataset_id
        )
        caching_enabled = self._sdk_config.enable_caching

        # try to load dataset from cache
        if caching_enabled:
            cached_df = load_cached_resource(
                cache_dir=self._sdk_config.cache_dir,
                resource="dataset",
                resource_id=cache_resource_id,
                resource_updated_at=dataset_updated_at,
            )
            if cached_df is not None:
                return self._examples_df_to_response(cached_df)

        dataset_df = None
        with ArizeFlightClient(
            api_key=self._sdk_config.api_key,
            host=self._sdk_config.flight_host,
            port=self._sdk_config.flight_port,
            scheme=self._sdk_config.flight_scheme,
            request_verify=self._sdk_config.request_verify,
            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
        ) as flight_client:
            try:
                dataset_df = flight_client.get_dataset_examples(
                    space_id=space_id,
                    dataset_id=dataset_id,
                    dataset_version_id=dataset_version_id,
                )
            except Exception as e:
                msg = f"Error during request: {e!s}"
                logger.exception(msg)
                raise RuntimeError(msg) from e
        if dataset_df is None:
            # This should not happen with proper Flight client implementation,
            # but we handle it defensively
            msg = "No response received from flight server during request"
            logger.error(msg)
            raise RuntimeError(msg)

        # cache dataset for future use (only when caching is enabled)
        if caching_enabled:
            cache_resource(
                cache_dir=self._sdk_config.cache_dir,
                resource="dataset",
                resource_id=cache_resource_id,
                resource_updated_at=dataset_updated_at,
                resource_data=dataset_df,
            )
        return self._examples_df_to_response(dataset_df)

    # TODO(Kiko): Needs flightserver support
    @prerelease_endpoint(
        key="datasets.append_examples", stage=ReleaseStage.BETA
    )
    def append_examples(
        self,
        *,
        dataset: str,
        space: str | None = None,
        dataset_version_id: str = "",
        examples: builtins.list[dict[str, object]] | pd.DataFrame,
    ) -> models.DatasetVersionWithExampleIds:
        """Append new examples to an existing dataset.

        This method adds examples to an existing dataset version. If
        `dataset_version_id` is not provided (empty string), the server appends
        the examples to the latest dataset version.

        The inserted examples are assigned system-generated IDs by the server.
        The response includes those IDs in `example_ids` and the version they
        were written to in `dataset_version_id`.

        Payload requirements (server-enforced):
            - Each example may contain arbitrary user-defined fields.
            - Do not include system-managed fields on input: `id`, `created_at`,
              `updated_at` (requests containing these fields will be rejected).
            - Each example must contain at least one property (i.e. empty
              examples are invalid).

        Args:
            dataset: Dataset ID or name.
            space: Space ID or name. Required when *dataset* is a name.
            dataset_version_id: Optional dataset version ID to append examples to. If empty,
                the latest dataset version is selected.
            examples: Examples to append, provided as either:
                - a list of JSON-like dicts, or
                - a :class:`pandas.DataFrame` (converted to records before upload).

        Returns:
            A :class:`DatasetVersionWithExampleIds` containing the dataset attributes,
            the version the examples were written to (``dataset_version_id``),
            and the IDs of the inserted examples (``example_ids``).

        Raises:
            TypeError: If `examples` is not a list of dicts or a :class:`pandas.DataFrame`.
            ApiException: If the REST API
                returns an error response (e.g. 400/401/403/404/429).
        """
        # Reject bad input locally before making any network calls.
        self._validate_examples_type(examples)
        dataset_id = _find_dataset_id(
            api=self._api,
            dataset=dataset,
            space=space,
        )
        from arize._generated import api_client as gen

        body = gen.DatasetsExamplesInsertRequest(
            examples=self._examples_to_records(examples)
        )
        return self._api.datasets_examples_insert(
            dataset_id=dataset_id,
            dataset_version_id=dataset_version_id,
            datasets_examples_insert_request=body,
        )

    @prerelease_endpoint(
        key="datasets.annotate_examples", stage=ReleaseStage.ALPHA
    )
    def annotate_examples(
        self,
        *,
        dataset: str,
        space: str | None = None,
        annotations: builtins.list[models.AnnotateRecordInput],
    ) -> None:
        """Write human annotations to a batch of examples in a dataset.

        Annotations are upserted by annotation config name for each example.
        Submitting the same annotation config name for the same example
        overwrites the previous value. Retrying on network failure will
        not create duplicates.

        Up to 1000 examples may be annotated per request.

        The write completes synchronously before the function returns. Visibility
        in read queries may lag by a short interval (HTTP 202 Accepted).

        Args:
            dataset: Dataset ID or name.
            space: Space ID or name. Required when *dataset* is a name.
            annotations: A list of :class:`AnnotateRecordInput` items. Each item
                must include a ``record_id`` (the dataset example ID) and ``values``
                (a list of :class:`AnnotationInput` items with ``name``, and
                optionally ``score``, ``label``, or ``text``).

        Raises:
            ApiException: If the REST API returns an error response
                (e.g. 400/401/403/404/429).
        """
        dataset_id = _find_dataset_id(
            api=self._api,
            dataset=dataset,
            space=space,
        )
        from arize._generated import api_client as gen

        body = gen.AnnotateDatasetExamplesRequestBody(annotations=annotations)
        return self._api.datasets_examples_annotate(
            dataset_id=dataset_id,
            annotate_dataset_examples_request_body=body,
        )

    def _create_dataset_via_flight(
        self,
        name: str,
        space_id: str,
        examples: pd.DataFrame,
    ) -> Dataset:
        """Create a dataset by uploading `examples` over Arrow Flight.

        Used by :meth:`create` when the payload exceeds the REST size threshold.
        The examples are copied, run through the OpenInference conversion
        helpers and `_set_default_columns_for_dataset`, validated with
        `validate_dataset_df`, converted to an Arrow table, and uploaded. The
        Flight server returns the new dataset ID, which is resolved to a full
        dataset object via :meth:`get`.

        Raises:
            RuntimeError: If client-side validation fails, the Flight request
                fails, or the Flight server returns no response.
            pyarrow.ArrowInvalid: If the DataFrame cannot be converted to Arrow.
        """
        # Work on a copy: the helpers below add/overwrite columns in place.
        data = examples.copy()
        # Convert datetime columns to int64 (ms since epoch)
        data = convert_datetime_columns_to_int(data)
        data = convert_boolean_columns_to_str(data)
        data = _set_default_columns_for_dataset(data)
        data = convert_default_columns_to_json_str(data)

        validation_errors = validate_dataset_df(data)
        if validation_errors:
            raise RuntimeError([e.error_message() for e in validation_errors])

        # Convert to Arrow table
        try:
            logger.debug("Converting data to Arrow format")
            pa_table = pa.Table.from_pandas(data, preserve_index=False)
        except pa.ArrowInvalid as e:
            logger.exception(INVALID_ARROW_CONVERSION_MSG)
            raise pa.ArrowInvalid(
                f"Error converting to Arrow format: {e!s}"
            ) from e
        except Exception:
            logger.exception("Unexpected error creating Arrow table")
            raise

        response = None
        with ArizeFlightClient(
            api_key=self._sdk_config.api_key,
            host=self._sdk_config.flight_host,
            port=self._sdk_config.flight_port,
            scheme=self._sdk_config.flight_scheme,
            request_verify=self._sdk_config.request_verify,
            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
        ) as flight_client:
            try:
                response = flight_client.create_dataset(
                    space_id=space_id,
                    dataset_name=name,
                    pa_table=pa_table,
                )
            except Exception as e:
                msg = f"Error during create request: {e!s}"
                logger.exception(msg)
                raise RuntimeError(msg) from e
        if response is None:
            # This should not happen with proper Flight client implementation,
            # but we handle it defensively
            msg = "No response received from flight server during create"
            logger.error(msg)
            raise RuntimeError(msg)
        # The response from flightserver is the dataset ID. To return the dataset
        # object we make a GET query
        return self.get(dataset=response)

    @staticmethod
    def _validate_examples_type(examples: object) -> None:
        """Raise TypeError unless `examples` is a list or a pandas DataFrame."""
        # Note: inside a function body `list` is the builtin, not the
        # `DatasetsClient.list` method defined in the class namespace.
        if not isinstance(examples, (list, pd.DataFrame)):
            msg = (
                "examples must be a list of dicts or a pandas DataFrame, "
                f"got {type(examples).__name__}"
            )
            raise TypeError(msg)

    @staticmethod
    def _examples_to_records(
        examples: builtins.list[dict[str, object]] | pd.DataFrame,
    ) -> builtins.list[dict[str, Any]]:
        """Return `examples` as a list of record dicts for a REST request body."""
        data = (
            examples.to_dict(orient="records")
            if isinstance(examples, pd.DataFrame)
            else examples
        )
        # Cast: pandas to_dict returns dict[Hashable, Any] but API requires dict[str, Any]
        return cast("list[dict[str, Any]]", data)

    @staticmethod
    def _examples_df_to_response(
        dataset_df: pd.DataFrame,
    ) -> DatasetExampleListResponse:
        """Wrap a full examples DataFrame in a single, non-paginated response."""
        examples = [
            obj
            for record in dataset_df.to_dict(orient="records")
            if (
                obj := models.DatasetExample.from_dict(
                    cast("dict[str, Any]", record)
                )
            )
            is not None
        ]
        return models.DatasetExampleListResponse(
            examples=examples,
            pagination=models.PaginationMetadata(
                has_more=False,  # Note that all=True
            ),
        )
def _set_default_columns_for_dataset(df: pd.DataFrame) -> pd.DataFrame:
"""Set default values for created_at and updated_at columns if missing or null."""
current_time = int(time.time() * 1000)
if "created_at" in df.columns:
if df["created_at"].isnull().any():
df["created_at"].fillna(current_time, inplace=True)
else:
df["created_at"] = current_time
if "updated_at" in df.columns:
if df["updated_at"].isnull().any():
df["updated_at"].fillna(current_time, inplace=True)
else:
df["updated_at"] = current_time
if "id" in df.columns:
if df["id"].isnull().any():
df["id"] = df["id"].apply(
lambda x: str(uuid.uuid4()) if pd.isnull(x) else x
)
else:
df["id"] = [str(uuid.uuid4()) for _ in range(len(df))]
return df