package auth import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "os" "path/filepath" "sync" "sync/atomic" "time" "github.com/fsnotify/fsnotify" "github.com/rs/zerolog/log" "gitlab.com/technofab/go-copilot-proxy/internal/config" ) type GithubToken struct { Token string `json:"token"` ExpiresAt int64 `json:"expires_at"` } type CopilotAuth struct { oauthToken string stateDir string tokenFile string tokenFileLock string isSelfWriting atomic.Bool mu sync.RWMutex githubToken *GithubToken } func NewCopilotAuth() (*CopilotAuth, error) { stateDir, err := getStatePath("") if err != nil { return nil, err } if err := os.MkdirAll(stateDir, 0755); err != nil { return nil, fmt.Errorf("could not create state dir: %w", err) } return &CopilotAuth{ stateDir: stateDir, tokenFile: filepath.Join(stateDir, "token.json"), tokenFileLock: filepath.Join(stateDir, "token.json.lock"), }, nil } func (ca *CopilotAuth) GetToken() string { ca.mu.RLock() defer ca.mu.RUnlock() if ca.githubToken != nil { return ca.githubToken.Token } return "" } func (ca *CopilotAuth) Start(ctx context.Context, wg *sync.WaitGroup) error { var err error ca.oauthToken, err = ca.getOAuthToken() if err != nil { return fmt.Errorf("could not get oauth token: %w", err) } if err := ca.loadTokenFromFile(); err != nil { log.Warn().Err(err).Msg("Initial token load failed, will attempt refresh") } if !ca.RefreshToken(true) { return errors.New("initial token refresh failed") } wg.Add(3) go ca.refreshTokenLoop(ctx, wg) go ca.watchTokenFile(ctx, wg) go ca.checkStaleLocks(ctx, wg) return nil } func (ca *CopilotAuth) RefreshToken(force bool) bool { if !force { if ca.isTokenValid() { log.Debug().Msg("Token still valid, skipping refresh") return true } if err := ca.loadTokenFromFile(); err != nil { log.Warn().Err(err).Msg("Could not reload token from file before refresh check") } if ca.isTokenValid() { log.Debug().Msg("Valid token loaded from file, skipping refresh") return true } } if !ca.acquireLock() { return ca.waitForTokenRefresh() } defer ca.releaseLock() req, err := http.NewRequest("GET", config.AuthURL, nil) if err != nil { log.Error().Err(err).Msg("Failed to create token request") return false } req.Header.Set("Authorization", "Bearer "+ca.oauthToken) req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", config.UserAgent) req.Header.Set("Editor-Version", config.EditorVersion) req.Header.Set("Editor-Plugin-Version", config.EditorPluginVersion) req.Header.Set("Copilot-Integration-Id", config.CopilotIntegrationID) client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { log.Error().Err(err).Msg("HTTP error during token refresh") return false } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) log.Error().Int("status", resp.StatusCode).Str("body", string(body)).Msg("Token refresh failed") return false } var token GithubToken if err := json.NewDecoder(resp.Body).Decode(&token); err != nil { log.Error().Err(err).Msg("Failed to decode token response") return false } ca.mu.Lock() ca.githubToken = &token ca.mu.Unlock() if err := ca.saveTokenToFile(); err != nil { log.Error().Err(err).Msg("Failed to save new token") return false } log.Debug().Msg("Token successfully refreshed") return true } func getConfigPath(filename string) (string, error) { configDir, err := os.UserConfigDir() if err != nil { return "", fmt.Errorf("could not find user config dir: %w", err) } return filepath.Join(configDir, "go-copilot-proxy", filename), nil } func getStatePath(filename string) (string, error) { stateDir := os.Getenv("XDG_STATE_HOME") if stateDir == "" { homeDir, err := os.UserHomeDir() if err != nil { return "", fmt.Errorf("could not find user home dir: %w", err) } stateDir = filepath.Join(homeDir, ".local", "state") } return filepath.Join(stateDir, "go-copilot-proxy", filename), nil } func (ca *CopilotAuth) getOAuthToken() (string, error) { path, err := getConfigPath("config.json") if err != nil { return "", err } if _, err := os.Stat(path); os.IsNotExist(err) { return "", errors.New("config.json not found, please run the 'auth' command first") } data, err := os.ReadFile(path) if err != nil { return "", fmt.Errorf("failed to read %s: %w", path, err) } var configData map[string]string if err := json.Unmarshal(data, &configData); err != nil { return "", fmt.Errorf("failed to parse %s: %w", path, err) } if token, ok := configData["oauth_token"]; ok && token != "" { return token, nil } return "", errors.New("oauth_token not found in config.json") } func (ca *CopilotAuth) acquireLock() bool { f, err := os.OpenFile(ca.tokenFileLock, os.O_CREATE|os.O_EXCL, 0644) if err != nil { if os.IsExist(err) { log.Debug().Msg("Lock file already exists") } else { log.Error().Err(err).Msg("Error acquiring lock") } return false } f.Close() return true } func (ca *CopilotAuth) releaseLock() { if err := os.Remove(ca.tokenFileLock); err != nil { if !os.IsNotExist(err) { log.Error().Err(err).Msg("Error releasing lock file") } } else { log.Debug().Msg("Lock file released successfully") } } func (ca *CopilotAuth) saveTokenToFile() error { ca.mu.RLock() tokenData, err := json.Marshal(ca.githubToken) ca.mu.RUnlock() if err != nil { return fmt.Errorf("failed to marshal token: %w", err) } ca.isSelfWriting.Store(true) defer ca.isSelfWriting.Store(false) tempFile := ca.tokenFile + ".tmp" if err := os.WriteFile(tempFile, tokenData, 0644); err != nil { return fmt.Errorf("failed to write to temporary token file: %w", err) } if err := os.Rename(tempFile, ca.tokenFile); err != nil { return fmt.Errorf("failed to atomically move token file: %w", err) } log.Debug().Msg("Token successfully saved to file") return nil } func (ca *CopilotAuth) loadTokenFromFile() error { data, err := os.ReadFile(ca.tokenFile) if err != nil { if os.IsNotExist(err) { return nil } return fmt.Errorf("failed to read token file: %w", err) } var token GithubToken if err := json.Unmarshal(data, &token); err != nil { return fmt.Errorf("failed to parse token file: %w", err) } ca.mu.Lock() ca.githubToken = &token ca.mu.Unlock() log.Debug().Msg("Token loaded from file") return nil } func (ca *CopilotAuth) isTokenValid() bool { ca.mu.RLock() defer ca.mu.RUnlock() if ca.githubToken == nil { return false } return ca.githubToken.ExpiresAt > time.Now().Add(config.TokenRefreshBuffer).Unix() } func (ca *CopilotAuth) waitForTokenRefresh() bool { log.Debug().Msg("Waiting for another process to refresh the token...") time.Sleep(5 * time.Second) if err := ca.loadTokenFromFile(); err != nil { log.Error().Err(err).Msg("Failed to reload token after waiting") return false } return ca.isTokenValid() } func (ca *CopilotAuth) refreshTokenLoop(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() for { var sleepDuration time.Duration ca.mu.RLock() if ca.githubToken != nil { expiryTime := time.Unix(ca.githubToken.ExpiresAt, 0) refreshTime := expiryTime.Add(-config.TokenRefreshBuffer) sleepDuration = time.Until(refreshTime) } ca.mu.RUnlock() if sleepDuration <= 0 { sleepDuration = config.RetryInterval } log.Debug().Dur("duration", sleepDuration).Msg("Scheduling next token refresh") select { case <-time.After(sleepDuration): ca.RefreshToken(false) case <-ctx.Done(): log.Debug().Msg("Refresh token loop shutting down.") return } } } func (ca *CopilotAuth) watchTokenFile(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() watcher, err := fsnotify.NewWatcher() if err != nil { log.Error().Err(err).Msg("Failed to create file watcher") return } defer watcher.Close() if err := watcher.Add(ca.stateDir); err != nil { log.Error().Err(err).Str("path", ca.stateDir).Msg("Failed to watch token directory") return } for { select { case event, ok := <-watcher.Events: if !ok { return } if event.Name == ca.tokenFile && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) { if ca.isSelfWriting.Load() { continue } log.Debug().Msg("Token file changed externally, reloading.") if err := ca.loadTokenFromFile(); err != nil { log.Error().Err(err).Msg("Failed to reload token from watched file") } } case err, ok := <-watcher.Errors: if !ok { return } log.Error().Err(err).Msg("File watcher error") case <-ctx.Done(): log.Debug().Msg("File watcher shutting down.") return } } } func (ca *CopilotAuth) checkStaleLocks(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() ticker := time.NewTicker(config.StaleLockCheckInterval) defer ticker.Stop() for { select { case <-ticker.C: info, err := os.Stat(ca.tokenFileLock) if err != nil { if os.IsNotExist(err) { continue } log.Error().Err(err).Msg("Error checking stale lock") continue } if time.Since(info.ModTime()) > config.StaleLockTimeout { log.Warn().Msg("Removing stale lock file") ca.releaseLock() } case <-ctx.Done(): log.Debug().Msg("Stale lock checker shutting down.") return } } }