1import hashlib
2import struct
3import threading
4from dataclasses import dataclass, field
5from typing import Optional
6
7
@dataclass(frozen=True, eq=True)
class Condition:
    """One predicate tested against a single attribute of an EvalContext.

    ``operator`` selects how the resolved attribute value is compared with
    ``values``; see ``_match_condition`` for the operators it understands.
    """

    # Either a built-in context field (user_id, email, plan, country) or a
    # key looked up in EvalContext.properties.
    attribute: str
    # Operator name, e.g. "eq", "neq", "in", "contains", "starts_with".
    operator: str
    # Comparison operands; only "in" uses more than the first element.
    values: list[str]
13
14
@dataclass(frozen=True, eq=True)
class Rule:
    """A targeting rule: the flag is served on when all its conditions match."""

    # Every condition must hold; an empty list matches any context.
    conditions: list[Condition]
    # Variant name attached to this rule; "" when the rule specifies none.
    variant: str = field(default="")
    # Rules with larger values are tried first.
    priority: int = field(default=0)
20
21
@dataclass(frozen=True, eq=True)
class FlagConfig:
    """Complete configuration for one feature flag."""

    # Lookup key; FlagEvaluator.update indexes configurations by it.
    key: str
    # Master switch: when False the flag is off regardless of anything else.
    enabled: bool
    # Rollout percentage; only applied when strictly between 0 and 100.
    percentage: float = field(default=0.0)
    # Targeting rules, tried in descending priority order.
    rules: list[Rule] = field(default_factory=list)
    # Presumably variant name -> weight.
    # NOTE(review): not read by FlagEvaluator in this file — confirm intent.
    variants: dict[str, float] = field(default_factory=dict)
    # Variant reported when the flag is on via percentage rollout or default.
    default_variant: str = field(default="")
30
31
@dataclass(eq=True)
class EvalContext:
    """Attributes describing who a flag is being evaluated for."""

    # Required; also the input to the percentage-rollout hash bucket.
    user_id: str
    email: str = field(default="")
    plan: str = field(default="")
    country: str = field(default="")
    # Free-form attributes for conditions on names other than the fields above.
    properties: dict[str, str] = field(default_factory=dict)
39
40
@dataclass(eq=True)
class EvalResult:
    """Outcome of evaluating one flag for one context."""

    # Whether the flag is on for this context.
    enabled: bool
    # Variant served; "" when none applies.
    variant: str = field(default="")
    # Short code naming the deciding branch, e.g. "not_found", "disabled",
    # "rule_match", "percentage", "percentage_excluded", "default".
    reason: str = field(default="")
