Skip to content

🧠 ML Integration

Specialized serializers for machine learning libraries and workflows.

🎯 Overview

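ML integration provides specialized serialization for objects from PyTorch, TensorFlow, JAX, scikit-learn, SciPy (sparse matrices), PIL, Hugging Face Transformers (tokenizers), CatBoost, Keras, Optuna, Plotly, and Polars.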

📦 Functions

ML Detection & Serialization

datason.detect_and_serialize_ml_object(obj: Any) -> Optional[Dict[str, Any]]

Detect and serialize ML/AI objects automatically.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `obj` | `Any` | Object that might be from an ML/AI library | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Optional[Dict[str, Any]]` | Serialized object, or `None` if `obj` is not an ML/AI object |

Source code in datason/ml_serializers.py
def detect_and_serialize_ml_object(obj: Any) -> Optional[Dict[str, Any]]:
    """Detect and serialize ML/AI objects automatically.

    Libraries are probed in a fixed order and the first matching serializer
    wins. When a library's lazy importer returns None, that library's check is
    skipped. Some libraries are matched with ``isinstance`` against their
    classes (PyTorch, scikit-learn, PIL, CatBoost, Keras, Optuna, Plotly,
    Polars). Others are matched by a substring of ``str(type(obj))`` plus the
    presence of characteristic attributes (TensorFlow, JAX, SciPy sparse,
    Hugging Face tokenizers).

    Detection is best-effort. Attribute probes never raise. A TypeError or
    AttributeError from an ``isinstance`` check against a library class (for
    example, when the library object is a mock) counts as "no match".

    Args:
        obj: Object that might be from an ML/AI library

    Returns:
        Serialized object or None if not an ML/AI object
    """

    def safe_hasattr(target: Any, attr: str) -> bool:
        # Built-in hasattr() only suppresses AttributeError; properties and
        # module-level __getattr__ hooks can raise anything else, which must
        # not abort detection.
        try:
            return hasattr(target, attr)
        except Exception:
            return False

    # Several checks below match on the type's repr; compute it once.
    type_name = str(type(obj))

    # PyTorch tensors
    torch = _lazy_import_torch()
    if torch is not None and isinstance(obj, torch.Tensor):
        return serialize_pytorch_tensor(obj)

    # TensorFlow tensors. The type-name test runs before the attribute probes
    # so unrelated objects never have their attributes touched.
    tf = _lazy_import_tensorflow()
    if (
        tf is not None
        and "tensorflow" in type_name
        and safe_hasattr(obj, "numpy")
        and safe_hasattr(obj, "shape")
        and safe_hasattr(obj, "dtype")
    ):
        return serialize_tensorflow_tensor(obj)

    # JAX arrays (the importer's second value, jax.numpy, is not needed here)
    jax, _ = _lazy_import_jax()
    if jax is not None and "jax" in type_name and safe_hasattr(obj, "shape") and safe_hasattr(obj, "dtype"):
        return serialize_jax_array(obj)

    # Scikit-learn models
    sklearn, BaseEstimator = _lazy_import_sklearn()
    if sklearn is not None and isinstance(BaseEstimator, type):
        try:
            if isinstance(obj, BaseEstimator):
                return serialize_sklearn_model(obj)
        except (TypeError, AttributeError):
            # Handle case where BaseEstimator is a Mock or invalid type
            pass

    # Scipy sparse matrices
    scipy = _lazy_import_scipy()
    if scipy is not None and "scipy.sparse" in type_name and safe_hasattr(obj, "tocoo"):
        return serialize_scipy_sparse(obj)

    # PIL Images
    Image = _lazy_import_pil()
    if Image is not None and isinstance(obj, Image.Image):
        return serialize_pil_image(obj)

    # HuggingFace tokenizers
    transformers = _lazy_import_transformers()
    if transformers is not None and "transformers" in type_name and safe_hasattr(obj, "encode"):
        return serialize_huggingface_tokenizer(obj)

    # CatBoost models
    catboost = _lazy_import_catboost()
    if catboost is not None:
        try:
            if isinstance(obj, (catboost.CatBoostClassifier, catboost.CatBoostRegressor)):
                return serialize_catboost_model(obj)
        except (TypeError, AttributeError):
            pass

    # Keras models: collect Model/Sequential from both `keras` and
    # `keras.models` (whichever exist), in that order.
    keras = _lazy_import_keras()
    if keras is not None:
        try:
            namespaces = [keras]
            if safe_hasattr(keras, "models"):
                namespaces.append(keras.models)
            keras_model_types = tuple(
                getattr(namespace, type_attr)
                for namespace in namespaces
                for type_attr in ("Model", "Sequential")
                if safe_hasattr(namespace, type_attr)
            )
            if keras_model_types and isinstance(obj, keras_model_types):
                return serialize_keras_model(obj)
        except (TypeError, AttributeError):
            pass

    # Optuna studies
    optuna = _lazy_import_optuna()
    if optuna is not None:
        try:
            if safe_hasattr(optuna, "Study") and isinstance(obj, optuna.Study):
                return serialize_optuna_study(obj)
        except (TypeError, AttributeError):
            pass

    # Plotly figures
    plotly = _lazy_import_plotly()
    if plotly is not None:
        try:
            import plotly.graph_objects as go

            if isinstance(obj, go.Figure):
                return serialize_plotly_figure(obj)
        except (TypeError, AttributeError, ImportError):
            pass

    # Polars DataFrames
    polars = _lazy_import_polars()
    if polars is not None:
        try:
            if safe_hasattr(polars, "DataFrame") and isinstance(obj, polars.DataFrame):
                return serialize_polars_dataframe(obj)
        except (TypeError, AttributeError):
            pass

    return None