Skip to content

Commit

Permalink
refactor: enhance configuration management with default values and im…
Browse files Browse the repository at this point in the history
…proved validation
  • Loading branch information
John0n1 committed Feb 2, 2025
1 parent 5e80893 commit f494767
Showing 1 changed file with 135 additions and 59 deletions.
194 changes: 135 additions & 59 deletions python/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@

logger = logger.getLogger(__name__)

# Constants
DEFAULT_MAX_GAS_PRICE = 100_000_000_000 # 100 Gwei in wei
DEFAULT_GAS_LIMIT = 1_000_000
DEFAULT_MAX_SLIPPAGE = 0.01
DEFAULT_MIN_PROFIT = 0.001
DEFAULT_MIN_BALANCE = 0.000001

# Standard Ethereum addresses
MAINNET_ADDRESSES = {
'WETH': '0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2',
'USDC': '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48',
'USDT': '0xdAC17F958D2ee523a2206206994597C13D831ec7'
}

class Configuration:
"""
Loads configuration from environment variables and JSON files.
Expand All @@ -28,12 +42,24 @@ def __init__(self, env_path: Optional[str] = None) -> None:
"""
self.env_path = env_path if env_path else ".env"
self._load_env()
self._initialize_defaults()
self.signatures: Dict[str, Dict[str, str]] = {}
self.method_selectors: Dict[str, Dict[str, str]] = {}
self.BASE_PATH = Path(__file__).parent.parent

def _initialize_defaults(self) -> None:
"""Initialize configuration with default values."""
# General settings
self.MAX_GAS_PRICE: int = self._get_env_int("MAX_GAS_PRICE", 100_000_000_000) # 100 Gwei in wei
self.GAS_LIMIT: int = self._get_env_int("GAS_LIMIT", 1_000_000)
self.MAX_SLIPPAGE: float = self._get_env_float("MAX_SLIPPAGE", 0.01)
self.MIN_PROFIT: float = self._get_env_float("MIN_PROFIT", 0.001)
self.MIN_BALANCE: float = self._get_env_float("MIN_BALANCE", 0.000001)
self.MAX_GAS_PRICE = self._get_env_int("MAX_GAS_PRICE", DEFAULT_MAX_GAS_PRICE)
self.GAS_LIMIT = self._get_env_int("GAS_LIMIT", DEFAULT_GAS_LIMIT)
self.MAX_SLIPPAGE = self._get_env_float("MAX_SLIPPAGE", DEFAULT_MAX_SLIPPAGE)
self.MIN_PROFIT = self._get_env_float("MIN_PROFIT", DEFAULT_MIN_PROFIT)
self.MIN_BALANCE = self._get_env_float("MIN_BALANCE", DEFAULT_MIN_BALANCE)

# Standard addresses
self.WETH_ADDRESS = MAINNET_ADDRESSES['WETH']
self.USDC_ADDRESS = MAINNET_ADDRESSES['USDC']
self.USDT_ADDRESS = MAINNET_ADDRESSES['USDT']

# API Keys and Endpoints
self.ETHERSCAN_API_KEY: str = self._get_env_str("ETHERSCAN_API_KEY")
Expand All @@ -51,7 +77,6 @@ def __init__(self, env_path: Optional[str] = None) -> None:
self.WALLET_KEY: str = self._get_env_str("WALLET_KEY")

# Paths
self.BASE_PATH: Path = Path(__file__).parent.parent
self.ERC20_ABI: Path = self._resolve_path("ERC20_ABI")
self.AAVE_FLASHLOAN_ABI: Path = self._resolve_path("AAVE_FLASHLOAN_ABI")
self.AAVE_POOL_ABI: Path = self._resolve_path("AAVE_POOL_ABI")
Expand Down Expand Up @@ -94,32 +119,32 @@ def __init__(self, env_path: Optional[str] = None) -> None:
# Create directories if they don't exist
os.makedirs(self.LINEAR_REGRESSION_PATH, exist_ok=True)

