Source code for arize.config

"""SDK configuration and settings management for the Arize client."""

import logging
from dataclasses import dataclass, field, fields
from pathlib import Path

from arize._env import (
    _env_bool,
    _env_float,
    _env_http_scheme,
    _env_int,
    _env_str,
)
from arize._headers import (
    _builtin_flight_headers,
    _builtin_grpc_headers,
    _builtin_http_headers,
    _validate_default_headers,
)
from arize.constants.config import (
    DEFAULT_API_HOST,
    DEFAULT_API_SCHEME,
    DEFAULT_ARIZE_DIRECTORY,
    DEFAULT_ENABLE_CACHING,
    DEFAULT_FLIGHT_HOST,
    DEFAULT_FLIGHT_PORT,
    DEFAULT_FLIGHT_SCHEME,
    DEFAULT_MAX_HTTP_PAYLOAD_SIZE_MB,
    DEFAULT_MAX_PAST_YEARS,
    DEFAULT_OTLP_HOST,
    DEFAULT_OTLP_SCHEME,
    DEFAULT_PYARROW_MAX_CHUNKSIZE,
    DEFAULT_REQUEST_VERIFY,
    DEFAULT_STREAM_MAX_QUEUE_BOUND,
    DEFAULT_STREAM_MAX_WORKERS,
    ENV_API_HOST,
    ENV_API_KEY,
    ENV_API_SCHEME,
    ENV_ARIZE_DIRECTORY,
    ENV_BASE_DOMAIN,
    ENV_ENABLE_CACHING,
    ENV_FLIGHT_HOST,
    ENV_FLIGHT_PORT,
    ENV_FLIGHT_SCHEME,
    ENV_MAX_HTTP_PAYLOAD_SIZE_MB,
    ENV_MAX_PAST_YEARS,
    ENV_OTLP_HOST,
    ENV_OTLP_SCHEME,
    ENV_PYARROW_MAX_CHUNKSIZE,
    ENV_REGION,
    ENV_REQUEST_VERIFY,
    ENV_SINGLE_HOST,
    ENV_SINGLE_PORT,
    ENV_STREAM_MAX_QUEUE_BOUND,
    ENV_STREAM_MAX_WORKERS,
)
from arize.constants.pyarrow import MAX_CHUNKSIZE
from arize.exceptions.auth import MissingAPIKeyError
from arize.exceptions.config import (
    MultipleEndpointOverridesError,
)
from arize.regions import REGION_ENDPOINTS, Region

logger = logging.getLogger(__name__)

SENSITIVE_FIELD_MARKERS = ("key", "token", "secret")


def _is_sensitive_field(name: str) -> bool:
    """Check if a field name contains sensitive information markers.

    Args:
        name: The field name to check.

    Returns:
        bool: True if the field name contains 'key', 'token', or 'secret' (case-insensitive).
    """
    n = name.lower()
    return bool(any(k in n for k in SENSITIVE_FIELD_MARKERS))


def _mask_secret(secret: str, N: int = 4) -> str:
    """Mask a secret string by showing only the first N characters.

    Args:
        secret: The secret string to mask.
        N: Number of characters to show before masking. Defaults to 4.

    Returns:
        str: The masked string (first N chars + '***'), or empty string if input is empty.
    """
    if len(secret) == 0:
        return ""
    return f"{secret[:N]}***"


def _endpoint(scheme: str, base: str, path: str = "") -> str:
    """Construct a full endpoint URL from scheme, base, and optional path.

    Args:
        scheme: The URL scheme (e.g., "http", "https").
        base: The base URL or hostname.
        path: Optional path to append to the base URL. Defaults to empty string.

    Returns:
        str: The fully constructed endpoint URL.
    """
    endpoint = scheme + "://" + base.rstrip("/")
    if path:
        endpoint += "/" + path.lstrip("/")
    return endpoint


