go-copilot-proxy/internal/auth/oauth.go
2025-08-05 11:08:53 +02:00

142 lines
3.6 KiB
Go

package auth
import (
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/atotto/clipboard"
"github.com/rs/zerolog/log"
"gitlab.com/technofab/go-copilot-proxy/internal/config"
)
// DeviceCodeResponse is the JSON payload returned by GitHub's device-code
// endpoint when the OAuth device authorization flow is started.
type DeviceCodeResponse struct {
	// DeviceCode identifies this authorization attempt when polling for a token.
	DeviceCode string `json:"device_code"`
	// UserCode is the short code the user types in at VerificationURI.
	UserCode string `json:"user_code"`
	// VerificationURI is the page the user opens in a browser to authorize.
	VerificationURI string `json:"verification_uri"`
	// ExpiresIn is the lifetime of the codes (seconds, per RFC 8628).
	ExpiresIn int `json:"expires_in"`
	// Interval is the minimum number of seconds to wait between token polls.
	Interval int `json:"interval"`
}
// AccessTokenResponse is the JSON payload returned by GitHub's OAuth token
// endpoint while polling during the device flow. When authorization has
// completed AccessToken is set; otherwise Error carries an OAuth error code
// such as "authorization_pending" and ErrorDescription explains it.
type AccessTokenResponse struct {
	AccessToken      string `json:"access_token"`      // granted token; empty until authorized
	TokenType        string `json:"token_type"`        // token type reported by the server
	Scope            string `json:"scope"`             // scopes granted to the token
	Error            string `json:"error"`             // OAuth error code; empty on success
	ErrorDescription string `json:"error_description"` // human-readable detail for Error
}
// RequestDeviceCode starts the OAuth device authorization flow by POSTing
// config.GHClientID and config.GHScope to config.GHDeviceCodeURL and decoding
// the returned device code, user code and verification URI.
//
// It returns an error if the request cannot be built or sent, the server
// answers with a non-200 status, the body is not valid JSON, the body carries
// an OAuth "error" field, or the reply contains no device code.
func RequestDeviceCode() (*DeviceCodeResponse, error) {
	// Marshal rather than string-format the body so values are JSON-escaped.
	body, err := json.Marshal(map[string]string{
		"client_id": config.GHClientID,
		"scope":     config.GHScope,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to encode device code request: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, config.GHDeviceCodeURL, strings.NewReader(string(body)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", config.UserAgent)

	// http.DefaultClient never times out; bound the call so a stalled
	// connection cannot hang the login flow indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to request device code: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("bad response from GitHub: %s", resp.Status)
	}

	// Decode into a wrapper so OAuth error fields are captured alongside the
	// normal payload; encoding/json promotes the embedded struct's fields.
	var data struct {
		DeviceCodeResponse
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return nil, fmt.Errorf("failed to decode device code response: %w", err)
	}
	if data.Error != "" {
		return nil, fmt.Errorf("device code request failed: %s - %s", data.Error, data.ErrorDescription)
	}
	if data.DeviceCode == "" {
		return nil, fmt.Errorf("device code missing from GitHub response")
	}
	result := data.DeviceCodeResponse
	return &result, nil
}
// PollForAccessToken polls config.GHOauthTokenURL until the user authorizes
// the device code in deviceCodeInfo, and returns the granted access token.
//
// The first request is sent after waiting deviceCodeInfo.Interval seconds
// (5s when unset or non-positive). Polling continues while the server answers
// "authorization_pending". A "slow_down" reply adds 5s to the interval for
// this and all later polls, as RFC 8628 §3.5 requires.
//
// Polling stops with an error in these cases:
//   - the server reports any other OAuth error (e.g. access_denied, expired_token);
//   - a request fails or returns a non-200 status without an OAuth error body;
//   - the next poll would fall after deviceCodeInfo.ExpiresIn seconds from the
//     start (only checked when ExpiresIn is positive).
//
// A dot is printed to stdout per poll as a progress indicator.
func PollForAccessToken(deviceCodeInfo *DeviceCodeResponse) (string, error) {
	interval := time.Duration(deviceCodeInfo.Interval) * time.Second
	if interval <= 0 {
		// A non-positive interval would turn the loop into a busy poll.
		interval = 5 * time.Second
	}
	var deadline time.Time
	if deviceCodeInfo.ExpiresIn > 0 {
		deadline = time.Now().Add(time.Duration(deviceCodeInfo.ExpiresIn) * time.Second)
	}
	// One bounded client for the whole polling session (DefaultClient has no timeout).
	client := &http.Client{Timeout: 30 * time.Second}

	log.Info().Msg("Waiting for you to authorize in the browser...")
	// Terminate the line of progress dots however polling ends.
	defer fmt.Println()
	for {
		if !deadline.IsZero() && time.Now().Add(interval).After(deadline) {
			return "", fmt.Errorf("authentication failed: device code expired before authorization")
		}
		time.Sleep(interval)
		fmt.Print(".")

		data, err := requestAccessTokenOnce(client, deviceCodeInfo.DeviceCode)
		if err != nil {
			return "", err
		}
		switch {
		case data.AccessToken != "":
			return data.AccessToken, nil
		case data.Error == "authorization_pending":
			// The user has not finished authorizing yet; keep polling.
		case data.Error == "slow_down":
			// RFC 8628 §3.5: increase the interval by 5 seconds for this and
			// all subsequent requests instead of giving up.
			interval += 5 * time.Second
		case data.Error != "":
			return "", fmt.Errorf("authentication failed: %s - %s", data.Error, data.ErrorDescription)
		}
	}
}

// requestAccessTokenOnce sends a single device-flow token request and decodes
// the JSON reply. OAuth-level outcomes such as "authorization_pending" are
// reported via the returned response's Error field; a non-nil Go error means
// the request failed or the reply was unusable.
func requestAccessTokenOnce(client *http.Client, deviceCode string) (*AccessTokenResponse, error) {
	// Marshal rather than string-format the body so values are JSON-escaped.
	body, err := json.Marshal(map[string]string{
		"client_id":   config.GHClientID,
		"device_code": deviceCode,
		"grant_type":  "urn:ietf:params:oauth:grant-type:device_code",
	})
	if err != nil {
		return nil, fmt.Errorf("failed to encode access token request: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, config.GHOauthTokenURL, strings.NewReader(string(body)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", config.UserAgent)
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to request access token: %w", err)
	}
	defer resp.Body.Close()

	// OAuth errors may arrive with HTTP 200 (as this flow has always assumed)
	// or with HTTP 400 (RFC 6749 §5.2), so decode regardless of status and
	// only treat the status as fatal when no OAuth error code was returned.
	var data AccessTokenResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&data)
	if resp.StatusCode != http.StatusOK && data.Error == "" {
		return nil, fmt.Errorf("bad response from GitHub: %s", resp.Status)
	}
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode access token response: %w", decodeErr)
	}
	return &data, nil
}
// SaveOAuthToken writes token to the file at getConfigPath("config.json") as
// the JSON object {"oauth_token": "<token>"}, replacing any existing file.
//
// Because the file holds a credential, it is created owner-only (0600). A
// config directory that does not exist yet is created owner-only (0700). The
// data is first written to a temporary file in the same directory and then
// renamed into place. This means:
//   - an older world-readable config.json is replaced by the 0600 file;
//   - a failed write leaves the previous config untouched.
func SaveOAuthToken(token string) error {
	path, err := getConfigPath("config.json")
	if err != nil {
		return err
	}
	tokenData := map[string]string{
		"oauth_token": token,
	}
	jsonData, err := json.MarshalIndent(tokenData, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal token data: %w", err)
	}
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0700); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}

	// os.CreateTemp creates the file with mode 0600, so the token is never
	// readable by other users, even transiently.
	tmp, err := os.CreateTemp(dir, ".config-*.json.tmp")
	if err != nil {
		return fmt.Errorf("failed to create temporary config file: %w", err)
	}
	tmpName := tmp.Name()
	committed := false
	defer func() {
		if !committed {
			// Best-effort cleanup of the partial temp file; the original
			// error is what the caller needs to see.
			_ = os.Remove(tmpName)
		}
	}()

	if _, err := tmp.Write(jsonData); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("failed to write config file: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return fmt.Errorf("failed to replace config file: %w", err)
	}
	committed = true
	return nil
}
// PromptUserForAuth logs the verification URL and the user code from
// deviceCodeResp so the user can complete authorization in a browser. As a
// convenience it also tries to copy the user code to the system clipboard,
// and mentions this only when the copy succeeded.
func PromptUserForAuth(deviceCodeResp *DeviceCodeResponse) {
	uri, code := deviceCodeResp.VerificationURI, deviceCodeResp.UserCode
	log.Info().Msgf("Please open this URL in your browser: %s", uri)
	log.Info().Msgf("And enter this code: %s", code)

	// Clipboard access is optional (e.g. headless sessions); a failure is
	// deliberately ignored because the code has already been logged.
	if copyErr := clipboard.WriteAll(code); copyErr != nil {
		return
	}
	log.Info().Msg("(The code has been copied to your clipboard!)")
}