package server

import (
	"bufio"
	"bytes"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"gitlab.com/technofab/go-copilot-proxy/internal/auth"
	"gitlab.com/technofab/go-copilot-proxy/internal/config"
)

// upstreamClient is shared by every handler so connections to the Copilot API
// are reused across requests. It intentionally has no overall Timeout because
// streamed completions may stay open for a long time; instead every upstream
// request carries the incoming request's context and is cancelled when the
// client goes away.
var upstreamClient = &http.Client{}

const (
	// maxStreamLineSize bounds a single line of an upstream event stream.
	// bufio.Scanner's default limit (64 KiB) would abort the whole stream on
	// one oversized line.
	maxStreamLineSize = 4 << 20 // 4 MiB

	// sseDelimiter terminates one server-sent event.
	sseDelimiter = "\n\n"
)

// NewRouter builds the proxy's HTTP router.
//
// /healthz is answered by the heartbeat middleware and CORS preflight
// (OPTIONS) requests are answered by corsHandler before routing. Everything
// under /v1 requires accessToken (see authMiddleware) and is forwarded to the
// GitHub Copilot API using the token provided by copilotAuth.
func NewRouter(accessToken string, copilotAuth *auth.CopilotAuth) chi.Router {
	r := chi.NewRouter()
	r.Use(middleware.Recoverer)
	r.Use(middleware.RealIP)
	r.Use(middleware.RequestID)
	r.Use(middleware.Heartbeat("/healthz"))
	r.Use(corsHandler)

	r.Route("/v1", func(r chi.Router) {
		r.Use(authMiddleware(accessToken))
		r.Post("/chat/completions", createChatHandler(copilotAuth))
		r.Post("/embeddings", createProxyHandler("https://api.githubcopilot.com/embeddings", copilotAuth))
		r.Get("/models", createProxyHandler("https://api.githubcopilot.com/models", copilotAuth))
	})

	return r
}

