Source code for arize.ml.client

"""Client implementation for managing ML models in the Arize platform."""

from __future__ import annotations

import copy
import logging
import time
from typing import TYPE_CHECKING, Any, cast

from arize._generated.protocol.rec import public_pb2 as pb2
from arize._lazy import require
from arize.constants.ml import (
    LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME,
    LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME,
    LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME,
    LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME,
    MAX_FUTURE_YEARS_FROM_CURRENT_TIME,
    MAX_NUMBER_OF_EMBEDDINGS,
    MAX_PAST_YEARS_FROM_CURRENT_TIME,
    MAX_TAG_LENGTH,
    MAX_TAG_LENGTH_TRUNCATION,
    RESERVED_TAG_COLS,
)
from arize.exceptions.base import (
    INVALID_ARROW_CONVERSION_MSG,
    ValidationFailure,
)
from arize.exceptions.models import MissingModelNameError
from arize.exceptions.parameters import (
    InvalidNumberOfEmbeddings,
    InvalidValueType,
)
from arize.exceptions.spaces import MissingSpaceIDError
from arize.logging import get_truncation_warning_message
from arize.ml.bounded_executor import BoundedExecutor
from arize.ml.casting import cast_dictionary, cast_typed_columns
from arize.ml.stream_validation import (
    validate_and_convert_prediction_id,
    validate_label,
)
from arize.ml.types import (
    CATEGORICAL_MODEL_TYPES,
    NUMERIC_MODEL_TYPES,
    ActualLabelTypes,
    BaseSchema,
    CorpusSchema,
    Embedding,
    EmbeddingColumnNames,
    Environments,
    LLMRunMetadata,
    Metrics,
    ModelTypes,
    PredictionIDType,
    PredictionLabelTypes,
    Schema,
    SimilaritySearchParams,
    TypedValue,
    convert_element,
)
from arize.utils.types import is_list_of

if TYPE_CHECKING:
    import concurrent.futures as cf
    from datetime import datetime

    import pandas as pd
    import requests
    from requests_futures.sessions import FuturesSession

    from arize.config import SDKConfiguration


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Modules passed to ``require`` (together with ``_STREAM_EXTRA``) at the top of
# ``MLModelsClient.log_stream`` before any streaming work is done.
_STREAM_DEPS = (
    "requests_futures",
    "google.protobuf",
)
# Name of the extra reported alongside ``_STREAM_DEPS``.
_STREAM_EXTRA = "ml-stream"

# Modules passed to ``require`` (together with ``_BATCH_EXTRA``) by the batch
# entry points: ``log``, ``export_to_df`` and ``export_to_parquet``.
_BATCH_DEPS = (
    "pandas",
    "google.protobuf",
    "pyarrow",
    "requests",
    "tqdm",
)
# Name of the extra reported alongside ``_BATCH_DEPS``.
_BATCH_EXTRA = "ml-batch"
# Modules passed to ``require`` (together with ``_MIMIC_EXTRA``) by ``log`` only
# when ``surrogate_explainability=True``.
_MIMIC_DEPS = (
    "interpret_community.mimic",
    "sklearn.preprocessing",
)
# Name of the extra reported alongside ``_MIMIC_DEPS``.
_MIMIC_EXTRA = "mimic-explainer"


class MLModelsClient:
    """Client for logging ML model predictions and actuals to Arize.

    This class is primarily intended for internal use within the SDK. Users are
    highly encouraged to access resource-specific functionality via
    :class:`arize.ArizeClient`.
    """

    def __init__(self, *, sdk_config: SDKConfiguration) -> None:
        """
        Args:
            sdk_config: Resolved SDK configuration.
        """  # noqa: D205, D212
        self._sdk_config = sdk_config

        # internal cache for the futures session
        self._session: FuturesSession | None = None

    def log_stream(
        self,
        *,
        space_id: str,
        model_name: str,
        model_type: ModelTypes,
        environment: Environments,
        model_version: str | None = None,
        prediction_id: PredictionIDType | None = None,
        prediction_timestamp: int | None = None,
        prediction_label: PredictionLabelTypes | None = None,
        actual_label: ActualLabelTypes | None = None,
        features: dict[str, str | bool | float | int | list[str] | TypedValue]
        | None = None,
        embedding_features: dict[str, Embedding] | None = None,
        shap_values: dict[str, float] | None = None,
        tags: dict[str, str | bool | float | int | TypedValue] | None = None,
        batch_id: str | None = None,
        prompt: str | Embedding | None = None,
        response: str | Embedding | None = None,
        prompt_template: str | None = None,
        prompt_template_version: str | None = None,
        llm_model_name: str | None = None,
        llm_params: dict[str, str | bool | float | int] | None = None,
        llm_run_metadata: LLMRunMetadata | None = None,
        timeout: float | None = None,
    ) -> cf.Future:
        """Log a single model prediction or actual to Arize asynchronously.

        This method sends a single prediction, actual, or both to Arize for ML
        monitoring. The request is made asynchronously and returns a Future that
        can be used to check the status or retrieve the response.

        Args:
            space_id: The space ID where the model resides.
            model_name: A unique name to identify your model in the Arize platform.
            model_type: The type of model. Supported types: BINARY, MULTI_CLASS,
                REGRESSION, RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not
                supported; use the spans module instead.
            environment: The environment this data belongs to (PRODUCTION,
                TRAINING, or VALIDATION).
            model_version: Optional version identifier for the model.
            prediction_id: Unique identifier for this prediction. If not provided,
                one will be auto-generated for PRODUCTION environment.
            prediction_timestamp: Unix timestamp (seconds) for when the prediction
                was made. If not provided, the current time is used. Must be within
                1 year in the future and 2 years in the past from the current time.
            prediction_label: The prediction output from your model. Type depends
                on model_type (e.g., string for categorical, float for numeric).
            actual_label: The ground truth label. Type depends on model_type.
                Falsy values such as ``0``, ``0.0`` or ``False`` are valid labels
                and are logged like any other value.
            features: Dictionary of feature name to feature value. Values can be
                str, bool, float, int, list[str], or TypedValue.
            embedding_features: Dictionary of embedding feature name to Embedding
                object. Maximum 50 embeddings per record. Object detection models
                support only 1.
            shap_values: Dictionary of feature name to SHAP value (float) for
                feature importance/explainability.
            tags: Dictionary of metadata tags. Tag names cannot end with "_shap" or
                be reserved names. Values must be under 1000 characters (warning
                at 100).
            batch_id: Required for VALIDATION environment; identifies the
                validation batch.
            prompt: For generative models, the prompt text or embedding sent to the
                model.
            response: For generative models, the response text or embedding from
                the model.
            prompt_template: Template used to generate the prompt.
            prompt_template_version: Version identifier for the prompt template.
            llm_model_name: Name of the LLM model used (e.g., "gpt-4").
            llm_params: Dictionary of LLM configuration parameters (e.g.,
                temperature, max_tokens).
            llm_run_metadata: Metadata about the LLM run including token counts and
                latency.
            timeout: Maximum time (in seconds) to wait for the request to complete.

        Returns:
            A concurrent.futures.Future object representing the async request.
            Call .result() to block and retrieve the Response object, or check
            .done() for completion status.

        Raises:
            ValueError: If model_type is GENERATIVE_LLM, or if validation
                environment is missing batch_id, or if training/validation
                environment is missing prediction or actual, or if timestamp is out
                of range, or if no data is provided (must have prediction_label,
                actual_label, tags, or shap_values), or if feature/tag/SHAP names
                end with "_shap", or if a tag value exceeds the length limit.
            TypeError: If a feature, embedding feature, or tag name is not a string.
            MissingSpaceIDError: If space_id is not provided or empty.
            MissingModelNameError: If model_name is not provided or empty.
            InvalidValueType: If features, tags, or other parameters have incorrect
                types.
            InvalidNumberOfEmbeddings: If more than 50 embedding features are
                provided.
            KeyError: If tag names include reserved names.

        Notes:
            - Timestamps must be within 1 year future and 2 years past from current
              time
            - Tag values longer than ``MAX_TAG_LENGTH`` are rejected; values longer
              than ``MAX_TAG_LENGTH_TRUNCATION`` only trigger a truncation warning
            - Features, embeddings, prompt/response, LLM fields and the timestamp
              are attached to the prediction, so they are only sent when
              prediction_label is provided
            - For GENERATIVE_LLM models, use the spans module or OTEL tracing
              instead
            - The Future returned can be monitored for request status
              asynchronously
        """
        require(_STREAM_EXTRA, _STREAM_DEPS)
        from arize._generated.protocol.rec import public_pb2 as pb2
        from arize.ml.proto import (
            get_pb_dictionary,
            get_pb_label,
            get_pb_timestamp,
        )

        if model_type == ModelTypes.GENERATIVE_LLM:
            raise ValueError(
                "Wrong model type found: GENERATIVE_LLM. To send LLM data to Arize, "
                "use the spans module `arize_client.spans` or OTEL tracing"
            )

        # This method requires a space_id and model_name
        if not space_id:
            raise MissingSpaceIDError()
        if not model_name:
            raise MissingModelNameError()

        # Validate batch_id
        if environment == Environments.VALIDATION and (
            batch_id is None
            or not isinstance(batch_id, str)
            or len(batch_id.strip()) == 0
        ):
            raise ValueError(
                "Batch ID must be a nonempty string if logging to validation environment."
            )

        # Convert & Validate prediction_id
        prediction_id = validate_and_convert_prediction_id(
            prediction_id,
            environment,
            prediction_label,
            actual_label,
            shap_values,
        )

        # Cast feature & tag values
        if features:
            features = cast_dictionary(features)
            # Defensive check
            if not isinstance(features, dict):
                raise InvalidValueType("features", features, "dict")
            for feat_name, feat_value in features.items():
                _validate_mapping_key(feat_name, "features")
                if is_list_of(feat_value, str):
                    continue
                val = convert_element(feat_value)
                if val is not None and not isinstance(
                    val, (str, bool, float, int)
                ):
                    raise InvalidValueType(
                        f"feature '{feat_name}'",
                        feat_value,
                        "one of: bool, int, float, str",
                    )

        # Validate embedding_features type
        if embedding_features:
            if not isinstance(embedding_features, dict):
                raise InvalidValueType(
                    "embedding_features", embedding_features, "dict"
                )
            if len(embedding_features) > MAX_NUMBER_OF_EMBEDDINGS:
                raise InvalidNumberOfEmbeddings(len(embedding_features))
            if (
                model_type == ModelTypes.OBJECT_DETECTION
                and len(embedding_features) > 1
            ):
                # Check that there is only 1 embedding feature for OD model types
                raise ValueError(
                    "Object Detection models only support one embedding feature"
                )
            for emb_name, emb_obj in embedding_features.items():
                _validate_mapping_key(emb_name, "embedding features")
                # Must verify embedding type
                if not isinstance(emb_obj, Embedding):
                    raise InvalidValueType(
                        f"embedding feature '{emb_name}'", emb_obj, "Embedding"
                    )
                emb_obj.validate(emb_name)

        if tags:
            tags = cast_dictionary(tags)
            # Defensive check
            if not isinstance(tags, dict):
                raise InvalidValueType("tags", tags, "dict")
            wrong_tags = [
                tag_name for tag_name in tags if tag_name in RESERVED_TAG_COLS
            ]
            if wrong_tags:
                raise KeyError(
                    f"The following tag names are not allowed as they are reserved: {wrong_tags}"
                )
            for tag_name, tag_value in tags.items():
                # Rejects non-string names (TypeError) and names ending in
                # `_shap` (ValueError), so no separate suffix check is needed.
                _validate_mapping_key(tag_name, "tags")
                val = convert_element(tag_value)
                if val is not None and not isinstance(
                    val, (str, bool, float, int)
                ):
                    raise InvalidValueType(
                        f"tag '{tag_name}'",
                        tag_value,
                        "one of: bool, int, float, str",
                    )
                tag_length = len(str(val))
                if tag_length > MAX_TAG_LENGTH:
                    raise ValueError(
                        f"The number of characters for each tag must be less than or equal to "
                        f"{MAX_TAG_LENGTH}. The tag {tag_name} with value {tag_value} has "
                        f"{tag_length} characters."
                    )
                if tag_length > MAX_TAG_LENGTH_TRUNCATION:
                    logger.warning(
                        get_truncation_warning_message(
                            "tags", MAX_TAG_LENGTH_TRUNCATION
                        )
                    )

        # Check the timestamp present on the event
        if prediction_timestamp is not None:
            if not isinstance(prediction_timestamp, int):
                raise InvalidValueType(
                    "prediction_timestamp", prediction_timestamp, "int"
                )
            # Send warning if prediction is sent with future timestamp
            now = int(time.time())
            if prediction_timestamp > now:
                logger.warning(
                    "Caution when sending a prediction with future timestamp. "
                    "Arize only stores 2 years worth of data. For example, if you sent a prediction "
                    "to Arize from 1.5 years ago, and now send a prediction with timestamp of a year in "
                    "the future, the oldest 0.5 years will be dropped to maintain the 2 years worth of data "
                    "requirement."
                )
            if not _is_timestamp_in_range(now, prediction_timestamp):
                raise ValueError(
                    f"prediction_timestamp: {prediction_timestamp} is out of range. "
                    f"Prediction timestamps must be within {MAX_FUTURE_YEARS_FROM_CURRENT_TIME} year in the "
                    f"future and {MAX_PAST_YEARS_FROM_CURRENT_TIME} years in the past from the current time."
                )

        # Construct the prediction. Everything merged into `p` below requires a
        # prediction label, since `p` only exists in that case.
        p = None
        if prediction_label is not None:
            if model_version is not None and not isinstance(model_version, str):
                raise InvalidValueType("model_version", model_version, "str")
            validate_label(
                prediction_or_actual="prediction",
                model_type=model_type,
                label=convert_element(prediction_label),
                embedding_features=embedding_features,
            )
            p = pb2.Prediction(
                prediction_label=get_pb_label(
                    prediction_or_actual="prediction",
                    value=prediction_label,
                    model_type=model_type,
                ),
                model_version=model_version,
            )
            if features is not None:
                converted_feats = get_pb_dictionary(features)
                p.MergeFrom(pb2.Prediction(features=converted_feats))

            if embedding_features or prompt or response:
                # NOTE: Work on a (shallow) copy so that adding the prompt and
                # response keys does not mutate the caller's dictionary.
                combined_embedding_features: dict[str, str | Embedding] = (
                    cast(
                        "dict[str, str | Embedding]", embedding_features.copy()
                    )
                    if embedding_features
                    else {}
                )
                # Map prompt as embedding features for generative models
                if prompt is not None:
                    combined_embedding_features["prompt"] = prompt
                # Map response as embedding features for generative models
                if response is not None:
                    combined_embedding_features["response"] = response
                converted_embedding_feats = get_pb_dictionary(
                    combined_embedding_features
                )
                p.MergeFrom(pb2.Prediction(features=converted_embedding_feats))

            if tags or llm_run_metadata:
                joined_tags = copy.deepcopy(tags) if tags is not None else {}
                if llm_run_metadata:
                    # LLM run metadata travels as reserved tags on the prediction.
                    if llm_run_metadata.total_token_count is not None:
                        joined_tags[
                            LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME
                        ] = llm_run_metadata.total_token_count
                    if llm_run_metadata.prompt_token_count is not None:
                        joined_tags[
                            LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME
                        ] = llm_run_metadata.prompt_token_count
                    if llm_run_metadata.response_token_count is not None:
                        joined_tags[
                            LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME
                        ] = llm_run_metadata.response_token_count
                    if llm_run_metadata.response_latency_ms is not None:
                        joined_tags[
                            LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME
                        ] = llm_run_metadata.response_latency_ms
                converted_tags = get_pb_dictionary(joined_tags)
                p.MergeFrom(pb2.Prediction(tags=converted_tags))

            if (
                prompt_template
                or prompt_template_version
                or llm_model_name
                or llm_params
            ):
                llm_fields = pb2.LLMFields(
                    prompt_template=prompt_template or "",
                    prompt_template_name=prompt_template_version or "",
                    llm_model_name=llm_model_name or "",
                    llm_params=get_pb_dictionary(llm_params),
                )
                p.MergeFrom(pb2.Prediction(llm_fields=llm_fields))

            if prediction_timestamp is not None:
                p.timestamp.MergeFrom(get_pb_timestamp(prediction_timestamp))

        # Validate and construct the optional actual
        is_latent_tags = prediction_label is None and tags is not None
        a = None
        # Compare against None rather than relying on truthiness: 0, 0.0 and
        # False are legitimate actual labels and must not be dropped.
        if actual_label is not None or is_latent_tags:
            a = pb2.Actual()
            if actual_label is not None:
                validate_label(
                    prediction_or_actual="actual",
                    model_type=model_type,
                    label=convert_element(actual_label),
                    embedding_features=embedding_features,
                )
                a.MergeFrom(
                    pb2.Actual(
                        actual_label=get_pb_label(
                            prediction_or_actual="actual",
                            value=actual_label,
                            model_type=model_type,
                        )
                    )
                )
            # Added to support delayed tags on actuals.
            if tags is not None:
                converted_tags = get_pb_dictionary(tags)
                a.MergeFrom(pb2.Actual(tags=converted_tags))

        # Validate and construct the optional feature importances
        fi = None
        if shap_values:
            for k, v in shap_values.items():
                if not isinstance(convert_element(v), float):
                    raise InvalidValueType(f"feature '{k}'", v, "float")
                if isinstance(k, str) and k.endswith("_shap"):
                    raise ValueError(
                        f"feature {k} must not be named with a `_shap` suffix"
                    )
            fi = pb2.FeatureImportances(feature_importances=shap_values)

        if p is None and a is None and fi is None:
            raise ValueError(
                "must provide at least one of prediction_label, actual_label, tags, or shap_values"
            )

        env_params = None
        if environment == Environments.TRAINING:
            if p is None or a is None:
                raise ValueError(
                    "Training records must have both Prediction and Actual"
                )
            env_params = pb2.Record.EnvironmentParams(
                training=pb2.Record.EnvironmentParams.Training()
            )
        elif environment == Environments.VALIDATION:
            if p is None or a is None:
                raise ValueError(
                    "Validation records must have both Prediction and Actual"
                )
            env_params = pb2.Record.EnvironmentParams(
                validation=pb2.Record.EnvironmentParams.Validation(
                    batch_id=batch_id
                )
            )
        elif environment == Environments.PRODUCTION:
            env_params = pb2.Record.EnvironmentParams(
                production=pb2.Record.EnvironmentParams.Production()
            )

        rec = pb2.Record(
            # We don't pass the deprecated space key
            # as part of the public record, we pass the space ID in the header
            model_id=model_name,
            prediction_id=prediction_id,
            prediction=p,
            actual=a,
            feature_importances=fi,
            environment_params=env_params,
        )
        headers = self._sdk_config.headers_grpc
        headers.update(
            {
                "Grpc-Metadata-arize-space-id": space_id,
                "Grpc-Metadata-arize-interface": "stream",
            }
        )
        return self._post(
            record=rec,
            headers=headers,
            timeout=timeout,
            indexes=None,
        )

    def log(
        self,
        *,
        space_id: str,
        model_name: str,
        model_type: ModelTypes,
        dataframe: pd.DataFrame,
        schema: BaseSchema,
        environment: Environments,
        model_version: str = "",
        batch_id: str = "",
        validate: bool = True,
        metrics_validation: list[Metrics] | None = None,
        surrogate_explainability: bool = False,
        timeout: float | None = None,
        tmp_dir: str = "",
    ) -> requests.Response:
        """Log a batch of model predictions and actuals to Arize from a :class:`pandas.DataFrame`.

        This method uploads multiple records to Arize in a single batch operation
        using Apache Arrow format for efficient transfer. The dataframe structure
        is defined by the provided schema which maps dataframe columns to Arize
        data fields.

        Args:
            space_id: The space ID where the model resides.
            model_name: A unique name to identify your model in the Arize platform.
            model_type: The type of model. Supported types: BINARY, MULTI_CLASS,
                REGRESSION, RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not
                supported; use the spans module instead.
            dataframe (:class:`pandas.DataFrame`): Pandas DataFrame containing the
                data to upload. Columns should correspond to the schema field
                mappings.
            schema: Schema object (Schema or CorpusSchema) that defines the mapping
                between dataframe columns and Arize data fields (e.g.,
                prediction_label_column_name, feature_column_names, etc.).
            environment: The environment this data belongs to (PRODUCTION,
                TRAINING, VALIDATION, or CORPUS).
            model_version: Optional version identifier for the model.
            batch_id: Required for VALIDATION environment; identifies the
                validation batch.
            validate: When True, performs comprehensive validation before sending
                data. Includes checks for required fields, data types, and value
                constraints.
            metrics_validation: Optional list of metric families to validate
                against.
            surrogate_explainability: When True, automatically generates SHAP
                values using MIMIC surrogate explainer. Requires the
                'mimic-explainer' extra. Has no effect if shap_values_column_names
                is already specified in schema.
            timeout: Maximum time (in seconds) to wait for the request to complete.
            tmp_dir: Optional temporary directory to store serialized Arrow data
                before upload.

        Returns:
            A requests.Response object from the upload request (only returned on
            HTTP 2xx). Non-2xx responses raise exceptions rather than being
            returned.

        Raises:
            MissingSpaceIDError: If space_id is not provided or empty.
            MissingModelNameError: If model_name is not provided or empty.
            ValueError: If model_type is GENERATIVE_LLM, or if environment is
                CORPUS with non-CorpusSchema, or if training/validation records are
                incomplete.
            ValidationFailure: If the required checks fail, or if validate=True
                and validation checks fail. Contains list of validation error
                messages.
            pa.ArrowInvalid: If the dataframe cannot be converted to Arrow format,
                typically due to mixed types in columns not specified in the
                schema.
            AuthenticationError: If the server returns HTTP 401 or 403 (invalid API
                key or space ID). Raised immediately to prevent further uploads
                with bad credentials.
            APIError: If the server returns any other non-2xx response (e.g. 400,
                422, 429, 5xx). Raised immediately to prevent further uploads when
                the server signals an error.

        Notes:
            - Categorical dtype columns are automatically converted to string
            - Extraneous columns not in the schema are removed before upload
            - Surrogate explainability requires 'mimic-explainer' extra
            - For GENERATIVE_LLM models, use the spans module or OTEL tracing
              instead
            - If logging actuals without predictions, ensure predictions were
              logged first
            - Data is sent via Apache Arrow for efficient large batch transfers
        """
        require(_BATCH_EXTRA, _BATCH_DEPS)
        import pandas.api.types as ptypes
        import pyarrow as pa

        from arize.ml.batch_validation.validator import Validator
        from arize.utils.arrow import post_arrow_table
        from arize.utils.dataframe import remove_extraneous_columns

        # This method requires a space_id and model_name
        if not space_id:
            raise MissingSpaceIDError()
        if not model_name:
            raise MissingModelNameError()

        # Deep copy the schema since we might modify it to add certain columns and don't
        # want to cause side effects
        schema = copy.deepcopy(schema)

        if model_type == ModelTypes.GENERATIVE_LLM:
            raise ValueError(
                "Wrong model type found: GENERATIVE_LLM. To send LLM data to Arize, "
                "use the spans module `arize_client.spans` or OTEL tracing"
            )

        # If typed columns are specified in the schema,
        # apply casting and return new copies of the dataframe + schema.
        # All downstream validations are kept the same.
        # note: we don't do any casting for Corpus schemas.
        if isinstance(schema, Schema) and schema.has_typed_columns():
            # The pandas nullable string column type (StringDType) is still considered experimental
            # and is unavailable before pandas 1.0.0.
            # Thus we can only offer this functionality with pandas>=1.0.0.
            try:
                dataframe, schema = cast_typed_columns(dataframe, schema)
            except Exception:
                logger.exception("Error casting typed columns")
                raise

        # Required checks always run, regardless of `validate`.
        logger.debug("Performing required validation.")
        self._raise_on_validation_errors(
            Validator.validate_required_checks(
                dataframe=dataframe,
                model_id=model_name,
                environment=environment,
                schema=schema,
                model_version=model_version,
                batch_id=batch_id,
            )
        )

        if validate:
            logger.debug("Performing parameters validation.")
            self._raise_on_validation_errors(
                Validator.validate_params(
                    dataframe=dataframe,
                    model_id=model_name,
                    model_type=model_type,
                    environment=environment,
                    schema=schema,
                    metric_families=metrics_validation,
                    model_version=model_version,
                    batch_id=batch_id,
                )
            )

        logger.debug("Removing unnecessary columns.")
        dataframe = remove_extraneous_columns(df=dataframe, schema=schema)

        # Always convert pandas categorical columns to string. The dtype is
        # checked with isinstance because `is_categorical_dtype` is deprecated
        # in recent pandas releases.
        cat_cols = [
            col_name
            for col_name, col_dtype in dataframe.dtypes.items()
            if isinstance(col_dtype, ptypes.CategoricalDtype)
        ]
        if cat_cols:
            dataframe = dataframe.astype(dict.fromkeys(cat_cols, "str"))

        if surrogate_explainability:
            require(_MIMIC_EXTRA, _MIMIC_DEPS)
            from arize.ml.surrogate_explainer.mimic import Mimic

            logger.debug("Running surrogate_explainability.")
            # Type ignore: schema typed as BaseSchema but runtime is Schema with these attrs
            if schema.shap_values_column_names:  # type: ignore[attr-defined]
                logger.info(
                    "surrogate_explainability=True has no effect "
                    "because shap_values_column_names is already specified in schema."
                )
            elif schema.feature_column_names is None or (  # type: ignore[attr-defined]
                hasattr(schema.feature_column_names, "__len__")  # type: ignore[attr-defined]
                and len(schema.feature_column_names) == 0  # type: ignore[attr-defined]
            ):
                logger.info(
                    "surrogate_explainability=True has no effect "
                    "because feature_column_names is empty or not specified in schema."
                )
            else:
                dataframe, schema = Mimic.augment(
                    df=dataframe,
                    schema=schema,  # type: ignore[arg-type]
                    model_type=model_type,
                )

        # Convert to Arrow table
        try:
            logger.debug("Converting data to Arrow format")
            # pyarrow will err if a mixed type column exist in the dataset even if
            # the column is not specified in schema. Caveat: There may be other
            # error conditions that we're currently not aware of.
            pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
        except pa.ArrowInvalid as e:
            logger.exception(INVALID_ARROW_CONVERSION_MSG)
            raise pa.ArrowInvalid(
                f"Error converting to Arrow format: {e!s}"
            ) from e
        except Exception:
            logger.exception("Unexpected error creating Arrow table")
            raise

        if validate:
            logger.debug("Performing types validation.")
            self._raise_on_validation_errors(
                Validator.validate_types(
                    model_type=model_type,
                    schema=schema,
                    pyarrow_schema=pa_table.schema,
                )
            )

            logger.debug("Performing values validation.")
            self._raise_on_validation_errors(
                Validator.validate_values(
                    dataframe=dataframe,
                    environment=environment,
                    schema=schema,
                    model_type=model_type,
                    max_past_years=self._sdk_config.max_past_years,
                )
            )

        if isinstance(schema, Schema) and not schema.has_prediction_columns():
            logger.warning(
                "Logging actuals without any predictions may result in "
                "unexpected behavior if corresponding predictions have not been logged prior. "
                "Please see the docs at https://docs.arize.com/arize/sending-data/sending-data-faq"
                "#what-happens-after-i-send-in-actual-data"
            )

        if environment == Environments.CORPUS:
            proto_schema = _get_pb_schema_corpus(
                schema=schema,  # type: ignore[arg-type]
                model_id=model_name,
            )
        else:
            proto_schema = _get_pb_schema(
                schema=schema,  # type: ignore[arg-type]
                model_id=model_name,
                model_version=model_version,
                model_type=model_type,
                environment=environment,
                batch_id=batch_id,
            )

        # Create headers copy for the batch upload
        # Safe to mutate, returns a deep copy
        headers = self._sdk_config.headers
        # Send the number of rows in the dataframe as a header
        # This helps the Arize server to return appropriate feedback, specially for async logging
        headers.update(
            {
                "arize-space-id": space_id,
                "arize-interface": "batch",
                "number-of-rows": str(len(dataframe)),
            }
        )
        return post_arrow_table(
            files_url=self._sdk_config.files_url,
            pa_table=pa_table,
            proto_schema=proto_schema,
            headers=headers,
            timeout=timeout,
            verify=self._sdk_config.request_verify,
            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
            tmp_dir=tmp_dir,
        )

    def export_to_df(
        self,
        *,
        space_id: str,
        model_name: str,
        environment: Environments,
        start_time: datetime,
        end_time: datetime,
        include_actuals: bool = False,
        model_version: str = "",
        batch_id: str = "",
        where: str = "",
        columns: list | None = None,
        similarity_search_params: SimilaritySearchParams | None = None,
        stream_chunk_size: int | None = None,
    ) -> pd.DataFrame:
        """Export model data from Arize to a :class:`pandas.DataFrame`.

        Retrieves prediction and optional actual data for a model within a
        specified time range and returns it as a :class:`pandas.DataFrame` for
        analysis.

        Args:
            space_id: The space ID where the model resides.
            model_name: The name of the model to export data from.
            environment: The environment to export from (PRODUCTION, TRAINING, or
                VALIDATION).
            start_time: Start of the time range (inclusive) as a datetime object.
            end_time: End of the time range (inclusive) as a datetime object.
            include_actuals: When True, includes actual labels in the export. When
                False, only predictions are returned.
            model_version: Optional model version to filter by. Empty string
                returns all versions.
            batch_id: Optional batch ID to filter by (for VALIDATION environment).
            where: Optional SQL-like WHERE clause to filter rows (e.g.,
                "feature_x > 0.5").
            columns: Optional list of column names to include. If None, all columns
                are returned.
            similarity_search_params: Optional parameters for embedding similarity
                search filtering.
            stream_chunk_size: Optional chunk size for streaming large result sets.

        Returns:
            :class:`pandas.DataFrame`: A pandas DataFrame containing the exported
            data with columns for predictions, actuals (if requested), features,
            tags, timestamps, and other model metadata.

        Raises:
            RuntimeError: If the Flight client request fails or returns no
                response.

        Notes:
            - Uses Apache Arrow Flight for efficient data transfer
            - Large exports may benefit from specifying stream_chunk_size
            - The where clause supports SQL-like filtering syntax
        """
        require(_BATCH_EXTRA, _BATCH_DEPS)
        from arize._exporter.client import ArizeExportClient

        with self._new_flight_client() as flight_client:
            exporter = ArizeExportClient(
                flight_client=flight_client,
            )
            return exporter.export_to_df(
                space_id=space_id,
                model_id=model_name,
                environment=environment,
                start_time=start_time,
                end_time=end_time,
                where=where,
                columns=columns,
                similarity_search_params=similarity_search_params,
                stream_chunk_size=stream_chunk_size,
                include_actuals=include_actuals,
                model_version=model_version,
                batch_id=batch_id,
            )

    def export_to_parquet(
        self,
        *,
        path: str,
        space_id: str,
        model_name: str,
        environment: Environments,
        start_time: datetime,
        end_time: datetime,
        include_actuals: bool = False,
        model_version: str = "",
        batch_id: str = "",
        where: str = "",
        columns: list | None = None,
        similarity_search_params: SimilaritySearchParams | None = None,
        stream_chunk_size: int | None = None,
    ) -> None:
        """Export model data from Arize to a Parquet file.

        Retrieves prediction and optional actual data for a model within a
        specified time range and writes it directly to a Parquet file at the
        specified path.

        Args:
            path: The file path where the Parquet file will be written.
            space_id: The space ID where the model resides.
            model_name: The name of the model to export data from.
            environment: The environment to export from (PRODUCTION, TRAINING, or
                VALIDATION).
            start_time: Start of the time range (inclusive) as a datetime object.
            end_time: End of the time range (inclusive) as a datetime object.
            include_actuals: When True, includes actual labels in the export. When
                False, only predictions are returned.
            model_version: Optional model version to filter by. Empty string
                returns all versions.
            batch_id: Optional batch ID to filter by (for VALIDATION environment).
            where: Optional SQL-like WHERE clause to filter rows (e.g.,
                "feature_x > 0.5").
            columns: Optional list of column names to include. If None, all columns
                are returned.
            similarity_search_params: Optional parameters for embedding similarity
                search filtering.
            stream_chunk_size: Optional chunk size for streaming large result sets.

        Raises:
            RuntimeError: If the Flight client request fails or returns no
                response.

        Notes:
            - Uses Apache Arrow Flight for efficient data transfer
            - Data is written directly to the specified path as a Parquet file
            - Large exports may benefit from specifying stream_chunk_size
        """
        require(_BATCH_EXTRA, _BATCH_DEPS)
        from arize._exporter.client import ArizeExportClient

        with self._new_flight_client() as flight_client:
            exporter = ArizeExportClient(
                flight_client=flight_client,
            )
            exporter.export_to_parquet(
                path=path,
                space_id=space_id,
                model_id=model_name,
                environment=environment,
                start_time=start_time,
                end_time=end_time,
                where=where,
                columns=columns,
                similarity_search_params=similarity_search_params,
                stream_chunk_size=stream_chunk_size,
                include_actuals=include_actuals,
                model_version=model_version,
                batch_id=batch_id,
            )

    @staticmethod
    def _raise_on_validation_errors(errors: Any) -> None:
        """Log every validation error and raise ``ValidationFailure`` if any exist."""
        if errors:
            for error in errors:
                logger.error(error)
            raise ValidationFailure(errors)

    def _new_flight_client(self) -> Any:
        """Build an ``ArizeFlightClient`` from the SDK configuration.

        The caller is expected to use the returned client as a context manager.
        """
        from arize._flight.client import ArizeFlightClient

        return ArizeFlightClient(
            api_key=self._sdk_config.api_key,
            host=self._sdk_config.flight_host,
            port=self._sdk_config.flight_port,
            scheme=self._sdk_config.flight_scheme,
            request_verify=self._sdk_config.request_verify,
            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
        )

    def _ensure_session(self) -> FuturesSession:
        """Lazily initialize and return the FuturesSession for async streaming requests."""
        session = object.__getattribute__(self, "_session")
        if session is not None:
            return session

        # Only import requests_futures when a session actually has to be built.
        from requests_futures.sessions import FuturesSession

        new_session = FuturesSession(
            executor=BoundedExecutor(
                self._sdk_config.stream_max_queue_bound,
                self._sdk_config.stream_max_workers,
            )
        )
        object.__setattr__(self, "_session", new_session)
        return new_session

    def _post(
        self,
        record: pb2.Record,
        headers: dict[str, str],
        timeout: float | None,
        indexes: tuple | None,
    ) -> cf.Future[Any]:
        """Post a record to Arize via async HTTP request with protobuf JSON serialization.

        Args:
            record: The protobuf record to send; serialized with ``MessageToDict``.
            headers: HTTP headers to send with the request.
            timeout: Request timeout in seconds, or None for no explicit timeout.
            indexes: Optional ``(start, end)`` pair. When it has exactly two
                elements they are attached to the returned future as
                ``starting_index`` and ``ending_index``.

        Returns:
            The future returned by the session's ``post`` call.
        """
        from google.protobuf.json_format import MessageToDict

        session = self._ensure_session()
        resp = session.post(
            self._sdk_config.records_url,
            headers=headers,
            timeout=timeout,
            json=MessageToDict(
                message=record,
                preserving_proto_field_name=True,
            ),
            verify=self._sdk_config.request_verify,
        )
        if indexes is not None and len(indexes) == 2:
            resp.starting_index = indexes[0]
            resp.ending_index = indexes[1]
        return resp
