feat: experimental AI support

This commit is contained in:
iff 2024-11-19 14:59:15 +01:00
parent 070343a5e2
commit 5d7624563d
6 changed files with 1312 additions and 0 deletions

View file

@ -30,6 +30,9 @@ mod replaces;
#[cfg(feature = "runtime-rules")]
mod runtime_rules;
#[cfg(feature = "request-ai")]
mod requests;
#[macro_use]
extern crate rust_i18n;
i18n!("i18n", fallback = "en", minify_key = true);
@ -40,6 +43,11 @@ fn main() {
let locale = get_locale().unwrap_or("en_US".to_string());
rust_i18n::set_locale(&locale[0..2]);
#[cfg(feature = "request-ai")]
{
std::env::set_var("_PR_LOCALE", &locale);
}
args::handle_args();
let shell = match std::env::var("_PR_SHELL") {

94
src/requests.rs Normal file
View file

@ -0,0 +1,94 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Result, Value};
use reqwest::blocking::Client;
#[derive(Serialize, Deserialize)]
struct Input {
role: String,
content: String,
}
#[derive(Serialize, Deserialize)]
struct Messages {
messages: Vec<Input>,
model: String,
}
#[derive(Serialize, Deserialize)]
pub struct AISuggest {
pub command: String,
pub note: String,
}
pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AISuggest> {
let mut map = HashMap::new();
map.insert("last_command", last_command);
map.insert("error_msg", error_msg);
let api_key = match std::env::var("_PR_AI_API_KEY") {
Ok(key) => Some(key),
Err(_) => option_env!("_DEV_PR_AI_API_KEY").map(|key| key.to_string())
};
let api_key = match api_key {
Some(key) => key,
None => {
return None;
}
};
let request_url = match std::env::var("_PR_AI_URL") {
Ok(url) => url,
Err(_) => "https://api.groq.com/openai/v1/chat/completions".to_string()
};
let model = match std::env::var("_PR_AI_MODEL") {
Ok(model) => model,
Err(_) => "llama3-8b-8192".to_string()
};
let user_locale = std::env::var("_PR_LOCALE").unwrap_or("en_US".to_string());
let ai_prompt = format!(r#"
You are a programmer trying to run a command in your shell. You run the command `{last_command}` and get the following error message: `{error_msg}`. What command should you run next to fix the error?
Answer in the following JSON format without any extra text:
```
{{"command":"your suggestion","note":"why you think this command will fix the error"}}
```
User locale is: {user_locale}, plese make sure to provide the note in the same language.
If you don't know the answer or can't provide a good suggestion, please reply the command field with `None` and provide a note explaining why you can't provide a suggestion
"#);
let messages = Messages {
messages: vec![Input {
role: "user".to_string(),
content: ai_prompt.to_string(),
}],
model,
};
let client = Client::new();
let res = client.post(&request_url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&messages)
.send();
let res = match res {
Ok(res) => res,
Err(_) => {
return None;
}
};
let res = &res.text().unwrap();
let json: Value = serde_json::from_str(res).unwrap();
let content = &json["choices"][0]["message"]["content"];
let suggestion: AISuggest = serde_json::from_str(content.as_str().unwrap()).unwrap();
Some(suggestion)
}

View file

@ -56,6 +56,21 @@ pub fn suggest_command(shell: &str, last_command: &str, error_msg: &str) -> Opti
}
}
#[cfg(feature = "request-ai")]{
use crate::requests::ai_suggestion;
let suggest = ai_suggestion(last_command, error_msg);
if let Some(suggest) = suggest {
eprintln!("{}: {}\n", t!("ai-suggestion").bold().blue(), suggest.note);
let command = suggest.command;
if command != "None" {
if PRIVILEGE_LIST.contains(&split_command[0].as_str()) {
return Some(format!("{} {}", split_command[0], command));
}
return Some(command);
}
}
}
None
}