From 609c7609d7580695dd2def106d3f8cf208084b00 Mon Sep 17 00:00:00 2001 From: iff Date: Fri, 22 Nov 2024 17:28:49 +0100 Subject: [PATCH] feat: optional libcurl linking --- Cargo.toml | 6 +++- src/args.rs | 8 +++-- src/requests.rs | 91 +++++++++++++++++++++++++++++++------------------ 3 files changed, 69 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7a9e961..3c96595 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,11 @@ pay-respects-parser = "0.2.6" [features] runtime-rules = ["dep:serde", "dep:toml"] -request-ai = ["dep:serde", "dep:serde_json", "dep:curl", "dep:textwrap"] +request-ai = ["dep:serde", "dep:serde_json", "dep:textwrap"] + +# linking to libcurl dynamically requires openssl when compiling and +# complicates cross compilation +libcurl = ["dep:curl"] [profile.release] strip = true diff --git a/src/args.rs b/src/args.rs index d21cd9b..ae7cafd 100644 --- a/src/args.rs +++ b/src/args.rs @@ -69,11 +69,15 @@ fn print_version() { println!("compile features:"); #[cfg(feature = "runtime-rules")] { - println!("\t- runtime-rules"); + println!(" - runtime-rules"); } #[cfg(feature = "request-ai")] { - println!("\t- request-ai"); + println!(" - request-ai"); + } + #[cfg(feature = "libcurl")] + { + println!(" - libcurl"); } std::process::exit(0); } diff --git a/src/requests.rs b/src/requests.rs index f84391b..b65a62a 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -1,12 +1,8 @@ use std::collections::HashMap; -use std::io::Read; use serde::{Deserialize, Serialize}; use serde_json::Value; -use curl::easy::Easy as Curl; -use curl::easy::List; - #[derive(Serialize, Deserialize)] struct Input { role: String, @@ -93,6 +89,7 @@ The command `{last_command}` returns the following error message: `{error_msg}`. "# ); + let res; let messages = Messages { messages: vec![Input { role: "user".to_string(), @@ -101,42 +98,70 @@ The command `{last_command}` returns the following error message: `{error_msg}`. model, }; - let str_json = serde_json::to_string(&messages).unwrap(); - let mut data = str_json.as_bytes(); - - let mut dst = Vec::new(); - let mut handle = Curl::new(); - - handle.url(&request_url).unwrap(); - handle.post(true).unwrap(); - handle.post_field_size(data.len() as u64).unwrap(); - - let mut headers = List::new(); - headers - .append(&format!("Authorization: Bearer {}", api_key)) - .unwrap(); - headers.append("Content-Type: application/json").unwrap(); - handle.http_headers(headers).unwrap(); - + #[cfg(feature = "libcurl")] { - let mut transfer = handle.transfer(); + use curl::easy::Easy as Curl; + use curl::easy::List; + use std::io::Read; - transfer - .read_function(|buf| Ok(data.read(buf).unwrap_or(0))) + let str_json = serde_json::to_string(&messages).unwrap(); + let mut data = str_json.as_bytes(); + + let mut dst = Vec::new(); + let mut handle = Curl::new(); + + handle.url(&request_url).unwrap(); + handle.post(true).unwrap(); + handle.post_field_size(data.len() as u64).unwrap(); + + let mut headers = List::new(); + headers + .append(&format!("Authorization: Bearer {}", api_key)) .unwrap(); + headers.append("Content-Type: application/json").unwrap(); + handle.http_headers(headers).unwrap(); - transfer - .write_function(|buf| { - dst.extend_from_slice(buf); - Ok(buf.len()) - }) - .unwrap(); + { + let mut transfer = handle.transfer(); - transfer.perform().expect("Failed to perform request"); + transfer + .read_function(|buf| Ok(data.read(buf).unwrap_or(0))) + .unwrap(); + + transfer + .write_function(|buf| { + dst.extend_from_slice(buf); + Ok(buf.len()) + }) + .unwrap(); + + transfer.perform().expect("Failed to perform request"); + } + + res = String::from_utf8(dst).unwrap(); } + #[cfg(not(feature = "libcurl"))] + { + let proc = std::process::Command::new("curl") + .arg("-X") + .arg("POST") + .arg("-H") + .arg(format!("Authorization: Bearer {}", api_key)) + .arg("-H") + .arg("Content-Type: application/json") + .arg("-d") + .arg(serde_json::to_string(&messages).unwrap()) + .arg(request_url) + .output(); - let res = String::from_utf8(dst).unwrap(); - + let out = match proc { + Ok(proc) => proc.stdout, + Err(_) => { + return None; + } + }; + res = String::from_utf8(out).unwrap(); + } let json: Value = serde_json::from_str(&res).unwrap(); let content = &json["choices"][0]["message"]["content"];