46
47
class FlagEvaluator:
    """Evaluate feature flags for an EvalContext against the latest configs.

    Configurations are installed wholesale with ``update()``. ``evaluate()``
    decides a flag in this order: missing -> disabled -> first matching rule
    (highest priority first) -> percentage rollout -> on by default.
    """

    def __init__(self):
        # flag key -> FlagConfig. update() swaps in a new dict instead of
        # mutating this one, so a reader holding a reference is unaffected.
        self._flags: dict[str, FlagConfig] = {}
        # Serializes writers in update(); evaluate() does not acquire it.
        self._lock = threading.Lock()

    def update(self, configs: list[FlagConfig]) -> None:
        """Replace the entire flag set with *configs*.

        If two configs share a key, the later one wins. The new dict is built
        before taking the lock so the critical section is only the swap.
        """
        new_flags = {c.key: c for c in configs}
        with self._lock:
            self._flags = new_flags

    def evaluate(self, flag_key: str, context: EvalContext) -> EvalResult:
        """Decide whether *flag_key* is on for *context*, and why.

        The returned ``reason`` is one of ``"not_found"``, ``"disabled"``,
        ``"rule_match"``, ``"percentage"``, ``"percentage_excluded"`` or
        ``"default"``.

        A percentage strictly between 0 and 100 gates the flag by a
        deterministic hash bucket of (flag_key, user_id). Any other value,
        including the FlagConfig default of 0, skips the gate, so the flag is
        on for every context that reaches that point.
        """
        # Read the dict reference once so a concurrent update() cannot switch
        # configurations partway through this call.
        flags = self._flags
        flag = flags.get(flag_key)

        if flag is None:
            return EvalResult(enabled=False, reason="not_found")
        if not flag.enabled:
            return EvalResult(enabled=False, reason="disabled")

        # Highest priority first. sorted() is stable even with reverse=True,
        # so equal-priority rules keep their configured order.
        for rule in sorted(flag.rules, key=lambda r: r.priority, reverse=True):
            if self._matches_all(rule.conditions, context):
                # A rule that names no variant serves the flag's default
                # variant, consistent with the percentage and default paths.
                variant = rule.variant or flag.default_variant
                return EvalResult(enabled=True, variant=variant, reason="rule_match")

        if 0 < flag.percentage < 100:
            bucket = self._hash_bucket(flag_key, context.user_id)
            if bucket < flag.percentage:
                return EvalResult(
                    enabled=True, variant=flag.default_variant, reason="percentage"
                )
            return EvalResult(enabled=False, reason="percentage_excluded")

        return EvalResult(enabled=True, variant=flag.default_variant, reason="default")

    def is_enabled(self, flag_key: str, context: EvalContext) -> bool:
        """Return only the on/off outcome of ``evaluate(flag_key, context)``."""
        return self.evaluate(flag_key, context).enabled

    @staticmethod
    def _hash_bucket(flag_key: str, user_id: str) -> float:
        """Map (flag_key, user_id) deterministically to a float in [0, 100].

        The first 4 bytes of SHA-256 over ``"{flag_key}:{user_id}"`` are read
        as a big-endian unsigned 32-bit integer and scaled by 100/0xFFFFFFFF.
        The hashed string includes the flag key, so buckets are assigned per
        flag rather than once per user.
        """
        digest = hashlib.sha256(f"{flag_key}:{user_id}".encode()).digest()
        value = struct.unpack(">I", digest[:4])[0]
        return (value / 0xFFFFFFFF) * 100

    @staticmethod
    def _matches_all(conditions: list[Condition], ctx: EvalContext) -> bool:
        """Return True if every condition holds for *ctx*; True when empty."""
        return all(
            _match_condition(cond, _get_attribute(cond.attribute, ctx))
            for cond in conditions
        )
97
98
def _get_attribute(attr: str, ctx: EvalContext) -> str:
    """Resolve a condition's attribute name to its value in *ctx*.

    The names ``user_id``, ``email``, ``plan`` and ``country`` read the
    corresponding EvalContext field. Any other name is looked up in
    ``ctx.properties``, and a name found in neither resolves to ``""``.
    """
    builtin_values = {
        name: getattr(ctx, name) for name in ("user_id", "email", "plan", "country")
    }
    from_properties = ctx.properties.get(attr, "")
    # Built-in fields take precedence over a same-named entry in properties.
    if attr in builtin_values:
        return builtin_values[attr]
    return from_properties
107
108
def _match_condition(cond: Condition, value: str) -> bool:
    """Test one resolved attribute *value* against *cond*.

    ``eq``, ``neq``, ``contains`` and ``starts_with`` compare against
    ``cond.values[0]`` only, and are False when ``cond.values`` is empty.
    ``in`` tests membership in the whole ``cond.values`` list. Any other
    operator never matches.
    """
    op = cond.operator
    operands = cond.values

    if op == "in":
        return value in operands
    if op not in ("eq", "neq", "contains", "starts_with"):
        # Unknown operator: never matches.
        return False
    if len(operands) == 0:
        return False

    first = operands[0]
    if op == "eq":
        return value == first
    if op == "neq":
        return value != first
    if op == "contains":
        return first in value
    return value.startswith(first)
123