mirror of
https://gitlab.com/TECHNOFAB/go-copilot-proxy.git
synced 2025-12-11 22:10:06 +01:00
343 lines
9.3 KiB
Go
343 lines
9.3 KiB
Go
|
|
package server
|
||
|
|
|
||
|
|
import (
	"bufio"
	"bytes"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"

	"gitlab.com/technofab/go-copilot-proxy/internal/auth"
	"gitlab.com/technofab/go-copilot-proxy/internal/config"
)
|
||
|
|
|
||
|
|
func NewRouter(accessToken string, copilotAuth *auth.CopilotAuth) chi.Router {
|
||
|
|
r := chi.NewRouter()
|
||
|
|
r.Use(middleware.Recoverer)
|
||
|
|
r.Use(middleware.RealIP)
|
||
|
|
r.Use(middleware.RequestID)
|
||
|
|
r.Use(middleware.Heartbeat("/healthz"))
|
||
|
|
r.Use(corsHandler)
|
||
|
|
|
||
|
|
r.Route("/v1", func(r chi.Router) {
|
||
|
|
r.Use(authMiddleware(accessToken))
|
||
|
|
r.Post("/chat/completions", createChatHandler(copilotAuth))
|
||
|
|
r.Post("/embeddings", createProxyHandler("https://api.githubcopilot.com/embeddings", copilotAuth))
|
||
|
|
r.Get("/models", createProxyHandler("https://api.githubcopilot.com/models", copilotAuth))
|
||
|
|
})
|
||
|
|
|
||
|
|
return r
|
||
|
|
}
|
||
|
|
|
||
|
|
// authMiddleware returns middleware that admits a request only when its
// Authorization header carries accessToken. Any other request gets
// 403 "Invalid access token".
//
// A leading "Bearer " is stripped if present. A header that holds the bare
// token is therefore accepted as well.
//
// The comparison uses subtle.ConstantTimeCompare, so response timing does not
// reveal how many leading bytes of a guessed token were correct. Only the
// token's length can leak, because ConstantTimeCompare returns early on a
// length mismatch.
func authMiddleware(accessToken string) func(http.Handler) http.Handler {
	expected := []byte(accessToken)
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
			if subtle.ConstantTimeCompare([]byte(token), expected) != 1 {
				http.Error(w, "Invalid access token", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
|
||
|
|
|
||
|
|
// corsHandler is a permissive CORS middleware. It stamps every response with
// wildcard-origin CORS headers. It answers OPTIONS (preflight) requests
// itself with an empty 200 and does not call the next handler for them.
func corsHandler(next http.Handler) http.Handler {
	corsHeaders := [][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE"},
		{"Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"},
		{"Access-Control-Expose-Headers", "Content-Type, Content-Length, Transfer-Encoding"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range corsHeaders {
			hdr.Set(kv[0], kv[1])
		}

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
}
|
||
|
|
|
||
|
|
func createChatHandler(copilotAuth *auth.CopilotAuth) http.HandlerFunc {
|
||
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
log.Debug().
|
||
|
|
Str("method", r.Method).
|
||
|
|
Str("path", r.URL.Path).
|
||
|
|
Msg("Handling chat completions request")
|
||
|
|
|
||
|
|
token := copilotAuth.GetToken()
|
||
|
|
if token == "" {
|
||
|
|
http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
body, err := io.ReadAll(r.Body)
|
||
|
|
if err != nil {
|
||
|
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
r.Body.Close()
|
||
|
|
|
||
|
|
req, err := http.NewRequest("POST", "https://api.githubcopilot.com/chat/completions", bytes.NewBuffer(body))
|
||
|
|
if err != nil {
|
||
|
|
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
headers := createRequestHeaders(token, "api.githubcopilot.com")
|
||
|
|
for key, value := range headers {
|
||
|
|
req.Header.Set(key, value)
|
||
|
|
}
|
||
|
|
|
||
|
|
client := &http.Client{}
|
||
|
|
resp, err := client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
log.Error().Err(err).Msg("Failed to make request to GitHub Copilot API")
|
||
|
|
http.Error(w, "Failed to proxy request", http.StatusBadGateway)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
copyHeaders(w.Header(), resp.Header)
|
||
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
|
|
|
||
|
|
isStreaming := resp.Header.Get("Transfer-Encoding") == "chunked" ||
|
||
|
|
strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
|
||
|
|
|
||
|
|
if resp.StatusCode == http.StatusOK && isStreaming {
|
||
|
|
handleStreamingResponse(w, resp, r)
|
||
|
|
} else if resp.StatusCode == http.StatusOK {
|
||
|
|
handleNonStreamingResponse(w, resp, r)
|
||
|
|
} else {
|
||
|
|
w.WriteHeader(resp.StatusCode)
|
||
|
|
io.Copy(w, resp.Body)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func createRequestHeaders(token, host string) map[string]string {
|
||
|
|
hash := sha256.Sum256([]byte(token))
|
||
|
|
machineID := hex.EncodeToString(hash[:])
|
||
|
|
|
||
|
|
return map[string]string{
|
||
|
|
"Authorization": "Bearer " + token,
|
||
|
|
"Host": host,
|
||
|
|
"X-Request-Id": uuid.New().String(),
|
||
|
|
"X-Github-Api-Version": "2025-04-01",
|
||
|
|
"Vscode-Sessionid": uuid.New().String() + fmt.Sprintf("%d", time.Now().Unix()),
|
||
|
|
"Vscode-Machineid": machineID,
|
||
|
|
"Editor-Version": config.EditorVersion,
|
||
|
|
"Editor-Plugin-Version": config.EditorPluginVersion,
|
||
|
|
"Openai-Organization": "github-copilot",
|
||
|
|
"Copilot-Integration-Id": config.CopilotIntegrationID,
|
||
|
|
"Openai-Intent": "conversation-panel",
|
||
|
|
"Content-Type": "application/json",
|
||
|
|
"User-Agent": config.UserAgent,
|
||
|
|
"Accept": "*/*",
|
||
|
|
"Accept-Encoding": "gzip, deflate, br",
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// hopByHopHeaders lists connection-specific header fields (RFC 9110 §7.6.1,
// plus the legacy Proxy-Connection). They describe a single transport hop and
// must not be forwarded by a proxy. Keys are in canonical MIME form, matching
// the keys net/http stores in http.Header.
var hopByHopHeaders = map[string]bool{
	"Connection":          true,
	"Keep-Alive":          true,
	"Proxy-Authenticate":  true,
	"Proxy-Authorization": true,
	"Proxy-Connection":    true,
	"Te":                  true,
	"Trailer":             true,
	"Transfer-Encoding":   true,
	"Upgrade":             true,
}

// copyHeaders appends every end-to-end header value from src to dst, using
// Add so values already in dst are kept.
//
// Hop-by-hop headers are skipped. This covers the fixed set above plus any
// header that src's Connection header names.
func copyHeaders(dst, src http.Header) {
	// Fields listed in Connection are hop-by-hop for this message as well.
	connectionScoped := make(map[string]bool)
	for _, v := range src["Connection"] {
		for _, name := range strings.Split(v, ",") {
			if name = strings.TrimSpace(name); name != "" {
				connectionScoped[http.CanonicalHeaderKey(name)] = true
			}
		}
	}

	for key, values := range src {
		if hopByHopHeaders[key] || connectionScoped[key] {
			continue
		}
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}
|
||
|
|
|
||
|
|
func handleStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
|
||
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
||
|
|
w.Header().Set("Cache-Control", "no-cache")
|
||
|
|
w.Header().Set("Connection", "keep-alive")
|
||
|
|
|
||
|
|
flusher, ok := w.(http.Flusher)
|
||
|
|
if !ok {
|
||
|
|
log.Error().Msg("Response writer does not support flushing")
|
||
|
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
scanner := bufio.NewScanner(resp.Body)
|
||
|
|
buffer := ""
|
||
|
|
delimiter := "\n\n"
|
||
|
|
|
||
|
|
for scanner.Scan() {
|
||
|
|
line := scanner.Text()
|
||
|
|
buffer += line + "\n"
|
||
|
|
|
||
|
|
if strings.HasSuffix(buffer, delimiter) {
|
||
|
|
lines := strings.Split(strings.TrimSuffix(buffer, delimiter), delimiter)
|
||
|
|
for _, chunk := range lines {
|
||
|
|
if cleanedChunk := cleanStreamLine(chunk); cleanedChunk != "" {
|
||
|
|
fmt.Fprint(w, cleanedChunk+delimiter)
|
||
|
|
flusher.Flush()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
buffer = ""
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if buffer != "" {
|
||
|
|
if cleanedChunk := cleanStreamLine(strings.TrimSuffix(buffer, "\n")); cleanedChunk != "" {
|
||
|
|
fmt.Fprint(w, cleanedChunk+"\n")
|
||
|
|
flusher.Flush()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := scanner.Err(); err != nil {
|
||
|
|
log.Error().Err(err).Msg("Error reading streaming response")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
|
||
|
|
body, err := io.ReadAll(resp.Body)
|
||
|
|
if err != nil {
|
||
|
|
log.Error().Err(err).Msg("Failed to read response body")
|
||
|
|
http.Error(w, "Failed to read response", http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
cleanedBody := cleanResponse(string(body))
|
||
|
|
w.WriteHeader(resp.StatusCode)
|
||
|
|
fmt.Fprint(w, cleanedBody)
|
||
|
|
}
|
||
|
|
|
||
|
|
func cleanResponse(responseBody string) string {
|
||
|
|
var data map[string]interface{}
|
||
|
|
if err := json.Unmarshal([]byte(responseBody), &data); err != nil {
|
||
|
|
log.Error().Err(err).Msg("Failed to parse response JSON")
|
||
|
|
return responseBody
|
||
|
|
}
|
||
|
|
|
||
|
|
data["object"] = "chat.completion"
|
||
|
|
delete(data, "prompt_filter_results")
|
||
|
|
|
||
|
|
if choices, ok := data["choices"].([]interface{}); ok {
|
||
|
|
for _, choice := range choices {
|
||
|
|
if choiceMap, ok := choice.(map[string]interface{}); ok {
|
||
|
|
delete(choiceMap, "content_filter_results")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
cleanedBytes, err := json.Marshal(data)
|
||
|
|
if err != nil {
|
||
|
|
log.Error().Err(err).Msg("Failed to marshal cleaned response")
|
||
|
|
return responseBody
|
||
|
|
}
|
||
|
|
|
||
|
|
return string(cleanedBytes)
|
||
|
|
}
|
||
|
|
|
||
|
|
func cleanStreamLine(line string) string {
|
||
|
|
if !strings.HasPrefix(line, "data: ") {
|
||
|
|
return line
|
||
|
|
}
|
||
|
|
|
||
|
|
dataContent := strings.TrimPrefix(line, "data: ")
|
||
|
|
if dataContent == "[DONE]" {
|
||
|
|
return line
|
||
|
|
}
|
||
|
|
|
||
|
|
var data map[string]interface{}
|
||
|
|
if err := json.Unmarshal([]byte(dataContent), &data); err != nil {
|
||
|
|
log.Debug().Err(err).Msg("Failed to parse stream line JSON")
|
||
|
|
return line
|
||
|
|
}
|
||
|
|
|
||
|
|
if choices, ok := data["choices"].([]interface{}); ok {
|
||
|
|
if len(choices) == 0 {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
|
||
|
|
data["object"] = "chat.completion.chunk"
|
||
|
|
|
||
|
|
for _, choice := range choices {
|
||
|
|
if choiceMap, ok := choice.(map[string]interface{}); ok {
|
||
|
|
delete(choiceMap, "content_filter_offsets")
|
||
|
|
delete(choiceMap, "content_filter_results")
|
||
|
|
|
||
|
|
if delta, ok := choiceMap["delta"].(map[string]interface{}); ok {
|
||
|
|
for key, value := range delta {
|
||
|
|
if value == nil {
|
||
|
|
delete(delta, key)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
cleanedBytes, err := json.Marshal(data)
|
||
|
|
if err != nil {
|
||
|
|
log.Error().Err(err).Msg("Failed to marshal cleaned stream data")
|
||
|
|
return line
|
||
|
|
}
|
||
|
|
|
||
|
|
return "data: " + string(cleanedBytes)
|
||
|
|
}
|
||
|
|
|
||
|
|
func createProxyHandler(targetURL string, copilotAuth *auth.CopilotAuth) http.HandlerFunc {
|
||
|
|
target, err := url.Parse(targetURL)
|
||
|
|
if err != nil {
|
||
|
|
log.Fatal().Err(err).Str("url", targetURL).Msg("Invalid target URL")
|
||
|
|
}
|
||
|
|
|
||
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
log.Debug().
|
||
|
|
Str("method", r.Method).
|
||
|
|
Str("path", r.URL.Path).
|
||
|
|
Str("target", targetURL).
|
||
|
|
Msg("Proxying request")
|
||
|
|
|
||
|
|
token := copilotAuth.GetToken()
|
||
|
|
if token == "" {
|
||
|
|
http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
body, err := io.ReadAll(r.Body)
|
||
|
|
if err != nil {
|
||
|
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
r.Body.Close()
|
||
|
|
|
||
|
|
req, err := http.NewRequest(r.Method, target.String(), bytes.NewBuffer(body))
|
||
|
|
if err != nil {
|
||
|
|
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
headers := createRequestHeaders(token, target.Host)
|
||
|
|
for key, value := range headers {
|
||
|
|
req.Header.Set(key, value)
|
||
|
|
}
|
||
|
|
|
||
|
|
client := &http.Client{}
|
||
|
|
resp, err := client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
log.Error().Err(err).Msg("Failed to make request")
|
||
|
|
http.Error(w, "Failed to proxy request", http.StatusBadGateway)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
copyHeaders(w.Header(), resp.Header)
|
||
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
|
|
w.WriteHeader(resp.StatusCode)
|
||
|
|
io.Copy(w, resp.Body)
|
||
|
|
}
|
||
|
|
}
|