# Source code for arize.embeddings.auto_generator

"""Automatic embedding generation factory for various ML use cases."""

from typing import TypeAlias

import pandas as pd

from arize.embeddings import constants
from arize.embeddings.base_generators import BaseEmbeddingGenerator
from arize.embeddings.constants import (
    CV_PRETRAINED_MODELS,
    DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
    DEFAULT_CV_OBJECT_DETECTION_MODEL,
    DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
    DEFAULT_NLP_SUMMARIZATION_MODEL,
    DEFAULT_TABULAR_MODEL,
    NLP_PRETRAINED_MODELS,
)
from arize.embeddings.cv_generators import (
    EmbeddingGeneratorForCVImageClassification,
    EmbeddingGeneratorForCVObjectDetection,
)
from arize.embeddings.nlp_generators import (
    EmbeddingGeneratorForNLPSequenceClassification,
    EmbeddingGeneratorForNLPSummarization,
)
from arize.embeddings.tabular_generators import (
    EmbeddingGeneratorForTabularFeatures,
)
from arize.embeddings.usecases import (
    CVUseCases,
    NLPUseCases,
    TabularUseCases,
    UseCases,
)

# Type accepted for the ``use_case`` argument of
# ``EmbeddingGenerator.from_use_case``: either a plain string or an
# ``NLPUseCases`` / ``CVUseCases`` / ``TabularUseCases`` value.
UseCaseLike: TypeAlias = str | NLPUseCases | CVUseCases | TabularUseCases


class EmbeddingGenerator:
    """Factory that hands out the embedding generator matching a use case.

    The class itself is never instantiated: build generators through
    :meth:`from_use_case`, and inspect the available models with
    :meth:`list_default_models` and :meth:`list_pretrained_models`.
    """

    def __init__(self, **kwargs: str) -> None:
        """Reject direct construction and point the caller at the factory.

        Args:
            **kwargs: Ignored; the constructor always raises.

        Raises:
            OSError: Always, because instances must be obtained through
                :meth:`from_use_case`.
        """
        cls_name = type(self).__name__
        raise OSError(
            f"{cls_name} is designed to be instantiated using the "
            f"`{cls_name}.from_use_case(use_case, **kwargs)` method."
        )

    @staticmethod
    def from_use_case(
        use_case: UseCaseLike, **kwargs: object
    ) -> BaseEmbeddingGenerator:
        """Construct the embedding generator registered for ``use_case``.

        Registered use cases are compared against ``use_case`` with ``==``
        in a fixed order, and the first match decides which generator class
        is built.

        Args:
            use_case: The use case to build a generator for.
            **kwargs: Passed through unchanged to the chosen generator's
                constructor.

        Returns:
            A new instance of the generator class that handles ``use_case``.

        Raises:
            ValueError: If ``use_case`` equals none of the registered use cases.
        """
        # Ordered (use case, generator class) pairs; order mirrors the
        # original if-chain so the first equal entry wins.
        registry = (
            (
                UseCases.NLP.SEQUENCE_CLASSIFICATION,
                EmbeddingGeneratorForNLPSequenceClassification,
            ),
            (UseCases.NLP.SUMMARIZATION, EmbeddingGeneratorForNLPSummarization),
            (
                UseCases.CV.IMAGE_CLASSIFICATION,
                EmbeddingGeneratorForCVImageClassification,
            ),
            (
                UseCases.CV.OBJECT_DETECTION,
                EmbeddingGeneratorForCVObjectDetection,
            ),
            (
                UseCases.STRUCTURED.TABULAR_EMBEDDINGS,
                EmbeddingGeneratorForTabularFeatures,
            ),
        )
        for candidate, generator_cls in registry:
            # ``use_case`` stays on the left so its own ``__eq__`` runs first.
            if use_case == candidate:
                return generator_cls(**kwargs)  # type: ignore[arg-type]
        raise ValueError(f"Invalid use case {use_case}")

    @classmethod
    def list_default_models(cls) -> pd.DataFrame:
        """Tabulate the default model used for each supported use case.

        Returns:
            A :class:`pandas.DataFrame` with columns ``Area``, ``Usecase`` and
            ``Model Name`` (one row per use case), sorted ascending by all
            three columns and re-indexed from zero.
        """
        defaults = [
            (
                "NLP",
                UseCases.NLP.SEQUENCE_CLASSIFICATION.name,
                DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
            ),
            (
                "NLP",
                UseCases.NLP.SUMMARIZATION.name,
                DEFAULT_NLP_SUMMARIZATION_MODEL,
            ),
            (
                "CV",
                UseCases.CV.IMAGE_CLASSIFICATION.name,
                DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
            ),
            (
                "CV",
                UseCases.CV.OBJECT_DETECTION.name,
                DEFAULT_CV_OBJECT_DETECTION_MODEL,
            ),
            (
                "STRUCTURED",
                UseCases.STRUCTURED.TABULAR_EMBEDDINGS.name,
                DEFAULT_TABULAR_MODEL,
            ),
        ]
        # Transpose the rows into per-column lists so the frame is still
        # built from a dict of lists (same dtype inference as column input).
        areas, use_case_names, model_names = [
            list(column) for column in zip(*defaults)
        ]
        table = pd.DataFrame(
            {
                "Area": areas,
                "Usecase": use_case_names,
                "Model Name": model_names,
            }
        )
        ordered = table.sort_values(by=list(table.columns), ascending=True)
        return ordered.reset_index(drop=True)

    @classmethod
    def list_pretrained_models(cls) -> pd.DataFrame:
        """Tabulate every pretrained NLP and CV model known to the package.

        Returns:
            A :class:`pandas.DataFrame` with columns ``Task`` (``"NLP"`` or
            ``"CV"``), ``Architecture`` and ``Model Name``, sorted ascending by
            all three columns and re-indexed from zero.

        Raises:
            ValueError: If a model name matches no known architecture.
        """
        all_models = NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS
        tasks = ["NLP"] * len(NLP_PRETRAINED_MODELS) + ["CV"] * len(
            CV_PRETRAINED_MODELS
        )
        architectures = [cls.__parse_model_arch(name) for name in all_models]
        table = pd.DataFrame(
            {
                "Task": tasks,
                "Architecture": architectures,
                "Model Name": all_models,
            }
        )
        ordered = table.sort_values(by=list(table.columns), ascending=True)
        return ordered.reset_index(drop=True)

    @staticmethod
    def __parse_model_arch(model_name: str) -> str:
        """Infer the architecture family of ``model_name``.

        The name is matched case-insensitively against ``constants.GPT``,
        ``constants.BERT`` and ``constants.VIT`` in that order; the first
        constant that appears as a substring is returned unchanged.

        Raises:
            ValueError: If none of the architecture markers occurs in the name.
        """
        lowered = model_name.lower()
        # Order matters: the first marker found decides the architecture.
        for arch in (constants.GPT, constants.BERT, constants.VIT):
            if arch.lower() in lowered:
                return arch
        raise ValueError("Invalid model_name, unknown architecture.")