// authMiddleware rejects, with 403 Forbidden, every request whose
// Authorization header (with an optional "Bearer " prefix stripped) does not
// equal accessToken.
func authMiddleware(accessToken string) func(http.Handler) http.Handler {
	// Compare fixed-size digests in constant time so neither the token's
	// contents nor its length leak through response timing.
	expected := sha256.Sum256([]byte(accessToken))

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
			got := sha256.Sum256([]byte(token))
			if subtle.ConstantTimeCompare(got[:], expected[:]) != 1 {
				http.Error(w, "Invalid access token", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// corsHandler adds permissive CORS headers to every response and answers
// OPTIONS (preflight) requests directly with 200 OK without calling next.
func corsHandler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
		w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
		w.Header().Set("Access-Control-Expose-Headers", "Content-Type, Content-Length, Transfer-Encoding")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// createChatHandler returns a handler that forwards a chat-completions request
// body to https://api.githubcopilot.com/chat/completions and relays the reply.
//
// 200 responses with an event-stream content type are rewritten event by event
// (handleStreamingResponse); other 200 responses are rewritten as a single JSON
// document (handleNonStreamingResponse). Non-200 responses are relayed as-is.
func createChatHandler(copilotAuth *auth.CopilotAuth) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		log.Debug().
			Str("method", r.Method).
			Str("path", r.URL.Path).
			Msg("Handling chat completions request")

		token := copilotAuth.GetToken()
		if token == "" {
			http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, "Failed to read request body", http.StatusBadRequest)
			return
		}
		r.Body.Close()

		// Bind the upstream call to the client's context so an abandoned
		// request (e.g. a cancelled stream) also stops the upstream call.
		req, err := http.NewRequestWithContext(r.Context(), http.MethodPost,
			"https://api.githubcopilot.com/chat/completions", bytes.NewReader(body))
		if err != nil {
			http.Error(w, "Failed to create request", http.StatusInternalServerError)
			return
		}
		for key, value := range createRequestHeaders(token, "api.githubcopilot.com") {
			req.Header.Set(key, value)
		}

		resp, err := upstreamClient.Do(req)
		if err != nil {
			log.Error().Err(err).Msg("Failed to make request to GitHub Copilot API")
			http.Error(w, "Failed to proxy request", http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		copyHeaders(w.Header(), resp.Header)
		// Set (not Add) collapses any upstream value into a single "*".
		w.Header().Set("Access-Control-Allow-Origin", "*")

		// net/http moves Transfer-Encoding out of resp.Header into
		// resp.TransferEncoding, so the content type is the only streaming
		// signal available here.
		isStreaming := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")

		switch {
		case resp.StatusCode != http.StatusOK:
			w.WriteHeader(resp.StatusCode)
			if _, err := io.Copy(w, resp.Body); err != nil {
				log.Debug().Err(err).Msg("Failed to relay upstream error response")
			}
		case isStreaming:
			handleStreamingResponse(w, resp, r)
		default:
			handleNonStreamingResponse(w, resp, r)
		}
	}
}

// createRequestHeaders returns the headers sent with every upstream request:
// editor-style identification values from the config package, a fresh request
// ID and session ID per call, and a machine ID derived as the hex SHA-256 of
// the Copilot token.
//
// Accept-Encoding is deliberately not set. When net/http's transport adds it
// itself it transparently decompresses gzip responses, which the chat handler
// relies on because it parses and rewrites response bodies.
//
// The Host entry has no effect on the wire: net/http takes the outgoing Host
// from the request URL (or Request.Host) and does not send a Host value placed
// in the header map.
func createRequestHeaders(token, host string) map[string]string {
	hash := sha256.Sum256([]byte(token))
	machineID := hex.EncodeToString(hash[:])

	return map[string]string{
		"Authorization":          "Bearer " + token,
		"Host":                   host,
		"X-Request-Id":           uuid.New().String(),
		"X-Github-Api-Version":   "2025-04-01",
		"Vscode-Sessionid":       uuid.New().String() + fmt.Sprintf("%d", time.Now().Unix()),
		"Vscode-Machineid":       machineID,
		"Editor-Version":         config.EditorVersion,
		"Editor-Plugin-Version":  config.EditorPluginVersion,
		"Openai-Organization":    "github-copilot",
		"Copilot-Integration-Id": config.CopilotIntegrationID,
		"Openai-Intent":          "conversation-panel",
		"Content-Type":           "application/json",
		"User-Agent":             config.UserAgent,
		"Accept":                 "*/*",
	}
}

// copyHeaders appends every value of every header in src to dst.
func copyHeaders(dst, src http.Header) {
	for key, values := range src {
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}

// handleStreamingResponse relays an upstream event stream to w, passing each
// event (the lines between blank lines) through cleanStreamLine and flushing
// after every event. Events that cleanStreamLine maps to "" are dropped.
// r is used only to tell a client disconnect apart from an upstream error.
func handleStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	// Events are rewritten, so any upstream Content-Length no longer applies.
	w.Header().Del("Content-Length")

	flusher, ok := w.(http.Flusher)
	if !ok {
		log.Error().Msg("Response writer does not support flushing")
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxStreamLineSize)

	// writeEvent cleans one event and sends it, terminated by a blank line.
	// It reports false once the client can no longer be written to.
	writeEvent := func(event string) bool {
		cleaned := cleanStreamLine(event)
		if cleaned == "" {
			return true
		}
		if _, err := io.WriteString(w, cleaned+sseDelimiter); err != nil {
			log.Debug().Err(err).Msg("Failed to write streaming chunk to client")
			return false
		}
		flusher.Flush()
		return true
	}

	var eventLines []string
	for scanner.Scan() {
		line := scanner.Text()
		if line != "" {
			eventLines = append(eventLines, line)
			continue
		}
		// A blank line ends the current event; blank lines with no pending
		// event carry nothing and must not leak into the next event.
		if len(eventLines) == 0 {
			continue
		}
		if !writeEvent(strings.Join(eventLines, "\n")) {
			return
		}
		eventLines = eventLines[:0]
	}

	// The upstream may close without a final blank line; still deliver that
	// event, properly terminated so clients dispatch it.
	if len(eventLines) > 0 && !writeEvent(strings.Join(eventLines, "\n")) {
		return
	}

	if err := scanner.Err(); err != nil {
		if r.Context().Err() != nil {
			log.Debug().Err(err).Msg("Client disconnected during streaming response")
			return
		}
		log.Error().Err(err).Msg("Error reading streaming response")
	}
}

// handleNonStreamingResponse reads the whole upstream body, rewrites it with
// cleanResponse and writes it with the upstream status code. r is unused.
func handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Error().Err(err).Msg("Failed to read response body")
		http.Error(w, "Failed to read response", http.StatusInternalServerError)
		return
	}

	cleanedBody := cleanResponse(string(body))

	// The body was re-encoded, so the Content-Length copied from upstream is
	// stale; advertise the length actually written.
	w.Header().Set("Content-Length", strconv.Itoa(len(cleanedBody)))
	w.WriteHeader(resp.StatusCode)
	if _, err := io.WriteString(w, cleanedBody); err != nil {
		log.Debug().Err(err).Msg("Failed to write response body to client")
	}
}

