# Source code for arize.embeddings.cv_generators

"""Computer vision embedding generators for image classification and object detection."""

from arize.embeddings.base_generators import CVEmbeddingGenerator
from arize.embeddings.constants import (
    DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
    DEFAULT_CV_OBJECT_DETECTION_MODEL,
)
from arize.embeddings.usecases import UseCases


class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
    """Embedding generator for computer vision image classification tasks.

    A thin specialization of ``CVEmbeddingGenerator`` that pins the use case
    to ``UseCases.CV.IMAGE_CLASSIFICATION`` and supplies a default model name.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
        **kwargs: object,
    ) -> None:
        """Initialize the image classification embedding generator.

        Args:
            model_name: Name of the pre-trained vision model. Defaults to
                ``DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL``.
            **kwargs: Additional arguments for model initialization; they are
                forwarded unchanged to ``CVEmbeddingGenerator.__init__``.
        """
        # kwargs is annotated as ``object`` for callers, which is looser than
        # the base-class parameter types; hence the targeted type-ignore.
        super().__init__(
            use_case=UseCases.CV.IMAGE_CLASSIFICATION,
            model_name=model_name,
            **kwargs,  # type: ignore[arg-type]
        )
class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
    """Embedding generator for computer vision object detection tasks.

    A thin specialization of ``CVEmbeddingGenerator`` that pins the use case
    to ``UseCases.CV.OBJECT_DETECTION`` and supplies a default model name.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL,
        **kwargs: object,
    ) -> None:
        """Initialize the object detection embedding generator.

        Args:
            model_name: Name of the pre-trained vision model. Defaults to
                ``DEFAULT_CV_OBJECT_DETECTION_MODEL``.
            **kwargs: Additional arguments for model initialization; they are
                forwarded unchanged to ``CVEmbeddingGenerator.__init__``.
        """
        # kwargs is annotated as ``object`` for callers, which is looser than
        # the base-class parameter types; hence the targeted type-ignore.
        super().__init__(
            use_case=UseCases.CV.OBJECT_DETECTION,
            model_name=model_name,
            **kwargs,  # type: ignore[arg-type]
        )