[docs] @dataclass(frozen=True) class SDKConfiguration: """Configuration for the Arize SDK with endpoint and authentication settings. This class holds pure configuration data and does not manage client lifecycle. Client creation and caching is handled by :class:`arize.ArizeClient`. This class is used internally by ArizeClient to manage SDK configuration. It is not recommended to use this class directly; users should interact with ArizeClient instead. Each configuration parameter follows this resolution order: 1. Explicit value passed to ArizeClient constructor (highest priority) 2. Environment variable value 3. Built-in default value (lowest priority) Args: api_key: Arize API key for authentication. Required. Environment variable: ARIZE_API_KEY. Default: None (must be provided via argument or environment variable). api_host: API endpoint host. Environment variable: ARIZE_API_HOST. Default: "api.arize.com". api_scheme: API endpoint scheme (http/https). Environment variable: ARIZE_API_SCHEME. Default: "https". otlp_host: OTLP (OpenTelemetry Protocol) endpoint host. Environment variable: ARIZE_OTLP_HOST. Default: "otlp.arize.com". otlp_scheme: OTLP endpoint scheme (http/https). Environment variable: ARIZE_OTLP_SCHEME. Default: "https". flight_host: Apache Arrow Flight endpoint host. Environment variable: ARIZE_FLIGHT_HOST. Default: "flight.arize.com". flight_port: Apache Arrow Flight endpoint port (1-65535). Environment variable: ARIZE_FLIGHT_PORT. Default: 443. flight_scheme: Apache Arrow Flight endpoint scheme. Environment variable: ARIZE_FLIGHT_SCHEME. Default: "grpc+tls". pyarrow_max_chunksize: Maximum chunk size for PyArrow operations (1 to MAX_CHUNKSIZE). Environment variable: ARIZE_MAX_CHUNKSIZE. Default: 10_000. request_verify: Whether to verify SSL certificates for HTTP requests. Environment variable: ARIZE_REQUEST_VERIFY. Default: True. stream_max_workers: Maximum number of worker threads for streaming operations (minimum: 1). Environment variable: ARIZE_STREAM_MAX_WORKERS. Default: 8. stream_max_queue_bound: Maximum queue size for streaming operations (minimum: 1). Environment variable: ARIZE_STREAM_MAX_QUEUE_BOUND. Default: 5000. max_http_payload_size_mb: Maximum HTTP payload size in megabytes (minimum: 1). Environment variable: ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB. Default: 100. arize_directory: Directory for Arize SDK files (cache, logs, etc.). Environment variable: ARIZE_DIRECTORY. Default: "~/.arize". enable_caching: Whether to enable local caching. Environment variable: ARIZE_ENABLE_CACHING. Default: True. region: Arize region (e.g., US_CENTRAL, EU_WEST). When specified, overrides individual host/port settings. Environment variable: ARIZE_REGION. Default: :class:`Region.UNSET`. single_host: Single host to use for all endpoints. When specified, overrides individual host settings. Environment variable: ARIZE_SINGLE_HOST. Default: "" (not set). single_port: Single port to use for all endpoints. When specified, overrides individual port settings (0-65535). Environment variable: ARIZE_SINGLE_PORT. Default: 0 (not set). base_domain: Base domain for generating all endpoint hosts. Intended for Private Connect setups. When specified, generates hosts as api.<base_domain>, otlp.<base_domain>, flight.<base_domain>. When specified, overrides individual host settings. Environment variable: ARIZE_BASE_DOMAIN. Default: "" (not set). max_past_years: Maximum number of years in the past allowed for prediction timestamps. For on-prem deployments with custom retention policies, this can be increased. When set to a value other than the default, a warning will be issued advising to contact Arize support. Environment variable: ARIZE_MAX_PAST_YEARS. Default: 5. default_headers: Custom headers added to every outbound request across all transports (HTTP REST, grpc-gateway, and Apache Arrow Flight). They appear in the ``headers``, ``headers_grpc``, and ``headers_flight`` properties. On the gRPC path each key is automatically prefixed with ``Grpc-Metadata-`` so grpc-gateway forwards it to the backend service. Keys may not collide (case-insensitively) with the SDK's built-in headers, start with ``Grpc-Metadata-``, or contain control characters; violations raise InvalidDefaultHeadersError. Unlike other fields, this has no environment variable and is programmatic-only (argument or default). Default: {} (empty). Note: The endpoint override options (region, single_host/single_port, base_domain) are mutually exclusive. Specifying more than one will raise MultipleEndpointOverridesError. Raises: MissingAPIKeyError: If api_key is not provided via argument or environment variable. MultipleEndpointOverridesError: If multiple endpoint override options are provided. InvalidDefaultHeadersError: If default_headers contains an invalid or reserved header. """ api_key: str = field( default_factory=lambda: _env_str(ENV_API_KEY, ""), ) api_host: str = field( default_factory=lambda: _env_str(ENV_API_HOST, DEFAULT_API_HOST) ) api_scheme: str = field( default_factory=lambda: _env_http_scheme( ENV_API_SCHEME, DEFAULT_API_SCHEME, ), ) otlp_host: str = field( default_factory=lambda: _env_str(ENV_OTLP_HOST, DEFAULT_OTLP_HOST) ) otlp_scheme: str = field( default_factory=lambda: _env_http_scheme( ENV_OTLP_SCHEME, DEFAULT_OTLP_SCHEME, ), ) flight_host: str = field( default_factory=lambda: _env_str(ENV_FLIGHT_HOST, DEFAULT_FLIGHT_HOST) ) flight_port: int = field( default_factory=lambda: _env_int( ENV_FLIGHT_PORT, DEFAULT_FLIGHT_PORT, min_val=1, max_val=65535 ) ) flight_scheme: str = field( default_factory=lambda: _env_str( ENV_FLIGHT_SCHEME, DEFAULT_FLIGHT_SCHEME, ), ) pyarrow_max_chunksize: int = field( default_factory=lambda: _env_int( ENV_PYARROW_MAX_CHUNKSIZE, DEFAULT_PYARROW_MAX_CHUNKSIZE, min_val=1, max_val=MAX_CHUNKSIZE, ) ) request_verify: bool = field( default_factory=lambda: _env_bool( ENV_REQUEST_VERIFY, DEFAULT_REQUEST_VERIFY ) ) stream_max_workers: int = field( default_factory=lambda: _env_int( ENV_STREAM_MAX_WORKERS, DEFAULT_STREAM_MAX_WORKERS, min_val=1, max_val=32, ) ) stream_max_queue_bound: int = field( default_factory=lambda: _env_int( ENV_STREAM_MAX_QUEUE_BOUND, DEFAULT_STREAM_MAX_QUEUE_BOUND, min_val=1, ) ) max_http_payload_size_mb: float = field( default_factory=lambda: _env_float( ENV_MAX_HTTP_PAYLOAD_SIZE_MB, DEFAULT_MAX_HTTP_PAYLOAD_SIZE_MB, min_val=1, ) ) arize_directory: str = field( default_factory=lambda: _env_str( ENV_ARIZE_DIRECTORY, DEFAULT_ARIZE_DIRECTORY ) ) enable_caching: bool = field( default_factory=lambda: _env_bool( ENV_ENABLE_CACHING, DEFAULT_ENABLE_CACHING ) ) region: Region = field( default_factory=lambda: Region(_env_str(ENV_REGION, "")) ) single_host: str = field( default_factory=lambda: _env_str(ENV_SINGLE_HOST, "") ) single_port: int = field( default_factory=lambda: _env_int( ENV_SINGLE_PORT, 0, min_val=0, max_val=65535 ) ) base_domain: str = field( default_factory=lambda: _env_str(ENV_BASE_DOMAIN, "") ) max_past_years: int = field( default_factory=lambda: _env_int( ENV_MAX_PAST_YEARS, DEFAULT_MAX_PAST_YEARS, min_val=1 ) ) default_headers: dict[str, str] = field(default_factory=dict) def __post_init__(self) -> None: """Validate and configure SDK endpoints after initialization. Endpoint override options are mutually exclusive. Only one of the following can be specified: 1. region - Overrides all via REGION_ENDPOINTS mapping 2. single_host/single_port - Overrides individual hosts/ports 3. base_domain - Generates hosts from base domain If none are specified, per-endpoint host/port settings are used. Raises: MissingAPIKeyError: If api_key is not provided. MultipleEndpointOverridesError: If multiple endpoint override options are provided. """ # Validate configuration if not self.api_key: raise MissingAPIKeyError() # Validate and freeze user-supplied default headers. Store a defensive # shallow copy so a caller mutating their dict afterward cannot alter # this (frozen) configuration. _validate_default_headers(self.default_headers) object.__setattr__(self, "default_headers", dict(self.default_headers)) # Check which override options are set has_base_domain = bool(self.base_domain) has_single_host = bool(self.single_host) has_single_port = self.single_port != 0 has_region = self.region is not Region.UNSET # Ensure only one override method is used (mutually exclusive) override_count = sum( [has_base_domain, has_single_host or has_single_port, has_region] ) if override_count > 1: # Determine which overrides were provided provided_overrides = [] if has_region: provided_overrides.append(f"region={self.region.value}") if has_single_host or has_single_port: if has_single_host: provided_overrides.append( f"single_host={self.single_host!r}" ) if has_single_port: provided_overrides.append(f"single_port={self.single_port}") if has_base_domain: provided_overrides.append(f"base_domain={self.base_domain!r}") error_message = ( f"Multiple endpoint override options provided: {', '.join(provided_overrides)}. " "Only one of the following can be specified: 'region', " "'single_host'/'single_port', or 'base_domain'." ) logger.error(error_message) raise MultipleEndpointOverridesError(error_message) if has_base_domain: logger.info( "Base domain %r provided; generating hosts from base domain.", self.base_domain, ) object.__setattr__(self, "api_host", f"api.{self.base_domain}") object.__setattr__(self, "otlp_host", f"otlp.{self.base_domain}") object.__setattr__( self, "flight_host", f"flight.{self.base_domain}" ) if has_single_host: logger.info( "Single host %r provided; overriding hosts configuration with single host.", self.single_host, ) object.__setattr__(self, "api_host", self.single_host) object.__setattr__(self, "otlp_host", self.single_host) object.__setattr__(self, "flight_host", self.single_host) if has_single_port: logger.info( "Single port %s provided; overriding ports configuration with single port.", self.single_port, ) object.__setattr__(self, "flight_port", self.single_port) if has_region: logger.info( "Region %s provided; overriding hosts & ports configuration with region defaults.", self.region.value, ) endpoints = REGION_ENDPOINTS[self.region] object.__setattr__(self, "api_host", endpoints.api_host) object.__setattr__(self, "otlp_host", endpoints.otlp_host) object.__setattr__(self, "flight_host", endpoints.flight_host) object.__setattr__(self, "flight_port", endpoints.flight_port) if self.max_past_years != DEFAULT_MAX_PAST_YEARS: logger.warning( f"max_past_years is set to {self.max_past_years} (default: {DEFAULT_MAX_PAST_YEARS}). " "This setting allows timestamps older than the default limit. " "Please contact Arize support to enable custom timestamp limits for your account." ) @property def cache_dir(self) -> str: """Return the path to the cache directory.""" return str(Path(self.arize_directory) / "cache") @property def api_url(self) -> str: """Return the base API URL.""" return _endpoint(self.api_scheme, self.api_host) @property def otlp_url(self) -> str: """Return the OTLP endpoint URL.""" return _endpoint(self.otlp_scheme, self.otlp_host, "/v1") @property def files_url(self) -> str: """Return the files upload endpoint URL.""" return _endpoint(self.api_scheme, self.api_host, "/v1/pandas_arrow") @property def records_url(self) -> str: """Return the records logging endpoint URL.""" return _endpoint(self.api_scheme, self.api_host, "/v1/log") @property def headers(self) -> dict[str, str]: """Return HTTP headers for API requests. User-supplied ``default_headers`` are included verbatim; the SDK's built-in headers are merged last and always win. """ return {**self.default_headers, **_builtin_http_headers(self.api_key)} @property def headers_grpc(self) -> dict[str, str]: """Return headers for gRPC (grpc-gateway) requests. User-supplied ``default_headers`` are each prefixed with ``Grpc-Metadata-`` so grpc-gateway forwards them to the backend service; the SDK's built-in headers are merged last and always win. """ prefixed = { f"Grpc-Metadata-{key}": value for key, value in self.default_headers.items() } return {**prefixed, **_builtin_grpc_headers(self.api_key)} @property def headers_flight(self) -> dict[str, str]: """Return headers for Apache Arrow Flight requests. User-supplied ``default_headers`` are included verbatim; the SDK's built-in headers are merged last and always win. The Flight client is responsible for byte-encoding these into the wire format. """ return {**self.default_headers, **_builtin_flight_headers(self.api_key)} def __repr__(self) -> str: """Return a detailed string representation with masked sensitive fields.""" # Dynamically build repr for all fields lines = [f"{self.__class__.__name__}("] for f in fields(self): if not f.repr: continue val = getattr(self, f.name) if f.name == "default_headers": # Mask values whose header name looks sensitive; show the rest. val = { k: (_mask_secret(v, 6) if _is_sensitive_field(k) else v) for k, v in val.items() } elif _is_sensitive_field(f.name): val = _mask_secret(val, 6) lines.append(f" {f.name}={val!r},") lines.append(")") return "\n".join(lines)