Skip to content

Commit

Permalink
Merge pull request #11 from bodo-run/token-count-shorthand
Browse files Browse the repository at this point in the history
Token count shorthand
  • Loading branch information
mohsen1 authored Jan 19, 2025
2 parents fbd5dd3 + da87ecc commit bc3e35e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ yek src/ | pbcopy
Cap the max size to 128K tokens and only process the `src` directory:

```bash
yek --max-size 128000 --tokens src/
yek --max-size 128K --tokens src/
```

Cap the max size to 100KB and only process the `src` directory, writing to a specific directory:
Expand Down
85 changes: 79 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,82 @@ use tracing::{info, Level};
use tracing_subscriber::fmt;
use yek::{find_config_file, load_config_file, serialize_repo};

fn parse_size_input(input: &str) -> std::result::Result<usize, String> {
Byte::from_str(input)
.map(|b| b.get_bytes() as usize)
.map_err(|e| e.to_string())
fn parse_size_input(input: &str, is_tokens: bool) -> std::result::Result<usize, String> {
if is_tokens {
// Handle token count with K suffix
let input = input.trim();
if input.to_uppercase().ends_with('K') {
let num = input[..input.len() - 1]
.parse::<usize>()
.map_err(|e| format!("Invalid token count: {}", e))?;
Ok(num * 1000)
} else {
input
.parse::<usize>()
.map_err(|e| format!("Invalid token count: {}", e))
}
} else {
Byte::from_str(input)
.map(|b| b.get_bytes() as usize)
.map_err(|e| e.to_string())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_parse_size_input_bytes() {
// Using byte_unit::Byte to calculate expected values
assert_eq!(
parse_size_input("10MB", false).unwrap(),
Byte::from_str("10MB").unwrap().get_bytes() as usize
);
assert_eq!(
parse_size_input("128KB", false).unwrap(),
Byte::from_str("128KB").unwrap().get_bytes() as usize
);
assert_eq!(
parse_size_input("1GB", false).unwrap(),
Byte::from_str("1GB").unwrap().get_bytes() as usize
);
assert!(parse_size_input("invalid", false).is_err());
}

#[test]
fn test_parse_size_input_tokens() {
// Test K suffix variations
assert_eq!(parse_size_input("100K", true).unwrap(), 100_000);
assert_eq!(parse_size_input("100k", true).unwrap(), 100_000);
assert_eq!(parse_size_input("0K", true).unwrap(), 0);
assert_eq!(parse_size_input("1K", true).unwrap(), 1_000);
assert_eq!(parse_size_input("1k", true).unwrap(), 1_000);

// Test without K suffix
assert_eq!(parse_size_input("100", true).unwrap(), 100);
assert_eq!(parse_size_input("1000", true).unwrap(), 1000);
assert_eq!(parse_size_input("0", true).unwrap(), 0);

// Test invalid inputs
assert!(parse_size_input("K", true).is_err());
assert!(parse_size_input("-1K", true).is_err());
assert!(parse_size_input("-100", true).is_err());
assert!(parse_size_input("100KB", true).is_err());
assert!(parse_size_input("invalid", true).is_err());
assert!(parse_size_input("", true).is_err());
assert!(parse_size_input(" ", true).is_err());
assert!(parse_size_input("100K100", true).is_err());
assert!(parse_size_input("100.5K", true).is_err());
}

#[test]
fn test_parse_size_input_whitespace() {
// Test whitespace handling
assert_eq!(parse_size_input(" 100K ", true).unwrap(), 100_000);
assert_eq!(parse_size_input("\t100k\n", true).unwrap(), 100_000);
assert_eq!(parse_size_input(" 100 ", true).unwrap(), 100);
}
}

fn main() -> Result<()> {
Expand All @@ -26,7 +98,7 @@ fn main() -> Result<()> {
.arg(
Arg::new("max-size")
.long("max-size")
.help("Maximum size per chunk (e.g. '10MB', '128KB', '1GB')")
.help("Maximum size per chunk (e.g. '10MB', '128KB', '1GB' or '100K' tokens when --tokens is used)")
.default_value("10MB"),
)
.arg(
Expand Down Expand Up @@ -68,7 +140,8 @@ fn main() -> Result<()> {

// Parse max size
let max_size_str = matches.get_one::<String>("max-size").unwrap();
let max_size = parse_size_input(max_size_str).map_err(|e| anyhow::anyhow!(e))?;
let max_size = parse_size_input(max_size_str, matches.get_flag("tokens"))
.map_err(|e| anyhow::anyhow!("{}", e))?;

// Get directories to process
let directories: Vec<PathBuf> = matches
Expand Down

0 comments on commit bc3e35e

Please sign in to comment.