# WETH and USDC addresses
self.WETH_ADDRESS: str = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2" # Mainnet WETH
self.USDC_ADDRESS: str = "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48" # Mainnet USDC
self.USDT_ADDRESS: str = "0xdAC17F958D2ee523a2206206994597C13D831ec7" # Mainnet USDT

def _load_env(self) -> None:
"""Load environment variables fromenv file."""
dotenv.load_dotenv(dotenv_path=self.env_path)
logger.debug(f"Environment variables loaded from: {self.env_path}")

def _validate_ethereum_address(self, address: str, var_name: str) -> str:
"""Validate Ethereum address format."""
"""
Validate Ethereum address format with improved checks.
Returns: Normalized address string
Raises: ValueError if invalid
"""
if not isinstance(address, str):
raise ValueError(f"Invalid {var_name}: must be a string")

clean_address = address.lower().replace('0x', '')
raise ValueError(f"{var_name} must be a string, got {type(address)}")

if len(clean_address) != 40:
raise ValueError(f"Invalid {var_name} length: {address}")
clean_address = address.lower().strip()
if not clean_address.startswith('0x'):
clean_address = '0x' + clean_address

if len(clean_address) != 42:
raise ValueError(f"Invalid {var_name} length: {len(clean_address)}")

try:
# Check if address contains only valid hex characters
int(clean_address, 16)
return f"0x{clean_address}"
int(clean_address[2:], 16)
return clean_address
except ValueError:
raise ValueError(f"Invalid hex format for {var_name}: {address}")
raise ValueError(f"Invalid hex format for {var_name}")

def _get_env_str(self, var_name: str, default: Optional[str] = None) -> str:
"""Get an environment variable as string, raising error if missing."""
Expand Down Expand Up @@ -179,26 +204,25 @@ def _resolve_path(self, path_env_var: str) -> Path:
logger.debug(f"Resolved path: {full_path}")
return full_path

async def _load_json(self, file_path: Path, description: str) -> Any:
"""Load JSON data from a file with proper async."""
async def _load_json_safe(self, file_path: Path, description: str) -> Any:
"""Load JSON with better error handling and validation."""
try:
async with aiofiles.open(file_path, 'r') as f:
data = json.loads(await f.read())
logger.debug(f"Successfully loaded {description} from {file_path}")
return data
except FileNotFoundError as e:
logger.error(f"File not found for {description}: {file_path} - {e}")
raise
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for {description} in file {file_path}: {e}")
raise
async with aiofiles.open(file_path, 'r') as f:
content = await f.read()
try:
data = json.loads(content)
logger.debug(f"Loaded {description} from {file_path}")
return data
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in {description}: {e}")
except FileNotFoundError:
raise FileNotFoundError(f"Missing {description} file: {file_path}")
except Exception as e:
logger.error(f"Unexpected error loading {description} from {file_path}: {e}")
raise
raise RuntimeError(f"Error loading {description}: {e}")

async def get_token_addresses(self) -> List[str]:
"""Get the list of monitored token addresses from the config file."""
data = await self._load_json(self.TOKEN_ADDRESSES, "monitored tokens")
data = await self._load_json_safe(self.TOKEN_ADDRESSES, "monitored tokens")
if not isinstance(data, dict):
logger.error("Invalid format for token addresses: must be a dictionary")
raise ValueError("Invalid format for token addresses: must be a dictionary")
Expand All @@ -210,15 +234,15 @@ async def get_token_addresses(self) -> List[str]:

async def get_token_symbols(self) -> Dict[str, str]:
"""Get the mapping of token addresses to symbols from the config file."""
data = await self._load_json(self.TOKEN_SYMBOLS, "token symbols")
data = await self._load_json_safe(self.TOKEN_SYMBOLS, "token symbols")
if not isinstance(data, dict):
logger.error("Invalid format for token symbols: must be a dict")
raise ValueError("Invalid format for token symbols: must be a dict")
return data

