# Source code for arize.embeddings.nlp_generators

"""NLP embedding generators for text classification and summarization tasks."""

import logging
from functools import partial

import pandas as pd

from arize.embeddings.base_generators import NLPEmbeddingGenerator
from arize.embeddings.constants import (
    DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
    DEFAULT_NLP_SUMMARIZATION_MODEL,
    IMPORT_ERROR_MESSAGE,
)
from arize.embeddings.usecases import UseCases

try:
    from datasets import Dataset
except ImportError:
    raise ImportError(IMPORT_ERROR_MESSAGE) from None


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
    """Embedding generator for NLP sequence classification tasks."""

    def __init__(
        self,
        model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
        **kwargs: object,
    ) -> None:
        """Initialize the sequence classification embedding generator.

        Args:
            model_name: Name of the pre-trained NLP model.
            **kwargs: Additional arguments for model initialization; they are
                forwarded unchanged to the base class initializer.
        """
        super().__init__(
            use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
            model_name=model_name,
            **kwargs,  # type: ignore[arg-type]
        )

    def generate_embeddings(  # type: ignore[override]
        self,
        text_col: pd.Series,
        class_label_col: pd.Series | None = None,
    ) -> pd.Series:
        """Obtain embedding vectors from your text data using a pre-trained model.

        Args:
            text_col: A pandas Series containing the different pieces of text.
            class_label_col: Optional pandas Series of class labels. When
                given, the sentence " The classification label is
                <class_label>. " is placed *in front of* each piece of text
                before it is embedded. It must have the same length and the
                same index labels as ``text_col``.

        Returns:
            A pandas Series containing the embedding vectors (the
            ``"embedding_vector"`` column of the processed dataset).

        Raises:
            TypeError: If ``text_col`` or ``class_label_col`` is not a pandas
                Series.
            ValueError: If ``class_label_col`` does not line up row-for-row
                with ``text_col``.
        """
        if not isinstance(text_col, pd.Series):
            raise TypeError("text_col must be a pandas Series")
        # Index validation is delegated to the base class helper.
        self.check_invalid_index(field=text_col)

        if class_label_col is not None:
            if not isinstance(class_label_col, pd.Series):
                raise TypeError("class_label_col must be a pandas Series")
            if len(class_label_col) != len(text_col):
                raise ValueError(
                    "text_col and class_label_col must have the same length"
                )
            # pd.concat(axis=1) outer-joins on the index, so mismatched index
            # labels would silently add rows filled with NaN.
            temp_df = pd.concat(
                {"text": text_col, "class_label": class_label_col}, axis=1
            )
            if len(temp_df) != len(text_col):
                raise ValueError(
                    "text_col and class_label_col must share the same index"
                )
            prepared_text_col = temp_df.apply(
                lambda row: f" The classification label is {row['class_label']}. {row['text']}",
                axis=1,
            )
            ds = Dataset.from_dict({"text": prepared_text_col})
        else:
            ds = Dataset.from_dict({"text": text_col})

        # Tokenization is applied on access through the inherited `tokenize`.
        ds.set_transform(partial(self.tokenize, text_feat_name="text"))
        logger.info("Generating embedding vectors")
        # NOTE(review): "cls_token" presumably selects CLS-token pooling in the
        # base class helper — confirm against NLPEmbeddingGenerator.
        ds = ds.map(
            lambda batch: self._get_embedding_vector(batch, "cls_token"),
            batched=True,
            batch_size=self.batch_size,
        )
        result_df: pd.DataFrame = ds.to_pandas()
        return result_df["embedding_vector"]


class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
    """Embedding generator for NLP text summarization tasks."""

    def __init__(
        self,
        model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL,
        **kwargs: object,
    ) -> None:
        """Set up the generator for the summarization use case.

        Args:
            model_name: Name of the pre-trained NLP model to use.
            **kwargs: Extra keyword arguments passed straight through to the
                base class initializer.
        """
        super().__init__(
            model_name=model_name,
            use_case=UseCases.NLP.SUMMARIZATION,
            **kwargs,  # type: ignore[arg-type]
        )

    def generate_embeddings(  # type: ignore[override]
        self,
        text_col: pd.Series,
    ) -> pd.Series:
        """Compute one embedding vector per entry of ``text_col``.

        Args:
            text_col: A pandas Series holding the pieces of text to embed.

        Returns:
            The ``"embedding_vector"`` column of the processed dataset, as a
            pandas Series.

        Raises:
            TypeError: If ``text_col`` is not a pandas Series.
        """
        if not isinstance(text_col, pd.Series):
            raise TypeError("text_col must be a pandas Series")
        # Index validation is delegated to the base class helper.
        self.check_invalid_index(field=text_col)

        dataset = Dataset.from_dict({"text": text_col})
        tokenize_text = partial(self.tokenize, text_feat_name="text")
        dataset.set_transform(tokenize_text)

        logger.info("Generating embedding vectors")
        embedded = dataset.map(
            lambda examples: self._get_embedding_vector(examples, "cls_token"),
            batched=True,
            batch_size=self.batch_size,
        )

        embeddings_frame: pd.DataFrame = embedded.to_pandas()
        return embeddings_frame["embedding_vector"]