feat: stream output

This commit is contained in:
iff 2025-04-08 03:39:32 +02:00
parent 86241547e8
commit 3215fe45f6
7 changed files with 1545 additions and 162 deletions

View file

@ -1,8 +1,12 @@
use askama::Template;
use std::collections::HashMap;
use sys_locale::get_locale;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::buffer;
struct Conf {
key: String,
@ -10,38 +14,57 @@ struct Conf {
model: String,
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize)]
struct Input {
role: String,
content: String,
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize)]
struct Messages {
messages: Vec<Input>,
model: String,
stream: bool,
}
#[derive(Serialize, Deserialize)]
pub struct AISuggest {
pub commands: Vec<String>,
pub note: String,
#[derive(Debug, Deserialize)]
pub struct ChatCompletion {
// id: String,
// object: String,
// created: usize,
choices: Vec<Choice>,
}
pub struct AIResponse {
pub suggestion: AISuggest,
pub think: Option<String>,
#[derive(Debug, Deserialize)]
pub struct Choice {
delta: Delta,
// index: usize,
// finish_reason: Option<String>,
}
pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AIResponse> {
#[derive(Debug, Deserialize)]
pub struct Delta {
content: Option<String>,
}
#[derive(Template)]
#[template(path = "prompt.txt")]
struct AiPrompt<'a> {
last_command: &'a str,
error_msg: &'a str,
additional_prompt: &'a str,
set_locale: &'a str,
}
pub async fn ai_suggestion(last_command: &str, error_msg: &str) {
if std::env::var("_PR_AI_DISABLE").is_ok() {
return None;
return;
}
let conf = match Conf::new() {
Some(conf) => conf,
None => {
return None;
return;
}
};
@ -77,138 +100,105 @@ pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AIResponse>
"".to_string()
};
let ai_prompt = format!(
r#"
{addtional_prompt}
`{last_command}` returns the following error message: `{error_msg}`. Provide possible commands to fix it. Answer in the following exact JSON template without any extra text:
```
{{"commands":["command 1","command 2"],"note":"why they may fix the error{set_locale}"}}
```
"#
);
let ai_prompt = AiPrompt {
last_command,
error_msg,
additional_prompt: &addtional_prompt,
set_locale: &set_locale,
}
.render()
.unwrap()
.trim()
.to_string();
#[cfg(debug_assertions)]
eprintln!("AI module: AI prompt: {}", ai_prompt);
let res;
let messages = Messages {
// let res;
let body = Messages {
messages: vec![Input {
role: "user".to_string(),
content: ai_prompt.trim().to_string(),
}],
model: conf.model,
stream: true,
};
#[cfg(feature = "libcurl")]
{
use curl::easy::Easy as Curl;
use curl::easy::List;
use std::io::Read;
let client = reqwest::Client::new();
let res = client
.post(&conf.url)
.body(serde_json::to_string(&body).unwrap())
.header("Content-Type", "application/json")
.bearer_auth(&conf.key)
.send()
.await;
let str_json = serde_json::to_string(&messages).unwrap();
let mut data = str_json.as_bytes();
let mut stream = res.unwrap().bytes_stream();
let mut json_buffer = vec![];
let mut buffer = buffer::Buffer::new();
while let Some(item) = stream.next().await {
let item = item.unwrap();
let str = std::str::from_utf8(&item).unwrap();
let mut dst = Vec::new();
let mut handle = Curl::new();
handle.url(&conf.url).unwrap();
handle.post(true).unwrap();
handle.post_field_size(data.len() as u64).unwrap();
let mut headers = List::new();
headers
.append(&format!("Authorization: Bearer {}", conf.key))
.unwrap();
headers.append("Content-Type: application/json").unwrap();
handle.http_headers(headers).unwrap();
{
let mut transfer = handle.transfer();
transfer
.read_function(|buf| Ok(data.read(buf).unwrap_or(0)))
.unwrap();
transfer
.write_function(|buf| {
dst.extend_from_slice(buf);
Ok(buf.len())
})
.unwrap();
transfer.perform().expect("Failed to perform request");
if json_buffer.is_empty() {
json_buffer.push(str.to_string());
continue;
}
res = String::from_utf8(dst).unwrap();
}
#[cfg(not(feature = "libcurl"))]
{
let proc = std::process::Command::new("curl")
.arg("-X")
.arg("POST")
.arg("-H")
.arg(format!("Authorization: Bearer {}", conf.key))
.arg("-H")
.arg("Content-Type: application/json")
.arg("-d")
.arg(serde_json::to_string(&messages).unwrap())
.arg(conf.url)
.output();
if !str.contains("\n\ndata: {") {
json_buffer.push(str.to_string());
continue;
}
let data_loc = str.find("\n\ndata: {").unwrap();
let split = str.split_at(data_loc);
json_buffer.push(split.0.to_string());
let working_str = json_buffer.join("");
json_buffer.clear();
json_buffer.push(split.1.to_string());
let out = match proc {
Ok(proc) => proc.stdout,
Err(_) => {
return None;
for part in working_str.split("\n\n") {
if let Some(data) = part.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
panic!("AI module: Failed to parse JSON content: {}", data)
});
let choice = json.choices.first().expect("AI module: No choices found");
if let Some(content) = &choice.delta.content {
buffer.proc(content);
}
}
};
res = String::from_utf8(out).unwrap();
}
}
let json: Value = {
let json = serde_json::from_str(&res);
if let Ok(json) = json {
json
} else {
eprintln!("AI module: Failed to parse JSON response: {}", res);
return None;
if !json_buffer.is_empty() {
let working_str = json_buffer.join("");
for part in working_str.split("\n\n") {
if let Some(data) = part.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
panic!("AI module: Failed to parse JSON content: {}", data)
});
let choice = json.choices.first().expect("AI module: No choices found");
if let Some(content) = &choice.delta.content {
buffer.proc(content);
}
}
}
};
let content = &json["choices"][0]["message"]["content"];
let mut str = content
.as_str()
.expect("AI module: Failed to get content from response")
json_buffer.clear();
}
let suggestions = buffer
.print_return_remain()
.trim()
.to_string();
.trim_end_matches("```")
.trim()
.trim_start_matches("<suggestions>")
.trim_end_matches("</suggestions>")
.replace("<br>", "<_PR_BR>");
let think = if str.starts_with("<think>") {
let start_len = "<think>".len();
let end_len = "</think>".len();
let end = str.find("</think>").unwrap() + end_len;
let think = str[start_len..end - end_len].to_string();
str = str[end..].to_string();
Some(think)
} else {
None
};
let suggestion: AISuggest = {
let str = {
str.trim()
.trim_start_matches("```")
.trim_start_matches("json")
.trim_end_matches("```")
};
let json = serde_json::from_str(str);
if json.is_err() {
eprintln!("AI module: Failed to parse JSON content: {}", str);
return None;
}
json.unwrap()
};
let response = AIResponse { suggestion, think };
Some(response)
println!("{}", suggestions);
}
impl Conf {