pay-respects/src/requests.rs

191 lines
4.1 KiB
Rust
Raw Normal View History

2024-11-19 14:59:15 +01:00
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
2024-11-19 17:02:46 +01:00
use serde_json::Value;
2024-11-19 14:59:15 +01:00
/// A single chat message in the completion request body.
///
/// Field names and order are serialized as-is, so they form the wire format.
#[derive(Deserialize, Serialize)]
struct Input {
    /// Speaker role; `ai_suggestion` always sends `"user"`.
    role: String,
    /// Message text; `ai_suggestion` puts the generated prompt here.
    content: String,
}
/// Request body POSTed to the completion endpoint.
///
/// Field names and order are serialized as-is, so they form the wire format.
#[derive(Deserialize, Serialize)]
struct Messages {
    /// Conversation to send; `ai_suggestion` sends exactly one user message.
    messages: Vec<Input>,
    /// Model identifier, taken from `_PR_AI_MODEL` or the built-in default.
    model: String,
}
/// A fix suggested by the AI backend, parsed from the JSON object the model
/// is asked to return (`{"command": ..., "note": ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AISuggest {
    /// Command the model proposes running to fix the failed one.
    pub command: String,
    /// The model's explanation of why the command may fix the error.
    pub note: String,
}
pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AISuggest> {
2024-11-20 21:16:14 +01:00
if std::env::var("_PR_AI_DISABLE").is_ok() {
return None;
}
2024-11-19 21:16:34 +01:00
let error_msg = if error_msg.len() > 300 {
&error_msg[..300]
} else {
error_msg
};
2024-11-19 14:59:15 +01:00
let mut map = HashMap::new();
map.insert("last_command", last_command);
map.insert("error_msg", error_msg);
let api_key = match std::env::var("_PR_AI_API_KEY") {
Ok(key) => Some(key),
2024-11-19 15:57:00 +01:00
Err(_) => {
let env_key = option_env!("_DEF_PR_AI_API_KEY").map(|key| key.to_string());
if env_key.is_none() {
2024-12-07 15:35:02 +01:00
Some("Y29uZ3JhdHVsYXRpb25zLCB5b3UgZm91bmQgdGhlIHNlY3JldCE=".to_string())
2024-11-19 22:27:37 +01:00
} else if env_key.as_ref().unwrap().is_empty() {
None
2024-11-19 15:57:00 +01:00
} else {
env_key
}
}
2024-11-19 14:59:15 +01:00
};
let api_key = match api_key {
Some(key) => {
if key.is_empty() {
return None;
}
key
}
2024-11-19 14:59:15 +01:00
None => {
return None;
}
};
let request_url = match std::env::var("_PR_AI_URL") {
Ok(url) => url,
2024-12-07 15:35:02 +01:00
Err(_) => "https://iff.envs.net/completions.py".to_string(),
2024-11-19 14:59:15 +01:00
};
let model = match std::env::var("_PR_AI_MODEL") {
Ok(model) => model,
2024-11-19 16:19:22 +01:00
Err(_) => "llama3-8b-8192".to_string(),
2024-11-19 14:59:15 +01:00
};
2024-11-20 21:16:14 +01:00
let user_locale = std::env::var("_PR_AI_LOCALE").unwrap_or("en-US".to_string());
let set_locale = if !user_locale.starts_with("en") {
format!(". Use language for locale {}", user_locale)
2024-11-19 17:16:38 +01:00
} else {
"".to_string()
};
2024-11-19 14:59:15 +01:00
2024-11-19 16:19:22 +01:00
let ai_prompt = format!(
r#"
2024-11-20 21:16:14 +01:00
The command `{last_command}` returns the following error message: `{error_msg}`. Provide a command to fix it. Answer in the following JSON format without any extra text:
2024-11-19 14:59:15 +01:00
```
2024-11-20 21:16:14 +01:00
{{"command":"suggestion","note":"why it may fix the error{set_locale}"}}
2024-11-19 14:59:15 +01:00
```
2024-11-19 16:19:22 +01:00
"#
);
2024-11-19 14:59:15 +01:00
2024-11-22 17:28:49 +01:00
let res;
2024-11-19 14:59:15 +01:00
let messages = Messages {
messages: vec![Input {
role: "user".to_string(),
content: ai_prompt.to_string(),
}],
model,
};
2024-11-19 16:19:22 +01:00
2024-11-22 17:28:49 +01:00
#[cfg(feature = "libcurl")]
{
use curl::easy::Easy as Curl;
use curl::easy::List;
use std::io::Read;
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
let str_json = serde_json::to_string(&messages).unwrap();
let mut data = str_json.as_bytes();
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
let mut dst = Vec::new();
let mut handle = Curl::new();
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
handle.url(&request_url).unwrap();
handle.post(true).unwrap();
handle.post_field_size(data.len() as u64).unwrap();
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
let mut headers = List::new();
headers
.append(&format!("Authorization: Bearer {}", api_key))
2024-11-22 10:49:24 +01:00
.unwrap();
2024-11-22 17:28:49 +01:00
headers.append("Content-Type: application/json").unwrap();
handle.http_headers(headers).unwrap();
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
{
let mut transfer = handle.transfer();
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
transfer
.read_function(|buf| Ok(data.read(buf).unwrap_or(0)))
.unwrap();
transfer
.write_function(|buf| {
dst.extend_from_slice(buf);
Ok(buf.len())
})
.unwrap();
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
transfer.perform().expect("Failed to perform request");
}
2024-11-22 10:49:24 +01:00
2024-11-22 17:28:49 +01:00
res = String::from_utf8(dst).unwrap();
}
#[cfg(not(feature = "libcurl"))]
{
let proc = std::process::Command::new("curl")
.arg("-X")
.arg("POST")
.arg("-H")
.arg(format!("Authorization: Bearer {}", api_key))
.arg("-H")
.arg("Content-Type: application/json")
.arg("-d")
.arg(serde_json::to_string(&messages).unwrap())
.arg(request_url)
.output();
let out = match proc {
Ok(proc) => proc.stdout,
Err(_) => {
return None;
}
};
res = String::from_utf8(out).unwrap();
}
2024-12-07 15:35:02 +01:00
let json: Value = {
let json = serde_json::from_str(&res);
if json.is_err() {
eprintln!("Failed to parse JSON response: {}", res);
return None;
} else {
json.unwrap()
}
};
2024-11-19 22:27:37 +01:00
2024-11-19 14:59:15 +01:00
let content = &json["choices"][0]["message"]["content"];
2024-11-19 22:27:37 +01:00
let suggestion: AISuggest = {
2024-11-20 09:15:00 +01:00
let str = {
let str = content.as_str();
2024-11-20 10:19:08 +01:00
str?;
2024-11-22 10:49:24 +01:00
str.expect("Failed to get content from response")
2024-11-20 10:19:08 +01:00
.trim_start_matches("```")
.trim_end_matches("```")
};
2024-11-19 22:27:37 +01:00
let json = serde_json::from_str(str);
if json.is_err() {
return None;
}
json.unwrap()
};
2024-11-19 14:59:15 +01:00
Some(suggestion)
}