def _validate_mapping_key(key_name: str, name: str) -> None:
    """Validate a key of a user-supplied mapping.

    Args:
        key_name: The key to validate.
        name: Label of the mapping being validated; it is only interpolated
            into the error messages.

    Raises:
        TypeError: If ``key_name`` is not a string.
        ValueError: If ``key_name`` ends with the ``_shap`` suffix.
    """
    if not isinstance(key_name, str):
        raise TypeError(
            f"{name} dictionary key {key_name} must be named with string, type used: {type(key_name)}"
        )
    if key_name.endswith("_shap"):
        raise ValueError(
            f"{name} dictionary key {key_name} must not be named with a `_shap` suffix"
        )


def _is_timestamp_in_range(now: int, ts: int) -> bool:
    """Check whether a timestamp falls inside the accepted window around now.

    The window spans ``MAX_PAST_YEARS_FROM_CURRENT_TIME`` years before ``now``
    to ``MAX_FUTURE_YEARS_FROM_CURRENT_TIME`` years after it. A year is
    approximated as 365 days and converted to seconds, so ``now`` and ``ts``
    are expected to be expressed in seconds. Both bounds are inclusive.

    Args:
        now: Reference time, in seconds.
        ts: Timestamp to check, in seconds.

    Returns:
        ``True`` when ``ts`` lies within the window, ``False`` otherwise.
    """
    max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
    min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
    return min_time <= ts <= max_time


