mirror of
https://gitlab.com/TECHNOFAB/go-copilot-proxy.git
synced 2025-12-12 06:20:05 +01:00
chore: initial commit
This commit is contained in:
commit
595200836c
16 changed files with 1571 additions and 0 deletions
375
internal/auth/copilot.go
Normal file
375
internal/auth/copilot.go
Normal file
|
|
@ -0,0 +1,375 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/config"
|
||||
)
|
||||
|
||||
type GithubToken struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
type CopilotAuth struct {
|
||||
oauthToken string
|
||||
stateDir string
|
||||
tokenFile string
|
||||
tokenFileLock string
|
||||
isSelfWriting atomic.Bool
|
||||
mu sync.RWMutex
|
||||
githubToken *GithubToken
|
||||
}
|
||||
|
||||
func NewCopilotAuth() (*CopilotAuth, error) {
|
||||
stateDir, err := getStatePath("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(stateDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("could not create state dir: %w", err)
|
||||
}
|
||||
|
||||
return &CopilotAuth{
|
||||
stateDir: stateDir,
|
||||
tokenFile: filepath.Join(stateDir, "token.json"),
|
||||
tokenFileLock: filepath.Join(stateDir, "token.json.lock"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) GetToken() string {
|
||||
ca.mu.RLock()
|
||||
defer ca.mu.RUnlock()
|
||||
if ca.githubToken != nil {
|
||||
return ca.githubToken.Token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) Start(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
var err error
|
||||
ca.oauthToken, err = ca.getOAuthToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get oauth token: %w", err)
|
||||
}
|
||||
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Warn().Err(err).Msg("Initial token load failed, will attempt refresh")
|
||||
}
|
||||
|
||||
if !ca.RefreshToken(true) {
|
||||
return errors.New("initial token refresh failed")
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
go ca.refreshTokenLoop(ctx, wg)
|
||||
go ca.watchTokenFile(ctx, wg)
|
||||
go ca.checkStaleLocks(ctx, wg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) RefreshToken(force bool) bool {
|
||||
if !force {
|
||||
if ca.isTokenValid() {
|
||||
log.Debug().Msg("Token still valid, skipping refresh")
|
||||
return true
|
||||
}
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Warn().Err(err).Msg("Could not reload token from file before refresh check")
|
||||
}
|
||||
if ca.isTokenValid() {
|
||||
log.Debug().Msg("Valid token loaded from file, skipping refresh")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if !ca.acquireLock() {
|
||||
return ca.waitForTokenRefresh()
|
||||
}
|
||||
defer ca.releaseLock()
|
||||
|
||||
req, err := http.NewRequest("GET", config.AuthURL, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create token request")
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+ca.oauthToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", config.UserAgent)
|
||||
req.Header.Set("Editor-Version", config.EditorVersion)
|
||||
req.Header.Set("Editor-Plugin-Version", config.EditorPluginVersion)
|
||||
req.Header.Set("Copilot-Integration-Id", config.CopilotIntegrationID)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("HTTP error during token refresh")
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Error().Int("status", resp.StatusCode).Str("body", string(body)).Msg("Token refresh failed")
|
||||
return false
|
||||
}
|
||||
|
||||
var token GithubToken
|
||||
if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decode token response")
|
||||
return false
|
||||
}
|
||||
|
||||
ca.mu.Lock()
|
||||
ca.githubToken = &token
|
||||
ca.mu.Unlock()
|
||||
|
||||
if err := ca.saveTokenToFile(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to save new token")
|
||||
return false
|
||||
}
|
||||
log.Debug().Msg("Token successfully refreshed")
|
||||
return true
|
||||
}
|
||||
|
||||
func getConfigPath(filename string) (string, error) {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not find user config dir: %w", err)
|
||||
}
|
||||
return filepath.Join(configDir, "go-copilot-proxy", filename), nil
|
||||
}
|
||||
|
||||
func getStatePath(filename string) (string, error) {
|
||||
stateDir := os.Getenv("XDG_STATE_HOME")
|
||||
if stateDir == "" {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not find user home dir: %w", err)
|
||||
}
|
||||
stateDir = filepath.Join(homeDir, ".local", "state")
|
||||
}
|
||||
return filepath.Join(stateDir, "go-copilot-proxy", filename), nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) getOAuthToken() (string, error) {
|
||||
path, err := getConfigPath("config.json")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return "", errors.New("config.json not found, please run the 'auth' command first")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read %s: %w", path, err)
|
||||
}
|
||||
|
||||
var configData map[string]string
|
||||
if err := json.Unmarshal(data, &configData); err != nil {
|
||||
return "", fmt.Errorf("failed to parse %s: %w", path, err)
|
||||
}
|
||||
|
||||
if token, ok := configData["oauth_token"]; ok && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
return "", errors.New("oauth_token not found in config.json")
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) acquireLock() bool {
|
||||
f, err := os.OpenFile(ca.tokenFileLock, os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
if os.IsExist(err) {
|
||||
log.Debug().Msg("Lock file already exists")
|
||||
} else {
|
||||
log.Error().Err(err).Msg("Error acquiring lock")
|
||||
}
|
||||
return false
|
||||
}
|
||||
f.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) releaseLock() {
|
||||
if err := os.Remove(ca.tokenFileLock); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Error().Err(err).Msg("Error releasing lock file")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Msg("Lock file released successfully")
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) saveTokenToFile() error {
|
||||
ca.mu.RLock()
|
||||
tokenData, err := json.Marshal(ca.githubToken)
|
||||
ca.mu.RUnlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
ca.isSelfWriting.Store(true)
|
||||
defer ca.isSelfWriting.Store(false)
|
||||
|
||||
tempFile := ca.tokenFile + ".tmp"
|
||||
if err := os.WriteFile(tempFile, tokenData, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write to temporary token file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tempFile, ca.tokenFile); err != nil {
|
||||
return fmt.Errorf("failed to atomically move token file: %w", err)
|
||||
}
|
||||
log.Debug().Msg("Token successfully saved to file")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) loadTokenFromFile() error {
|
||||
data, err := os.ReadFile(ca.tokenFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var token GithubToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
ca.mu.Lock()
|
||||
ca.githubToken = &token
|
||||
ca.mu.Unlock()
|
||||
log.Debug().Msg("Token loaded from file")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) isTokenValid() bool {
|
||||
ca.mu.RLock()
|
||||
defer ca.mu.RUnlock()
|
||||
if ca.githubToken == nil {
|
||||
return false
|
||||
}
|
||||
return ca.githubToken.ExpiresAt > time.Now().Add(config.TokenRefreshBuffer).Unix()
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) waitForTokenRefresh() bool {
|
||||
log.Debug().Msg("Waiting for another process to refresh the token...")
|
||||
time.Sleep(5 * time.Second)
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload token after waiting")
|
||||
return false
|
||||
}
|
||||
return ca.isTokenValid()
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) refreshTokenLoop(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
var sleepDuration time.Duration
|
||||
ca.mu.RLock()
|
||||
if ca.githubToken != nil {
|
||||
expiryTime := time.Unix(ca.githubToken.ExpiresAt, 0)
|
||||
refreshTime := expiryTime.Add(-config.TokenRefreshBuffer)
|
||||
sleepDuration = time.Until(refreshTime)
|
||||
}
|
||||
ca.mu.RUnlock()
|
||||
|
||||
if sleepDuration <= 0 {
|
||||
sleepDuration = config.RetryInterval
|
||||
}
|
||||
|
||||
log.Debug().Dur("duration", sleepDuration).Msg("Scheduling next token refresh")
|
||||
select {
|
||||
case <-time.After(sleepDuration):
|
||||
ca.RefreshToken(false)
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("Refresh token loop shutting down.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) watchTokenFile(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create file watcher")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
if err := watcher.Add(ca.stateDir); err != nil {
|
||||
log.Error().Err(err).Str("path", ca.stateDir).Msg("Failed to watch token directory")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name == ca.tokenFile && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) {
|
||||
if ca.isSelfWriting.Load() {
|
||||
continue
|
||||
}
|
||||
log.Debug().Msg("Token file changed externally, reloading.")
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload token from watched file")
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("File watcher error")
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("File watcher shutting down.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) checkStaleLocks(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
ticker := time.NewTicker(config.StaleLockCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
info, err := os.Stat(ca.tokenFileLock)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
log.Error().Err(err).Msg("Error checking stale lock")
|
||||
continue
|
||||
}
|
||||
if time.Since(info.ModTime()) > config.StaleLockTimeout {
|
||||
log.Warn().Msg("Removing stale lock file")
|
||||
ca.releaseLock()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("Stale lock checker shutting down.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
142
internal/auth/oauth.go
Normal file
142
internal/auth/oauth.go
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/atotto/clipboard"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/config"
|
||||
)
|
||||
|
||||
type DeviceCodeResponse struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
type AccessTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
}
|
||||
|
||||
func RequestDeviceCode() (*DeviceCodeResponse, error) {
|
||||
payload := strings.NewReader(fmt.Sprintf(`{"client_id":"%s","scope":"%s"}`, config.GHClientID, config.GHScope))
|
||||
req, err := http.NewRequest("POST", config.GHDeviceCodeURL, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("User-Agent", config.UserAgent)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("bad response from GitHub: %s", resp.Status)
|
||||
}
|
||||
|
||||
var data DeviceCodeResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func PollForAccessToken(deviceCodeInfo *DeviceCodeResponse) (string, error) {
|
||||
interval := time.Duration(deviceCodeInfo.Interval) * time.Second
|
||||
if interval == 0 {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
log.Info().Msg("Waiting for you to authorize in the browser...")
|
||||
|
||||
for {
|
||||
time.Sleep(interval)
|
||||
fmt.Print(".")
|
||||
|
||||
payload := strings.NewReader(fmt.Sprintf(
|
||||
`{"client_id":"%s","device_code":"%s","grant_type":"urn:ietf:params:oauth:grant-type:device_code"}`,
|
||||
config.GHClientID, deviceCodeInfo.DeviceCode,
|
||||
))
|
||||
req, err := http.NewRequest("POST", config.GHOauthTokenURL, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("User-Agent", config.UserAgent)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var data AccessTokenResponse
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
resp.Body.Close()
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if data.AccessToken != "" {
|
||||
fmt.Println()
|
||||
return data.AccessToken, nil
|
||||
}
|
||||
|
||||
if data.Error == "authorization_pending" {
|
||||
continue
|
||||
}
|
||||
|
||||
if data.Error != "" {
|
||||
return "", fmt.Errorf("authentication failed: %s - %s", data.Error, data.ErrorDescription)
|
||||
}
|
||||
}
|
||||
}
|
||||
func SaveOAuthToken(token string) error {
|
||||
path, err := getConfigPath("config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenData := map[string]string{
|
||||
"oauth_token": token,
|
||||
}
|
||||
|
||||
jsonData, err := json.MarshalIndent(tokenData, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token data: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(path, jsonData, 0644)
|
||||
}
|
||||
|
||||
func PromptUserForAuth(deviceCodeResp *DeviceCodeResponse) {
|
||||
log.Info().Msgf("Please open this URL in your browser: %s", deviceCodeResp.VerificationURI)
|
||||
log.Info().Msgf("And enter this code: %s", deviceCodeResp.UserCode)
|
||||
|
||||
if err := clipboard.WriteAll(deviceCodeResp.UserCode); err == nil {
|
||||
log.Info().Msg("(The code has been copied to your clipboard!)")
|
||||
}
|
||||
}
|
||||
38
internal/cmd/auth.go
Normal file
38
internal/cmd/auth.go
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/auth"
|
||||
)
|
||||
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Authenticates with GitHub to get the initial OAuth token",
|
||||
Long: "Initiates the GitHub OAuth device flow to retrieve an OAuth token required for this proxy to work.",
|
||||
Run: runAuth,
|
||||
}
|
||||
|
||||
func runAuth(cmd *cobra.Command, args []string) {
|
||||
log.Info().Msg("Starting GitHub authentication for Copilot...")
|
||||
|
||||
deviceCodeResp, err := auth.RequestDeviceCode()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to request device code")
|
||||
}
|
||||
|
||||
auth.PromptUserForAuth(deviceCodeResp)
|
||||
|
||||
accessToken, err := auth.PollForAccessToken(deviceCodeResp)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to obtain OAuth token")
|
||||
}
|
||||
|
||||
if err := auth.SaveOAuthToken(accessToken); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to save OAuth token")
|
||||
}
|
||||
|
||||
log.Info().Msg("✅ Authentication successful! The OAuth token has been saved.")
|
||||
log.Info().Msg("You can now run the 'serve' command.")
|
||||
}
|
||||
23
internal/cmd/root.go
Normal file
23
internal/cmd/root.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/config"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "go-copilot-proxy",
|
||||
Short: "A Go-based proxy for GitHub Copilot",
|
||||
Long: "go-copilot-proxy provides a local proxy server for GitHub Copilot API requests with automatic token management.",
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func init() {
|
||||
config.InitLogging()
|
||||
rootCmd.AddCommand(serveCmd)
|
||||
rootCmd.AddCommand(authCmd)
|
||||
}
|
||||
86
internal/cmd/serve.go
Normal file
86
internal/cmd/serve.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/auth"
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/server"
|
||||
)
|
||||
|
||||
var serveCmd = &cobra.Command{
|
||||
Use: "serve",
|
||||
Short: "Starts the proxy server",
|
||||
Long: "Starts the HTTP proxy server that forwards requests to the GitHub Copilot API with automatic token management.",
|
||||
Run: runServe,
|
||||
}
|
||||
|
||||
func init() {
|
||||
serveCmd.Flags().String("host", "127.0.0.1", "Host to bind the server to")
|
||||
serveCmd.Flags().Int("port", 8080, "Port to run the server on")
|
||||
}
|
||||
|
||||
func runServe(cmd *cobra.Command, args []string) {
|
||||
accessToken := os.Getenv("GO_COPILOT_PROXY_TOKEN")
|
||||
if accessToken == "" {
|
||||
log.Fatal().Msg("GO_COPILOT_PROXY_TOKEN environment variable is not set. Please set it to a secure token to protect your proxy.")
|
||||
}
|
||||
log.Debug().Msg("Proxy access token is configured.")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
copilotAuth, err := auth.NewCopilotAuth()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to initialize Copilot Auth")
|
||||
}
|
||||
|
||||
if err := copilotAuth.Start(ctx, &wg); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start Copilot auth background tasks. Did you run the 'auth' command first?")
|
||||
}
|
||||
|
||||
r := server.NewRouter(accessToken, copilotAuth)
|
||||
|
||||
host, _ := cmd.Flags().GetString("host")
|
||||
port, _ := cmd.Flags().GetInt("port")
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
httpServer := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Info().Str("address", addr).Msg("Starting server...")
|
||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatal().Err(err).Msg("Server failed to start")
|
||||
}
|
||||
}()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Info().Msg("Shutting down server...")
|
||||
|
||||
cancel()
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
log.Error().Err(err).Msg("Server forced to shutdown")
|
||||
}
|
||||
|
||||
log.Debug().Msg("Waiting for background tasks to complete...")
|
||||
wg.Wait()
|
||||
log.Info().Msg("Server exiting.")
|
||||
}
|
||||
34
internal/config/config.go
Normal file
34
internal/config/config.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthURL = "https://api.github.com/copilot_internal/v2/token"
|
||||
StaleLockTimeout = 5 * time.Minute
|
||||
StaleLockCheckInterval = 1 * time.Minute
|
||||
TokenRefreshBuffer = 2 * time.Minute
|
||||
RetryInterval = 1 * time.Minute
|
||||
|
||||
GHDeviceCodeURL = "https://github.com/login/device/code"
|
||||
GHOauthTokenURL = "https://github.com/login/oauth/access_token"
|
||||
GHClientID = "Iv1.b507a08c87ecfe98"
|
||||
GHScope = "read:user"
|
||||
)
|
||||
|
||||
const (
|
||||
UserAgent = "GitHubCopilotChat/0.26.7"
|
||||
EditorVersion = "vscode/1.99.3"
|
||||
EditorPluginVersion = "copilot-chat/0.26.7"
|
||||
CopilotIntegrationID = "vscode-chat"
|
||||
)
|
||||
|
||||
func InitLogging() {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339})
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
}
|
||||
342
internal/server/server.go
Normal file
342
internal/server/server.go
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/auth"
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/config"
|
||||
)
|
||||
|
||||
func NewRouter(accessToken string, copilotAuth *auth.CopilotAuth) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.Heartbeat("/healthz"))
|
||||
r.Use(corsHandler)
|
||||
|
||||
r.Route("/v1", func(r chi.Router) {
|
||||
r.Use(authMiddleware(accessToken))
|
||||
r.Post("/chat/completions", createChatHandler(copilotAuth))
|
||||
r.Post("/embeddings", createProxyHandler("https://api.githubcopilot.com/embeddings", copilotAuth))
|
||||
r.Get("/models", createProxyHandler("https://api.githubcopilot.com/models", copilotAuth))
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func authMiddleware(accessToken string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token != accessToken {
|
||||
http.Error(w, "Invalid access token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func corsHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "Content-Type, Content-Length, Transfer-Encoding")
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func createChatHandler(copilotAuth *auth.CopilotAuth) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug().
|
||||
Str("method", r.Method).
|
||||
Str("path", r.URL.Path).
|
||||
Msg("Handling chat completions request")
|
||||
|
||||
token := copilotAuth.GetToken()
|
||||
if token == "" {
|
||||
http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
r.Body.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", "https://api.githubcopilot.com/chat/completions", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headers := createRequestHeaders(token, "api.githubcopilot.com")
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to make request to GitHub Copilot API")
|
||||
http.Error(w, "Failed to proxy request", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyHeaders(w.Header(), resp.Header)
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
isStreaming := resp.Header.Get("Transfer-Encoding") == "chunked" ||
|
||||
strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if resp.StatusCode == http.StatusOK && isStreaming {
|
||||
handleStreamingResponse(w, resp, r)
|
||||
} else if resp.StatusCode == http.StatusOK {
|
||||
handleNonStreamingResponse(w, resp, r)
|
||||
} else {
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createRequestHeaders(token, host string) map[string]string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
machineID := hex.EncodeToString(hash[:])
|
||||
|
||||
return map[string]string{
|
||||
"Authorization": "Bearer " + token,
|
||||
"Host": host,
|
||||
"X-Request-Id": uuid.New().String(),
|
||||
"X-Github-Api-Version": "2025-04-01",
|
||||
"Vscode-Sessionid": uuid.New().String() + fmt.Sprintf("%d", time.Now().Unix()),
|
||||
"Vscode-Machineid": machineID,
|
||||
"Editor-Version": config.EditorVersion,
|
||||
"Editor-Plugin-Version": config.EditorPluginVersion,
|
||||
"Openai-Organization": "github-copilot",
|
||||
"Copilot-Integration-Id": config.CopilotIntegrationID,
|
||||
"Openai-Intent": "conversation-panel",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": config.UserAgent,
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src http.Header) {
|
||||
for key, values := range src {
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
log.Error().Msg("Response writer does not support flushing")
|
||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buffer := ""
|
||||
delimiter := "\n\n"
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
buffer += line + "\n"
|
||||
|
||||
if strings.HasSuffix(buffer, delimiter) {
|
||||
lines := strings.Split(strings.TrimSuffix(buffer, delimiter), delimiter)
|
||||
for _, chunk := range lines {
|
||||
if cleanedChunk := cleanStreamLine(chunk); cleanedChunk != "" {
|
||||
fmt.Fprint(w, cleanedChunk+delimiter)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
buffer = ""
|
||||
}
|
||||
}
|
||||
|
||||
if buffer != "" {
|
||||
if cleanedChunk := cleanStreamLine(strings.TrimSuffix(buffer, "\n")); cleanedChunk != "" {
|
||||
fmt.Fprint(w, cleanedChunk+"\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Error().Err(err).Msg("Error reading streaming response")
|
||||
}
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to read response body")
|
||||
http.Error(w, "Failed to read response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
cleanedBody := cleanResponse(string(body))
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
fmt.Fprint(w, cleanedBody)
|
||||
}
|
||||
|
||||
func cleanResponse(responseBody string) string {
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(responseBody), &data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to parse response JSON")
|
||||
return responseBody
|
||||
}
|
||||
|
||||
data["object"] = "chat.completion"
|
||||
delete(data, "prompt_filter_results")
|
||||
|
||||
if choices, ok := data["choices"].([]interface{}); ok {
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]interface{}); ok {
|
||||
delete(choiceMap, "content_filter_results")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanedBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal cleaned response")
|
||||
return responseBody
|
||||
}
|
||||
|
||||
return string(cleanedBytes)
|
||||
}
|
||||
|
||||
func cleanStreamLine(line string) string {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return line
|
||||
}
|
||||
|
||||
dataContent := strings.TrimPrefix(line, "data: ")
|
||||
if dataContent == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataContent), &data); err != nil {
|
||||
log.Debug().Err(err).Msg("Failed to parse stream line JSON")
|
||||
return line
|
||||
}
|
||||
|
||||
if choices, ok := data["choices"].([]interface{}); ok {
|
||||
if len(choices) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
data["object"] = "chat.completion.chunk"
|
||||
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]interface{}); ok {
|
||||
delete(choiceMap, "content_filter_offsets")
|
||||
delete(choiceMap, "content_filter_results")
|
||||
|
||||
if delta, ok := choiceMap["delta"].(map[string]interface{}); ok {
|
||||
for key, value := range delta {
|
||||
if value == nil {
|
||||
delete(delta, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanedBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal cleaned stream data")
|
||||
return line
|
||||
}
|
||||
|
||||
return "data: " + string(cleanedBytes)
|
||||
}
|
||||
|
||||
func createProxyHandler(targetURL string, copilotAuth *auth.CopilotAuth) http.HandlerFunc {
|
||||
target, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Str("url", targetURL).Msg("Invalid target URL")
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug().
|
||||
Str("method", r.Method).
|
||||
Str("path", r.URL.Path).
|
||||
Str("target", targetURL).
|
||||
Msg("Proxying request")
|
||||
|
||||
token := copilotAuth.GetToken()
|
||||
if token == "" {
|
||||
http.Error(w, "No Copilot token available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
r.Body.Close()
|
||||
|
||||
req, err := http.NewRequest(r.Method, target.String(), bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headers := createRequestHeaders(token, target.Host)
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to make request")
|
||||
http.Error(w, "Failed to proxy request", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyHeaders(w.Header(), resp.Header)
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue