1package saga
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "math"
9 "time"
10
11 "github.com/google/uuid"
12)
13
14type Store interface {
15 Save(ctx context.Context, state *SagaState) error
16 Get(ctx context.Context, id string) (*SagaState, error)
17 SaveWithOptimisticLock(ctx context.Context, state *SagaState) error
18}
19
20type Orchestrator[T any] struct {
21 steps []StepDefinition[T]
22 store Store
23 logger *slog.Logger
24}
25
26func NewOrchestrator[T any](store Store, logger *slog.Logger, steps ...StepDefinition[T]) *Orchestrator[T] {
27 return &Orchestrator[T]{
28 steps: steps,
29 store: store,
30 logger: logger,
31 }
32}
33
34func (o *Orchestrator[T]) Execute(ctx context.Context, sagaType string, sagaCtx *T) (string, error) {
35 ctxBytes, err := json.Marshal(sagaCtx)
36 if err != nil {
37 return "", fmt.Errorf("marshal saga context: %w", err)
38 }
39
40 state := &SagaState{
41 ID: uuid.New().String(),
42 SagaType: sagaType,
43 Status: FlowRunning,
44 CurrentStep: 0,
45 Context: ctxBytes,
46 CompletedSteps: make([]string, 0),
47 Version: 1,
48 CreatedAt: time.Now(),
49 UpdatedAt: time.Now(),
50 }
51
52 if err := o.store.Save(ctx, state); err != nil {
53 return "", fmt.Errorf("save initial state: %w", err)
54 }
55
56 o.logger.Info("saga started",
57 "saga_id", state.ID,
58 "saga_type", sagaType,
59 "steps", len(o.steps),
60 )
61
62 if err := o.executeSteps(ctx, state, sagaCtx); err != nil {
63 return state.ID, err
64 }
65
66 return state.ID, nil
67}
68
69func (o *Orchestrator[T]) executeSteps(ctx context.Context, state *SagaState, sagaCtx *T) error {
70 for i := state.CurrentStep; i < len(o.steps); i++ {
71 step := o.steps[i]
72
73 o.logger.Info("executing step",
74 "saga_id", state.ID,
75 "step", step.Name,
76 "index", i,
77 )
78
79 if err := o.executeStepWithRetry(ctx, step, sagaCtx); err != nil {
80 o.logger.Error("step failed",
81 "saga_id", state.ID,
82 "step", step.Name,
83 "error", err,
84 )
85
86 state.Status = FlowCompensating
87 state.Error = err.Error()
88 state.UpdatedAt = time.Now()
89 _ = o.store.SaveWithOptimisticLock(ctx, state)
90
91 compErr := o.compensate(ctx, state, sagaCtx)
92 if compErr != nil {
93 state.Status = FlowFailed
94 state.Error = fmt.Sprintf("execute: %s; compensate: %s", err, compErr)
95 } else {
96 state.Status = FlowFailed
97 }
98 state.UpdatedAt = time.Now()
99 _ = o.store.SaveWithOptimisticLock(ctx, state)
100
101 return fmt.Errorf("saga failed at step %s: %w", step.Name, err)
102 }
103
104 state.CurrentStep = i + 1
105 state.CompletedSteps = append(state.CompletedSteps, step.Name)
106 state.UpdatedAt = time.Now()
107
108
109 ctxBytes, _ := json.Marshal(sagaCtx)
110 state.Context = ctxBytes
111
112 if err := o.store.SaveWithOptimisticLock(ctx, state); err != nil {
113 return fmt.Errorf("persist state after step %s: %w", step.Name, err)
114 }
115 }
116
117 state.Status = FlowCompleted
118 state.UpdatedAt = time.Now()
119 _ = o.store.SaveWithOptimisticLock(ctx, state)
120
121 o.logger.Info("saga completed", "saga_id", state.ID)
122 return nil
123}
124
125func (o *Orchestrator[T]) executeStepWithRetry(ctx context.Context, step StepDefinition[T], sagaCtx *T) error {
126 policy := step.RetryPolicy
127 if policy == nil {
128 policy = &RetryPolicy{MaxAttempts: 1}
129 }
130
131 var lastErr error
132 for attempt := 0; attempt < policy.MaxAttempts; attempt++ {
133 if attempt > 0 {
134 backoff := time.Duration(float64(policy.InitialBackoff) * math.Pow(policy.Multiplier, float64(attempt-1)))
135 if backoff > policy.MaxBackoff {
136 backoff = policy.MaxBackoff
137 }
138 time.Sleep(backoff)
139 }
140
141 stepCtx := ctx
142 if step.Timeout > 0 {
143 var cancel context.CancelFunc
144 stepCtx, cancel = context.WithTimeout(ctx, step.Timeout)
145 defer cancel()
146 }
147
148 lastErr = step.Execute(stepCtx, sagaCtx)
149 if lastErr == nil {
150 return nil
151 }
152
153 o.logger.Warn("step attempt failed",
154 "step", step.Name,
155 "attempt", attempt+1,
156 "max_attempts", policy.MaxAttempts,
157 "error", lastErr,
158 )
159 }
160
161 return lastErr
162}
163
164func (o *Orchestrator[T]) compensate(ctx context.Context, state *SagaState, sagaCtx *T) error {
165 var compensationErrors []error
166
167
168 for i := len(state.CompletedSteps) - 1; i >= 0; i-- {
169 stepName := state.CompletedSteps[i]
170 step := o.findStep(stepName)
171 if step == nil {
172 continue
173 }
174
175 o.logger.Info("compensating step",
176 "saga_id", state.ID,
177 "step", stepName,
178 )
179
180 compCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
181 if err := step.Compensate(compCtx, sagaCtx); err != nil {
182 o.logger.Error("compensation failed",
183 "saga_id", state.ID,
184 "step", stepName,
185 "error", err,
186 )
187 compensationErrors = append(compensationErrors, fmt.Errorf("compensate %s: %w", stepName, err))
188 }
189 cancel()
190 }
191
192 if len(compensationErrors) > 0 {
193 return fmt.Errorf("compensation errors: %v", compensationErrors)
194 }
195 return nil
196}
197
198func (o *Orchestrator[T]) findStep(name string) *StepDefinition[T] {
199 for i := range o.steps {
200 if o.steps[i].Name == name {
201 return &o.steps[i]
202 }
203 }
204 return nil
205}
206