async def get_erc20_signatures(self) -> Dict[str, str]:
"""Load ERC20 function signatures from JSON."""
data = await self._load_json(self.ERC20_SIGNATURES, "ERC20 function signatures")
data = await self._load_json_safe(self.ERC20_SIGNATURES, "ERC20 function signatures")
if not isinstance(data, dict):
logger.error("Invalid format for ERC20 signatures: must be a dict")
raise ValueError("Invalid format for ERC20 signatures: must be a dict")
Expand Down Expand Up @@ -255,33 +279,85 @@ async def load(self) -> None:
"""Load and validate all configuration data."""
try:
# Create required directories if they don't exist
os.makedirs(self.LINEAR_REGRESSION_PATH, exist_ok=True)
os.makedirs(self.BASE_PATH / "abi", exist_ok=True)
os.makedirs(self.BASE_PATH / "utils", exist_ok=True)

# Load critical ABIs
self.AAVE_FLASHLOAN_ABI = await self.load_abi_from_path(self._resolve_path("AAVE_FLASHLOAN_ABI"))
self.AAVE_POOL_ABI = await self.load_abi_from_path(self._resolve_path("AAVE_POOL_ABI"))
self._create_required_directories()

# Validate API keys are set
required_keys = [
'ETHERSCAN_API_KEY',
'INFURA_API_KEY',
'COINGECKO_API_KEY',
'COINMARKETCAP_API_KEY',
'CRYPTOCOMPARE_API_KEY'
]
# Load and validate critical ABIs
await self._load_critical_abis()

# Validate API keys
self._validate_api_keys()

# Validate addresses
self._validate_addresses()

for key in required_keys:
if not getattr(self, key):
logger.warning(f"Missing API key: {key}")

logger.info("Configuration loaded successfully")

except Exception as e:
logger.error(f"Configuration load failed: {e}")
logger.error(f"Configuration load failed: {str(e)}")
raise

def _create_required_directories(self) -> None:
"""Create necessary directories if they don't exist."""
required_dirs = [
self.LINEAR_REGRESSION_PATH,
self.BASE_PATH / "abi",
self.BASE_PATH / "utils"
]

for directory in required_dirs:
os.makedirs(directory, exist_ok=True)
logger.debug(f"Ensured directory exists: {directory}")

async def _load_critical_abis(self) -> None:
"""Load and validate critical ABIs."""
try:
self.AAVE_FLASHLOAN_ABI = await self.load_abi_from_path(
self._resolve_path("AAVE_FLASHLOAN_ABI")
)
self.AAVE_POOL_ABI = await self.load_abi_from_path(
self._resolve_path("AAVE_POOL_ABI")
)
except Exception as e:
raise RuntimeError(f"Failed to load critical ABIs: {e}")

def _validate_api_keys(self) -> None:
"""Validate required API keys are set."""
required_keys = [
'ETHERSCAN_API_KEY',
'INFURA_API_KEY',
'COINGECKO_API_KEY',
'COINMARKETCAP_API_KEY',
'CRYPTOCOMPARE_API_KEY'
]

missing_keys = [
key for key in required_keys
if not getattr(self, key, None)
]

if missing_keys:
logger.warning(f"Missing API keys: {', '.join(missing_keys)}")

def _validate_addresses(self) -> None:
"""Validate all Ethereum addresses in configuration."""
address_fields = [
'WALLET_ADDRESS',
'UNISWAP_ADDRESS',
'SUSHISWAP_ADDRESS',
'AAVE_POOL_ADDRESS',
'AAVE_FLASHLOAN_ADDRESS'
]

for field in address_fields:
value = getattr(self, field, None)
if value:
try:
setattr(self, field,
self._validate_ethereum_address(value, field))
except ValueError as e:
logger.error(f"Invalid address for {field}: {e}")
raise

def _validate_abi(self, abi: List[Dict], abi_type: str) -> bool:
"""Validate the structure and required methods of an ABI."""
if not isinstance(abi, list):
Expand Down

0 comments on commit f494767

Please sign in to comment.