mirror of
https://gitlab.com/TECHNOFAB/go-copilot-proxy.git
synced 2025-12-11 22:10:06 +01:00
chore: initial commit
This commit is contained in:
commit
595200836c
16 changed files with 1571 additions and 0 deletions
375
internal/auth/copilot.go
Normal file
375
internal/auth/copilot.go
Normal file
|
|
@ -0,0 +1,375 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gitlab.com/technofab/go-copilot-proxy/internal/config"
|
||||
)
|
||||
|
||||
type GithubToken struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
type CopilotAuth struct {
|
||||
oauthToken string
|
||||
stateDir string
|
||||
tokenFile string
|
||||
tokenFileLock string
|
||||
isSelfWriting atomic.Bool
|
||||
mu sync.RWMutex
|
||||
githubToken *GithubToken
|
||||
}
|
||||
|
||||
func NewCopilotAuth() (*CopilotAuth, error) {
|
||||
stateDir, err := getStatePath("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(stateDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("could not create state dir: %w", err)
|
||||
}
|
||||
|
||||
return &CopilotAuth{
|
||||
stateDir: stateDir,
|
||||
tokenFile: filepath.Join(stateDir, "token.json"),
|
||||
tokenFileLock: filepath.Join(stateDir, "token.json.lock"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) GetToken() string {
|
||||
ca.mu.RLock()
|
||||
defer ca.mu.RUnlock()
|
||||
if ca.githubToken != nil {
|
||||
return ca.githubToken.Token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) Start(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
var err error
|
||||
ca.oauthToken, err = ca.getOAuthToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get oauth token: %w", err)
|
||||
}
|
||||
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Warn().Err(err).Msg("Initial token load failed, will attempt refresh")
|
||||
}
|
||||
|
||||
if !ca.RefreshToken(true) {
|
||||
return errors.New("initial token refresh failed")
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
go ca.refreshTokenLoop(ctx, wg)
|
||||
go ca.watchTokenFile(ctx, wg)
|
||||
go ca.checkStaleLocks(ctx, wg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) RefreshToken(force bool) bool {
|
||||
if !force {
|
||||
if ca.isTokenValid() {
|
||||
log.Debug().Msg("Token still valid, skipping refresh")
|
||||
return true
|
||||
}
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Warn().Err(err).Msg("Could not reload token from file before refresh check")
|
||||
}
|
||||
if ca.isTokenValid() {
|
||||
log.Debug().Msg("Valid token loaded from file, skipping refresh")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if !ca.acquireLock() {
|
||||
return ca.waitForTokenRefresh()
|
||||
}
|
||||
defer ca.releaseLock()
|
||||
|
||||
req, err := http.NewRequest("GET", config.AuthURL, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create token request")
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+ca.oauthToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", config.UserAgent)
|
||||
req.Header.Set("Editor-Version", config.EditorVersion)
|
||||
req.Header.Set("Editor-Plugin-Version", config.EditorPluginVersion)
|
||||
req.Header.Set("Copilot-Integration-Id", config.CopilotIntegrationID)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("HTTP error during token refresh")
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Error().Int("status", resp.StatusCode).Str("body", string(body)).Msg("Token refresh failed")
|
||||
return false
|
||||
}
|
||||
|
||||
var token GithubToken
|
||||
if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decode token response")
|
||||
return false
|
||||
}
|
||||
|
||||
ca.mu.Lock()
|
||||
ca.githubToken = &token
|
||||
ca.mu.Unlock()
|
||||
|
||||
if err := ca.saveTokenToFile(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to save new token")
|
||||
return false
|
||||
}
|
||||
log.Debug().Msg("Token successfully refreshed")
|
||||
return true
|
||||
}
|
||||
|
||||
func getConfigPath(filename string) (string, error) {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not find user config dir: %w", err)
|
||||
}
|
||||
return filepath.Join(configDir, "go-copilot-proxy", filename), nil
|
||||
}
|
||||
|
||||
func getStatePath(filename string) (string, error) {
|
||||
stateDir := os.Getenv("XDG_STATE_HOME")
|
||||
if stateDir == "" {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not find user home dir: %w", err)
|
||||
}
|
||||
stateDir = filepath.Join(homeDir, ".local", "state")
|
||||
}
|
||||
return filepath.Join(stateDir, "go-copilot-proxy", filename), nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) getOAuthToken() (string, error) {
|
||||
path, err := getConfigPath("config.json")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return "", errors.New("config.json not found, please run the 'auth' command first")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read %s: %w", path, err)
|
||||
}
|
||||
|
||||
var configData map[string]string
|
||||
if err := json.Unmarshal(data, &configData); err != nil {
|
||||
return "", fmt.Errorf("failed to parse %s: %w", path, err)
|
||||
}
|
||||
|
||||
if token, ok := configData["oauth_token"]; ok && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
return "", errors.New("oauth_token not found in config.json")
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) acquireLock() bool {
|
||||
f, err := os.OpenFile(ca.tokenFileLock, os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
if os.IsExist(err) {
|
||||
log.Debug().Msg("Lock file already exists")
|
||||
} else {
|
||||
log.Error().Err(err).Msg("Error acquiring lock")
|
||||
}
|
||||
return false
|
||||
}
|
||||
f.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) releaseLock() {
|
||||
if err := os.Remove(ca.tokenFileLock); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Error().Err(err).Msg("Error releasing lock file")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Msg("Lock file released successfully")
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) saveTokenToFile() error {
|
||||
ca.mu.RLock()
|
||||
tokenData, err := json.Marshal(ca.githubToken)
|
||||
ca.mu.RUnlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
ca.isSelfWriting.Store(true)
|
||||
defer ca.isSelfWriting.Store(false)
|
||||
|
||||
tempFile := ca.tokenFile + ".tmp"
|
||||
if err := os.WriteFile(tempFile, tokenData, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write to temporary token file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tempFile, ca.tokenFile); err != nil {
|
||||
return fmt.Errorf("failed to atomically move token file: %w", err)
|
||||
}
|
||||
log.Debug().Msg("Token successfully saved to file")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) loadTokenFromFile() error {
|
||||
data, err := os.ReadFile(ca.tokenFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var token GithubToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
ca.mu.Lock()
|
||||
ca.githubToken = &token
|
||||
ca.mu.Unlock()
|
||||
log.Debug().Msg("Token loaded from file")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) isTokenValid() bool {
|
||||
ca.mu.RLock()
|
||||
defer ca.mu.RUnlock()
|
||||
if ca.githubToken == nil {
|
||||
return false
|
||||
}
|
||||
return ca.githubToken.ExpiresAt > time.Now().Add(config.TokenRefreshBuffer).Unix()
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) waitForTokenRefresh() bool {
|
||||
log.Debug().Msg("Waiting for another process to refresh the token...")
|
||||
time.Sleep(5 * time.Second)
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload token after waiting")
|
||||
return false
|
||||
}
|
||||
return ca.isTokenValid()
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) refreshTokenLoop(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
var sleepDuration time.Duration
|
||||
ca.mu.RLock()
|
||||
if ca.githubToken != nil {
|
||||
expiryTime := time.Unix(ca.githubToken.ExpiresAt, 0)
|
||||
refreshTime := expiryTime.Add(-config.TokenRefreshBuffer)
|
||||
sleepDuration = time.Until(refreshTime)
|
||||
}
|
||||
ca.mu.RUnlock()
|
||||
|
||||
if sleepDuration <= 0 {
|
||||
sleepDuration = config.RetryInterval
|
||||
}
|
||||
|
||||
log.Debug().Dur("duration", sleepDuration).Msg("Scheduling next token refresh")
|
||||
select {
|
||||
case <-time.After(sleepDuration):
|
||||
ca.RefreshToken(false)
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("Refresh token loop shutting down.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) watchTokenFile(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create file watcher")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
if err := watcher.Add(ca.stateDir); err != nil {
|
||||
log.Error().Err(err).Str("path", ca.stateDir).Msg("Failed to watch token directory")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name == ca.tokenFile && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) {
|
||||
if ca.isSelfWriting.Load() {
|
||||
continue
|
||||
}
|
||||
log.Debug().Msg("Token file changed externally, reloading.")
|
||||
if err := ca.loadTokenFromFile(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload token from watched file")
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("File watcher error")
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("File watcher shutting down.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *CopilotAuth) checkStaleLocks(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
ticker := time.NewTicker(config.StaleLockCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
info, err := os.Stat(ca.tokenFileLock)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
log.Error().Err(err).Msg("Error checking stale lock")
|
||||
continue
|
||||
}
|
||||
if time.Since(info.ModTime()) > config.StaleLockTimeout {
|
||||
log.Warn().Msg("Removing stale lock file")
|
||||
ca.releaseLock()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("Stale lock checker shutting down.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue