pay-respects/module-request-ai/src/requests.rs

256 lines
5.3 KiB
Rust
Raw Normal View History

2025-04-08 03:39:32 +02:00
use askama::Template;
2024-11-19 14:59:15 +01:00
use std::collections::HashMap;
2025-04-08 03:39:32 +02:00
use futures_util::StreamExt;
2024-11-19 14:59:15 +01:00
use serde::{Deserialize, Serialize};
2025-04-08 03:39:32 +02:00
use crate::buffer;
2024-11-19 14:59:15 +01:00
2024-12-12 16:02:01 +01:00
/// Connection settings for the AI backend, as resolved by `Conf::new`.
struct Conf {
    /// Credential sent as a bearer token with the chat request.
    key: String,
    /// Endpoint the chat request is POSTed to.
    url: String,
    /// Model name placed in the request body.
    model: String,
}
2025-04-08 03:39:32 +02:00
#[derive(Serialize)]
2024-11-19 14:59:15 +01:00
struct Input {
role: String,
content: String,
}
2025-04-08 03:39:32 +02:00
#[derive(Serialize)]
2024-11-19 14:59:15 +01:00
struct Messages {
messages: Vec<Input>,
model: String,
2025-04-08 03:39:32 +02:00
stream: bool,
}
/// One streamed chunk of a chat-completion response.
///
/// Only the fields this module reads are declared; serde's derive skips any
/// other keys present in the payload.
#[derive(Debug, Deserialize)]
pub struct ChatCompletion {
    /// Candidate continuations carried by this chunk.
    choices: Vec<Choice>,
}
/// A single candidate inside a streamed chunk.
///
/// Unused response fields (index, finish reason, ...) are intentionally not
/// declared; serde ignores them.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Incremental update carried by this candidate.
    delta: Delta,
}
2025-04-08 03:39:32 +02:00
#[derive(Debug, Deserialize)]
pub struct Delta {
content: Option<String>,
2024-11-19 14:59:15 +01:00
}
2025-04-08 03:39:32 +02:00
/// Values substituted into the `prompt.txt` template to build the prompt.
///
/// Field names must match the placeholders used by the template.
#[derive(Template)]
#[template(path = "prompt.txt")]
struct AiPrompt<'a> {
    /// The last command the user ran.
    last_command: &'a str,
    /// Error output of that command (possibly truncated by the caller).
    error_msg: &'a str,
    /// Extra user-supplied instructions, or an empty string.
    additional_prompt: &'a str,
    /// Language instruction for non-English locales, or an empty string.
    set_locale: &'a str,
}
2025-06-02 14:17:28 +02:00
pub async fn ai_suggestion(last_command: &str, error_msg: &str, locale: &str) {
2024-12-12 16:02:01 +01:00
let conf = match Conf::new() {
Some(conf) => conf,
None => {
2025-04-08 03:39:32 +02:00
return;
2024-12-12 16:02:01 +01:00
}
};
2024-11-19 21:16:34 +01:00
let error_msg = if error_msg.len() > 300 {
&error_msg[..300]
} else {
error_msg
};
2024-11-19 14:59:15 +01:00
let mut map = HashMap::new();
map.insert("last_command", last_command);
map.insert("error_msg", error_msg);
2024-12-08 16:39:29 +01:00
let user_locale = {
let locale = std::env::var("_PR_AI_LOCALE")
2025-06-02 14:17:28 +02:00
.unwrap_or_else(|_| locale.to_string());
2024-12-08 16:39:29 +01:00
if locale.len() < 2 {
"en-US".to_string()
} else {
locale
}
};
2024-11-20 21:16:14 +01:00
let set_locale = if !user_locale.starts_with("en") {
format!(". Use language for locale {}", user_locale)
2024-11-19 17:16:38 +01:00
} else {
"".to_string()
};
2024-11-19 14:59:15 +01:00
2025-04-06 18:15:55 +02:00
let addtional_prompt = if std::env::var("_PR_AI_ADDITIONAL_PROMPT").is_ok() {
std::env::var("_PR_AI_ADDITIONAL_PROMPT").unwrap()
} else {
"".to_string()
};
2025-04-08 03:39:32 +02:00
let ai_prompt = AiPrompt {
last_command,
error_msg,
additional_prompt: &addtional_prompt,
set_locale: &set_locale,
}
.render()
.unwrap()
.trim()
.to_string();
2024-11-19 14:59:15 +01:00
2025-04-06 18:15:55 +02:00
#[cfg(debug_assertions)]
eprintln!("AI module: AI prompt: {}", ai_prompt);
2025-04-08 03:39:32 +02:00
// let res;
let body = Messages {
2024-11-19 14:59:15 +01:00
messages: vec![Input {
role: "user".to_string(),
2025-04-06 17:19:46 +02:00
content: ai_prompt.trim().to_string(),
2024-11-19 14:59:15 +01:00
}],
2024-12-12 16:02:01 +01:00
model: conf.model,
2025-04-08 03:39:32 +02:00
stream: true,
2024-11-19 14:59:15 +01:00
};
2024-11-19 16:19:22 +01:00
2025-04-08 03:39:32 +02:00
let client = reqwest::Client::new();
let res = client
.post(&conf.url)
.body(serde_json::to_string(&body).unwrap())
.header("Content-Type", "application/json")
.bearer_auth(&conf.key)
.send()
.await
.unwrap();
if res.status() != 200 {
eprintln!("AI module: Status code: {}", res.status());
eprintln!(
"AI module: Error message:\n {}",
res.text().await.unwrap().replace("\n", "\n ")
);
return;
}
2025-04-08 03:39:32 +02:00
let mut stream = res.bytes_stream();
let mut json_buffer = String::new();
2025-04-08 03:39:32 +02:00
let mut buffer = buffer::Buffer::new();
2025-04-08 03:39:32 +02:00
while let Some(item) = stream.next().await {
let item = item.unwrap();
let str = std::str::from_utf8(&item).unwrap();
if json_buffer.is_empty() {
json_buffer.push_str(str);
2025-04-08 03:39:32 +02:00
continue;
2024-11-22 17:28:49 +01:00
}
2024-11-22 10:49:24 +01:00
2025-04-08 03:39:32 +02:00
if !str.contains("\n\ndata: {") {
json_buffer.push_str(str);
2025-04-08 03:39:32 +02:00
continue;
}
let data_loc = str.find("\n\ndata: {").unwrap();
let split = str.split_at(data_loc);
json_buffer.push_str(split.0);
let working_str = json_buffer.clone();
2025-04-08 03:39:32 +02:00
json_buffer.clear();
json_buffer.push_str(split.1);
2025-04-08 03:39:32 +02:00
for part in working_str.split("\n\n") {
if let Some(data) = part.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
panic!("AI module: Failed to parse JSON content: {}", data)
});
let choice = json.choices.first().expect("AI module: No choices found");
if let Some(content) = &choice.delta.content {
buffer.proc(content);
}
2024-11-22 17:28:49 +01:00
}
2025-04-08 03:39:32 +02:00
}
2024-11-22 17:28:49 +01:00
}
2025-04-08 03:39:32 +02:00
if !json_buffer.is_empty() {
let working_str = json_buffer.clone();
2025-04-08 03:39:32 +02:00
for part in working_str.split("\n\n") {
if let Some(data) = part.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
panic!("AI module: Failed to parse JSON content: {}", data)
});
let choice = json.choices.first().expect("AI module: No choices found");
if let Some(content) = &choice.delta.content {
buffer.proc(content);
}
}
2024-12-07 15:35:02 +01:00
}
2025-04-08 03:39:32 +02:00
json_buffer.clear();
}
let suggestions = buffer
.print_return_remain()
2025-04-06 17:53:27 +02:00
.trim()
2025-04-08 03:39:32 +02:00
.trim_end_matches("```")
.trim()
2025-04-09 15:51:56 +02:00
.trim_start_matches("<suggest>")
.trim_end_matches("</suggest>")
2025-04-08 03:39:32 +02:00
.replace("<br>", "<_PR_BR>");
2025-04-06 17:53:27 +02:00
2025-04-08 03:39:32 +02:00
println!("{}", suggestions);
2024-11-19 14:59:15 +01:00
}
2024-12-12 16:02:01 +01:00
impl Conf {
    /// Resolves the AI backend settings.
    ///
    /// Each setting is resolved in this order:
    /// 1. its runtime environment variable (`_PR_AI_API_KEY`, `_PR_AI_URL`,
    ///    `_PR_AI_MODEL`), unless unset or not valid Unicode;
    /// 2. the matching `_DEF_PR_AI_*` value captured at compile time;
    /// 3. a built-in default.
    ///
    /// Returns `None` as soon as any resolved setting is empty. Setting one of
    /// these variables to an empty string therefore disables the request.
    pub fn new() -> Option<Self> {
        let key = Self::setting(
            "_PR_AI_API_KEY",
            option_env!("_DEF_PR_AI_API_KEY"),
            "Y29uZ3JhdHVsYXRpb25zLCB5b3UgZm91bmQgdGhlIHNlY3JldCE=",
        )?;
        let url = Self::setting(
            "_PR_AI_URL",
            option_env!("_DEF_PR_AI_URL"),
            "https://iff.envs.net/stream-completions.py",
        )?;
        let model = Self::setting(
            "_PR_AI_MODEL",
            option_env!("_DEF_PR_AI_MODEL"),
            "{{ _PR_AI_MODEL }}",
        )?;
        Some(Conf { key, url, model })
    }

    /// Reads `var` from the environment, falling back first to `built_in` and
    /// then to `fallback`; yields `None` when the chosen value is empty.
    fn setting(var: &str, built_in: Option<&str>, fallback: &str) -> Option<String> {
        let value = match std::env::var(var) {
            Ok(value) => value,
            Err(_) => built_in.unwrap_or(fallback).to_owned(),
        };
        if value.is_empty() {
            None
        } else {
            Some(value)
        }
    }
}