def _get_pb_schema(
    schema: Schema,
    model_id: str,
    model_version: str | None,
    model_type: ModelTypes,
    environment: Environments,
    batch_id: str | None,
) -> pb2.Schema:
    """Construct a protocol buffer Schema from the user's Schema for batch logging.

    Constants (model id/version, environment, model type, batch id) are
    written first, followed by the column-name mapping. Optional values are
    only written when they are provided.

    Statement order matters: for numeric model types the score column names
    are written into the *label* fields and therefore overwrite a label
    column name that was set earlier in the function.

    Args:
        schema: User-facing schema describing the dataframe columns.
        model_id: Model identifier stored in the schema constants.
        model_version: Optional model version; skipped when ``None``.
        model_type: External model type, mapped to the internal protobuf
            model type. A model type matching none of the handled cases
            leaves the protobuf model type unset.
        environment: Environment the data belongs to.
        batch_id: Optional batch identifier; skipped when ``None``.

    Returns:
        The populated ``pb2.Schema`` message.

    Raises:
        ValueError: If ``environment`` is not production, validation or
            training.
    """
    s = pb2.Schema()
    s.constants.model_id = model_id

    if model_version is not None:
        s.constants.model_version = model_version

    if environment == Environments.PRODUCTION:
        s.constants.environment = pb2.Schema.Environment.PRODUCTION
    elif environment == Environments.VALIDATION:
        s.constants.environment = pb2.Schema.Environment.VALIDATION
    elif environment == Environments.TRAINING:
        s.constants.environment = pb2.Schema.Environment.TRAINING
    else:
        raise ValueError(f"unexpected environment: {environment}")

    # Map user-friendly external model types -> internal model types when sending to Arize
    if model_type in NUMERIC_MODEL_TYPES:
        s.constants.model_type = pb2.Schema.ModelType.NUMERIC
    elif model_type in CATEGORICAL_MODEL_TYPES:
        s.constants.model_type = pb2.Schema.ModelType.SCORE_CATEGORICAL
    elif model_type == ModelTypes.RANKING:
        s.constants.model_type = pb2.Schema.ModelType.RANKING
    elif model_type == ModelTypes.OBJECT_DETECTION:
        s.constants.model_type = pb2.Schema.ModelType.OBJECT_DETECTION
    elif model_type == ModelTypes.GENERATIVE_LLM:
        s.constants.model_type = pb2.Schema.ModelType.GENERATIVE_LLM
    elif model_type == ModelTypes.MULTI_CLASS:
        s.constants.model_type = pb2.Schema.ModelType.MULTI_CLASS

    if batch_id is not None:
        s.constants.batch_id = batch_id

    if schema.prediction_id_column_name is not None:
        s.arrow_schema.prediction_id_column_name = (
            schema.prediction_id_column_name
        )

    if schema.timestamp_column_name is not None:
        s.arrow_schema.timestamp_column_name = schema.timestamp_column_name

    if schema.prediction_label_column_name is not None:
        s.arrow_schema.prediction_label_column_name = (
            schema.prediction_label_column_name
        )

    # Prediction columns for object detection / segmentation.
    if model_type == ModelTypes.OBJECT_DETECTION:
        if schema.object_detection_prediction_column_names is not None:
            obj_det_pred = schema.object_detection_prediction_column_names
            pred_labels = (
                s.arrow_schema.prediction_object_detection_label_column_names
            )
            pred_labels.bboxes_coordinates_column_name = (
                obj_det_pred.bounding_boxes_coordinates_column_name
            )
            pred_labels.bboxes_categories_column_name = (
                obj_det_pred.categories_column_name
            )
            if obj_det_pred.scores_column_name is not None:
                pred_labels.bboxes_scores_column_name = (
                    obj_det_pred.scores_column_name
                )

        if schema.semantic_segmentation_prediction_column_names is not None:
            seg_pred_cols = schema.semantic_segmentation_prediction_column_names
            pred_seg_labels = s.arrow_schema.prediction_semantic_segmentation_label_column_names
            pred_seg_labels.polygons_coordinates_column_name = (
                seg_pred_cols.polygon_coordinates_column_name
            )
            pred_seg_labels.polygons_categories_column_name = (
                seg_pred_cols.categories_column_name
            )

        if schema.instance_segmentation_prediction_column_names is not None:
            inst_seg_pred_cols = (
                schema.instance_segmentation_prediction_column_names
            )
            pred_inst_seg_labels = s.arrow_schema.prediction_instance_segmentation_label_column_names
            pred_inst_seg_labels.polygons_coordinates_column_name = (
                inst_seg_pred_cols.polygon_coordinates_column_name
            )
            pred_inst_seg_labels.polygons_categories_column_name = (
                inst_seg_pred_cols.categories_column_name
            )
            if inst_seg_pred_cols.scores_column_name is not None:
                pred_inst_seg_labels.polygons_scores_column_name = (
                    inst_seg_pred_cols.scores_column_name
                )
            if (
                inst_seg_pred_cols.bounding_boxes_coordinates_column_name
                is not None
            ):
                pred_inst_seg_labels.bboxes_coordinates_column_name = (
                    inst_seg_pred_cols.bounding_boxes_coordinates_column_name
                )

    if schema.prediction_score_column_name is not None:
        if model_type in NUMERIC_MODEL_TYPES:
            # allow numeric prediction to be sent in as either prediction_label (legacy) or
            # prediction_score. This overwrites a label column name set above.
            s.arrow_schema.prediction_label_column_name = (
                schema.prediction_score_column_name
            )
        else:
            s.arrow_schema.prediction_score_column_name = (
                schema.prediction_score_column_name
            )

    if schema.feature_column_names is not None:
        s.arrow_schema.feature_column_names.extend(schema.feature_column_names)

    if schema.embedding_feature_column_names is not None:
        for (
            emb_name,
            emb_col_names,
        ) in schema.embedding_feature_column_names.items():
            # emb_name is how it will show in the UI
            s.arrow_schema.embedding_feature_column_names_map[
                emb_name
            ].vector_column_name = emb_col_names.vector_column_name
            if emb_col_names.data_column_name:
                s.arrow_schema.embedding_feature_column_names_map[
                    emb_name
                ].data_column_name = emb_col_names.data_column_name
            if emb_col_names.link_to_data_column_name:
                s.arrow_schema.embedding_feature_column_names_map[
                    emb_name
                ].link_to_data_column_name = (
                    emb_col_names.link_to_data_column_name
                )

    # Prompt / response are stored as embedding features under fixed keys.
    # A plain string is treated as the data column; an EmbeddingColumnNames
    # provides the vector column and, optionally, the data column.
    if schema.prompt_column_names is not None:
        if isinstance(schema.prompt_column_names, str):
            s.arrow_schema.embedding_feature_column_names_map[
                "prompt"
            ].data_column_name = schema.prompt_column_names
        elif isinstance(schema.prompt_column_names, EmbeddingColumnNames):
            col_names = schema.prompt_column_names
            s.arrow_schema.embedding_feature_column_names_map[
                "prompt"
            ].vector_column_name = col_names.vector_column_name
            if col_names.data_column_name:
                s.arrow_schema.embedding_feature_column_names_map[
                    "prompt"
                ].data_column_name = col_names.data_column_name

    if schema.response_column_names is not None:
        if isinstance(schema.response_column_names, str):
            s.arrow_schema.embedding_feature_column_names_map[
                "response"
            ].data_column_name = schema.response_column_names
        elif isinstance(schema.response_column_names, EmbeddingColumnNames):
            col_names = schema.response_column_names
            s.arrow_schema.embedding_feature_column_names_map[
                "response"
            ].vector_column_name = col_names.vector_column_name
            if col_names.data_column_name:
                s.arrow_schema.embedding_feature_column_names_map[
                    "response"
                ].data_column_name = col_names.data_column_name

    if schema.tag_column_names is not None:
        s.arrow_schema.tag_column_names.extend(schema.tag_column_names)

    # Actual label: for ranking models the relevance labels take precedence,
    # then the attributions column, before falling back to the generic
    # actual label column.
    if (
        model_type == ModelTypes.RANKING
        and schema.relevance_labels_column_name is not None
    ):
        s.arrow_schema.actual_label_column_name = (
            schema.relevance_labels_column_name
        )
    elif (
        model_type == ModelTypes.RANKING
        and schema.attributions_column_name is not None
    ):
        s.arrow_schema.actual_label_column_name = (
            schema.attributions_column_name
        )
    elif schema.actual_label_column_name is not None:
        s.arrow_schema.actual_label_column_name = (
            schema.actual_label_column_name
        )

    if (
        model_type == ModelTypes.RANKING
        and schema.relevance_score_column_name is not None
    ):
        s.arrow_schema.actual_score_column_name = (
            schema.relevance_score_column_name
        )
    elif schema.actual_score_column_name is not None:
        if model_type in NUMERIC_MODEL_TYPES:
            # allow numeric actual to be sent in as either actual_label (legacy) or
            # actual_score. This overwrites a label column name set above.
            s.arrow_schema.actual_label_column_name = (
                schema.actual_score_column_name
            )
        else:
            s.arrow_schema.actual_score_column_name = (
                schema.actual_score_column_name
            )

    if schema.shap_values_column_names is not None:
        s.arrow_schema.shap_values_column_names.update(
            schema.shap_values_column_names
        )

    if schema.prediction_group_id_column_name is not None:
        s.arrow_schema.prediction_group_id_column_name = (
            schema.prediction_group_id_column_name
        )

    if schema.rank_column_name is not None:
        s.arrow_schema.rank_column_name = schema.rank_column_name

    # Actual columns for object detection / segmentation.
    if model_type == ModelTypes.OBJECT_DETECTION:
        if schema.object_detection_actual_column_names is not None:
            obj_det_actual = schema.object_detection_actual_column_names
            actual_labels = (
                s.arrow_schema.actual_object_detection_label_column_names
            )
            actual_labels.bboxes_coordinates_column_name = (
                obj_det_actual.bounding_boxes_coordinates_column_name
            )
            actual_labels.bboxes_categories_column_name = (
                obj_det_actual.categories_column_name
            )
            if obj_det_actual.scores_column_name is not None:
                actual_labels.bboxes_scores_column_name = (
                    obj_det_actual.scores_column_name
                )

        if schema.semantic_segmentation_actual_column_names is not None:
            sem_seg_actual = schema.semantic_segmentation_actual_column_names
            sem_seg_labels = (
                s.arrow_schema.actual_semantic_segmentation_label_column_names
            )
            sem_seg_labels.polygons_coordinates_column_name = (
                sem_seg_actual.polygon_coordinates_column_name
            )
            sem_seg_labels.polygons_categories_column_name = (
                sem_seg_actual.categories_column_name
            )

        if schema.instance_segmentation_actual_column_names is not None:
            inst_seg_actual = schema.instance_segmentation_actual_column_names
            inst_seg_labels = (
                s.arrow_schema.actual_instance_segmentation_label_column_names
            )
            inst_seg_labels.polygons_coordinates_column_name = (
                inst_seg_actual.polygon_coordinates_column_name
            )
            inst_seg_labels.polygons_categories_column_name = (
                inst_seg_actual.categories_column_name
            )
            if (
                inst_seg_actual.bounding_boxes_coordinates_column_name
                is not None
            ):
                inst_seg_labels.bboxes_coordinates_column_name = (
                    inst_seg_actual.bounding_boxes_coordinates_column_name
                )

    if model_type == ModelTypes.GENERATIVE_LLM:
        if schema.prompt_template_column_names is not None:
            prompt_template_names = schema.prompt_template_column_names
            arrow_prompt_names = s.arrow_schema.prompt_template_column_names
            arrow_prompt_names.template_column_name = (
                prompt_template_names.template_column_name
            )
            arrow_prompt_names.template_version_column_name = (
                prompt_template_names.template_version_column_name
            )

        if schema.llm_config_column_names is not None:
            s.arrow_schema.llm_config_column_names.model_column_name = (
                schema.llm_config_column_names.model_column_name
            )
            s.arrow_schema.llm_config_column_names.params_map_column_name = (
                schema.llm_config_column_names.params_column_name
            )

        if schema.retrieved_document_ids_column_name is not None:
            s.arrow_schema.retrieved_document_ids_column_name = (
                schema.retrieved_document_ids_column_name
            )

    if model_type == ModelTypes.MULTI_CLASS:
        if schema.prediction_score_column_name is not None:
            s.arrow_schema.prediction_score_column_name = (
                schema.prediction_score_column_name
            )
        if schema.multi_class_threshold_scores_column_name is not None:
            s.arrow_schema.multi_class_threshold_scores_column_name = (
                schema.multi_class_threshold_scores_column_name
            )
        if schema.actual_score_column_name is not None:
            s.arrow_schema.actual_score_column_name = (
                schema.actual_score_column_name
            )

    return s


