1
2package com.example.saga;
3
4import com.example.saga.domain.SagaState;
5import com.fasterxml.jackson.databind.ObjectMapper;
6import org.slf4j.Logger;
7import org.slf4j.LoggerFactory;
8import org.springframework.transaction.annotation.Transactional;
9
10import java.time.Duration;
11import java.util.*;
12
13public class SagaOrchestrator<T> {
14
15 private static final Logger log = LoggerFactory.getLogger(SagaOrchestrator.class);
16
17 private final List<SagaStep<T>> steps;
18 private final SagaStateRepository repository;
19 private final ObjectMapper objectMapper;
20 private final Class<T> contextType;
21
22 public SagaOrchestrator(
23 List<SagaStep<T>> steps,
24 SagaStateRepository repository,
25 ObjectMapper objectMapper,
26 Class<T> contextType) {
27 this.steps = List.copyOf(steps);
28 this.repository = repository;
29 this.objectMapper = objectMapper;
30 this.contextType = contextType;
31 }
32
33 public String execute(String sagaType, T context) {
34 try {
35 String contextJson = objectMapper.writeValueAsString(context);
36 String sagaId = UUID.randomUUID().toString();
37 SagaState state = SagaState.create(sagaId, sagaType, contextJson);
38 repository.save(state);
39
40 log.info("Saga started: id={}, type={}, steps={}", sagaId, sagaType, steps.size());
41
42 executeSteps(state, context);
43 return sagaId;
44
45 } catch (Exception e) {
46 throw new SagaExecutionException("Failed to execute saga", e);
47 }
48 }
49
50 private void executeSteps(SagaState state, T context) {
51 for (int i = state.getCurrentStep(); i < steps.size(); i++) {
52 SagaStep<T> step = steps.get(i);
53
54 log.info("Executing step: saga={}, step={}", state.getId(), step.getName());
55
56 try {
57 executeWithRetry(step, context);
58 state.markStepCompleted(step.getName());
59 persistContext(state, context);
60 repository.save(state);
61
62 } catch (Exception e) {
63 log.error("Step failed: saga={}, step={}, error={}",
64 state.getId(), step.getName(), e.getMessage());
65
66 state.markCompensating(e.getMessage());
67 repository.save(state);
68
69 compensate(state, context);
70 return;
71 }
72 }
73
74 state.markCompleted();
75 repository.save(state);
76 log.info("Saga completed: id={}", state.getId());
77 }
78
79 private void executeWithRetry(SagaStep<T> step, T context) throws SagaStepException {
80 RetryPolicy policy = step.getRetryPolicy();
81 SagaStepException lastException = null;
82
83 for (int attempt = 0; attempt < policy.maxAttempts(); attempt++) {
84 try {
85 if (attempt > 0) {
86 long backoffMs = (long) (policy.initialBackoff().toMillis()
87 * Math.pow(policy.multiplier(), attempt - 1));
88 backoffMs = Math.min(backoffMs, policy.maxBackoff().toMillis());
89 Thread.sleep(backoffMs);
90 }
91
92 step.execute(context);
93 return;
94
95 } catch (SagaStepException e) {
96 lastException = e;
97 if (!e.isRetryable()) throw e;
98
99 log.warn("Step attempt failed: step={}, attempt={}/{}, error={}",
100 step.getName(), attempt + 1, policy.maxAttempts(), e.getMessage());
101
102 } catch (InterruptedException e) {
103 Thread.currentThread().interrupt();
104 throw new SagaStepException("Interrupted during retry backoff", e, false);
105 }
106 }
107
108 throw lastException;
109 }
110
111 private void compensate(SagaState state, T context) {
112 List<String> toCompensate = new ArrayList<>(state.getCompletedSteps());
113 Collections.reverse(toCompensate);
114 List<String> compensationErrors = new ArrayList<>();
115
116 for (String stepName : toCompensate) {
117 SagaStep<T> step = findStep(stepName);
118 if (step == null) continue;
119
120 log.info("Compensating step: saga={}, step={}", state.getId(), stepName);
121
122 try {
123 step.compensate(context);
124 } catch (SagaStepException e) {
125 log.error("Compensation failed: saga={}, step={}, error={}",
126 state.getId(), stepName, e.getMessage());
127 compensationErrors.add(stepName + ": " + e.getMessage());
128 }
129 }
130
131 if (compensationErrors.isEmpty()) {
132 state.markFailed(state.getError());
133 } else {
134 state.markFailed(state.getError() + "; compensation errors: " +
135 String.join(", ", compensationErrors));
136 }
137 repository.save(state);
138 }
139
140 private SagaStep<T> findStep(String name) {
141 return steps.stream()
142 .filter(s -> s.getName().equals(name))
143 .findFirst()
144 .orElse(null);
145 }
146
147 private void persistContext(SagaState state, T context) {
148 try {
149 state.setContext(objectMapper.writeValueAsString(context));
150 } catch (Exception e) {
151 log.warn("Failed to persist saga context: {}", e.getMessage());
152 }
153 }
154}
155