# Production ML Model Serving Patterns
This guide covers production-ready patterns for serving ML models with datason, including error handling, monitoring, performance optimization, and security considerations.
## 🎯 Overview
Production ML serving requires more than basic serialization. This guide demonstrates:
- Error Handling: Graceful degradation and recovery
- Performance Monitoring: Metrics, health checks, and observability
- Security: Input validation and resource limits
- Scalability: Batch processing and memory management
- Observability: Logging, tracing, and debugging
## 🏗️ Production Architecture

### Model Wrapper Pattern
```python
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import datason
from datason.config import get_api_config, SerializationConfig


class ProductionModelWrapper:
    """Production-ready model wrapper with comprehensive monitoring."""

    def __init__(self, model_id: str, model_version: str):
        self.model_id = model_id
        self.model_version = model_version

        # Monitoring metrics
        self.request_count = 0
        self.error_count = 0
        self.total_latency = 0.0

        # Configuration
        self.api_config = get_api_config()
        self.performance_config = SerializationConfig(
            uuid_format="string",
            parse_uuids=False,
            date_format="unix",
            preserve_decimals=False,
            max_depth=20,
            max_size=50_000_000  # 50MB limit
        )

        # Setup logging
        self.logger = logging.getLogger(f"model.{model_id}")

    def predict(self, features: Any) -> Dict[str, Any]:
        """Make prediction with full monitoring and error handling."""
        start_time = time.perf_counter()
        request_id = str(uuid.uuid4())

        try:
            self.request_count += 1
            self.logger.info(f"Processing request {request_id}")

            # Input validation
            self._validate_input(features)

            # Process with datason
            processed_features = datason.auto_deserialize(
                features,
                config=self.api_config
            )

            # Model inference
            prediction = self._run_inference(processed_features)

            # Prepare response
            response = {
                "request_id": request_id,
                "model_id": self.model_id,
                "model_version": self.model_version,
                "prediction": prediction,
                "timestamp": datetime.now(),
                "processing_time_ms": (time.perf_counter() - start_time) * 1000
            }

            # Serialize response
            return datason.serialize(response, config=self.performance_config)

        except Exception as e:
            self.error_count += 1
            self.logger.error(f"Request {request_id} failed: {e}")
            return self._create_error_response(request_id, str(e))

        finally:
            self.total_latency += time.perf_counter() - start_time

    def _validate_input(self, features: Any) -> None:
        """Validate input with security checks."""
        if not features:
            raise ValueError("Features cannot be empty")

        # Type validation
        if not isinstance(features, (dict, list)):
            raise ValueError("Features must be dict or list")

        # Size validation (rough estimate based on the string representation)
        estimated_size = len(str(features))
        if estimated_size > self.performance_config.max_size:
            raise ValueError(f"Input too large: {estimated_size} characters")

    def _run_inference(self, features: Any) -> Any:
        """Run actual model inference (implement with your model)."""
        # Replace with your actual model
        return {"class": "positive", "confidence": 0.85}

    def _create_error_response(self, request_id: str, error: str) -> Dict[str, Any]:
        """Create standardized error response."""
        error_response = {
            "request_id": request_id,
            "error": error,
            "model_id": self.model_id,
            "timestamp": datetime.now(),
            "status": "error"
        }
        return datason.serialize(error_response, config=self.api_config)

    def get_health_metrics(self) -> Dict[str, Any]:
        """Get comprehensive health metrics."""
        avg_latency = self.total_latency / max(self.request_count, 1)
        error_rate = self.error_count / max(self.request_count, 1)

        metrics = {
            "model_id": self.model_id,
            "model_version": self.model_version,
            "status": "healthy" if error_rate < 0.1 else "degraded",
            "metrics": {
                "total_requests": self.request_count,
                "error_count": self.error_count,
                "error_rate": error_rate,
                "avg_latency_ms": avg_latency * 1000,
                "last_check": datetime.now()
            }
        }
        return datason.serialize(metrics, config=self.api_config)
```
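
The wrapper can be exercised directly before wiring it into a serving framework. A minimal usage sketch (model ID and feature names are illustrative; `_run_inference` is still the stub above):

```python
# Minimal usage sketch - names are illustrative, _run_inference is the stub above
model = ProductionModelWrapper("classifier", "1.0.0")

result = model.predict({"feature1": 1.0, "feature2": 0.5})
print(result["prediction"], result["processing_time_ms"])

# Aggregated health metrics reflect the requests processed so far
print(model.get_health_metrics())
```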
## 🚀 Framework Integration Patterns

### FastAPI Production Service
```python
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

app = FastAPI(
    title="Production ML API",
    description="Production ML serving with datason",
    version="1.0.0"
)

# Add middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model registry
models = {
    "classifier_v1": ProductionModelWrapper("classifier", "1.0.0"),
    "regressor_v1": ProductionModelWrapper("regressor", "1.0.0"),
}


class PredictionRequest(BaseModel):
    features: Dict[str, Any]
    model_id: str = "classifier_v1"
    options: Optional[Dict[str, Any]] = None


@app.get("/health")
async def health_check():
    """Comprehensive health check."""
    health_status = {
        "status": "healthy",
        "timestamp": datetime.now(),
        "models": {
            model_id: model.get_health_metrics()
            for model_id, model in models.items()
        }
    }
    return datason.serialize(health_status, config=get_api_config())


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Single prediction with full error handling."""
    if request.model_id not in models:
        raise HTTPException(
            status_code=404,
            detail=f"Model {request.model_id} not found"
        )

    model = models[request.model_id]
    return model.predict(request.features)


@app.post("/predict/batch")
async def predict_batch(
    batch_features: List[Dict[str, Any]],
    model_id: str = "classifier_v1"
):
    """Batch prediction with size limits."""
    MAX_BATCH_SIZE = 100

    if len(batch_features) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Batch size {len(batch_features)} exceeds limit {MAX_BATCH_SIZE}"
        )

    if model_id not in models:
        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")

    model = models[model_id]
    results = []
    for features in batch_features:
        result = model.predict(features)
        results.append(result)

    batch_response = {
        "batch_id": str(uuid.uuid4()),
        "model_id": model_id,
        "results": results,
        "batch_size": len(results),
        "timestamp": datetime.now()
    }
    return datason.serialize(batch_response, config=get_api_config())


@app.get("/metrics")
async def get_metrics():
    """Detailed metrics endpoint."""
    metrics = {
        "service_metrics": {
            "total_models": len(models),
            "timestamp": datetime.now()
        },
        "model_metrics": {
            model_id: model.get_health_metrics()
            for model_id, model in models.items()
        }
    }
    return datason.serialize(metrics, config=get_api_config())


# Run with: uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4
```
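
Once the service is running, a client can call the routes defined above. A hedged example using `requests` (host and port are assumptions taken from the uvicorn command; the payload shape follows `PredictionRequest`):

```python
# Hypothetical client call against the service above (assumes localhost:8000)
import requests

payload = {
    "features": {"feature1": 1.0, "feature2": 0.5},
    "model_id": "classifier_v1",
}
resp = requests.post("http://localhost:8000/predict", json=payload, timeout=10)
resp.raise_for_status()
print(resp.json()["prediction"])
```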
### BentoML Production Service
```python
import bentoml
from bentoml.io import JSON
from bentoml import metrics

# Create service with monitoring
svc = bentoml.Service("production_ml_service")

# Metrics
request_counter = metrics.Counter(
    name="prediction_requests_total",
    documentation="Total prediction requests"
)
error_counter = metrics.Counter(
    name="prediction_errors_total",
    documentation="Total prediction errors"
)
latency_histogram = metrics.Histogram(
    name="prediction_latency_seconds",
    documentation="Prediction latency"
)

# Model wrapper
model = ProductionModelWrapper("bentoml_model", "1.0.0")


@svc.api(input=JSON(), output=JSON())
@metrics.track(request_counter, error_counter, latency_histogram)
def predict(input_data: dict) -> dict:
    """Production prediction with metrics."""
    try:
        if "features" not in input_data:
            raise ValueError("Missing 'features' in input")

        result = model.predict(input_data["features"])
        return result

    except Exception as e:
        error_counter.inc()
        return model._create_error_response(str(uuid.uuid4()), str(e))


@svc.api(input=JSON(), output=JSON())
def health() -> dict:
    """Health check endpoint."""
    return model.get_health_metrics()


@svc.api(input=JSON(), output=JSON())
def metrics_endpoint() -> dict:
    """Custom metrics endpoint."""
    return {
        "custom_metrics": model.get_health_metrics(),
        "timestamp": datetime.now()
    }


# Deploy with: bentoml serve production_service:svc --production
```
### Ray Serve Production Deployment
```python
import ray
from ray import serve


@serve.deployment(
    num_replicas=3,
    max_concurrent_queries=100,
    ray_actor_options={
        "num_cpus": 1,
        "memory": 2000 * 1024 * 1024  # 2GB
    }
)
class ProductionMLDeployment:
    def __init__(self):
        self.model = ProductionModelWrapper("ray_serve_model", "1.0.0")
        self.logger = logging.getLogger("ray_serve_deployment")
        self.logger.info("Ray Serve deployment initialized")

    async def __call__(self, request):
        """Handle requests with comprehensive error handling."""
        try:
            # Parse request
            if hasattr(request, 'json'):
                payload = await request.json()
            else:
                payload = request

            # Validate
            if "features" not in payload:
                raise ValueError("Missing 'features' in request")

            # Predict
            result = self.model.predict(payload["features"])
            return result

        except Exception as e:
            self.logger.error(f"Deployment error: {e}")
            return self.model._create_error_response(str(uuid.uuid4()), str(e))

    def health_check(self):
        """Health check for Ray Serve."""
        return self.model.get_health_metrics()


# Deploy with autoscaling
deployment = ProductionMLDeployment.options(
    autoscaling_config={
        "min_replicas": 1,
        "max_replicas": 10,
        "target_num_ongoing_requests_per_replica": 5
    }
)
app = deployment.bind()

# Deploy with: serve run production_deployment:app
```
## 📊 Monitoring and Observability

### Comprehensive Logging
```python
import logging
import structlog
from datetime import datetime

# Configure structured logging
structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer()
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)


class MonitoredModelWrapper(ProductionModelWrapper):
    """Model wrapper with enhanced monitoring."""

    def __init__(self, model_id: str, model_version: str):
        super().__init__(model_id, model_version)
        self.logger = structlog.get_logger("model").bind(
            model_id=model_id,
            model_version=model_version
        )

    def predict(self, features: Any) -> Dict[str, Any]:
        """Prediction with structured logging."""
        request_id = str(uuid.uuid4())
        start_time = time.perf_counter()

        self.logger.info(
            "prediction_started",
            request_id=request_id,
            input_size=len(str(features))
        )

        try:
            result = super().predict(features)

            self.logger.info(
                "prediction_completed",
                request_id=request_id,
                duration_ms=(time.perf_counter() - start_time) * 1000,
                success=True
            )
            return result

        except Exception as e:
            self.logger.error(
                "prediction_failed",
                request_id=request_id,
                error=str(e),
                duration_ms=(time.perf_counter() - start_time) * 1000,
                success=False
            )
            raise
```
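
A brief usage sketch; the log line in the comment is only representative of the `JSONRenderer` output, since exact fields, timestamps, and IDs will differ:

```python
# Illustrative usage of the structured-logging wrapper
# (stdlib logging needs a handler, e.g. logging.basicConfig(level=logging.INFO),
# for the events to appear on the console)
monitored_model = MonitoredModelWrapper("monitored_classifier", "1.0.0")
monitored_model.predict({"feature1": 1.0})

# Each request emits JSON log events roughly like:
# {"event": "prediction_started", "request_id": "...", "model_id": "monitored_classifier",
#  "model_version": "1.0.0", "input_size": ..., "timestamp": "...", "level": "info"}
```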
### Metrics Collection
```python
import threading

from prometheus_client import Counter, Histogram, Gauge, start_http_server

# Prometheus metrics
PREDICTION_REQUESTS = Counter(
    'ml_prediction_requests_total',
    'Total prediction requests',
    ['model_id', 'model_version', 'status']
)
PREDICTION_LATENCY = Histogram(
    'ml_prediction_latency_seconds',
    'Prediction latency',
    ['model_id', 'model_version']
)
ACTIVE_MODELS = Gauge(
    'ml_active_models',
    'Number of active models'
)


class MetricsModelWrapper(ProductionModelWrapper):
    """Model wrapper with Prometheus metrics."""

    def predict(self, features: Any) -> Dict[str, Any]:
        """Prediction with metrics collection."""
        start_time = time.perf_counter()

        try:
            result = super().predict(features)

            # Record success metrics
            PREDICTION_REQUESTS.labels(
                model_id=self.model_id,
                model_version=self.model_version,
                status='success'
            ).inc()
            PREDICTION_LATENCY.labels(
                model_id=self.model_id,
                model_version=self.model_version
            ).observe(time.perf_counter() - start_time)

            return result

        except Exception:
            # Record error metrics
            PREDICTION_REQUESTS.labels(
                model_id=self.model_id,
                model_version=self.model_version,
                status='error'
            ).inc()
            raise


# Start Prometheus metrics server
def start_metrics_server(port: int = 8080):
    """Start Prometheus metrics server."""
    start_http_server(port)
    print(f"Metrics server started on port {port}")


# Start in background thread
metrics_thread = threading.Thread(target=start_metrics_server)
metrics_thread.daemon = True
metrics_thread.start()
```
## 🔒 Security and Validation

### Input Validation
```python
import sys

import jsonschema


class SecureModelWrapper(ProductionModelWrapper):
    """Model wrapper with enhanced security."""

    def __init__(self, model_id: str, model_version: str):
        super().__init__(model_id, model_version)

        # Define input schema
        self.input_schema = {
            "type": "object",
            "properties": {
                "features": {
                    "type": "object",
                    "additionalProperties": {"type": "number"}
                }
            },
            "required": ["features"],
            "additionalProperties": False
        }

        # Security limits
        self.max_input_size = 1024 * 1024  # 1MB
        self.max_features = 1000
        self.allowed_feature_types = (int, float, str, bool)

    def _validate_input(self, features: Any) -> None:
        """Enhanced security validation."""
        super()._validate_input(features)

        # Schema validation
        try:
            jsonschema.validate({"features": features}, self.input_schema)
        except jsonschema.ValidationError as e:
            raise ValueError(f"Schema validation failed: {e.message}")

        # Feature count validation
        if isinstance(features, dict) and len(features) > self.max_features:
            raise ValueError(f"Too many features: {len(features)} > {self.max_features}")

        # Feature type validation
        if isinstance(features, dict):
            for key, value in features.items():
                if not isinstance(value, self.allowed_feature_types):
                    raise ValueError(f"Invalid feature type for {key}: {type(value)}")

        # Size validation
        if sys.getsizeof(features) > self.max_input_size:
            raise ValueError("Input size exceeds limit")
```
### Rate Limiting
```python
import time
from collections import defaultdict
from threading import Lock


class RateLimitedModelWrapper(ProductionModelWrapper):
    """Model wrapper with rate limiting."""

    def __init__(self, model_id: str, model_version: str):
        super().__init__(model_id, model_version)

        # Rate limiting
        self.requests_per_minute = 100
        self.requests_per_hour = 1000
        self.request_times = defaultdict(list)
        self.lock = Lock()

    def _check_rate_limit(self, client_id: str = "default") -> None:
        """Check rate limits for client."""
        current_time = time.time()

        with self.lock:
            # Clean old requests
            self.request_times[client_id] = [
                req_time for req_time in self.request_times[client_id]
                if current_time - req_time < 3600  # Keep last hour
            ]

            # Check limits
            recent_requests = [
                req_time for req_time in self.request_times[client_id]
                if current_time - req_time < 60  # Last minute
            ]

            if len(recent_requests) >= self.requests_per_minute:
                raise ValueError("Rate limit exceeded: too many requests per minute")

            if len(self.request_times[client_id]) >= self.requests_per_hour:
                raise ValueError("Rate limit exceeded: too many requests per hour")

            # Record this request
            self.request_times[client_id].append(current_time)

    def predict(self, features: Any, client_id: str = "default") -> Dict[str, Any]:
        """Prediction with rate limiting."""
        self._check_rate_limit(client_id)
        return super().predict(features)
```
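
A quick sketch of the limiter in action, assuming the default limit of 100 requests per minute for a single client:

```python
# Illustrative: the 101st request from one client within a minute is rejected
limited_model = RateLimitedModelWrapper("limited_model", "1.0.0")

try:
    for _ in range(101):
        limited_model.predict({"feature1": 1.0}, client_id="client_a")
except ValueError as e:
    print(e)  # Rate limit exceeded: too many requests per minute
```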
## 🚀 Performance Optimization

### Async Processing
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List


class AsyncModelWrapper(ProductionModelWrapper):
    """Async model wrapper for high throughput."""

    def __init__(self, model_id: str, model_version: str):
        super().__init__(model_id, model_version)
        self.executor = ThreadPoolExecutor(max_workers=4)

    async def predict_async(self, features: Any) -> Dict[str, Any]:
        """Async prediction."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            self.executor,
            self.predict,
            features
        )

    async def predict_batch_async(self, batch_features: List[Any]) -> List[Dict[str, Any]]:
        """Async batch prediction."""
        tasks = [
            self.predict_async(features)
            for features in batch_features
        ]
        return await asyncio.gather(*tasks)


# Usage
async def main():
    model = AsyncModelWrapper("async_model", "1.0.0")

    # Single prediction
    result = await model.predict_async({"feature1": 1.0})

    # Batch prediction
    batch_features = [{"feature1": i} for i in range(10)]
    results = await model.predict_batch_async(batch_features)
    print(f"Processed {len(results)} predictions")

# Run with: asyncio.run(main())
```
### Caching
```python
import hashlib
import json
from functools import lru_cache


class CachedModelWrapper(ProductionModelWrapper):
    """Model wrapper with prediction caching."""

    def __init__(self, model_id: str, model_version: str):
        super().__init__(model_id, model_version)
        self.cache_hits = 0
        self.cache_misses = 0

    def _get_cache_key(self, features: Any) -> str:
        """Generate cache key for features."""
        features_str = json.dumps(features, sort_keys=True)
        return hashlib.md5(features_str.encode()).hexdigest()

    @lru_cache(maxsize=1000)
    def _cached_inference(self, cache_key: str, features_json: str) -> str:
        """Cached inference keyed on the feature hash (JSON-encoded result)."""
        features = json.loads(features_json)
        result = self._run_inference(features)
        return json.dumps(result)

    def predict(self, features: Any) -> Dict[str, Any]:
        """Prediction with caching."""
        cache_key = self._get_cache_key(features)
        features_json = json.dumps(features, sort_keys=True)

        # Compare LRU hit counts to tell whether this request was served from cache
        hits_before = self._cached_inference.cache_info().hits
        prediction = json.loads(self._cached_inference(cache_key, features_json))
        was_cached = self._cached_inference.cache_info().hits > hits_before

        if was_cached:
            self.cache_hits += 1
        else:
            self.cache_misses += 1

        # Build full response
        response = {
            "request_id": str(uuid.uuid4()),
            "model_id": self.model_id,
            "prediction": prediction,
            "cached": was_cached,
            "timestamp": datetime.now()
        }
        return datason.serialize(response, config=self.performance_config)

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache performance statistics."""
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / max(total_requests, 1)

        return {
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate,
            "total_requests": total_requests
        }
```
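
A short sketch of the cache behavior, assuming repeated identical inputs (names are illustrative):

```python
# Illustrative: repeated identical inputs are served from the cache
cached_model = CachedModelWrapper("cached_classifier", "1.0.0")

for _ in range(5):
    cached_model.predict({"feature1": 1.0})

print(cached_model.get_cache_stats())  # 1 miss, 4 hits -> hit_rate 0.8
```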
## 🔧 Configuration Management

### Environment-Based Configuration
```python
import os
from dataclasses import dataclass


@dataclass
class ProductionConfig:
    """Production configuration with environment variables."""

    # Model settings
    model_id: str = os.getenv("MODEL_ID", "default_model")
    model_version: str = os.getenv("MODEL_VERSION", "1.0.0")

    # Performance settings
    max_batch_size: int = int(os.getenv("MAX_BATCH_SIZE", "32"))
    timeout_seconds: int = int(os.getenv("TIMEOUT_SECONDS", "30"))
    max_memory_mb: int = int(os.getenv("MAX_MEMORY_MB", "1000"))

    # Rate limiting
    requests_per_minute: int = int(os.getenv("REQUESTS_PER_MINUTE", "100"))
    requests_per_hour: int = int(os.getenv("REQUESTS_PER_HOUR", "1000"))

    # Caching
    enable_caching: bool = os.getenv("ENABLE_CACHING", "true").lower() == "true"
    cache_size: int = int(os.getenv("CACHE_SIZE", "1000"))

    # Monitoring
    enable_metrics: bool = os.getenv("ENABLE_METRICS", "true").lower() == "true"
    metrics_port: int = int(os.getenv("METRICS_PORT", "8080"))

    # Datason settings
    uuid_format: str = os.getenv("UUID_FORMAT", "string")
    parse_uuids: bool = os.getenv("PARSE_UUIDS", "false").lower() == "true"


def create_production_model(config: ProductionConfig) -> ProductionModelWrapper:
    """Create production model with configuration."""
    # Choose wrapper based on config
    if config.enable_caching:
        model = CachedModelWrapper(config.model_id, config.model_version)
    else:
        model = ProductionModelWrapper(config.model_id, config.model_version)

    # Apply configuration
    model.max_batch_size = config.max_batch_size
    model.timeout_seconds = config.timeout_seconds

    return model


# Usage
config = ProductionConfig()
model = create_production_model(config)
```
## 📈 Deployment Strategies

### Blue-Green Deployment
```python
class ModelRegistry:
    """Model registry for blue-green deployments."""

    def __init__(self):
        self.models = {}
        self.active_model = None
        self.staging_model = None

    def register_model(self, model_id: str, model: ProductionModelWrapper, stage: str = "staging"):
        """Register a new model version."""
        self.models[model_id] = model

        if stage == "staging":
            self.staging_model = model_id
        elif stage == "active":
            self.active_model = model_id

    def promote_to_production(self, model_id: str):
        """Promote staging model to production."""
        if model_id not in self.models:
            raise ValueError(f"Model {model_id} not found")

        # Swap models
        old_active = self.active_model
        self.active_model = model_id
        self.staging_model = old_active

        print(f"Promoted {model_id} to production")

    def get_active_model(self) -> ProductionModelWrapper:
        """Get currently active model."""
        if not self.active_model:
            raise ValueError("No active model")
        return self.models[self.active_model]

    def health_check_all(self) -> Dict[str, Any]:
        """Health check all models."""
        return {
            model_id: model.get_health_metrics()
            for model_id, model in self.models.items()
        }


# Usage
registry = ModelRegistry()

# Register models
registry.register_model("model_v1", ProductionModelWrapper("model", "1.0"), "active")
registry.register_model("model_v2", ProductionModelWrapper("model", "2.0"), "staging")

# Test staging model
staging_model = registry.models["model_v2"]
test_result = staging_model.predict({"test": "data"})

# Promote if tests pass
if test_result.get("status") != "error":
    registry.promote_to_production("model_v2")
```
### Canary Deployment
```python
import random


class CanaryModelWrapper:
    """Canary deployment wrapper."""

    def __init__(
        self,
        stable_model: ProductionModelWrapper,
        canary_model: ProductionModelWrapper,
        canary_percentage: float = 0.1
    ):
        self.stable_model = stable_model
        self.canary_model = canary_model
        self.canary_percentage = canary_percentage
        self.canary_requests = 0
        self.stable_requests = 0

    def predict(self, features: Any) -> Dict[str, Any]:
        """Route requests between stable and canary models."""
        # Decide which model to use
        if random.random() < self.canary_percentage:
            # Use canary model
            self.canary_requests += 1
            result = self.canary_model.predict(features)
            result["model_type"] = "canary"
        else:
            # Use stable model
            self.stable_requests += 1
            result = self.stable_model.predict(features)
            result["model_type"] = "stable"

        return result

    def get_canary_stats(self) -> Dict[str, Any]:
        """Get canary deployment statistics."""
        total_requests = self.canary_requests + self.stable_requests

        return {
            "canary_requests": self.canary_requests,
            "stable_requests": self.stable_requests,
            "canary_percentage_actual": self.canary_requests / max(total_requests, 1),
            "canary_percentage_target": self.canary_percentage,
            "total_requests": total_requests
        }


# Usage
stable_model = ProductionModelWrapper("stable", "1.0")
canary_model = ProductionModelWrapper("canary", "2.0")
canary_wrapper = CanaryModelWrapper(stable_model, canary_model, 0.1)

# Route traffic
for i in range(100):
    result = canary_wrapper.predict({"test": i})
    print(f"Request {i}: {result.get('model_type')}")

print(canary_wrapper.get_canary_stats())
```
## 🎯 Best Practices Summary

### ✅ Do's
- Use structured logging for debugging and monitoring
- Implement comprehensive health checks with meaningful metrics
- Validate inputs thoroughly with schema validation
- Set resource limits to prevent abuse and ensure stability
- Use async processing for high-throughput scenarios
- Implement caching for repeated predictions
- Monitor performance with metrics and alerting
- Use blue-green or canary deployments for safe updates
### ❌ Don'ts
- Don't skip input validation - always validate and sanitize inputs
- Don't ignore error handling - implement graceful degradation
- Don't hardcode configurations - use environment variables
- Don't skip monitoring - observability is critical in production
- Don't deploy without testing - always test in staging first
- Don't ignore security - implement rate limiting and access controls
### 🔧 Configuration Checklist
- Input validation and sanitization
- Error handling and graceful degradation
- Performance monitoring and metrics
- Health checks and observability
- Rate limiting and security controls
- Caching for performance optimization
- Async processing for scalability
- Environment-based configuration
- Deployment strategies (blue-green/canary)
- Comprehensive logging and debugging
This production guide ensures your ML models are served reliably, securely, and efficiently with datason's serialization capabilities.