def _get_pb_schema_corpus(
    schema: CorpusSchema,
    model_id: str,
) -> pb2.Schema:
    """Construct a protocol buffer Schema from CorpusSchema for document corpus logging.

    The environment is always set to ``CORPUS`` and the model type to
    ``GENERATIVE_LLM``. Document column names are only written when they are
    provided.

    Args:
        schema: Corpus schema describing the document columns.
        model_id: Model identifier stored in the schema constants.

    Returns:
        The populated ``pb2.Schema`` message.
    """
    s = pb2.Schema()
    s.constants.model_id = model_id
    s.constants.environment = pb2.Schema.Environment.CORPUS
    s.constants.model_type = pb2.Schema.ModelType.GENERATIVE_LLM

    if schema.document_id_column_name is not None:
        s.arrow_schema.document_column_names.id_column_name = (
            schema.document_id_column_name
        )

    if schema.document_version_column_name is not None:
        s.arrow_schema.document_column_names.version_column_name = (
            schema.document_version_column_name
        )

    if schema.document_text_embedding_column_names is not None:
        doc_text_emb_cols = schema.document_text_embedding_column_names
        doc_text_col = s.arrow_schema.document_column_names.text_column_name
        doc_text_col.vector_column_name = doc_text_emb_cols.vector_column_name
        # Skip an unset data column instead of assigning None to the protobuf
        # field, mirroring the optional handling of the link column below.
        if doc_text_emb_cols.data_column_name is not None:
            doc_text_col.data_column_name = doc_text_emb_cols.data_column_name
        if doc_text_emb_cols.link_to_data_column_name is not None:
            doc_text_col.link_to_data_column_name = (
                doc_text_emb_cols.link_to_data_column_name
            )

    return s