Source code for arize.experiments.types

"""Type definitions and data models for experiments."""

from __future__ import annotations

import json
import textwrap
from collections.abc import Awaitable, Callable, Iterable, Mapping
from copy import copy, deepcopy
from dataclasses import dataclass, field
from datetime import datetime, timezone
from importlib.metadata import version
from random import getrandbits
from typing import (
    NoReturn,
    cast,
)

import pandas as pd
from wrapt import ObjectProxy

from arize.experiments.evaluators.types import (
    EvaluationResult,
    JSONSerializable,
)

ExperimentId = str
ExampleId = str
RepetitionNumber = int
ExperimentRunId = str
TraceId = str


[docs] @dataclass(frozen=True) class Example: """Represents an example in an experiment dataset. Args: id: The unique identifier for the example. updated_at: The timestamp when the example was last updated. input: The input data for the example. output: The output data for the example. metadata: Additional metadata for the example. dataset_row: The original dataset row containing the example data. """ id: ExampleId = field(default_factory=str) updated_at: datetime = field(default_factory=datetime.now) input: Mapping[str, JSONSerializable] = field(default_factory=dict) output: Mapping[str, JSONSerializable] = field(default_factory=dict) metadata: Mapping[str, JSONSerializable] = field(default_factory=dict) dataset_row: Mapping[str, JSONSerializable] = field(default_factory=dict) def __post_init__(self) -> None: """Initialize example fields from dataset_row if provided.""" if self.dataset_row is not None: object.__setattr__( self, "dataset_row", _make_read_only(self.dataset_row) ) if "attributes.input.value" in self.dataset_row: object.__setattr__( self, "input", _make_read_only(self.dataset_row["attributes.input.value"]), ) if "attributes.output.value" in self.dataset_row: object.__setattr__( self, "output", _make_read_only( self.dataset_row["attributes.output.value"] ), ) if "attributes.metadata" in self.dataset_row: object.__setattr__( self, "metadata", _make_read_only(self.dataset_row["attributes.metadata"]), ) if "id" in self.dataset_row: object.__setattr__(self, "id", self.dataset_row["id"]) if "updated_at" in self.dataset_row: object.__setattr__( self, "updated_at", self.dataset_row["updated_at"] ) else: object.__setattr__(self, "input", self.input) object.__setattr__(self, "output", self.output) object.__setattr__(self, "metadata", self.metadata)
[docs] @classmethod def from_dict(cls, obj: Mapping[str, object]) -> Example: """Create an Example instance from a dictionary.""" return cls( id=cast("str", obj["id"]), input=cast("Mapping[str, JSONSerializable]", obj["input"]), output=cast("Mapping[str, JSONSerializable]", obj["output"]), metadata=cast( "Mapping[str, JSONSerializable]", obj.get("metadata") or {} ), updated_at=cast("datetime", obj["updated_at"]), )
def __repr__(self) -> str: """Return a formatted string representation of the example.""" spaces = " " * 4 name = self.__class__.__name__ identifiers = [f'{spaces}id="{self.id}",'] contents = [] for key in ("input", "output", "metadata", "dataset_row"): value = getattr(self, key, None) if value: contents.append( spaces + f"{_blue(key)}=" + json.dumps( _shorten(value), ensure_ascii=False, sort_keys=True, indent=len(spaces), ) .replace("\n", f"\n{spaces}") .replace(' "..."\n', " ...\n") + "," ) return "\n".join([f"{name}(", *identifiers, *contents, ")"])
def _shorten( obj: dict[str, object] | list[object] | str | object, width: int = 50 ) -> dict[str, object] | list[object] | str | object: if isinstance(obj, str): return textwrap.shorten(obj, width=width, placeholder="...") if isinstance(obj, dict): return {k: _shorten(v) for k, v in obj.items()} if isinstance(obj, list): if len(obj) > 2: return [_shorten(v) for v in obj[:2]] + ["..."] return [_shorten(v) for v in obj] return obj def _make_read_only( obj: dict[str, object] | list[object] | str | object, ) -> dict[str, object] | list[object] | str | object: if isinstance(obj, dict): return _ReadOnly({k: _make_read_only(v) for k, v in obj.items()}) if isinstance(obj, str): return obj if isinstance(obj, list): return _ReadOnly(list(map(_make_read_only, obj))) return obj class _ReadOnly(ObjectProxy): def __setitem__(self, *args: object, **kwargs: object) -> object: raise NotImplementedError def __delitem__(self, *args: object, **kwargs: object) -> object: raise NotImplementedError def __iadd__(self, *args: object, **kwargs: object) -> object: raise NotImplementedError def pop(self, *args: object, **kwargs: object) -> object: raise NotImplementedError def append(self, *args: object, **kwargs: object) -> object: raise NotImplementedError def __copy__(self, *args: object, **kwargs: object) -> object: return copy(self.__wrapped__) def __deepcopy__(self, *args: object, **kwargs: object) -> object: return deepcopy(self.__wrapped__) def __repr__(self) -> str: return repr(self.__wrapped__) def __str__(self) -> str: return str(self.__wrapped__) def _blue(text: str) -> str: return f"\033[1m\033[94m{text}\033[0m" @dataclass(frozen=True) class TestCase: """Container for an experiment test case with example data and repetition number.""" example: Example repetition_number: RepetitionNumber EXP_ID: ExperimentId = "EXP_ID" def _exp_id() -> str: suffix = getrandbits(24).to_bytes(3, "big").hex() return f"{EXP_ID}_{suffix}"
[docs] @dataclass(frozen=True) class ExperimentRun: """Represents a single run of an experiment. Args: start_time: The start time of the experiment run. end_time: The end time of the experiment run. experiment_id: The unique identifier for the experiment. dataset_example_id: The unique identifier for the dataset example. repetition_number: The repetition number of the experiment run. output: The output of the experiment run. error: The error message if the experiment run failed. id: The unique identifier for the experiment run. trace_id: The trace identifier for the experiment run. """ start_time: datetime end_time: datetime experiment_id: ExperimentId dataset_example_id: ExampleId repetition_number: RepetitionNumber output: JSONSerializable error: str | None = None id: ExperimentRunId = field(default_factory=_exp_id) trace_id: TraceId | None = None
[docs] @classmethod def from_dict(cls, obj: Mapping[str, object]) -> ExperimentRun: """Create an ExperimentRun instance from a dictionary.""" return cls( start_time=cast("datetime", obj["start_time"]), end_time=cast("datetime", obj["end_time"]), experiment_id=cast("str", obj["experiment_id"]), dataset_example_id=cast("str", obj["dataset_example_id"]), repetition_number=cast("int", obj.get("repetition_number") or 1), output=cast("JSONSerializable", _make_read_only(obj.get("output"))), error=cast("str | None", obj.get("error")), id=cast("str", obj["id"]), trace_id=cast("str | None", obj.get("trace_id")), )
def __post_init__(self) -> None: """Validate that exactly one of output or error is specified. Raises: ValueError: If both or neither output and error are specified. """ if (self.output is None) == (self.error is None): raise ValueError( "Must specify exactly one of experiment_run_output or error" )
[docs] @dataclass(frozen=True) class ExperimentEvaluationRun: """Represents a single evaluation run of an experiment. Args: experiment_run_id: The unique identifier for the experiment run. start_time: The start time of the evaluation run. end_time: The end time of the evaluation run. name: The name of the evaluation run. annotator_kind: The kind of annotator used in the evaluation run. error: The error message if the evaluation run failed. result (EvaluationResult | :obj:`None`): The result of the evaluation run. id (str): The unique identifier for the evaluation run. trace_id (TraceId | :obj:`None`): The trace identifier for the evaluation run. """ experiment_run_id: ExperimentRunId start_time: datetime end_time: datetime name: str annotator_kind: str error: str | None = None result: EvaluationResult | None = None id: str = field(default_factory=_exp_id) trace_id: TraceId | None = None
[docs] @classmethod def from_dict(cls, obj: Mapping[str, object]) -> ExperimentEvaluationRun: """Create an ExperimentEvaluationRun instance from a dictionary.""" return cls( experiment_run_id=cast("str", obj["experiment_run_id"]), start_time=cast("datetime", obj["start_time"]), end_time=cast("datetime", obj["end_time"]), name=cast("str", obj["name"]), annotator_kind=cast("str", obj["annotator_kind"]), error=cast("str | None", obj.get("error")), result=EvaluationResult.from_dict( cast("Mapping[str, object] | None", obj.get("result")) ), id=cast("str", obj["id"]), trace_id=cast("str | None", obj.get("trace_id")), )
def __post_init__(self) -> None: """Validate that exactly one of result or error is specified. Raises: ValueError: If both or neither result and error are specified. """ if bool(self.result) == bool(self.error): raise ValueError("Must specify either result or error")
_LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo def local_now() -> datetime: """Get the current datetime in the local timezone. Returns: A datetime object representing the current time in local timezone. """ return datetime.now(timezone.utc).astimezone(tz=_LOCAL_TIMEZONE) @dataclass(frozen=True) class _HasStats: _title: str = field(repr=False, default="") _timestamp: datetime = field(repr=False, default_factory=local_now) stats: pd.DataFrame = field(repr=False, default_factory=pd.DataFrame) @property def title(self) -> str: return f"{self._title} ({self._timestamp:%x %I:%M %p %z})" def __str__(self) -> str: try: pandas_major = int(version("pandas").split(".")[0]) if pandas_major < 1: raise ImportError("Pandas version < 1.0") # noqa: TRY301 # `tabulate` is used by pandas >= 1.0 in DataFrame.to_markdown() import tabulate # noqa: F401 except ImportError: text = self.stats.__str__() else: text = self.stats.to_markdown(index=False) return f"{self.title}\n{'-' * len(self.title)}\n" + text @dataclass(frozen=True) class _TaskSummary(_HasStats): """Summary statistics of experiment task executions. **Users should not instantiate this object directly.** """ _title: str = "Tasks Summary" @classmethod def from_task_runs( cls, n_examples: int, task_runs: Iterable[ExperimentRun | None] ) -> _TaskSummary: df = pd.DataFrame.from_records( [ { "example_id": run.dataset_example_id, "error": run.error, } for run in task_runs if run is not None ] ) n_runs = len(df) n_errors = 0 if df.empty else df.loc[:, "error"].astype(bool).sum() record = { "n_examples": n_examples, "n_runs": n_runs, "n_errors": n_errors, **( {"top_error": _top_string(df.loc[:, "error"])} if n_errors else {} ), } stats = pd.DataFrame.from_records([record]) summary: _TaskSummary = object.__new__(cls) summary.__init__(stats=stats) # type: ignore[misc] return summary @classmethod def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # Direct instantiation by users is discouraged. raise NotImplementedError @classmethod def __init_subclass__(cls, **kwargs: object) -> None: # Direct sub-classing by users is discouraged. raise NotImplementedError def _top_string(s: pd.Series, length: int = 100) -> str | None: if (cnt := s.dropna().str.slice(0, length).value_counts()).empty: return None return cast("str", cnt.sort_values(ascending=False).index[0])
[docs] @dataclass class ExperimentTaskFieldNames: """Column names for mapping experiment task results in a :class:`pandas.DataFrame`. Args: example_id: Name of column containing example IDs. The ID values must match the id of the dataset rows. output: Name of column containing task results """ example_id: str output: str
TaskOutput = JSONSerializable ExampleOutput = Mapping[str, JSONSerializable] ExampleMetadata = Mapping[str, JSONSerializable] ExampleInput = Mapping[str, JSONSerializable] ExperimentTask = ( Callable[[Example], TaskOutput] | Callable[[Example], Awaitable[TaskOutput]] ) # Public re-exports of generated API response types for this subdomain. from arize._generated.api_client.models.annotate_record_input import ( # noqa: E402 AnnotateRecordInput, ) from arize._generated.api_client.models.annotation_input import ( # noqa: E402 AnnotationInput, ) from arize._generated.api_client.models.experiment import ( # noqa: E402 Experiment, ) from arize._generated.api_client.models.experiments_list200_response import ( # noqa: E402 ExperimentsList200Response, ) from arize._generated.api_client.models.experiments_runs_list200_response import ( # noqa: E402 ExperimentsRunsList200Response, ) __all__ = [ "AnnotateRecordInput", "AnnotationInput", "Example", "ExampleId", "ExampleInput", "ExampleMetadata", "ExampleOutput", "Experiment", "ExperimentEvaluationRun", "ExperimentId", "ExperimentRun", "ExperimentRunId", "ExperimentTask", "ExperimentTaskFieldNames", "ExperimentsList200Response", "ExperimentsRunsList200Response", "RepetitionNumber", "TaskOutput", "TestCase", "TraceId", ]