mirror of
https://gitlab.com/TECHNOFAB/go-copilot-proxy.git
synced 2025-12-11 22:10:06 +01:00
375 lines
9.2 KiB
Go
375 lines
9.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/fsnotify/fsnotify"
|
|
"github.com/rs/zerolog/log"
|
|
|
|
"gitlab.com/technofab/go-copilot-proxy/internal/config"
|
|
)
|
|
|
|
// GithubToken is the Copilot API token as returned by the token endpoint and
// as stored in the on-disk token file.
type GithubToken struct {
	// Token is the bearer credential value.
	Token string `json:"token"`
	// ExpiresAt is the expiry instant in Unix seconds.
	ExpiresAt int64 `json:"expires_at"`
}
|
|
|
|
type CopilotAuth struct {
|
|
oauthToken string
|
|
stateDir string
|
|
tokenFile string
|
|
tokenFileLock string
|
|
isSelfWriting atomic.Bool
|
|
mu sync.RWMutex
|
|
githubToken *GithubToken
|
|
}
|
|
|
|
func NewCopilotAuth() (*CopilotAuth, error) {
|
|
stateDir, err := getStatePath("")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := os.MkdirAll(stateDir, 0755); err != nil {
|
|
return nil, fmt.Errorf("could not create state dir: %w", err)
|
|
}
|
|
|
|
return &CopilotAuth{
|
|
stateDir: stateDir,
|
|
tokenFile: filepath.Join(stateDir, "token.json"),
|
|
tokenFileLock: filepath.Join(stateDir, "token.json.lock"),
|
|
}, nil
|
|
}
|
|
|
|
func (ca *CopilotAuth) GetToken() string {
|
|
ca.mu.RLock()
|
|
defer ca.mu.RUnlock()
|
|
if ca.githubToken != nil {
|
|
return ca.githubToken.Token
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (ca *CopilotAuth) Start(ctx context.Context, wg *sync.WaitGroup) error {
|
|
var err error
|
|
ca.oauthToken, err = ca.getOAuthToken()
|
|
if err != nil {
|
|
return fmt.Errorf("could not get oauth token: %w", err)
|
|
}
|
|
|
|
if err := ca.loadTokenFromFile(); err != nil {
|
|
log.Warn().Err(err).Msg("Initial token load failed, will attempt refresh")
|
|
}
|
|
|
|
if !ca.RefreshToken(true) {
|
|
return errors.New("initial token refresh failed")
|
|
}
|
|
|
|
wg.Add(3)
|
|
go ca.refreshTokenLoop(ctx, wg)
|
|
go ca.watchTokenFile(ctx, wg)
|
|
go ca.checkStaleLocks(ctx, wg)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ca *CopilotAuth) RefreshToken(force bool) bool {
|
|
if !force {
|
|
if ca.isTokenValid() {
|
|
log.Debug().Msg("Token still valid, skipping refresh")
|
|
return true
|
|
}
|
|
if err := ca.loadTokenFromFile(); err != nil {
|
|
log.Warn().Err(err).Msg("Could not reload token from file before refresh check")
|
|
}
|
|
if ca.isTokenValid() {
|
|
log.Debug().Msg("Valid token loaded from file, skipping refresh")
|
|
return true
|
|
}
|
|
}
|
|
|
|
if !ca.acquireLock() {
|
|
return ca.waitForTokenRefresh()
|
|
}
|
|
defer ca.releaseLock()
|
|
|
|
req, err := http.NewRequest("GET", config.AuthURL, nil)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to create token request")
|
|
return false
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+ca.oauthToken)
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("User-Agent", config.UserAgent)
|
|
req.Header.Set("Editor-Version", config.EditorVersion)
|
|
req.Header.Set("Editor-Plugin-Version", config.EditorPluginVersion)
|
|
req.Header.Set("Copilot-Integration-Id", config.CopilotIntegrationID)
|
|
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("HTTP error during token refresh")
|
|
return false
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
log.Error().Int("status", resp.StatusCode).Str("body", string(body)).Msg("Token refresh failed")
|
|
return false
|
|
}
|
|
|
|
var token GithubToken
|
|
if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
|
|
log.Error().Err(err).Msg("Failed to decode token response")
|
|
return false
|
|
}
|
|
|
|
ca.mu.Lock()
|
|
ca.githubToken = &token
|
|
ca.mu.Unlock()
|
|
|
|
if err := ca.saveTokenToFile(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to save new token")
|
|
return false
|
|
}
|
|
log.Debug().Msg("Token successfully refreshed")
|
|
return true
|
|
}
|
|
|
|
// getConfigPath returns the path of filename inside this application's
// "go-copilot-proxy" directory under the user's configuration root, as
// reported by os.UserConfigDir.
func getConfigPath(filename string) (string, error) {
	root, err := os.UserConfigDir()
	if err == nil {
		return filepath.Join(root, "go-copilot-proxy", filename), nil
	}
	return "", fmt.Errorf("could not find user config dir: %w", err)
}
|
|
|
|
// getStatePath returns the path of filename inside this application's
// "go-copilot-proxy" directory under the user's XDG state root. An empty
// filename yields the application directory itself.
//
// The root is $XDG_STATE_HOME when that is set to an absolute path, and
// $HOME/.local/state otherwise.
func getStatePath(filename string) (string, error) {
	stateDir := os.Getenv("XDG_STATE_HOME")
	// The XDG Base Directory spec requires these variables to hold absolute
	// paths and says relative values must be treated as invalid and ignored.
	// Honoring a relative value would make the token location depend on the
	// current working directory. An unset (empty) value is not absolute either.
	if !filepath.IsAbs(stateDir) {
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("could not find user home dir: %w", err)
		}
		stateDir = filepath.Join(homeDir, ".local", "state")
	}
	return filepath.Join(stateDir, "go-copilot-proxy", filename), nil
}
|
|
|
|
func (ca *CopilotAuth) getOAuthToken() (string, error) {
|
|
path, err := getConfigPath("config.json")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
|
return "", errors.New("config.json not found, please run the 'auth' command first")
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to read %s: %w", path, err)
|
|
}
|
|
|
|
var configData map[string]string
|
|
if err := json.Unmarshal(data, &configData); err != nil {
|
|
return "", fmt.Errorf("failed to parse %s: %w", path, err)
|
|
}
|
|
|
|
if token, ok := configData["oauth_token"]; ok && token != "" {
|
|
return token, nil
|
|
}
|
|
return "", errors.New("oauth_token not found in config.json")
|
|
}
|
|
|
|
func (ca *CopilotAuth) acquireLock() bool {
|
|
f, err := os.OpenFile(ca.tokenFileLock, os.O_CREATE|os.O_EXCL, 0644)
|
|
if err != nil {
|
|
if os.IsExist(err) {
|
|
log.Debug().Msg("Lock file already exists")
|
|
} else {
|
|
log.Error().Err(err).Msg("Error acquiring lock")
|
|
}
|
|
return false
|
|
}
|
|
f.Close()
|
|
return true
|
|
}
|
|
|
|
func (ca *CopilotAuth) releaseLock() {
|
|
if err := os.Remove(ca.tokenFileLock); err != nil {
|
|
if !os.IsNotExist(err) {
|
|
log.Error().Err(err).Msg("Error releasing lock file")
|
|
}
|
|
} else {
|
|
log.Debug().Msg("Lock file released successfully")
|
|
}
|
|
}
|
|
|
|
func (ca *CopilotAuth) saveTokenToFile() error {
|
|
ca.mu.RLock()
|
|
tokenData, err := json.Marshal(ca.githubToken)
|
|
ca.mu.RUnlock()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal token: %w", err)
|
|
}
|
|
|
|
ca.isSelfWriting.Store(true)
|
|
defer ca.isSelfWriting.Store(false)
|
|
|
|
tempFile := ca.tokenFile + ".tmp"
|
|
if err := os.WriteFile(tempFile, tokenData, 0644); err != nil {
|
|
return fmt.Errorf("failed to write to temporary token file: %w", err)
|
|
}
|
|
if err := os.Rename(tempFile, ca.tokenFile); err != nil {
|
|
return fmt.Errorf("failed to atomically move token file: %w", err)
|
|
}
|
|
log.Debug().Msg("Token successfully saved to file")
|
|
return nil
|
|
}
|
|
|
|
func (ca *CopilotAuth) loadTokenFromFile() error {
|
|
data, err := os.ReadFile(ca.tokenFile)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("failed to read token file: %w", err)
|
|
}
|
|
|
|
var token GithubToken
|
|
if err := json.Unmarshal(data, &token); err != nil {
|
|
return fmt.Errorf("failed to parse token file: %w", err)
|
|
}
|
|
|
|
ca.mu.Lock()
|
|
ca.githubToken = &token
|
|
ca.mu.Unlock()
|
|
log.Debug().Msg("Token loaded from file")
|
|
return nil
|
|
}
|
|
|
|
func (ca *CopilotAuth) isTokenValid() bool {
|
|
ca.mu.RLock()
|
|
defer ca.mu.RUnlock()
|
|
if ca.githubToken == nil {
|
|
return false
|
|
}
|
|
return ca.githubToken.ExpiresAt > time.Now().Add(config.TokenRefreshBuffer).Unix()
|
|
}
|
|
|
|
func (ca *CopilotAuth) waitForTokenRefresh() bool {
|
|
log.Debug().Msg("Waiting for another process to refresh the token...")
|
|
time.Sleep(5 * time.Second)
|
|
if err := ca.loadTokenFromFile(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to reload token after waiting")
|
|
return false
|
|
}
|
|
return ca.isTokenValid()
|
|
}
|
|
|
|
func (ca *CopilotAuth) refreshTokenLoop(ctx context.Context, wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
for {
|
|
var sleepDuration time.Duration
|
|
ca.mu.RLock()
|
|
if ca.githubToken != nil {
|
|
expiryTime := time.Unix(ca.githubToken.ExpiresAt, 0)
|
|
refreshTime := expiryTime.Add(-config.TokenRefreshBuffer)
|
|
sleepDuration = time.Until(refreshTime)
|
|
}
|
|
ca.mu.RUnlock()
|
|
|
|
if sleepDuration <= 0 {
|
|
sleepDuration = config.RetryInterval
|
|
}
|
|
|
|
log.Debug().Dur("duration", sleepDuration).Msg("Scheduling next token refresh")
|
|
select {
|
|
case <-time.After(sleepDuration):
|
|
ca.RefreshToken(false)
|
|
case <-ctx.Done():
|
|
log.Debug().Msg("Refresh token loop shutting down.")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ca *CopilotAuth) watchTokenFile(ctx context.Context, wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
watcher, err := fsnotify.NewWatcher()
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to create file watcher")
|
|
return
|
|
}
|
|
defer watcher.Close()
|
|
|
|
if err := watcher.Add(ca.stateDir); err != nil {
|
|
log.Error().Err(err).Str("path", ca.stateDir).Msg("Failed to watch token directory")
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case event, ok := <-watcher.Events:
|
|
if !ok {
|
|
return
|
|
}
|
|
if event.Name == ca.tokenFile && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) {
|
|
if ca.isSelfWriting.Load() {
|
|
continue
|
|
}
|
|
log.Debug().Msg("Token file changed externally, reloading.")
|
|
if err := ca.loadTokenFromFile(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to reload token from watched file")
|
|
}
|
|
}
|
|
case err, ok := <-watcher.Errors:
|
|
if !ok {
|
|
return
|
|
}
|
|
log.Error().Err(err).Msg("File watcher error")
|
|
case <-ctx.Done():
|
|
log.Debug().Msg("File watcher shutting down.")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ca *CopilotAuth) checkStaleLocks(ctx context.Context, wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
ticker := time.NewTicker(config.StaleLockCheckInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
info, err := os.Stat(ca.tokenFileLock)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
continue
|
|
}
|
|
log.Error().Err(err).Msg("Error checking stale lock")
|
|
continue
|
|
}
|
|
if time.Since(info.ModTime()) > config.StaleLockTimeout {
|
|
log.Warn().Msg("Removing stale lock file")
|
|
ca.releaseLock()
|
|
}
|
|
case <-ctx.Done():
|
|
log.Debug().Msg("Stale lock checker shutting down.")
|
|
return
|
|
}
|
|
}
|
|
}
|