use askama::Template;
use std::collections::HashMap;
use sys_locale::get_locale;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};

use crate::buffer;

/// Runtime configuration for the AI request: API key, endpoint URL and model name.
/// Resolved from environment variables (with compile-time defaults) by `Conf::new`.
struct Conf {
    key: String,
    url: String,
    model: String,
}

/// A single chat message in the request payload.
#[derive(Serialize)]
struct Input {
    role: String,
    content: String,
}

/// Request payload for the completion endpoint.
#[derive(Serialize)]
struct Messages {
    messages: Vec<Input>,
    model: String,
    stream: bool,
}

/// One streamed chunk of the completion response. Only the fields needed to
/// extract the suggestion text are deserialized; the rest stay commented out.
#[derive(Debug, Deserialize)]
pub struct ChatCompletion {
    // id: String,
    // object: String,
    // created: usize,
    choices: Vec<Choice>,
}

#[derive(Debug, Deserialize)]
pub struct Choice {
    delta: Delta,
    // index: usize,
    // finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct Delta {
    content: Option<String>,
}

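// For reference, each event in the response stream is expected to look roughly like
// the following (an assumed OpenAI-compatible shape; exact fields vary by provider):
//
//   data: {"id":"...","object":"chat.completion.chunk","created":0,
//          "choices":[{"index":0,"delta":{"content":"sudo "},"finish_reason":null}]}
//
//   data: [DONE]
//
// Events are separated by blank lines ("\n\n"), which is what the parsing below splits on.
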
/// Askama template that assembles the prompt sent to the model (see `prompt.txt`).
#[derive(Template)]
#[template(path = "prompt.txt")]
struct AiPrompt<'a> {
    last_command: &'a str,
    error_msg: &'a str,
    additional_prompt: &'a str,
    set_locale: &'a str,
}

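/// Renders a prompt from the failing command and its error output, queries the
/// configured completion endpoint, and prints the streamed suggestion.
///
/// Illustrative call (hypothetical values):
///
/// ```ignore
/// ai_suggestion("gti status", "gti: command not found").await;
/// ```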
pub async fn ai_suggestion(last_command: &str, error_msg: &str) {
    // Bail out silently when no usable configuration is available.
    let conf = match Conf::new() {
        Some(conf) => conf,
        None => {
            return;
        }
    };

    // Keep the prompt small: truncate long error output to roughly 300 characters,
    // staying on a char boundary so multi-byte UTF-8 output cannot cause a panic.
    let error_msg = match error_msg.char_indices().nth(300) {
        Some((idx, _)) => &error_msg[..idx],
        None => error_msg,
    };

    let mut map = HashMap::new();
    map.insert("last_command", last_command);
    map.insert("error_msg", error_msg);

    // Determine the reply language: explicit override via _PR_AI_LOCALE, otherwise
    // the system locale, falling back to English.
    let user_locale = {
        let locale = std::env::var("_PR_AI_LOCALE")
            .unwrap_or_else(|_| get_locale().unwrap_or_else(|| "en-us".to_string()));
        if locale.len() < 2 {
            "en-US".to_string()
        } else {
            locale
        }
    };

    // Only non-English locales need an explicit language instruction in the prompt.
    let set_locale = if !user_locale.starts_with("en") {
        format!(". Use language for locale {}", user_locale)
    } else {
        "".to_string()
    };

    // Optional extra instructions supplied by the user.
    let additional_prompt = std::env::var("_PR_AI_ADDITIONAL_PROMPT").unwrap_or_default();

    let ai_prompt = AiPrompt {
        last_command,
        error_msg,
        additional_prompt: &additional_prompt,
        set_locale: &set_locale,
    }
    .render()
    .unwrap()
    .trim()
    .to_string();

    #[cfg(debug_assertions)]
    eprintln!("AI module: AI prompt: {}", ai_prompt);

    let body = Messages {
        messages: vec![Input {
            role: "user".to_string(),
            content: ai_prompt,
        }],
        model: conf.model,
        stream: true,
    };

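    // Serialized with serde_json, the request body looks like (illustrative values):
    //   {"messages":[{"role":"user","content":"<rendered prompt>"}],"model":"<model>","stream":true}
    // and is POSTed below with the configured key as a Bearer token.
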
    let client = reqwest::Client::new();
    let res = client
        .post(&conf.url)
        .body(serde_json::to_string(&body).unwrap())
        .header("Content-Type", "application/json")
        .bearer_auth(&conf.key)
        .send()
        .await;

    let mut stream = res.unwrap().bytes_stream();
    // Raw SSE text that has not yet formed a complete event.
    let mut json_buffer = String::new();
    // Output buffer for the suggestion text (see crate::buffer).
    let mut buffer = buffer::Buffer::new();
    while let Some(item) = stream.next().await {
        let item = item.unwrap();
        let str = std::str::from_utf8(&item).unwrap();

        // First chunk: just start accumulating.
        if json_buffer.is_empty() {
            json_buffer.push_str(str);
            continue;
        }

        // No event boundary in this chunk yet; keep accumulating.
        if !str.contains("\n\ndata: {") {
            json_buffer.push_str(str);
            continue;
        }
        // Split at the boundary: the part before it completes the buffered events,
        // the part after it starts the next one.
        let data_loc = str.find("\n\ndata: {").unwrap();
        let split = str.split_at(data_loc);
        json_buffer.push_str(split.0);
        let working_str = json_buffer.clone();
        json_buffer.clear();
        json_buffer.push_str(split.1);

        for part in working_str.split("\n\n") {
            if let Some(data) = part.strip_prefix("data: ") {
                if data == "[DONE]" {
                    break;
                }
                let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
                    panic!("AI module: Failed to parse JSON content: {}", data)
                });
                let choice = json.choices.first().expect("AI module: No choices found");
                if let Some(content) = &choice.delta.content {
                    buffer.proc(content);
                }
            }
        }
    }

    // Parse whatever is still buffered after the stream ends.
    if !json_buffer.is_empty() {
        let working_str = json_buffer.clone();
        for part in working_str.split("\n\n") {
            if let Some(data) = part.strip_prefix("data: ") {
                if data == "[DONE]" {
                    break;
                }
                let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
                    panic!("AI module: Failed to parse JSON content: {}", data)
                });
                let choice = json.choices.first().expect("AI module: No choices found");
                if let Some(content) = &choice.delta.content {
                    buffer.proc(content);
                }
            }
        }
        json_buffer.clear();
    }

    // Strip wrappers the model may emit (trailing code fence, <suggest> tags) and
    // convert <br> to the internal <_PR_BR> marker.
    let suggestions = buffer
        .print_return_remain()
        .trim()
        .trim_end_matches("```")
        .trim()
        .trim_start_matches("<suggest>")
        .trim_end_matches("</suggest>")
        .replace("<br>", "<_PR_BR>");

    println!("{}", suggestions);
}
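
// Sketch of a possible refactor (not wired in): the two parsing loops in
// `ai_suggestion` run identical per-event logic, which a helper like this could share.
// Assumes `Buffer::proc` accepts the streamed text exactly as it is used above.
#[allow(dead_code)]
fn parse_sse_events(events: &str, buffer: &mut buffer::Buffer) {
    for part in events.split("\n\n") {
        if let Some(data) = part.strip_prefix("data: ") {
            if data == "[DONE]" {
                break;
            }
            let json = serde_json::from_str::<ChatCompletion>(data)
                .unwrap_or_else(|_| panic!("AI module: Failed to parse JSON content: {}", data));
            let choice = json.choices.first().expect("AI module: No choices found");
            if let Some(content) = &choice.delta.content {
                buffer.proc(content);
            }
        }
    }
}
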
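// Configuration is read from the environment first, for example (illustrative values only):
//
//   _PR_AI_API_KEY=sk-...
//   _PR_AI_URL=https://<provider>/v1/chat/completions
//   _PR_AI_MODEL=<model name>
//
// with `_DEF_PR_AI_*` compile-time fallbacks baked in via `option_env!`.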
impl Conf {
    /// Resolves the configuration, preferring runtime environment variables over
    /// compile-time `_DEF_PR_AI_*` defaults. Returns `None` when any value is empty.
    pub fn new() -> Option<Self> {
        let key = match std::env::var("_PR_AI_API_KEY") {
            Ok(key) => key,
            Err(_) => {
                if let Some(key) = option_env!("_DEF_PR_AI_API_KEY") {
                    key.to_string()
                } else {
                    "Y29uZ3JhdHVsYXRpb25zLCB5b3UgZm91bmQgdGhlIHNlY3JldCE=".to_string()
                }
            }
        };
        if key.is_empty() {
            return None;
        }

        let url = match std::env::var("_PR_AI_URL") {
            Ok(url) => url,
            Err(_) => {
                if let Some(url) = option_env!("_DEF_PR_AI_URL") {
                    url.to_string()
                } else {
                    "https://iff.envs.net/stream-completions.py".to_string()
                }
            }
        };
        if url.is_empty() {
            return None;
        }

        let model = match std::env::var("_PR_AI_MODEL") {
            Ok(model) => model,
            Err(_) => {
                if let Some(model) = option_env!("_DEF_PR_AI_MODEL") {
                    model.to_string()
                } else {
                    "{{ _PR_AI_MODEL }}".to_string()
                }
            }
        };
        if model.is_empty() {
            return None;
        }

        Some(Conf { key, url, model })
    }
}