go-copilot-proxy/internal/server/server.go

343 lines
9.3 KiB
Go
Raw Normal View History

2025-08-05 11:08:53 +02:00
package server
import (
	"bufio"
	"bytes"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"gitlab.com/technofab/go-copilot-proxy/internal/auth"
	"gitlab.com/technofab/go-copilot-proxy/internal/config"
)
// NewRouter assembles the proxy's HTTP router.
//
// Every request passes through, in order: panic recovery, real-IP detection,
// request-ID assignment, a heartbeat endpoint at /healthz and the permissive
// CORS handler. Routes under /v1 additionally require accessToken (see
// authMiddleware) and are served from the Copilot API using credentials from
// copilotAuth.
func NewRouter(accessToken string, copilotAuth *auth.CopilotAuth) chi.Router {
	const copilotAPI = "https://api.githubcopilot.com"

	router := chi.NewRouter()
	router.Use(
		middleware.Recoverer,
		middleware.RealIP,
		middleware.RequestID,
		middleware.Heartbeat("/healthz"),
		corsHandler,
	)

	router.Route("/v1", func(v1 chi.Router) {
		v1.Use(authMiddleware(accessToken))
		v1.Post("/chat/completions", createChatHandler(copilotAuth))
		v1.Post("/embeddings", createProxyHandler(copilotAPI+"/embeddings", copilotAuth))
		v1.Get("/models", createProxyHandler(copilotAPI+"/models", copilotAuth))
	})

	return router
}
// authMiddleware returns middleware that only lets a request through when its
// Authorization header carries accessToken, either as "Bearer <token>" or as
// the bare token. Any other request is answered with 403 Forbidden.
//
// Note: when accessToken is empty, a request without an Authorization header
// is accepted, because its (empty) token matches.
func authMiddleware(accessToken string) func(http.Handler) http.Handler {
	expected := []byte(accessToken)
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
			// Compare in constant time so response latency does not reveal how
			// many leading bytes of a guessed token were correct.
			if subtle.ConstantTimeCompare([]byte(token), expected) != 1 {
				http.Error(w, "Invalid access token", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// corsHandler adds permissive CORS headers (any origin) to every response.
// Preflight OPTIONS requests are answered directly with 200 OK and never reach
// next; all other methods are passed on after the headers are set.
func corsHandler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
		h.Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
		h.Set("Access-Control-Expose-Headers", "Content-Type, Content-Length, Transfer-Encoding")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
}
// createChatHandler returns the handler for the chat completions route.
//
// It forwards the client's request body to the Copilot chat completions
// endpoint, authenticated with the token from copilotAuth.GetToken and the
// headers from createRequestHeaders. The reply is relayed as follows:
//   - 200 with a text/event-stream body: rewritten event by event
//     (handleStreamingResponse);
//   - any other 200: rewritten as a single JSON document
//     (handleNonStreamingResponse);
//   - non-200: status and body passed through unchanged.
//
// It answers 503 when no Copilot token is available, 400 when the request body
// cannot be read, 500 when the upstream request cannot be built and 502 when
// the upstream call fails.
func createChatHandler(copilotAuth *auth.CopilotAuth) http.HandlerFunc {
	const upstreamURL = "https://api.githubcopilot.com/chat/completions"

	return func(w http.ResponseWriter, r *http.Request) {
		log.Debug().
			Str("method", r.Method).
			Str("path", r.URL.Path).
			Msg("Handling chat completions request")

		token := copilotAuth.GetToken()
		if token == "" {
			http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, "Failed to read request body", http.StatusBadRequest)
			return
		}
		r.Body.Close()

		// Bind the upstream call to the client's context: if the client
		// disconnects, the (possibly long-running) upstream stream is aborted
		// instead of being read to completion for nobody.
		req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, upstreamURL, bytes.NewReader(body))
		if err != nil {
			http.Error(w, "Failed to create request", http.StatusInternalServerError)
			return
		}
		for key, value := range createRequestHeaders(token, "api.githubcopilot.com") {
			req.Header.Set(key, value)
		}

		client := &http.Client{}
		resp, err := client.Do(req)
		if err != nil {
			log.Error().Err(err).Msg("Failed to make request to GitHub Copilot API")
			http.Error(w, "Failed to proxy request", http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		copyHeaders(w.Header(), resp.Header)
		// copyHeaders appends, so an upstream CORS header would sit next to the
		// one corsHandler already set; Set collapses them to a single "*".
		w.Header().Set("Access-Control-Allow-Origin", "*")

		// Only Content-Type decides whether the reply is an event stream. The
		// net/http client removes Transfer-Encoding from resp.Header (it is
		// exposed as resp.TransferEncoding instead), and a chunked JSON body is
		// not an event stream anyway.
		isStreaming := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")

		switch {
		case resp.StatusCode != http.StatusOK:
			w.WriteHeader(resp.StatusCode)
			if _, err := io.Copy(w, resp.Body); err != nil {
				log.Debug().Err(err).Msg("Failed to relay upstream error response")
			}
		case isStreaming:
			handleStreamingResponse(w, resp, r)
		default:
			handleNonStreamingResponse(w, resp, r)
		}
	}
}
// createRequestHeaders returns the headers attached to every request sent to
// the Copilot API: the bearer token, editor/integration identification taken
// from the config package, a fresh X-Request-Id and Vscode-Sessionid on every
// call, and a Vscode-Machineid derived from the SHA-256 of token (so it
// changes whenever the token does).
//
// Accept-Encoding is deliberately not set. When a request carries its own
// Accept-Encoding, net/http's Transport does not decompress the response. The
// standard library also cannot decode "br" at all. The chat handler parses and
// rewrites response bodies, so it needs plain bytes. Leaving the header unset
// lets the Transport request gzip itself and decode it transparently (it then
// also drops the now-stale Content-Encoding/Content-Length response headers).
//
// The "Host" entry has no effect on outgoing requests: the client takes the
// host from the request URL and ignores Host in Request.Header.
func createRequestHeaders(token, host string) map[string]string {
	hash := sha256.Sum256([]byte(token))
	machineID := hex.EncodeToString(hash[:])
	return map[string]string{
		"Authorization":          "Bearer " + token,
		"Host":                   host,
		"X-Request-Id":           uuid.New().String(),
		"X-Github-Api-Version":   "2025-04-01",
		"Vscode-Sessionid":       uuid.New().String() + fmt.Sprintf("%d", time.Now().Unix()),
		"Vscode-Machineid":       machineID,
		"Editor-Version":         config.EditorVersion,
		"Editor-Plugin-Version":  config.EditorPluginVersion,
		"Openai-Organization":    "github-copilot",
		"Copilot-Integration-Id": config.CopilotIntegrationID,
		"Openai-Intent":          "conversation-panel",
		"Content-Type":           "application/json",
		"User-Agent":             config.UserAgent,
		"Accept":                 "*/*",
	}
}
// copyHeaders appends every value of every header in src onto dst, keeping
// whatever dst already holds. Keys are stored via http.Header.Add, which
// canonicalizes them.
func copyHeaders(dst, src http.Header) {
	for name := range src {
		values := src[name]
		for i := 0; i < len(values); i++ {
			dst.Add(name, values[i])
		}
	}
}
// handleStreamingResponse relays an upstream server-sent-events body to the
// client.
//
// Lines from resp.Body are accumulated until a blank line ends an event. Each
// event is passed through cleanStreamLine; unless that returns "", the result
// is written with a "\n\n" terminator and flushed immediately so tokens reach
// the client as they arrive. A partial event left when the upstream body ends
// is cleaned and written with a single "\n" terminator. Streaming stops early
// if a write to the client fails. If w cannot flush, a 500 is returned instead.
func handleStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
	const (
		delimiter = "\n\n"
		// bufio.Scanner gives up with ErrTooLong on a line longer than its
		// buffer (64 KiB by default), which would silently cut the stream off
		// on one oversized SSE line; allow lines up to 1 MiB instead.
		maxLineSize = 1 << 20
	)

	header := w.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	// Events are rewritten, so an upstream Content-Length copied in by the
	// caller would no longer match the bytes actually sent.
	header.Del("Content-Length")

	flusher, ok := w.(http.Flusher)
	if !ok {
		log.Error().Msg("Response writer does not support flushing")
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var pending strings.Builder
	for scanner.Scan() {
		pending.WriteString(scanner.Text())
		pending.WriteByte('\n')
		event := pending.String()
		if !strings.HasSuffix(event, delimiter) {
			continue
		}
		// A blank line was just read, so event holds exactly one complete
		// SSE event followed by the delimiter.
		pending.Reset()
		if cleaned := cleanStreamLine(strings.TrimSuffix(event, delimiter)); cleaned != "" {
			if _, err := io.WriteString(w, cleaned+delimiter); err != nil {
				log.Debug().Err(err).Msg("Client write failed; stopping stream")
				return
			}
			flusher.Flush()
		}
	}

	// Upstream ended without a final blank line: emit what is left.
	if rest := pending.String(); rest != "" {
		if cleaned := cleanStreamLine(strings.TrimSuffix(rest, "\n")); cleaned != "" {
			io.WriteString(w, cleaned+"\n")
			flusher.Flush()
		}
	}

	if err := scanner.Err(); err != nil {
		if r.Context().Err() != nil {
			// The client is gone, so a failing upstream read is expected and
			// not worth an error-level log.
			log.Debug().Err(err).Msg("Streaming response aborted by client")
			return
		}
		log.Error().Err(err).Msg("Error reading streaming response")
	}
}
// handleNonStreamingResponse reads the whole upstream body, rewrites it with
// cleanResponse and sends it to the client with resp's status code. If the
// upstream body cannot be read, a 500 is returned instead. r is unused.
func handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Error().Err(err).Msg("Failed to read response body")
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}
	cleanedBody := cleanResponse(string(body))

	// The caller copies upstream headers onto w, including Content-Length,
	// which describes the original body. The cleaned body usually has a
	// different size, and a mismatched Content-Length truncates or stalls the
	// response, so drop it and let net/http frame what is actually written.
	w.Header().Del("Content-Length")
	w.WriteHeader(resp.StatusCode)
	if _, err := io.WriteString(w, cleanedBody); err != nil {
		log.Debug().Err(err).Msg("Failed to write response body")
	}
}
// cleanResponse rewrites a non-streaming chat completion body: it forces
// "object" to "chat.completion", removes the top-level "prompt_filter_results"
// and removes "content_filter_results" from every choice object.
//
// Input that is not a JSON object is returned unchanged: unparsable JSON (after
// logging), and a literal `null`. The input is also returned unchanged if
// re-encoding fails. Note that re-encoding does not preserve key order or
// whitespace of the original.
func cleanResponse(responseBody string) string {
	var data map[string]interface{}
	if err := json.Unmarshal([]byte(responseBody), &data); err != nil {
		log.Error().Err(err).Msg("Failed to parse response JSON")
		return responseBody
	}
	// `null` decodes without error into a nil map; assigning to it below
	// would panic, so pass such a body through untouched.
	if data == nil {
		return responseBody
	}

	data["object"] = "chat.completion"
	delete(data, "prompt_filter_results")
	if choices, ok := data["choices"].([]interface{}); ok {
		for _, choice := range choices {
			if choiceMap, ok := choice.(map[string]interface{}); ok {
				delete(choiceMap, "content_filter_results")
			}
		}
	}

	cleanedBytes, err := json.Marshal(data)
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal cleaned response")
		return responseBody
	}
	return string(cleanedBytes)
}
// cleanStreamLine rewrites one SSE event from the chat completions stream.
//
// Events that do not start with "data: ", the "data: [DONE]" terminator, and
// payloads that are not valid JSON are returned unchanged. When the JSON
// payload has a "choices" array:
//   - an empty array makes the whole event disappear (returns "");
//   - otherwise "object" is set to "chat.completion.chunk", each choice loses
//     "content_filter_offsets" and "content_filter_results", and null-valued
//     keys are removed from each choice's "delta" object.
//
// The payload is then re-encoded and returned with its "data: " prefix. If
// re-encoding fails, the input is returned unchanged.
func cleanStreamLine(line string) string {
	const dataPrefix = "data: "

	if !strings.HasPrefix(line, dataPrefix) {
		return line
	}
	payload := line[len(dataPrefix):]
	if payload == "[DONE]" {
		return line
	}

	var event map[string]interface{}
	if err := json.Unmarshal([]byte(payload), &event); err != nil {
		log.Debug().Err(err).Msg("Failed to parse stream line JSON")
		return line
	}

	choices, hasChoices := event["choices"].([]interface{})
	if hasChoices {
		if len(choices) == 0 {
			return ""
		}
		event["object"] = "chat.completion.chunk"
		for _, raw := range choices {
			choice, isMap := raw.(map[string]interface{})
			if !isMap {
				continue
			}
			delete(choice, "content_filter_results")
			delete(choice, "content_filter_offsets")

			delta, isMap := choice["delta"].(map[string]interface{})
			if !isMap {
				continue
			}
			// Deleting while ranging over a map is permitted in Go.
			for field, value := range delta {
				if value == nil {
					delete(delta, field)
				}
			}
		}
	}

	encoded, err := json.Marshal(event)
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal cleaned stream data")
		return line
	}
	return dataPrefix + string(encoded)
}
// createProxyHandler returns a handler that forwards the incoming request's
// method and body to targetURL and relays the upstream status, headers and
// body unchanged. The request uses the Copilot token from copilotAuth.GetToken
// and the headers from createRequestHeaders.
//
// targetURL is parsed once, up front; an invalid URL terminates the process
// via log.Fatal. The incoming request's path and query string are not
// forwarded: every request goes to targetURL exactly.
//
// The handler answers 503 when no Copilot token is available, 400 when the
// request body cannot be read, 500 when the upstream request cannot be built
// and 502 when the upstream call fails.
func createProxyHandler(targetURL string, copilotAuth *auth.CopilotAuth) http.HandlerFunc {
	target, err := url.Parse(targetURL)
	if err != nil {
		log.Fatal().Err(err).Str("url", targetURL).Msg("Invalid target URL")
	}

	return func(w http.ResponseWriter, r *http.Request) {
		log.Debug().
			Str("method", r.Method).
			Str("path", r.URL.Path).
			Str("target", targetURL).
			Msg("Proxying request")

		token := copilotAuth.GetToken()
		if token == "" {
			http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, "Failed to read request body", http.StatusBadRequest)
			return
		}
		r.Body.Close()

		// Bind the upstream call to the client's context so it is cancelled
		// when the client disconnects.
		req, err := http.NewRequestWithContext(r.Context(), r.Method, target.String(), bytes.NewReader(body))
		if err != nil {
			http.Error(w, "Failed to create request", http.StatusInternalServerError)
			return
		}
		for key, value := range createRequestHeaders(token, target.Host) {
			req.Header.Set(key, value)
		}

		client := &http.Client{}
		resp, err := client.Do(req)
		if err != nil {
			log.Error().Err(err).Msg("Failed to make request")
			http.Error(w, "Failed to proxy request", http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		copyHeaders(w.Header(), resp.Header)
		// copyHeaders appends, so collapse any upstream CORS header together
		// with the one corsHandler already set into a single "*".
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.WriteHeader(resp.StatusCode)
		if _, err := io.Copy(w, resp.Body); err != nil {
			log.Debug().Err(err).Msg("Failed to relay response body")
		}
	}
}