mirror of
https://github.com/TECHNOFAB11/pay-respects.git
synced 2025-12-12 14:30:10 +01:00
feat: stream output
This commit is contained in:
parent
86241547e8
commit
3215fe45f6
7 changed files with 1545 additions and 162 deletions
138
module-request-ai/src/buffer.rs
Normal file
138
module-request-ai/src/buffer.rs
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
use std::io::Write;
|
||||
use textwrap::fill as textwrap_fill;
|
||||
|
||||
fn termwidth() -> usize {
|
||||
use terminal_size::{terminal_size, Height, Width};
|
||||
let size = terminal_size();
|
||||
if let Some((Width(w), Height(_))) = size {
|
||||
std::cmp::min(w as usize, 80)
|
||||
} else {
|
||||
80
|
||||
}
|
||||
}
|
||||
|
||||
fn fill(str: &str) -> Option<String> {
|
||||
let width = termwidth();
|
||||
let filled = textwrap_fill(str, width);
|
||||
if filled.contains('\n') {
|
||||
Some(filled)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
use colored::Colorize;
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum State {
|
||||
Write,
|
||||
Store,
|
||||
}
|
||||
|
||||
pub struct Buffer {
|
||||
pub buf: Vec<String>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl Buffer {
|
||||
pub fn new() -> Self {
|
||||
Buffer {
|
||||
buf: vec![],
|
||||
state: State::Write,
|
||||
}
|
||||
}
|
||||
pub fn proc(&mut self, data: &str) {
|
||||
if self.state == State::Write {
|
||||
if !data.contains("\n") {
|
||||
self.buf.push(data.to_string());
|
||||
let buffered = self.buf.join("").trim().to_string();
|
||||
let filled = fill(&buffered);
|
||||
if let Some(filled) = filled {
|
||||
self.buf.clear();
|
||||
let formatted = format!("\r{}", filled);
|
||||
eprint!("{}", formatted);
|
||||
self.buf
|
||||
.push(formatted.split_once("\n").unwrap().1.to_string());
|
||||
std::io::stdout().flush().unwrap();
|
||||
return;
|
||||
}
|
||||
eprint!("{}", data);
|
||||
std::io::stdout().flush().unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let mut data = data.to_string();
|
||||
while data.contains("\n") {
|
||||
let lines = data.split_once("\n").unwrap();
|
||||
let first = lines.0;
|
||||
let last = lines.1;
|
||||
self.buf.push(first.to_string());
|
||||
let buffered = self.buf.join("").trim().to_string();
|
||||
self.buf.clear();
|
||||
if buffered.ends_with("<note>") {
|
||||
let warn = format!("\r{}:", t!("ai-suggestion"))
|
||||
.bold()
|
||||
.blue()
|
||||
.to_string();
|
||||
let first = buffered.replace("<note>", &warn);
|
||||
eprintln!("{}", first);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else if buffered.ends_with("</note>") {
|
||||
let tag = "</note>";
|
||||
let whitespace = " ".repeat(tag.len());
|
||||
let formatted = format!("\r{}", whitespace);
|
||||
let first = buffered.replace("</note>", &formatted);
|
||||
eprintln!("{}", first);
|
||||
self.state = State::Store;
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else if buffered.ends_with("<think>") {
|
||||
let tag = "<think>";
|
||||
let warn = format!("\r{}:", t!("ai-thinking"))
|
||||
.bold()
|
||||
.blue()
|
||||
.to_string();
|
||||
let first = buffered.replace(tag, &warn);
|
||||
eprintln!("{}", first);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else if buffered.ends_with("</think>") {
|
||||
let tag = "</think>";
|
||||
let whitespace = " ".repeat(tag.len());
|
||||
let formatted = format!("\r{}", whitespace);
|
||||
let first = buffered.replace(tag, &formatted);
|
||||
eprintln!("{}", first);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else if buffered.ends_with("```") {
|
||||
let tag = "```";
|
||||
let whitespace = " ".repeat(tag.len());
|
||||
let formatted = format!("\r{}", whitespace);
|
||||
let first = buffered.replace(tag, &formatted);
|
||||
eprintln!("{}", first);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else {
|
||||
eprintln!("{}", first);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
data = last.to_string();
|
||||
}
|
||||
eprint!("{}", data);
|
||||
return;
|
||||
}
|
||||
self.buf.push(data.to_string());
|
||||
}
|
||||
|
||||
pub fn print_return_remain(&mut self) -> String {
|
||||
let buffered = self.buf.join("").trim().to_string();
|
||||
self.buf.clear();
|
||||
if self.state == State::Store {
|
||||
return buffered;
|
||||
}
|
||||
|
||||
let split = buffered.split_once("<suggestions>");
|
||||
if let Some((first, last)) = split {
|
||||
eprint!("{}", first);
|
||||
std::io::stdout().flush().unwrap();
|
||||
return last.to_string();
|
||||
}
|
||||
"".to_string()
|
||||
}
|
||||
}
|
||||
|
|
@ -15,16 +15,16 @@
|
|||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
use crate::requests::ai_suggestion;
|
||||
use colored::Colorize;
|
||||
use sys_locale::get_locale;
|
||||
use textwrap::fill;
|
||||
mod buffer;
|
||||
mod requests;
|
||||
|
||||
#[macro_use]
|
||||
extern crate rust_i18n;
|
||||
i18n!("i18n", fallback = "en", minify_key = true);
|
||||
|
||||
fn main() -> Result<(), std::io::Error> {
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), std::io::Error> {
|
||||
let locale = {
|
||||
let sys_locale = get_locale().unwrap_or("en-US".to_string());
|
||||
if sys_locale.len() < 2 {
|
||||
|
|
@ -57,31 +57,8 @@ fn main() -> Result<(), std::io::Error> {
|
|||
if command.split_whitespace().count() == 1 {
|
||||
return Ok(());
|
||||
}
|
||||
let suggest = ai_suggestion(&command, &error);
|
||||
if let Some(suggest) = suggest {
|
||||
if let Some(thinking) = suggest.think {
|
||||
let note = format!("{}:", t!("ai-thinking")).bold().blue();
|
||||
let thinking = fill(&thinking, termwidth());
|
||||
eprintln!("{}{}", note, thinking);
|
||||
}
|
||||
let warn = format!("{}:", t!("ai-suggestion")).bold().blue();
|
||||
let note = fill(&suggest.suggestion.note, termwidth());
|
||||
|
||||
eprintln!("{}\n{}\n", warn, note);
|
||||
let suggestions = suggest.suggestion.commands;
|
||||
for suggestion in suggestions {
|
||||
print!("{}<_PR_BR>", suggestion);
|
||||
}
|
||||
}
|
||||
ai_suggestion(&command, &error).await;
|
||||
// if let Some(suggest) = suggest {
|
||||
// }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn termwidth() -> usize {
|
||||
use terminal_size::{terminal_size, Height, Width};
|
||||
let size = terminal_size();
|
||||
if let Some((Width(w), Height(_))) = size {
|
||||
std::cmp::min(w as usize, 80)
|
||||
} else {
|
||||
80
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
use askama::Template;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use sys_locale::get_locale;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::buffer;
|
||||
|
||||
struct Conf {
|
||||
key: String,
|
||||
|
|
@ -10,38 +14,57 @@ struct Conf {
|
|||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize)]
|
||||
struct Input {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize)]
|
||||
struct Messages {
|
||||
messages: Vec<Input>,
|
||||
model: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AISuggest {
|
||||
pub commands: Vec<String>,
|
||||
pub note: String,
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatCompletion {
|
||||
// id: String,
|
||||
// object: String,
|
||||
// created: usize,
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
pub struct AIResponse {
|
||||
pub suggestion: AISuggest,
|
||||
pub think: Option<String>,
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Choice {
|
||||
delta: Delta,
|
||||
// index: usize,
|
||||
// finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AIResponse> {
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Delta {
|
||||
content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "prompt.txt")]
|
||||
struct AiPrompt<'a> {
|
||||
last_command: &'a str,
|
||||
error_msg: &'a str,
|
||||
additional_prompt: &'a str,
|
||||
set_locale: &'a str,
|
||||
}
|
||||
|
||||
pub async fn ai_suggestion(last_command: &str, error_msg: &str) {
|
||||
if std::env::var("_PR_AI_DISABLE").is_ok() {
|
||||
return None;
|
||||
return;
|
||||
}
|
||||
|
||||
let conf = match Conf::new() {
|
||||
Some(conf) => conf,
|
||||
None => {
|
||||
return None;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -77,138 +100,105 @@ pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AIResponse>
|
|||
"".to_string()
|
||||
};
|
||||
|
||||
let ai_prompt = format!(
|
||||
r#"
|
||||
{addtional_prompt}
|
||||
`{last_command}` returns the following error message: `{error_msg}`. Provide possible commands to fix it. Answer in the following exact JSON template without any extra text:
|
||||
```
|
||||
{{"commands":["command 1","command 2"],"note":"why they may fix the error{set_locale}"}}
|
||||
```
|
||||
"#
|
||||
);
|
||||
let ai_prompt = AiPrompt {
|
||||
last_command,
|
||||
error_msg,
|
||||
additional_prompt: &addtional_prompt,
|
||||
set_locale: &set_locale,
|
||||
}
|
||||
.render()
|
||||
.unwrap()
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
eprintln!("AI module: AI prompt: {}", ai_prompt);
|
||||
|
||||
let res;
|
||||
let messages = Messages {
|
||||
// let res;
|
||||
let body = Messages {
|
||||
messages: vec![Input {
|
||||
role: "user".to_string(),
|
||||
content: ai_prompt.trim().to_string(),
|
||||
}],
|
||||
model: conf.model,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
#[cfg(feature = "libcurl")]
|
||||
{
|
||||
use curl::easy::Easy as Curl;
|
||||
use curl::easy::List;
|
||||
use std::io::Read;
|
||||
let client = reqwest::Client::new();
|
||||
let res = client
|
||||
.post(&conf.url)
|
||||
.body(serde_json::to_string(&body).unwrap())
|
||||
.header("Content-Type", "application/json")
|
||||
.bearer_auth(&conf.key)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let str_json = serde_json::to_string(&messages).unwrap();
|
||||
let mut data = str_json.as_bytes();
|
||||
let mut stream = res.unwrap().bytes_stream();
|
||||
let mut json_buffer = vec![];
|
||||
let mut buffer = buffer::Buffer::new();
|
||||
while let Some(item) = stream.next().await {
|
||||
let item = item.unwrap();
|
||||
let str = std::str::from_utf8(&item).unwrap();
|
||||
|
||||
let mut dst = Vec::new();
|
||||
let mut handle = Curl::new();
|
||||
|
||||
handle.url(&conf.url).unwrap();
|
||||
handle.post(true).unwrap();
|
||||
handle.post_field_size(data.len() as u64).unwrap();
|
||||
|
||||
let mut headers = List::new();
|
||||
headers
|
||||
.append(&format!("Authorization: Bearer {}", conf.key))
|
||||
.unwrap();
|
||||
headers.append("Content-Type: application/json").unwrap();
|
||||
handle.http_headers(headers).unwrap();
|
||||
|
||||
{
|
||||
let mut transfer = handle.transfer();
|
||||
|
||||
transfer
|
||||
.read_function(|buf| Ok(data.read(buf).unwrap_or(0)))
|
||||
.unwrap();
|
||||
|
||||
transfer
|
||||
.write_function(|buf| {
|
||||
dst.extend_from_slice(buf);
|
||||
Ok(buf.len())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
transfer.perform().expect("Failed to perform request");
|
||||
if json_buffer.is_empty() {
|
||||
json_buffer.push(str.to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
res = String::from_utf8(dst).unwrap();
|
||||
}
|
||||
#[cfg(not(feature = "libcurl"))]
|
||||
{
|
||||
let proc = std::process::Command::new("curl")
|
||||
.arg("-X")
|
||||
.arg("POST")
|
||||
.arg("-H")
|
||||
.arg(format!("Authorization: Bearer {}", conf.key))
|
||||
.arg("-H")
|
||||
.arg("Content-Type: application/json")
|
||||
.arg("-d")
|
||||
.arg(serde_json::to_string(&messages).unwrap())
|
||||
.arg(conf.url)
|
||||
.output();
|
||||
if !str.contains("\n\ndata: {") {
|
||||
json_buffer.push(str.to_string());
|
||||
continue;
|
||||
}
|
||||
let data_loc = str.find("\n\ndata: {").unwrap();
|
||||
let split = str.split_at(data_loc);
|
||||
json_buffer.push(split.0.to_string());
|
||||
let working_str = json_buffer.join("");
|
||||
json_buffer.clear();
|
||||
json_buffer.push(split.1.to_string());
|
||||
|
||||
let out = match proc {
|
||||
Ok(proc) => proc.stdout,
|
||||
Err(_) => {
|
||||
return None;
|
||||
for part in working_str.split("\n\n") {
|
||||
if let Some(data) = part.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
|
||||
panic!("AI module: Failed to parse JSON content: {}", data)
|
||||
});
|
||||
let choice = json.choices.first().expect("AI module: No choices found");
|
||||
if let Some(content) = &choice.delta.content {
|
||||
buffer.proc(content);
|
||||
}
|
||||
}
|
||||
};
|
||||
res = String::from_utf8(out).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let json: Value = {
|
||||
let json = serde_json::from_str(&res);
|
||||
if let Ok(json) = json {
|
||||
json
|
||||
} else {
|
||||
eprintln!("AI module: Failed to parse JSON response: {}", res);
|
||||
return None;
|
||||
if !json_buffer.is_empty() {
|
||||
let working_str = json_buffer.join("");
|
||||
for part in working_str.split("\n\n") {
|
||||
if let Some(data) = part.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
let json = serde_json::from_str::<ChatCompletion>(data).unwrap_or_else(|_| {
|
||||
panic!("AI module: Failed to parse JSON content: {}", data)
|
||||
});
|
||||
let choice = json.choices.first().expect("AI module: No choices found");
|
||||
if let Some(content) = &choice.delta.content {
|
||||
buffer.proc(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let content = &json["choices"][0]["message"]["content"];
|
||||
let mut str = content
|
||||
.as_str()
|
||||
.expect("AI module: Failed to get content from response")
|
||||
json_buffer.clear();
|
||||
}
|
||||
let suggestions = buffer
|
||||
.print_return_remain()
|
||||
.trim()
|
||||
.to_string();
|
||||
.trim_end_matches("```")
|
||||
.trim()
|
||||
.trim_start_matches("<suggestions>")
|
||||
.trim_end_matches("</suggestions>")
|
||||
.replace("<br>", "<_PR_BR>");
|
||||
|
||||
let think = if str.starts_with("<think>") {
|
||||
let start_len = "<think>".len();
|
||||
let end_len = "</think>".len();
|
||||
let end = str.find("</think>").unwrap() + end_len;
|
||||
let think = str[start_len..end - end_len].to_string();
|
||||
str = str[end..].to_string();
|
||||
Some(think)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let suggestion: AISuggest = {
|
||||
let str = {
|
||||
str.trim()
|
||||
.trim_start_matches("```")
|
||||
.trim_start_matches("json")
|
||||
.trim_end_matches("```")
|
||||
};
|
||||
let json = serde_json::from_str(str);
|
||||
if json.is_err() {
|
||||
eprintln!("AI module: Failed to parse JSON content: {}", str);
|
||||
return None;
|
||||
}
|
||||
json.unwrap()
|
||||
};
|
||||
|
||||
let response = AIResponse { suggestion, think };
|
||||
Some(response)
|
||||
println!("{}", suggestions);
|
||||
}
|
||||
|
||||
impl Conf {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue