1
2import asyncio
3import time
4from starlette.middleware.base import BaseHTTPMiddleware
5from starlette.requests import Request
6from starlette.responses import Response, JSONResponse
7from prometheus_client import Counter, Histogram, Gauge
8
# Counter of handled requests, labelled by HTTP method, URL path and status.
REQUEST_COUNT = Counter(
    name="http_requests_total",
    documentation="Total requests",
    labelnames=["method", "path", "status"],
)

# Histogram of request durations in seconds, labelled by method and URL path.
REQUEST_LATENCY = Histogram(
    name="http_request_duration_seconds",
    documentation="Request latency",
    labelnames=["method", "path"],
)

# Unlabelled gauge tracking the number of requests currently in flight.
ACTIVE_REQUESTS = Gauge(
    name="http_active_requests",
    documentation="Active requests",
)
20
class RequestTracker:
    """Count in-flight requests and mirror the count to ``ACTIVE_REQUESTS``.

    The internal counter backs ``active_count`` and ``wait_for_drain``; each
    increment/decrement is also applied to the ``ACTIVE_REQUESTS`` gauge.
    """

    # Longest single sleep between two checks in ``wait_for_drain`` (seconds).
    _POLL_INTERVAL = 0.1

    def __init__(self):
        self._count = 0
        self._lock = asyncio.Lock()

    async def increment(self):
        """Register one more in-flight request."""
        async with self._lock:
            self._count += 1
            ACTIVE_REQUESTS.inc()

    async def decrement(self):
        """Register that one in-flight request has finished."""
        async with self._lock:
            self._count -= 1
            ACTIVE_REQUESTS.dec()

    @property
    def active_count(self) -> int:
        """Return the current number of in-flight requests."""
        return self._count

    async def wait_for_drain(self, timeout: float = 30.0) -> bool:
        """Wait until no requests are in flight or *timeout* seconds elapse.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            ``True`` if the in-flight count reached zero, ``False`` if the
            deadline passed while requests were still active.
        """
        # Inside a coroutine the running loop is the right clock source;
        # resolve it once instead of on every iteration.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while self._count != 0:
            remaining = deadline - loop.time()
            if remaining <= 0:
                return False
            # Never sleep past the deadline, so short timeouts are honoured.
            await asyncio.sleep(min(self._POLL_INTERVAL, remaining))
        return True
47
class RequestTrackerMiddleware(BaseHTTPMiddleware):
    """Track in-flight requests and record per-request Prometheus metrics."""

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler while counting and timing the request.

        A ``RequestTracker`` is created on ``request.app.state`` on first use
        and shared by subsequent requests. After the handler finishes (or
        raises) the tracker is decremented and ``REQUEST_COUNT`` /
        ``REQUEST_LATENCY`` are updated.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that invokes the rest of the stack.

        Returns:
            The response produced by ``call_next``. Exceptions raised by
            ``call_next`` propagate after the metrics are recorded.
        """
        state = request.app.state
        if not hasattr(state, "request_tracker"):
            state.request_tracker = RequestTracker()
        tracker = state.request_tracker

        await tracker.increment()

        # Status reported when the downstream handler raises before producing
        # a response; overwritten as soon as a real response is available.
        status_code = 500
        start = time.monotonic()
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration = time.monotonic() - start
            await tracker.decrement()

            # NOTE(review): the raw URL path is used as a label value, so
            # paths with embedded IDs produce one time series each — confirm
            # this cardinality is acceptable.
            path = request.url.path
            REQUEST_COUNT.labels(
                method=request.method,
                path=path,
                status=status_code,
            ).inc()
            REQUEST_LATENCY.labels(
                method=request.method,
                path=path,
            ).observe(duration)
73
class ShutdownMiddleware(BaseHTTPMiddleware):
    """Answer non-health requests with 503 while the service is not ready."""

    def __init__(self, app, health_checker):
        """Wrap *app* and keep *health_checker*, whose ``ready`` flag is read
        on every request."""
        super().__init__(app)
        self.health_checker = health_checker

    async def dispatch(self, request: Request, call_next):
        """Forward the request, or reject it when the service is not ready.

        Requests pass through when ``health_checker.ready`` is truthy or the
        URL path starts with ``/health``; every other request receives a 503
        JSON response with ``Connection: close`` and ``Retry-After: 5``.
        """
        serving = self.health_checker.ready
        if serving or request.url.path.startswith("/health"):
            return await call_next(request)

        rejection_headers = {
            "Connection": "close",
            "Retry-After": "5",
        }
        return JSONResponse(
            status_code=503,
            content={"error": "Service shutting down"},
            headers=rejection_headers,
        )
91