"""Client implementation for managing prompts in the Arize platform."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from arize.pre_releases import ReleaseStage, prerelease_endpoint
from arize.prompts.types import PromptVersion, PromptWithVersion
from arize.utils.resolve import (
_find_prompt_id,
_find_space_id,
_resolve_resource,
)
if TYPE_CHECKING:
import builtins
from arize._generated.api_client.api_client import ApiClient
from arize.config import SDKConfiguration
from arize.prompts.types import (
InputVariableFormat,
InvocationParams,
LLMMessage,
LlmProvider,
Prompt,
PromptListResponse,
PromptVersionLabelsResponse,
PromptVersionListResponse,
ProviderParams,
)
logger = logging.getLogger(__name__)
[docs]
class PromptsClient:
"""Client for managing prompts in the Arize platform.
This class is primarily intended for internal use within the SDK. Users are
highly encouraged to access resource-specific functionality via
:class:`arize.ArizeClient`.
The prompts client is a thin wrapper around the generated REST API client,
using the shared generated API client owned by
:class:`arize.config.SDKConfiguration`.
"""
def __init__(
self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
) -> None:
"""
Args:
sdk_config: Resolved SDK configuration.
generated_client: Shared generated API client instance.
""" # noqa: D205, D212
self._sdk_config = sdk_config
# Import at runtime so it's still lazy and extras-gated by the parent
from arize._generated import api_client as gen
# Use the provided client directly
self._api = gen.PromptsApi(generated_client)
self._spaces_api = gen.SpacesApi(generated_client)
[docs]
@prerelease_endpoint(key="prompts.list", stage=ReleaseStage.ALPHA)
def list(
self,
*,
name: str | None = None,
space: str | None = None,
limit: int = 100,
cursor: str | None = None,
) -> PromptListResponse:
"""List prompts in a space.
Args:
name: Optional case-insensitive substring filter on the prompt name.
space: Optional space filter. If the value is a base64-encoded resource ID it is
treated as a space ID; otherwise it is used as a case-insensitive
substring filter on the space name.
limit: Maximum number of prompts to return. The server enforces an
upper bound of 100.
cursor: Opaque pagination cursor returned from a previous response.
Returns:
A response object with the prompts and pagination information.
Raises:
ApiException: If the REST API
returns an error response (e.g. 401/403/429).
"""
resolved_space = _resolve_resource(space)
return self._api.prompts_list(
space_id=resolved_space.id,
space_name=resolved_space.name,
name=name,
limit=limit,
cursor=cursor,
)
[docs]
@prerelease_endpoint(key="prompts.create", stage=ReleaseStage.ALPHA)
def create(
self,
*,
space: str,
name: str,
commit_message: str,
input_variable_format: InputVariableFormat,
provider: LlmProvider,
messages: builtins.list[LLMMessage],
description: str | None = None,
model: str | None = None,
invocation_params: InvocationParams | None = None,
provider_params: ProviderParams | None = None,
) -> PromptWithVersion:
"""Create a prompt with an initial version.
Args:
space: Space ID or name to create the prompt in. If a name is
provided it will be resolved to a space ID automatically.
name: Prompt name (must be unique within the space).
commit_message: Commit message describing the initial version.
input_variable_format: Variable interpolation format for the prompt
template (e.g. ``InputVariableFormat.F_STRING``).
provider: LLM provider for the prompt.
messages: Messages that make up the prompt template (at least one required).
description: Optional description of the prompt.
model: Optional model name. If omitted, no default model is set on the
version.
invocation_params: Optional invocation parameters (e.g. temperature,
max_tokens).
provider_params: Optional provider-specific parameters.
Returns:
The created prompt with its initial version.
Raises:
ApiException: If the REST API
returns an error response (e.g. 400/401/403/409/429).
"""
space_id = _find_space_id(self._spaces_api, space)
from arize._generated import api_client as gen
version = gen.PromptVersionCreateRequest(
commit_message=commit_message,
input_variable_format=input_variable_format,
provider=provider,
model=model,
messages=messages,
invocation_params=invocation_params,
provider_params=provider_params,
)
body = gen.PromptsCreateRequest(
space_id=space_id,
name=name,
description=description,
version=version,
)
result = self._api.prompts_create(prompts_create_request=body)
return PromptWithVersion.model_validate(result, from_attributes=True)
[docs]
@prerelease_endpoint(key="prompts.get", stage=ReleaseStage.ALPHA)
def get(
self,
*,
prompt: str,
space: str | None = None,
version_id: str | None = None,
label: str | None = None,
) -> PromptWithVersion:
"""Get a prompt by ID or name.
Optionally resolves a specific version by ``version_id`` or a ``label``.
If neither is supplied, the latest version is returned.
Args:
prompt: Prompt ID or name. If a name is provided, ``space`` must
also be supplied so the name can be resolved.
space: Optional space ID or name. Required when *prompt* is a name.
version_id: Optional specific version ID to retrieve.
label: Optional label name to resolve to a version (e.g. ``"production"``).
Returns:
The prompt object with its resolved version.
Raises:
ApiException: If the REST API
returns an error response (e.g. 401/403/404/429).
"""
prompt_id = _find_prompt_id(
api=self._api,
prompt=prompt,
space=space,
)
result = self._api.prompts_get(
prompt_id=prompt_id,
version_id=version_id,
label=label,
)
return PromptWithVersion.model_validate(result, from_attributes=True)
[docs]
@prerelease_endpoint(key="prompts.update", stage=ReleaseStage.ALPHA)
def update(
self,
*,
prompt: str,
space: str | None = None,
description: str,
) -> Prompt:
"""Update a prompt's metadata.
Args:
prompt: Prompt ID or name. If a name is provided, ``space`` must
also be supplied so the name can be resolved.
space: Optional space ID or name. Required when *prompt* is a name.
description: Updated description for the prompt.
Returns:
The updated prompt object.
Raises:
ValueError: If no fields to update are provided.
ApiException: If the REST API
returns an error response (e.g. 401/403/404/429).
"""
prompt_id = _find_prompt_id(
api=self._api,
prompt=prompt,
space=space,
)
from arize._generated import api_client as gen
body = gen.PromptsUpdateRequest(description=description)
return self._api.prompts_update(
prompt_id=prompt_id, prompts_update_request=body
)
[docs]
@prerelease_endpoint(key="prompts.delete", stage=ReleaseStage.ALPHA)
def delete(self, *, prompt: str, space: str | None = None) -> None:
"""Delete a prompt by ID or name.
This operation is irreversible and removes all associated versions.
Args:
prompt: Prompt ID or name. If a name is provided, ``space`` must
also be supplied so the name can be resolved.
space: Optional space ID or name. Required when *prompt* is a name.
Returns:
None on success (204 No Content).
Raises:
ApiException: If the REST API
returns an error response (e.g. 401/403/404/429).
"""
prompt_id = _find_prompt_id(
api=self._api,
prompt=prompt,
space=space,
)
return self._api.prompts_delete(prompt_id=prompt_id)
[docs]
@prerelease_endpoint(key="prompts.list_versions", stage=ReleaseStage.ALPHA)
def list_versions(
self,
*,
prompt: str,
space: str | None = None,
limit: int = 100,
cursor: str | None = None,
) -> PromptVersionListResponse:
"""List versions for a prompt.
Args:
prompt: Prompt ID or name. If a name is provided, ``space`` must
also be supplied so the name can be resolved.
space: Optional space ID or name. Required when *prompt* is a name.
limit: Maximum number of versions to return. The server enforces an
upper bound of 100.
cursor: Opaque pagination cursor returned from a previous response.
Returns:
A response object with the prompt versions and pagination information.
Raises:
ApiException: If the REST API
returns an error response (e.g. 401/403/404/429).
"""
prompt_id = _find_prompt_id(
api=self._api,
prompt=prompt,
space=space,
)
return self._api.prompt_versions_list(
prompt_id=prompt_id,
limit=limit,
cursor=cursor,
)
[docs]
@prerelease_endpoint(key="prompts.create_version", stage=ReleaseStage.ALPHA)
def create_version(
self,
*,
prompt: str,
space: str | None = None,
commit_message: str,
input_variable_format: InputVariableFormat,
provider: LlmProvider,
messages: builtins.list[LLMMessage],
model: str | None = None,
invocation_params: InvocationParams | None = None,
provider_params: ProviderParams | None = None,
) -> PromptVersion:
"""Create a new version for an existing prompt.
Args:
prompt: Prompt ID or name. If a name is provided, ``space`` must
also be supplied so the name can be resolved.
space: Optional space ID or name. Required when *prompt* is a name.
commit_message: Commit message describing this version.
input_variable_format: Variable interpolation format for the prompt
template (e.g. ``InputVariableFormat.F_STRING``).
provider: LLM provider for this version.
messages: Messages that make up the prompt template (at least one required).
model: Optional model name. If omitted, no default model is set on this
version.
invocation_params: Optional invocation parameters (e.g. temperature,
max_tokens).
provider_params: Optional provider-specific parameters.
Returns:
The created prompt version.
Raises:
ApiException: If the REST API
returns an error response (e.g. 400/401/403/404/429).
"""
prompt_id = _find_prompt_id(
api=self._api,
prompt=prompt,
space=space,
)
from arize._generated import api_client as gen
body = gen.PromptVersionsCreateRequest(
commit_message=commit_message,
input_variable_format=input_variable_format,
provider=provider,
model=model,
messages=messages,
invocation_params=invocation_params,
provider_params=provider_params,
)
result = self._api.prompt_versions_create(
prompt_id=prompt_id, prompt_versions_create_request=body
)
return PromptVersion.model_validate(result, from_attributes=True)
[docs]
@prerelease_endpoint(key="prompts.get_label", stage=ReleaseStage.ALPHA)
def get_label(
self, *, prompt: str, space: str | None = None, label_name: str
) -> PromptVersion:
"""Resolve a label to a prompt version.
Args:
prompt: Prompt ID or name. If a name is provided, ``space`` must
also be supplied so the name can be resolved.
space: Optional space ID or name. Required when *prompt* is a name.
label_name: Label name to resolve (e.g. ``"production"``, ``"staging"``).
Returns:
The prompt version the label currently points to.
Raises:
ApiException: If the REST API
returns an error response (e.g. 401/403/404/429).
"""
prompt_id = _find_prompt_id(
api=self._api,
prompt=prompt,
space=space,
)
result = self._api.prompt_labels_get(
prompt_id=prompt_id, label_name=label_name
)
return PromptVersion.model_validate(result, from_attributes=True)
[docs]
@prerelease_endpoint(key="prompts.set_labels", stage=ReleaseStage.ALPHA)
def set_labels(
self,
*,
version_id: str,
labels: builtins.list[str],
) -> PromptVersionLabelsResponse:
"""Set labels on a prompt version.
Replaces all existing labels on the version with the provided list.
Args:
version_id: Version ID to set labels on.
labels: List of label names to assign (replaces all existing labels).
Returns:
The response with the updated labels.
Raises:
ApiException: If the REST API
returns an error response (e.g. 400/401/403/404/429).
"""
from arize._generated import api_client as gen
body = gen.PromptVersionLabelsSetRequest(labels=labels)
return self._api.prompt_version_labels_set(
version_id=version_id, prompt_version_labels_set_request=body
)
[docs]
@prerelease_endpoint(key="prompts.delete_label", stage=ReleaseStage.ALPHA)
def delete_label(self, *, version_id: str, label_name: str) -> None:
"""Remove a label from a prompt version.
Args:
version_id: Version ID to remove the label from.
label_name: Label name to remove (e.g. ``"production"``, ``"staging"``).
Returns:
None on success (204 No Content).
Raises:
ApiException: If the REST API
returns an error response (e.g. 401/403/404/429).
"""
return self._api.prompt_version_labels_delete(
version_id=version_id, label_name=label_name
)