Coverage for yield_analysis_sdk\validators.py: 78%

59 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-02 19:22 +0800

1""" 

2Common validators and mixins for the yield analysis SDK. 

3""" 

4 

5import re 

6from typing import TYPE_CHECKING, Any, Union 

7 

8from pydantic import field_validator 

9 

10from .exceptions import ValidationError 

11 

12if TYPE_CHECKING: 

13 from .type import Chain 

14 

15 

class ChainValidatorMixin:
    """Mixin that coerces the ``chain`` field into a ``Chain`` member."""

    @field_validator("chain", mode="before")
    @classmethod
    def validate_chain(cls, v: Any) -> "Chain":
        """Coerce the raw ``chain`` value, falling back to ``Chain.OTHER``.

        Args:
            v: Raw value supplied for the ``chain`` field

        Returns:
            ``Chain(v)`` for a string the enum accepts, ``v`` itself when it
            is already a ``Chain``, and ``Chain.OTHER`` in every other case
        """
        from .type import Chain  # Deferred import to avoid a circular import

        if isinstance(v, str):
            try:
                return Chain(v)
            except ValueError:
                # Chain() rejected the string
                return Chain.OTHER
        return v if isinstance(v, Chain) else Chain.OTHER

34 

35 

class VaultAddressValidatorMixin:
    """Mixin class that provides vault address validation functionality."""

    @field_validator("vault_address", mode="before")
    @classmethod
    def validate_vault_address(cls, v: Any) -> str:
        """Validate vault address format and normalize it.

        Args:
            v: Raw value supplied for the ``vault_address`` field

        Returns:
            The address as returned by ``normalize_address``

        Raises:
            ValidationError: If the value is None, or if
                ``normalize_address`` rejects it
        """
        if v is None:
            raise ValidationError("Vault address cannot be None")
        # Non-string inputs are converted to text and then checked like any
        # other address instead of being returned unvalidated.
        return normalize_address(v if isinstance(v, str) else str(v))

49 

50 

class UnderlyingTokenValidatorMixin:
    """Mixin class that provides token address validation functionality."""

    @field_validator("underlying_token", mode="before")
    @classmethod
    def validate_underlying_token(cls, v: Any) -> str:
        """Validate underlying token address format and normalize it.

        Args:
            v: Raw value supplied for the ``underlying_token`` field

        Returns:
            The address as returned by ``normalize_address``

        Raises:
            ValidationError: If the value is None, or if
                ``normalize_address`` rejects it
        """
        if v is None:
            raise ValidationError("Underlying token cannot be None")
        # Non-string inputs are converted to text and then checked like any
        # other address instead of being returned unvalidated.
        return normalize_address(v if isinstance(v, str) else str(v))

64 

65 

def validate_chain_value(value: Any) -> "Chain":
    """
    Coerce an arbitrary value into a ``Chain`` member.

    Args:
        value: The value to validate

    Returns:
        ``Chain(value)`` for a string the enum accepts, ``value`` itself when
        it is already a ``Chain``, otherwise ``Chain.OTHER``
    """
    from .type import Chain  # Deferred import to avoid a circular import

    # Anything that is not a string is either already a Chain or unsupported
    if not isinstance(value, str):
        return value if isinstance(value, Chain) else Chain.OTHER

    try:
        resolved = Chain(value)
    except ValueError:
        # Chain() rejected the string
        resolved = Chain.OTHER
    return resolved

87 

88 

def normalize_address(address: str) -> str:
    """
    Normalize address format.

    Surrounding whitespace is stripped, the value is lowercased, and a ``0x``
    prefix is added when missing. The prefix check runs after lowercasing so
    that an uppercase ``0X`` prefix is recognised rather than prefixed again.

    Args:
        address: The address to normalize

    Returns:
        Normalized address (lowercase, with 0x prefix)

    Raises:
        ValidationError: If the address is empty (or only whitespace), or is
            not ``0x`` followed by exactly 40 hexadecimal characters
    """
    if not address:
        raise ValidationError("Address cannot be empty")

    # Strip and lowercase first so the prefix check below is case-insensitive
    address = address.strip().lower()

    # Whitespace-only input is empty once stripped
    if not address:
        raise ValidationError("Address cannot be empty")

    # Ensure it starts with 0x
    if not address.startswith("0x"):
        address = "0x" + address

    # Validate format (0x followed by 40 hex characters)
    if not re.fullmatch(r"0x[a-f0-9]{40}", address):
        raise ValidationError(f"Invalid address format: {address}")

    return address

117 

118 

def validate_address_value(address: str) -> str:
    """
    Validate an address outside of a model by delegating to
    ``normalize_address``.

    Args:
        address: The address to validate

    Returns:
        The value returned by ``normalize_address``

    Raises:
        ValidationError: If ``normalize_address`` rejects the address
    """
    normalized = normalize_address(address)
    return normalized
132 return normalize_address(address)