diff --git a/Cargo.toml b/Cargo.toml index d6285fc..d546ea4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,21 +4,21 @@ version = "0.1.0" edition = "2024" [dependencies] -anyhow = "1.0.98" +anyhow = "1.0" backon = "1" -clap = { version = "4", features = ["derive"] } +clap = { version = "4.5", features = ["derive"] } futures = "0.3" -regex = "1.11.1" -reqwest = { version = "0.12.22", features = ["json"] } -serde_json = "1.0.140" +regex = "1.11" +reqwest = { version = "0.12", features = ["json"] } +serde_json = "1.0" tokio = { version = "1.46", features = [ "macros", "rt-multi-thread", "time", "fs", ] } -url = { version = "2.5.4", features = ["serde"] } +url = { version = "2.5", features = ["serde"] } async-stream = "0.3" -figment = { version = "0.10.19", features = ["env", "toml"] } -serde = { version = "1.0.219", features = ["derive"] } +figment = { version = "0.10", features = ["env", "toml"] } +serde = { version = "1.0", features = ["derive"] } toml = "0.9" diff --git a/src/config/basic_client.rs b/src/config/basic_client.rs index 46a15d4..b3dcf29 100644 --- a/src/config/basic_client.rs +++ b/src/config/basic_client.rs @@ -5,9 +5,9 @@ use reqwest::Client; use crate::config::KeyCheckerConfig; pub fn client_builder(config: &KeyCheckerConfig) -> Result { - let mut builder = Client::builder().timeout(Duration::from_secs(config.timeout_sec())); + let mut builder = Client::builder().timeout(Duration::from_secs(config.timeout_sec)); - if let Some(ref proxy_url) = config.proxy() { + if let Some(ref proxy_url) = config.proxy { builder = builder.proxy(reqwest::Proxy::all(proxy_url.clone())?); } diff --git a/src/config/basic_config.rs b/src/config/basic_config.rs index 50efcbc..74c56b8 100644 --- a/src/config/basic_config.rs +++ b/src/config/basic_config.rs @@ -1,4 +1,4 @@ -use anyhow::{Ok, Result}; +use anyhow::Result; use clap::Parser; use figment::{ Figment, @@ -10,55 +10,72 @@ use std::path::PathBuf; use std::sync::LazyLock; use url::Url; -#[derive(Debug, Serialize, Deserialize, Parser)] +/// Cli arguments +#[derive(Parser, Debug, Serialize, Deserialize)] +struct Cli { + #[arg(short = 'i', long)] + #[serde(skip_serializing_if = "Option::is_none")] + input_path: Option, + + #[arg(short = 'o', long)] + #[serde(skip_serializing_if = "Option::is_none")] + output_path: Option, + + #[arg(short = 'b', long)] + #[serde(skip_serializing_if = "Option::is_none")] + backup_path: Option, + + #[arg(short = 'u', long)] + #[serde(skip_serializing_if = "Option::is_none")] + api_host: Option, + + #[arg(short = 't', long)] + #[serde(skip_serializing_if = "Option::is_none")] + timeout_sec: Option, + + #[arg(short = 'c', long)] + #[serde(skip_serializing_if = "Option::is_none")] + concurrency: Option, + + #[arg(short = 'x', long)] + #[serde(skip_serializing_if = "Option::is_none")] + proxy: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeyCheckerConfig { // Input file path containing API keys to check. #[serde(default)] - #[arg(short, long)] - input_path: Option, + pub input_path: PathBuf, // Output file path for valid API keys. #[serde(default)] - #[arg(short, long)] - output_path: Option, + pub output_path: PathBuf, // Backup file path for all API keys. #[serde(default)] - #[arg(short, long)] - backup_path: Option, + pub backup_path: PathBuf, // API host URL for key validation. - #[serde(default)] - #[arg(short, long)] - api_host: Option, + #[serde(default = "default_api_host")] + pub api_host: Url, // Request timeout in seconds. #[serde(default)] - #[arg(short, long)] - timeout_sec: Option, + pub timeout_sec: u64, // Maximum number of concurrent requests. #[serde(default)] - #[arg(short, long)] - concurrency: Option, + pub concurrency: usize, // Optional proxy URL for HTTP requests (e.g., --proxy http://user:pass@host:port). #[serde(default)] - #[arg(short, long)] - proxy: Option, + pub proxy: Option, } impl Default for KeyCheckerConfig { fn default() -> Self { - Self { - input_path: Some(default_input_path()), - output_path: Some(default_output_path()), - backup_path: Some(default_backup_path()), - api_host: Some(default_api_host()), - timeout_sec: Some(default_timeout()), - concurrency: Some(default_concurrency()), - proxy: None, - } + (*DEFAULT_CONFIG).clone() } } impl KeyCheckerConfig { @@ -73,81 +90,30 @@ impl KeyCheckerConfig { fs::write(CONFIG_PATH.as_path(), toml_content)?; } - // Load configuration from config.toml, environment variables, and defaults - let mut figment = Figment::new() + // Load configuration from config.toml, environment variables, and CLI arguments + let config: Self = Figment::new() .merge(Serialized::defaults(Self::default())) .merge(Toml::file(CONFIG_PATH.as_path())) - .merge(Env::prefixed("KEYCHECKER_")); + .merge(Env::prefixed("KEYCHECKER_")) + .merge(Serialized::defaults(Cli::parse())) + .extract()?; - // Only merge non-None command line arguments - let cli_args = Self::parse(); - if let Some(input_path) = cli_args.input_path { - figment = figment.merge(("input_path", input_path)); - } - if let Some(output_path) = cli_args.output_path { - figment = figment.merge(("output_path", output_path)); - } - if let Some(backup_path) = cli_args.backup_path { - figment = figment.merge(("backup_path", backup_path)); - } - if let Some(api_host) = cli_args.api_host { - figment = figment.merge(("api_host", api_host)); - } - if let Some(timeout_sec) = cli_args.timeout_sec { - figment = figment.merge(("timeout_sec", timeout_sec)); - } - if let Some(concurrency) = cli_args.concurrency { - figment = figment.merge(("concurrency", concurrency)); - } - if let Some(proxy) = cli_args.proxy { - figment = figment.merge(("proxy", proxy)); - } - - let config = figment.extract()?; - - println!("Final loaded config: {:?}", config); + dbg!(&config); Ok(config) } - pub fn input_path(&self) -> PathBuf { - self.input_path.clone().unwrap_or_else(default_input_path) - } - pub fn output_path(&self) -> PathBuf { - self.output_path.clone().unwrap_or_else(default_output_path) - } - pub fn backup_path(&self) -> PathBuf { - self.backup_path.clone().unwrap_or_else(default_backup_path) - } - pub fn api_host(&self) -> Url { - self.api_host.clone().unwrap_or_else(default_api_host) - } - pub fn timeout_sec(&self) -> u64 { - self.timeout_sec.unwrap_or_else(default_timeout) - } - pub fn concurrency(&self) -> usize { - self.concurrency.unwrap_or_else(default_concurrency) - } - pub fn proxy(&self) -> Option { - self.proxy.clone() - } } -fn default_input_path() -> PathBuf { - "keys.txt".into() -} - -fn default_output_path() -> PathBuf { - "output_keys.txt".into() -} -fn default_backup_path() -> PathBuf { - "backup_keys.txt".into() -} +// Single LazyLock for entire default configuration +static DEFAULT_CONFIG: LazyLock = LazyLock::new(|| KeyCheckerConfig { + input_path: "keys.txt".into(), + output_path: "output_keys.txt".into(), + backup_path: "backup_keys.txt".into(), + api_host: Url::parse("https://generativelanguage.googleapis.com/").unwrap(), + timeout_sec: 15, + concurrency: 50, + proxy: None, +}); fn default_api_host() -> Url { - Url::parse("https://generativelanguage.googleapis.com/").unwrap() -} -fn default_timeout() -> u64 { - 20 -} -fn default_concurrency() -> usize { - 30 + DEFAULT_CONFIG.api_host.clone() } diff --git a/src/key_validator.rs b/src/key_validator.rs index 41f56e7..8304531 100644 --- a/src/key_validator.rs +++ b/src/key_validator.rs @@ -61,7 +61,6 @@ pub async fn validate_key_with_retry( async fn keytest(client: Client, api_host: &Url, key: &GeminiKey) -> Result { const API_PATH: &str = "v1beta/models/gemini-2.0-flash-exp:generateContent"; let full_url = api_host.join(API_PATH)?; - let request_body = serde_json::json!({ "contents": [ { diff --git a/src/main.rs b/src/main.rs index a20d2d6..cb4b70d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use gemini_keychecker::validation::ValidationService; #[tokio::main] async fn main() -> Result<()> { let config = KeyCheckerConfig::load_config().unwrap(); - let keys = load_keys(config.input_path().as_path())?; + let keys = load_keys(config.input_path.as_path())?; let client = client_builder(&config)?; let validation_service = ValidationService::new(config, client); diff --git a/src/validation.rs b/src/validation.rs index 07b7cdd..f3e7d7f 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -42,13 +42,13 @@ impl ValidationService { // Create stream to validate keys concurrently let valid_keys_stream = stream - .map(|key| validate_key_with_retry(self.client.to_owned(), self.config.api_host(), key)) - .buffer_unordered(self.config.concurrency()) + .map(|key| validate_key_with_retry(self.client.to_owned(), self.config.api_host.clone(), key)) + .buffer_unordered(self.config.concurrency) .filter_map(|r| async { r }); pin_mut!(valid_keys_stream); // Open output file for writing valid keys - let output_file = fs::File::create(&self.config.output_path()).await?; + let output_file = fs::File::create(&self.config.output_path).await?; let mut buffer_writer = tokio::io::BufWriter::new(output_file); // Process validated keys and write to output file