feat: reasoning AI models

This commit is contained in:
iff 2025-04-06 17:53:27 +02:00
parent 874b924496
commit 8233ab723d
4 changed files with 48 additions and 8 deletions

View file

@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Support reasoning AI models (can take more than 20 seconds)
- Allow adding additional custom prompts to adjust the AI's role or persona
- `exe_contains` condition to check if the command contains the argument
### Fixed

View file

@ -12,3 +12,14 @@ ja = "AIからの提案"
ko = "AI 제안"
zh = "AI 建议"
[ai-thinking]
en = "AI is thinking..."
es = "La IA está pensando..."
de = "KI denkt nach..."
fr = "L'IA réfléchit..."
it = "L'IA sta pensando..."
pt = "A IA está pensando..."
ru = "ИИ думает..."
ja = "AIが考えています..."
ko = "AI가 생각 중입니다..."
zh = "AI正在思考..."

View file

@ -48,11 +48,16 @@ fn main() -> Result<(), std::io::Error> {
} }
let suggest = ai_suggestion(&command, &error); let suggest = ai_suggestion(&command, &error);
if let Some(suggest) = suggest { if let Some(suggest) = suggest {
if let Some(thinking) = suggest.think {
let note = format!("{}:", t!("ai-thinking")).bold().blue();
let thinking = fill(&thinking, termwidth());
eprintln!("{}{}", note, thinking);
}
let warn = format!("{}:", t!("ai-suggestion")).bold().blue(); let warn = format!("{}:", t!("ai-suggestion")).bold().blue();
let note = fill(&suggest.note, termwidth()); let note = fill(&suggest.suggestion.note, termwidth());
eprintln!("{}\n{}\n", warn, note); eprintln!("{}\n{}\n", warn, note);
let suggestions = suggest.commands; let suggestions = suggest.suggestion.commands;
for suggestion in suggestions { for suggestion in suggestions {
print!("{}<_PR_BR>", suggestion); print!("{}<_PR_BR>", suggestion);
} }

View file

@ -28,7 +28,12 @@ pub struct AISuggest {
pub note: String, pub note: String,
} }
pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AISuggest> { pub struct AIResponse {
pub suggestion: AISuggest,
pub think: Option<String>,
}
pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AIResponse> {
if std::env::var("_PR_AI_DISABLE").is_ok() { if std::env::var("_PR_AI_DISABLE").is_ok() {
return None; return None;
} }
@ -148,6 +153,7 @@ pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AISuggest> {
}; };
res = String::from_utf8(out).unwrap(); res = String::from_utf8(out).unwrap();
} }
let json: Value = { let json: Value = {
let json = serde_json::from_str(&res); let json = serde_json::from_str(&res);
if let Ok(json) = json { if let Ok(json) = json {
@ -159,12 +165,26 @@ pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AISuggest> {
}; };
let content = &json["choices"][0]["message"]["content"]; let content = &json["choices"][0]["message"]["content"];
let mut str = content
.as_str()
.expect("AI module: Failed to get content from response")
.trim()
.to_string();
let think = if str.starts_with("<think>") {
let start_len = "<think>".len();
let end_len = "</think>".len();
let end = str.find("</think>").unwrap() + end_len;
let think = str[start_len..end - end_len].to_string();
str = str[end..].to_string();
Some(think)
} else {
None
};
let suggestion: AISuggest = { let suggestion: AISuggest = {
let str = { let str = {
let str = content.as_str(); str.trim()
str?;
str.expect("AI module: Failed to get content from response")
.trim_start_matches("```") .trim_start_matches("```")
.trim_start_matches("json") .trim_start_matches("json")
.trim_end_matches("```") .trim_end_matches("```")
@ -176,7 +196,9 @@ pub fn ai_suggestion(last_command: &str, error_msg: &str) -> Option<AISuggest> {
} }
json.unwrap() json.unwrap()
}; };
Some(suggestion)
let response = AIResponse { suggestion, think };
Some(response)
} }
impl Conf { impl Conf {
@ -215,7 +237,7 @@ impl Conf {
if let Some(model) = option_env!("_DEF_PR_AI_MODEL") { if let Some(model) = option_env!("_DEF_PR_AI_MODEL") {
model.to_string() model.to_string()
} else { } else {
"llama3-70b-8192".to_string() "qwen-2.5-32b".to_string()
} }
} }
}; };