diff --git a/README.md b/README.md index 809f94d..27d8e87 100644 --- a/README.md +++ b/README.md @@ -48,9 +48,12 @@ suggest = "{{command[0]}} fix {{command[2:]}}" pattern = [ "pattern 1", ] -# this will add a `sudo` before the command if the `sudo` is found by `which` +# this will add a `sudo` before the command if: +# - the `sudo` is found by `which` +# - the command last command does not contain `sudo` +# - the last command typed by the user does not contain `sudo` suggest = ''' -#[executable(sudo)] +#[executable(sudo), !cmd_contains(sudo)] sudo {{command}}''' ``` @@ -60,6 +63,12 @@ The placeholder is evaluated as following: - `{{command[1]}}`: The first argument of the command (the command itself has index of 0) - `{{command[2:5]}}`: The second to fifth arguments. If any of the side is not specified, them it defaults to the start (if it is left) or the end (if it is right). +The suggestion can have additional conditions to check. To specify the conditions, add a `#[...]` at the first line (just like derive macros in Rust). Available conditions: + +- `executable`: Check if the argument can be found by `which` +- `cmd_contains`: Check if the last user input contains the argument +- `err_contains`: Check if the error of the command contains the argument + ## Current Progress We need more rule files! diff --git a/rule_parser/src/lib.rs b/rule_parser/src/lib.rs index 761424a..7beb7e2 100644 --- a/rule_parser/src/lib.rs +++ b/rule_parser/src/lib.rs @@ -14,11 +14,11 @@ pub fn parse_rules(input: TokenStream) -> TokenStream { #[derive(serde::Deserialize)] struct Rule { command: String, - match_output: Vec, + match_err: Vec, } #[derive(serde::Deserialize)] -struct MatchOutput { +struct MatchError { pattern: Vec, suggest: Vec, } @@ -43,13 +43,13 @@ fn gen_string_hashmap(rules: Vec) -> String { for rule in rules { let command = rule.command.to_owned(); string_hashmap.push_str(&format!("(\"{}\", vec![", command)); - for match_output in rule.match_output { - let pattern = match_output + for match_err in rule.match_err { + let pattern = match_err .pattern .iter() .map(|x| x.to_lowercase()) .collect::>(); - let suggest = match_output.suggest; + let suggest = match_err.suggest; string_hashmap.push_str(&format!( "(vec![\"{}\"], vec![\"{}\"]),", pattern.join("\", \""), diff --git a/rules/sudo-doas.toml b/rules/sudo-doas.toml index e73d928..699fce6 100644 --- a/rules/sudo-doas.toml +++ b/rules/sudo-doas.toml @@ -1,6 +1,6 @@ command = "privilege" -[[match_output]] +[[match_err]] pattern = [ "as root", "authentication is required", diff --git a/src/corrections.rs b/src/corrections.rs index e3cd642..07ebcc9 100644 --- a/src/corrections.rs +++ b/src/corrections.rs @@ -6,25 +6,25 @@ use crate::shell::{command_output, PRIVILEGE_LIST}; use crate::style::highlight_difference; pub fn correct_command(shell: &str, last_command: &str) -> Option { - let command_output = command_output(shell, last_command); + let err = command_output(shell, last_command); let split_command = last_command.split_whitespace().collect::>(); - let command = match PRIVILEGE_LIST.contains(&split_command[0]) { + let executable = match PRIVILEGE_LIST.contains(&split_command[0]) { true => split_command.get(1).expect("No command found."), false => split_command.first().expect("No command found."), }; - if !PRIVILEGE_LIST.contains(command) { - let suggest = match_pattern("privilege", &command_output); + if !PRIVILEGE_LIST.contains(executable) { + let suggest = match_pattern("privilege", last_command, &err); if let Some(suggest) = suggest { let suggest = eval_suggest(&suggest, last_command); return Some(suggest); } } - let suggest = match_pattern(command, &command_output); + let suggest = match_pattern(executable, last_command, &err); if let Some(suggest) = suggest { let suggest = eval_suggest(&suggest, last_command); - if PRIVILEGE_LIST.contains(command) { + if PRIVILEGE_LIST.contains(executable) { return Some(format!("{} {}", split_command[0], suggest)); } return Some(suggest); @@ -32,15 +32,15 @@ pub fn correct_command(shell: &str, last_command: &str) -> Option { None } -fn match_pattern(command: &str, error_msg: &str) -> Option { +fn match_pattern(executable: &str, command: &str, error_msg: &str) -> Option { let rules = parse_rules!("rules"); - if rules.contains_key(command) { - let suggest = rules.get(command).unwrap(); + if rules.contains_key(executable) { + let suggest = rules.get(executable).unwrap(); for (pattern, suggest) in suggest { for pattern in pattern { if error_msg.contains(pattern) { for suggest in suggest { - if let Some(suggest) = check_suggest(suggest) { + if let Some(suggest) = check_suggest(suggest, command, error_msg) { return Some(suggest); } } @@ -53,27 +53,38 @@ fn match_pattern(command: &str, error_msg: &str) -> Option { } } -fn check_suggest(suggest: &str) -> Option { +fn check_suggest(suggest: &str, command: &str, error_msg: &str) -> Option { if !suggest.starts_with('#') { return Some(suggest.to_owned()); } let lines = suggest.lines().collect::>(); - let conditions = lines.first().unwrap(); - let conditions = conditions.trim_matches(|c| c == '#' || c == '[' || c == ']'); - let conditions = conditions.split(',').collect::>(); - for condition in conditions { - let condition = condition.trim(); - let (condition, arg) = condition.split_once('(').unwrap(); - let arg = arg.trim_matches(|c| c == '(' || c == ')'); + let conditions = lines.first().unwrap().trim().replacen("#", "", 1); + let conditions = conditions + .trim_start_matches('[') + .trim_end_matches(']'); + let conditions = conditions + .split(",") + .collect::>() ; - if eval_condition(condition, arg) == false { + for condition in conditions { + let (mut condition, arg) = condition.split_once('(').unwrap(); + condition = condition.trim(); + let arg = arg.trim_matches(|c| c == '(' || c == ')'); + let reverse = match condition.starts_with('!') { + true => { + condition = condition.trim_start_matches('!'); + true + }, + false => false, + }; + if eval_condition(condition, arg, command, error_msg) == reverse { return None; } } Some(lines[1..].join("\n")) } -fn eval_condition(condition: &str, arg: &str) -> bool { +fn eval_condition(condition: &str, arg: &str, command: &str, error_msg: &str) -> bool { match condition { "executable" => { let output = std::process::Command::new("which") @@ -81,8 +92,14 @@ fn eval_condition(condition: &str, arg: &str) -> bool { .output() .expect("failed to execute process"); output.status.success() - } - _ => false, + }, + "err_contains" => { + error_msg.contains(arg) + }, + "cmd_contains" => { + command.contains(arg) + }, + _ => unreachable!("Unknown condition when evaluation condition: {}", condition) } }