"""Tabular data embedding generators for structured feature embeddings."""
import logging
from functools import partial
import pandas as pd
from arize.embeddings.base_generators import NLPEmbeddingGenerator
from arize.embeddings.constants import (
DEFAULT_TABULAR_MODEL,
IMPORT_ERROR_MESSAGE,
)
from arize.embeddings.usecases import UseCases
from arize.utils.types import is_list_of
try:
from datasets import Dataset
except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE) from None
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names of the pre-trained NLP models accepted by
# `EmbeddingGeneratorForTabularFeatures`; any other `model_name` is rejected
# in its constructor.
TABULAR_PRETRAINED_MODELS: list[str] = [
    "bert-base-uncased",
    "distilbert-base-uncased",
    "xlm-roberta-base",
]
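# [docs]  NOTE(review): looks like a documentation-viewer link marker left over from extraction, not source code.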
class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
    """Embedding generator for tabular feature data using prompt-based LLM encoding.

    Each row of a DataFrame is rendered as a text prompt built from the selected
    columns, and the prompts are embedded with one of the models listed in
    ``TABULAR_PRETRAINED_MODELS``.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_TABULAR_MODEL,
        **kwargs: object,
    ) -> None:
        """Initialize the tabular features embedding generator.

        Args:
            model_name: Name of the pre-trained NLP model for tabular data.
            **kwargs: Additional arguments for model initialization.

        Raises:
            ValueError: If model_name is not in supported models list.
        """
        if model_name not in TABULAR_PRETRAINED_MODELS:
            raise ValueError(
                "model_name not supported. Check supported models with "
                "`EmbeddingGeneratorForTabularFeatures.list_pretrained_models()`"
            )
        super().__init__(
            use_case=UseCases.STRUCTURED.TABULAR_EMBEDDINGS,
            model_name=model_name,
            **kwargs,  # type: ignore[arg-type]
        )

    def __repr__(self) -> str:
        """Return a string representation of the tabular embedding generator."""
        return (
            f"{self.__class__.__name__}(\n"
            f" use_case={self.use_case},\n"
            f" model_name={self.model_name},\n"
            f" tokenizer_max_length={self.tokenizer_max_length},\n"
            f" tokenizer={self.tokenizer.__class__},\n"
            f" model={self.model.__class__},\n"
            ")"
        )

    def generate_embeddings(  # type: ignore[override]
        self,
        df: pd.DataFrame,
        selected_columns: list[str],
        col_name_map: dict[str, str] | None = None,
        return_prompt_col: bool = False,
    ) -> pd.Series | tuple[pd.Series, pd.Series]:
        """Obtain embedding vectors from your tabular data.

        Prompts are generated from your `selected_columns` and passed to a pre-trained
        large language model for embedding vector computation.

        Args:
            df: Pandas DataFrame containing the tabular data. Not all columns will be
                considered, see `selected_columns`.
            selected_columns: Columns to be considered to construct the prompt to be
                passed to the LLM.
            col_name_map: Mapping between selected column names and a more verbose
                description of the name. This helps the LLM understand the features
                better. Defaults to no renaming.
            return_prompt_col: If set to True, an extra pandas Series will be returned
                containing the constructed prompts. Defaults to False.

        Returns:
            A pandas Series containing the embedding vectors and, if
            `return_prompt_col` is set to True, a pandas Series containing the prompts
            created from tabular features.

        Raises:
            TypeError: If `df` is not a DataFrame, `selected_columns` is not a list of
                strings, or `col_name_map` is not a dict of string keys and values.
            ValueError: If `selected_columns` or the keys of `col_name_map` name
                columns that are not present in `df`.
        """
        if col_name_map is None:
            col_name_map = {}

        # --- Input validation -------------------------------------------------
        if not isinstance(df, pd.DataFrame):
            raise TypeError("df must be a pandas DataFrame")
        self.check_invalid_index(field=df)

        if not is_list_of(selected_columns, str):
            raise TypeError("columns must be a list of column names (strings)")
        missing_cols = set(selected_columns).difference(df.columns)
        if missing_cols:
            raise ValueError(
                "selected_columns list must only contain columns of the dataframe. "
                f"The following columns are not found {missing_cols}"
            )

        if not isinstance(col_name_map, dict):
            raise TypeError(
                "col_name_map must be a dictionary mapping column names to new column "
                "names"
            )
        for k, v in col_name_map.items():
            if not isinstance(k, str) or not isinstance(v, str):
                raise TypeError(
                    "col_name_map dictionary keys and values should be strings"
                )
        missing_cols = set(col_name_map.keys()).difference(df.columns)
        if missing_cols:
            raise ValueError(
                "col_name_map must only contain keys which are columns of the dataframe. "
                f"The following columns are not found {missing_cols}"
            )

        # --- Prompt construction ----------------------------------------------
        # The frame is renamed first, so the prompt function must look rows up by
        # the *renamed* column labels.
        prompt_columns = [col_name_map.get(col, col) for col in selected_columns]
        prompts: pd.Series = df.rename(columns=col_name_map).apply(
            partial(self.__prompt_fn, columns=prompt_columns),
            axis=1,
        )

        # --- Embedding computation --------------------------------------------
        ds = Dataset.from_dict({"prompt": prompts})
        ds.set_transform(partial(self.tokenize, text_feat_name="prompt"))
        logger.info("Generating embedding vectors")
        # Resolved once up front: it depends only on self.model_name, not on the
        # batch, so there is no reason to recompute it for every batch.
        embedding_method = self.__get_method_for_embedding_calculation()
        ds = ds.map(
            lambda batch: self._get_embedding_vector(batch, embedding_method),
            batched=True,
            batch_size=self.batch_size,
        )
        result_df: pd.DataFrame = ds.to_pandas()

        if return_prompt_col:
            return result_df["embedding_vector"], prompts
        return result_df["embedding_vector"]

    @staticmethod
    def __prompt_fn(row: pd.Series, columns: list[str]) -> str:
        """Render one row as a prompt of "The <column> is <value>." sentences.

        Args:
            row: A single DataFrame row, as passed by `DataFrame.apply(axis=1)`.
            columns: Labels to include, in order. Underscores in a label are shown
                as spaces; values are stringified and stripped of surrounding
                whitespace.

        Returns:
            The sentences for all `columns`, joined by single spaces.
        """
        return " ".join(
            f"The {col.replace('_', ' ')} is {str(row[col]).strip()}."
            for col in columns
        )

    def __get_method_for_embedding_calculation(self) -> str:
        """Return the method identifier handed to `_get_embedding_vector`.

        The identifier is chosen from `self.model_name`: "avg_token" for the BERT
        and DistilBERT models, "cls_token" for XLM-RoBERTa.
        NOTE(review): presumably these select token-average vs. CLS-token pooling
        in the base class — confirm against `_get_embedding_vector`.

        Raises:
            ValueError: If `self.model_name` has no associated method.
        """
        methods = {
            "bert-base-uncased": "avg_token",
            "distilbert-base-uncased": "avg_token",
            "xlm-roberta-base": "cls_token",
        }
        try:
            return methods[self.model_name]
        # KeyError: unknown name; TypeError: unhashable model_name.
        except (KeyError, TypeError) as exc:
            raise ValueError(
                f"Unsupported model_name {self.model_name}"
            ) from exc

    @staticmethod
    def list_pretrained_models() -> pd.DataFrame:
        """Return a :class:`pandas.DataFrame` of available pretrained tabular models."""
        return pd.DataFrame({"Model Name": sorted(TABULAR_PRETRAINED_MODELS)})