1use crate::error::{AptosError, AptosResult};
26use std::collections::HashSet;
27use std::future::Future;
28use std::sync::Arc;
29use std::time::Duration;
30use tokio::time::sleep;
31
/// Tuning knobs for retrying failed operations with exponential backoff.
///
/// Construct with [`RetryConfig::default`], one of the presets
/// ([`RetryConfig::no_retry`], [`RetryConfig::aggressive`],
/// [`RetryConfig::conservative`]) or [`RetryConfig::builder`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt (0 disables retrying),
    /// so an operation runs at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Delay in milliseconds before the first retry; later delays grow from this value.
    pub initial_delay_ms: u64,
    /// Cap, in milliseconds, on the exponentially growing delay; see
    /// [`RetryConfig::delay_for_attempt`] for how jitter interacts with it.
    pub max_delay_ms: u64,
    /// Multiplier applied to the delay for each additional retry.
    pub exponential_base: f64,
    /// Whether each delay is randomized by up to `jitter_factor` of its value.
    pub jitter: bool,
    /// Fraction of the capped delay used as the jitter range (the delay moves by
    /// at most plus/minus this fraction). The builder clamps it to `[0.0, 1.0]`;
    /// a value assigned directly on the struct is not validated.
    pub jitter_factor: f64,
    /// HTTP status codes for which an `AptosError::Api` error is considered retryable.
    pub retryable_status_codes: HashSet<u16>,
}
51
52impl Default for RetryConfig {
53 fn default() -> Self {
54 Self {
55 max_retries: 3,
56 initial_delay_ms: 100,
57 max_delay_ms: 10_000,
58 exponential_base: 2.0,
59 jitter: true,
60 jitter_factor: 0.5,
61 retryable_status_codes: [
62 408, 429, 500, 502, 503, 504, ]
69 .into_iter()
70 .collect(),
71 }
72 }
73}
74
75impl RetryConfig {
76 pub fn builder() -> RetryConfigBuilder {
78 RetryConfigBuilder::default()
79 }
80
81 pub fn no_retry() -> Self {
83 Self {
84 max_retries: 0,
85 ..Default::default()
86 }
87 }
88
89 pub fn aggressive() -> Self {
91 Self {
92 max_retries: 5,
93 initial_delay_ms: 50,
94 max_delay_ms: 5_000,
95 exponential_base: 1.5,
96 jitter: true,
97 jitter_factor: 0.3,
98 ..Default::default()
99 }
100 }
101
102 pub fn conservative() -> Self {
104 Self {
105 max_retries: 3,
106 initial_delay_ms: 500,
107 max_delay_ms: 30_000,
108 exponential_base: 2.0,
109 jitter: true,
110 jitter_factor: 0.5,
111 ..Default::default()
112 }
113 }
114
115 #[allow(clippy::cast_possible_truncation)] pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
118 if attempt == 0 {
119 return Duration::from_millis(0);
120 }
121
122 #[allow(clippy::cast_precision_loss)] let base_delay = self.initial_delay_ms as f64
125 * self
126 .exponential_base
127 .powi(attempt.saturating_sub(1).cast_signed());
128
129 #[allow(clippy::cast_precision_loss)] let capped_delay = base_delay.min(self.max_delay_ms as f64);
132
133 let final_delay = if self.jitter {
135 let jitter_range = capped_delay * self.jitter_factor;
136 let jitter = rand::random::<f64>() * jitter_range * 2.0 - jitter_range;
137 (capped_delay + jitter).max(0.0)
138 } else {
139 capped_delay
140 };
141
142 #[allow(clippy::cast_sign_loss)] Duration::from_millis(final_delay as u64)
144 }
145
146 #[inline]
148 pub fn is_retryable_status(&self, status_code: u16) -> bool {
149 self.retryable_status_codes.contains(&status_code)
150 }
151
152 #[inline]
154 pub fn is_retryable_error(&self, error: &AptosError) -> bool {
155 match error {
156 AptosError::Http(_) | AptosError::RateLimited { .. } => true,
158 AptosError::Api { status_code, .. } => self.is_retryable_status(*status_code),
160 _ => false,
162 }
163 }
164}
165
/// Builder for [`RetryConfig`].
///
/// Every field starts unset (`None`). [`RetryConfigBuilder::build`] fills each
/// unset field from [`RetryConfig::default`].
#[derive(Debug, Clone, Default)]
pub struct RetryConfigBuilder {
    max_retries: Option<u32>,
    initial_delay_ms: Option<u64>,
    max_delay_ms: Option<u64>,
    exponential_base: Option<f64>,
    jitter: Option<bool>,
    // Stored already clamped to [0.0, 1.0] by the `jitter_factor` setter.
    jitter_factor: Option<f64>,
    // `None` means "use the default set"; `Some` replaces the default set entirely.
    retryable_status_codes: Option<HashSet<u16>>,
}
177
178impl RetryConfigBuilder {
179 #[must_use]
181 pub fn max_retries(mut self, max_retries: u32) -> Self {
182 self.max_retries = Some(max_retries);
183 self
184 }
185
186 #[must_use]
188 pub fn initial_delay_ms(mut self, initial_delay_ms: u64) -> Self {
189 self.initial_delay_ms = Some(initial_delay_ms);
190 self
191 }
192
193 #[must_use]
195 pub fn max_delay_ms(mut self, max_delay_ms: u64) -> Self {
196 self.max_delay_ms = Some(max_delay_ms);
197 self
198 }
199
200 #[must_use]
202 pub fn exponential_base(mut self, base: f64) -> Self {
203 self.exponential_base = Some(base);
204 self
205 }
206
207 #[must_use]
209 pub fn jitter(mut self, jitter: bool) -> Self {
210 self.jitter = Some(jitter);
211 self
212 }
213
214 #[must_use]
216 pub fn jitter_factor(mut self, factor: f64) -> Self {
217 self.jitter_factor = Some(factor.clamp(0.0, 1.0));
218 self
219 }
220
221 #[must_use]
223 pub fn retryable_status_codes(mut self, codes: impl IntoIterator<Item = u16>) -> Self {
224 self.retryable_status_codes = Some(codes.into_iter().collect());
225 self
226 }
227
228 #[must_use]
230 pub fn add_retryable_status_code(mut self, code: u16) -> Self {
231 let mut codes = self.retryable_status_codes.unwrap_or_default();
232 codes.insert(code);
233 self.retryable_status_codes = Some(codes);
234 self
235 }
236
237 #[must_use]
239 pub fn build(self) -> RetryConfig {
240 let default = RetryConfig::default();
241 RetryConfig {
242 max_retries: self.max_retries.unwrap_or(default.max_retries),
243 initial_delay_ms: self.initial_delay_ms.unwrap_or(default.initial_delay_ms),
244 max_delay_ms: self.max_delay_ms.unwrap_or(default.max_delay_ms),
245 exponential_base: self.exponential_base.unwrap_or(default.exponential_base),
246 jitter: self.jitter.unwrap_or(default.jitter),
247 jitter_factor: self.jitter_factor.unwrap_or(default.jitter_factor),
248 retryable_status_codes: self
249 .retryable_status_codes
250 .unwrap_or(default.retryable_status_codes),
251 }
252 }
253}
254
/// Runs fallible async operations, retrying them according to a [`RetryConfig`].
///
/// The configuration is held behind an [`Arc`], so cloning an executor is cheap
/// and all clones share the same configuration.
#[derive(Debug, Clone)]
pub struct RetryExecutor {
    config: Arc<RetryConfig>,
}
260
261impl RetryExecutor {
262 pub fn new(config: RetryConfig) -> Self {
264 Self {
265 config: Arc::new(config),
266 }
267 }
268
269 pub fn from_shared(config: Arc<RetryConfig>) -> Self {
271 Self { config }
272 }
273
274 pub fn with_defaults() -> Self {
276 Self::new(RetryConfig::default())
277 }
278
279 pub async fn execute<F, Fut, T>(&self, operation: F) -> AptosResult<T>
290 where
291 F: Fn() -> Fut,
292 Fut: Future<Output = AptosResult<T>>,
293 {
294 let mut attempt = 0;
295
296 loop {
297 match operation().await {
298 Ok(result) => return Ok(result),
299 Err(error) => {
300 if attempt >= self.config.max_retries || !self.config.is_retryable_error(&error)
302 {
303 return Err(error);
304 }
305
306 attempt += 1;
307
308 let delay = if let AptosError::RateLimited {
311 retry_after_secs: Some(secs),
312 } = &error
313 {
314 let capped_secs = (*secs).min(300); Duration::from_secs(capped_secs)
317 } else {
318 self.config.delay_for_attempt(attempt)
319 };
320
321 if !delay.is_zero() {
322 sleep(delay).await;
323 }
324 }
325 }
326 }
327 }
328
329 pub async fn execute_with_predicate<F, Fut, T, P>(
336 &self,
337 operation: F,
338 should_retry: P,
339 ) -> AptosResult<T>
340 where
341 F: Fn() -> Fut,
342 Fut: Future<Output = AptosResult<T>>,
343 P: Fn(&AptosError) -> bool,
344 {
345 let mut attempt = 0;
346
347 loop {
348 match operation().await {
349 Ok(result) => return Ok(result),
350 Err(error) => {
351 if attempt >= self.config.max_retries || !should_retry(&error) {
352 return Err(error);
353 }
354
355 attempt += 1;
356
357 let delay = if let AptosError::RateLimited {
359 retry_after_secs: Some(secs),
360 } = &error
361 {
362 let capped_secs = (*secs).min(300);
363 Duration::from_secs(capped_secs)
364 } else {
365 self.config.delay_for_attempt(attempt)
366 };
367
368 if !delay.is_zero() {
369 sleep(delay).await;
370 }
371 }
372 }
373 }
374 }
375}
376
/// Extension trait for running a value under a [`RetryConfig`].
///
/// `with_retry` consumes `self` and returns a future resolving to an
/// [`AptosResult<T>`]. No implementation appears in this module.
/// NOTE(review): confirm which types implement it and whether they delegate to
/// [`RetryExecutor`].
pub trait RetryExt<T> {
    /// Presumably runs `self` with retries governed by `config`; the returned
    /// future yields the final result (TODO confirm against implementations).
    fn with_retry(self, config: &RetryConfig) -> impl Future<Output = AptosResult<T>>;
}
382
383pub async fn retry<F, Fut, T>(operation: F) -> AptosResult<T>
391where
392 F: Fn() -> Fut,
393 Fut: Future<Output = AptosResult<T>>,
394{
395 RetryExecutor::with_defaults().execute(operation).await
396}
397
398pub async fn retry_with_config<F, Fut, T>(config: &RetryConfig, operation: F) -> AptosResult<T>
406where
407 F: Fn() -> Fut,
408 Fut: Future<Output = AptosResult<T>>,
409{
410 RetryExecutor::new(config.clone()).execute(operation).await
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn test_default_config() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay_ms, 100);
        assert!(config.jitter);
    }

    #[test]
    fn test_no_retry_config() {
        let config = RetryConfig::no_retry();
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_builder() {
        let config = RetryConfig::builder()
            .max_retries(5)
            .initial_delay_ms(200)
            .max_delay_ms(5000)
            .exponential_base(1.5)
            .jitter(false)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay_ms, 200);
        assert_eq!(config.max_delay_ms, 5000);
        assert!((config.exponential_base - 1.5).abs() < f64::EPSILON);
        assert!(!config.jitter);
    }

    #[test]
    fn test_delay_calculation_no_jitter() {
        let config = RetryConfig::builder()
            .initial_delay_ms(100)
            .exponential_base(2.0)
            .jitter(false)
            .build();

        // Attempt 0 is the initial try: no wait.
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(0));
        // The first retry waits the initial delay.
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(100));
        // Each later retry doubles the previous wait.
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_capped_at_max() {
        let config = RetryConfig::builder()
            .initial_delay_ms(1000)
            .max_delay_ms(2000)
            .exponential_base(2.0)
            .jitter(false)
            .build();

        // 1000 * 2^2 = 4000 ms, capped to 2000 ms.
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(2000));
    }

    #[test]
    fn test_retryable_status_codes() {
        let config = RetryConfig::default();

        assert!(config.is_retryable_status(429)); // Too Many Requests
        assert!(config.is_retryable_status(503)); // Service Unavailable
        assert!(!config.is_retryable_status(400)); // Bad Request
        assert!(!config.is_retryable_status(404)); // Not Found
    }

    #[test]
    fn test_retryable_errors() {
        let config = RetryConfig::default();

        let api_error = AptosError::Api {
            status_code: 503,
            message: "Service Unavailable".to_string(),
            error_code: None,
            vm_error_code: None,
        };
        assert!(config.is_retryable_error(&api_error));

        let rate_limited = AptosError::RateLimited {
            retry_after_secs: Some(30),
        };
        assert!(config.is_retryable_error(&rate_limited));

        let api_error_400 = AptosError::Api {
            status_code: 400,
            message: "Bad Request".to_string(),
            error_code: None,
            vm_error_code: None,
        };
        assert!(!config.is_retryable_error(&api_error_400));

        let not_found = AptosError::NotFound("resource".to_string());
        assert!(!config.is_retryable_error(&not_found));
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_try() {
        let executor = RetryExecutor::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, AptosError>(42)
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1) // keep the test fast
            .jitter(false)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(AptosError::Api {
                            status_code: 503,
                            message: "Service Unavailable".to_string(),
                            error_code: None,
                            vm_error_code: None,
                        })
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        // Two failures, then success on the third call.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig::builder()
            .max_retries(2)
            .initial_delay_ms(1)
            .jitter(false)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(AptosError::Api {
                        status_code: 503,
                        message: "Always fails".to_string(),
                        error_code: None,
                        vm_error_code: None,
                    })
                }
            })
            .await;

        assert!(result.is_err());
        // Initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(AptosError::Api {
                        status_code: 400, // not retryable by default
                        message: "Bad Request".to_string(),
                        error_code: None,
                        vm_error_code: None,
                    })
                }
            })
            .await;

        assert!(result.is_err());
        // Fails immediately without retrying.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_aggressive_config() {
        let config = RetryConfig::aggressive();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay_ms, 50);
        assert_eq!(config.max_delay_ms, 5_000);
        assert!((config.exponential_base - 1.5).abs() < f64::EPSILON);
        assert!(config.jitter);
    }

    #[test]
    fn test_conservative_config() {
        let config = RetryConfig::conservative();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay_ms, 500);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!((config.exponential_base - 2.0).abs() < f64::EPSILON);
        assert!(config.jitter);
    }

    #[test]
    fn test_builder_jitter_factor() {
        let config = RetryConfig::builder().jitter_factor(0.25).build();

        assert!((config.jitter_factor - 0.25).abs() < f64::EPSILON);
    }

    #[test]
    fn test_builder_retryable_status_codes() {
        let config = RetryConfig::builder()
            .retryable_status_codes([500, 502])
            .build();

        assert!(config.is_retryable_status(500));
        assert!(config.is_retryable_status(502));
        // The custom set replaces the defaults.
        assert!(!config.is_retryable_status(503));
    }

    #[test]
    fn test_delay_with_jitter() {
        let config = RetryConfig::builder()
            .initial_delay_ms(100)
            .jitter(true)
            .jitter_factor(0.5)
            .build();

        let delay1 = config.delay_for_attempt(1);
        // 100 ms +/- 50%.
        assert!(delay1 >= Duration::from_millis(50));
        assert!(delay1 <= Duration::from_millis(150));
    }

    #[test]
    fn test_delay_zero_for_first_attempt() {
        let config = RetryConfig::default();
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(0));
    }

    #[test]
    fn test_retryable_error_transaction_error() {
        let config = RetryConfig::default();

        let txn_error = AptosError::Transaction("failed".to_string());
        assert!(!config.is_retryable_error(&txn_error));
    }

    #[test]
    fn test_retryable_error_invalid_address() {
        let config = RetryConfig::default();

        let addr_error = AptosError::InvalidAddress("bad".to_string());
        assert!(!config.is_retryable_error(&addr_error));
    }

    #[tokio::test]
    async fn test_retry_with_no_retry_config() {
        let config = RetryConfig::no_retry();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(AptosError::Api {
                        status_code: 503,
                        message: "Service Unavailable".to_string(),
                        error_code: None,
                        vm_error_code: None,
                    })
                }
            })
            .await;

        assert!(result.is_err());
        // A retryable error still runs only once when retries are disabled.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_config_clone() {
        let config = RetryConfig::builder()
            .max_retries(5)
            .initial_delay_ms(200)
            .build();

        let cloned = config.clone();
        assert_eq!(config.max_retries, cloned.max_retries);
        assert_eq!(config.initial_delay_ms, cloned.initial_delay_ms);
    }

    #[test]
    fn test_retry_config_debug() {
        let config = RetryConfig::default();
        let debug = format!("{config:?}");
        assert!(debug.contains("RetryConfig"));
        assert!(debug.contains("max_retries"));
    }

    #[test]
    fn test_builder_add_retryable_status_code() {
        let config = RetryConfig::builder()
            .add_retryable_status_code(599)
            .build();

        assert!(config.is_retryable_status(599));
    }

    #[test]
    fn test_builder_add_duplicate_status_code() {
        let config = RetryConfig::builder()
            .add_retryable_status_code(500)
            .add_retryable_status_code(500) // duplicate
            .build();

        assert!(config.is_retryable_status(500));
        // The set deduplicates, and `add_*` starts from an empty set rather than the defaults.
        assert_eq!(config.retryable_status_codes.len(), 1);
    }

    #[test]
    fn test_builder_jitter_factor_clamped() {
        let config = RetryConfig::builder()
            .jitter_factor(2.0) // above range, clamped to 1.0
            .build();

        assert!((config.jitter_factor - 1.0).abs() < f64::EPSILON);

        let config2 = RetryConfig::builder()
            .jitter_factor(-1.0) // below range, clamped to 0.0
            .build();

        assert!(config2.jitter_factor.abs() < f64::EPSILON);
    }

    #[test]
    fn test_retry_executor_new() {
        let config = RetryConfig::default();
        let executor = RetryExecutor::new(config.clone());

        let debug = format!("{executor:?}");
        assert!(debug.contains("RetryExecutor"));
    }

    #[tokio::test]
    async fn test_retry_with_custom_predicate() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1)
            .jitter(false)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        // `NotFound` is not retryable by default, but the custom predicate retries it.
        let result = executor
            .execute_with_predicate(
                || {
                    let counter = counter_clone.clone();
                    async move {
                        let count = counter.fetch_add(1, Ordering::SeqCst);
                        if count < 2 {
                            Err(AptosError::NotFound("test".to_string()))
                        } else {
                            Ok(42)
                        }
                    }
                },
                |_| true, // retry everything
            )
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_with_predicate_no_retry() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        // A 503 would normally be retried; the predicate vetoes it.
        let result = executor
            .execute_with_predicate(
                || {
                    let counter = counter_clone.clone();
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<i32, _>(AptosError::Api {
                            status_code: 503,
                            message: "Fail".to_string(),
                            error_code: None,
                            vm_error_code: None,
                        })
                    }
                },
                |_| false, // never retry
            )
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_convenience_function() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry(|| {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, AptosError>(42)
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_with_config_convenience_function() {
        let config = RetryConfig::builder()
            .max_retries(1)
            .initial_delay_ms(1)
            .jitter(false)
            .build();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_config(&config, || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 1 {
                    Err(AptosError::Api {
                        status_code: 503,
                        message: "Service Unavailable".to_string(),
                        error_code: None,
                        vm_error_code: None,
                    })
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retryable_rate_limited_error() {
        let config = RetryConfig::default();

        let rate_limited = AptosError::RateLimited {
            retry_after_secs: Some(5),
        };
        assert!(config.is_retryable_error(&rate_limited));
    }

    #[test]
    fn test_builder_default_debug() {
        let builder = RetryConfigBuilder::default();
        let debug = format!("{builder:?}");
        assert!(debug.contains("RetryConfigBuilder"));
    }
}