// decodeJSONObject parses s as exactly one JSON object. Numbers are kept as
// json.Number so re-encoding reproduces them verbatim instead of rounding them
// through float64 (which corrupts integers above 2^53). JSON null, non-object
// values and trailing data are reported as errors.
func decodeJSONObject(s string) (map[string]interface{}, error) {
	dec := json.NewDecoder(strings.NewReader(s))
	dec.UseNumber()

	var data map[string]interface{}
	if err := dec.Decode(&data); err != nil {
		return nil, err
	}
	if _, err := dec.Token(); err != io.EOF {
		return nil, errors.New("unexpected data after JSON value")
	}
	if data == nil {
		return nil, errors.New("JSON value is not an object")
	}
	return data, nil
}

// cleanResponse rewrites a non-streaming completion body: it sets "object" to
// "chat.completion", removes the top-level "prompt_filter_results" and every
// choice's "content_filter_results". If the body is not a JSON object or
// cannot be re-encoded, it is returned unchanged.
func cleanResponse(responseBody string) string {
	data, err := decodeJSONObject(responseBody)
	if err != nil {
		log.Error().Err(err).Msg("Failed to parse response JSON")
		return responseBody
	}

	data["object"] = "chat.completion"
	delete(data, "prompt_filter_results")

	if choices, ok := data["choices"].([]interface{}); ok {
		for _, choice := range choices {
			if choiceMap, ok := choice.(map[string]interface{}); ok {
				delete(choiceMap, "content_filter_results")
			}
		}
	}

	cleanedBytes, err := json.Marshal(data)
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal cleaned response")
		return responseBody
	}
	return string(cleanedBytes)
}

// cleanStreamLine rewrites one event of a completion stream.
//
// Input that does not start with "data: ", the "data: [DONE]" sentinel, and
// payloads that are not a JSON object are returned unchanged. For payloads
// with a "choices" array: an empty array yields "" (the caller drops the
// event); otherwise "object" is set to "chat.completion.chunk", each choice
// loses "content_filter_offsets" and "content_filter_results", and null-valued
// keys are removed from each choice's "delta".
func cleanStreamLine(line string) string {
	if !strings.HasPrefix(line, "data: ") {
		return line
	}
	dataContent := strings.TrimPrefix(line, "data: ")
	if dataContent == "[DONE]" {
		return line
	}

	data, err := decodeJSONObject(dataContent)
	if err != nil {
		log.Debug().Err(err).Msg("Failed to parse stream line JSON")
		return line
	}

	if choices, ok := data["choices"].([]interface{}); ok {
		if len(choices) == 0 {
			return ""
		}
		data["object"] = "chat.completion.chunk"
		for _, choice := range choices {
			choiceMap, ok := choice.(map[string]interface{})
			if !ok {
				continue
			}
			delete(choiceMap, "content_filter_offsets")
			delete(choiceMap, "content_filter_results")
			if delta, ok := choiceMap["delta"].(map[string]interface{}); ok {
				// Deleting during range is permitted in Go.
				for key, value := range delta {
					if value == nil {
						delete(delta, key)
					}
				}
			}
		}
	}

	cleanedBytes, err := json.Marshal(data)
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal cleaned stream data")
		return line
	}
	return "data: " + string(cleanedBytes)
}

// createProxyHandler returns a handler that forwards the request method and
// body unchanged to targetURL with the Copilot headers from
// createRequestHeaders, and relays the upstream status, headers and body
// verbatim. It calls log.Fatal if targetURL does not parse.
func createProxyHandler(targetURL string, copilotAuth *auth.CopilotAuth) http.HandlerFunc {
	target, err := url.Parse(targetURL)
	if err != nil {
		log.Fatal().Err(err).Str("url", targetURL).Msg("Invalid target URL")
	}

	return func(w http.ResponseWriter, r *http.Request) {
		log.Debug().
			Str("method", r.Method).
			Str("path", r.URL.Path).
			Str("target", targetURL).
			Msg("Proxying request")

		token := copilotAuth.GetToken()
		if token == "" {
			http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, "Failed to read request body", http.StatusBadRequest)
			return
		}
		r.Body.Close()

		// Cancel the upstream call when the client goes away.
		req, err := http.NewRequestWithContext(r.Context(), r.Method, target.String(), bytes.NewReader(body))
		if err != nil {
			http.Error(w, "Failed to create request", http.StatusInternalServerError)
			return
		}
		for key, value := range createRequestHeaders(token, target.Host) {
			req.Header.Set(key, value)
		}

		resp, err := upstreamClient.Do(req)
		if err != nil {
			log.Error().Err(err).Msg("Failed to make request")
			http.Error(w, "Failed to proxy request", http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		copyHeaders(w.Header(), resp.Header)
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.WriteHeader(resp.StatusCode)
		if _, err := io.Copy(w, resp.Body); err != nil {
			log.Debug().Err(err).Msg("Failed to relay upstream response")
		}
	}
}