From 595200836c266f1f36d716dd400ec2aa4ff4ba9e Mon Sep 17 00:00:00 2001 From: technofab Date: Tue, 5 Aug 2025 11:08:53 +0200 Subject: [PATCH] chore: initial commit --- .envrc | 1 + .gitignore | 5 + README.md | 19 ++ cmd/go-copilot-proxy/main.go | 13 ++ flake.lock | 325 ++++++++++++++++++++++++++++++ flake.nix | 91 +++++++++ go.mod | 20 ++ go.sum | 34 ++++ internal/auth/copilot.go | 375 +++++++++++++++++++++++++++++++++++ internal/auth/oauth.go | 142 +++++++++++++ internal/cmd/auth.go | 38 ++++ internal/cmd/root.go | 23 +++ internal/cmd/serve.go | 86 ++++++++ internal/config/config.go | 34 ++++ internal/server/server.go | 342 ++++++++++++++++++++++++++++++++ package.nix | 23 +++ 16 files changed, 1571 insertions(+) create mode 100644 .envrc create mode 100644 .gitignore create mode 100644 README.md create mode 100644 cmd/go-copilot-proxy/main.go create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/auth/copilot.go create mode 100644 internal/auth/oauth.go create mode 100644 internal/cmd/auth.go create mode 100644 internal/cmd/root.go create mode 100644 internal/cmd/serve.go create mode 100644 internal/config/config.go create mode 100644 internal/server/server.go create mode 100644 package.nix diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..f990172 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake . --impure --accept-flake-config diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8076b5a --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.direnv +.devenv +result +.pre-commit-config.yaml +.crush/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..c41bd89 --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +# Go Copilot Proxy + +A simple, single binary proxy for GitHub Copilot. + +Based on [copilot-openai-api](https://github.com/yuchanns/copilot-openai-api) +and [openai-github-copilot](https://gitea.com/PublicAffairs/openai-github-copilot). + +## Usage + +Run the proxy server: + +```sh +go-copilot-proxy auth # to log in with oauth +go-copilot-proxy serve +``` + +Use `http://localhost:8080` as the OpenAI endpoint. + +Run `go-copilot-proxy --help` for more. diff --git a/cmd/go-copilot-proxy/main.go b/cmd/go-copilot-proxy/main.go new file mode 100644 index 0000000..0448add --- /dev/null +++ b/cmd/go-copilot-proxy/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "os" + + "gitlab.com/technofab/go-copilot-proxy/internal/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + os.Exit(1) + } +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..524cf70 --- /dev/null +++ b/flake.lock @@ -0,0 +1,325 @@ +{ + "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "devenv" + ], + "flake-compat": [ + "devenv" + ], + "git-hooks": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1748883665, + "narHash": "sha256-R0W7uAg+BLoHjMRMQ8+oiSbTq8nkGz5RDpQ+ZfxxP3A=", + "owner": "cachix", + "repo": "cachix", + "rev": "f707778d902af4d62d8dd92c269f8e70de09acbe", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "devenv": { + "inputs": { + "cachix": "cachix", + "flake-compat": "flake-compat", + "git-hooks": "git-hooks", + "nix": "nix", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1754158015, + "narHash": "sha256-B/o0XiDj06Knm7t/9KmLKnkrpI9s5O13qU+SNL/4Gp8=", + "owner": "cachix", + "repo": "devenv", + "rev": "062f3f42de2f6bb7382f88f6dbcbbbaa118a3791", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "devenv", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733312601, + "narHash": "sha256-4pDvzqnegAfRkPwO3wmwBhVi/Sye1mzps0zHWYnP88c=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "205b12d8b7cd4802fbcb8e8ef6a0f1408781a4f9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_2": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib" + }, + "locked": { + "lastModified": 1754091436, + "narHash": "sha256-XKqDMN1/Qj1DKivQvscI4vmHfDfvYR2pfuFOJiCeewM=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "67df8c627c2c39c41dbec76a1f201929929ab0bd", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1750779888, + "narHash": "sha256-wibppH3g/E2lxU43ZQHC5yA/7kIKLGxVEnsnVK1BtRg=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nix": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-parts": "flake-parts", + "git-hooks-nix": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-23-11": [ + "devenv" + ], + "nixpkgs-regression": [ + "devenv" + ] + }, + "locked": { + "lastModified": 1752773918, + "narHash": "sha256-dOi/M6yNeuJlj88exI+7k154z+hAhFcuB8tZktiW7rg=", + "owner": "cachix", + "repo": "nix", + "rev": "031c3cf42d2e9391eee373507d8c12e0f9606779", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "devenv-2.30", + "repo": "nix", + "type": "github" + } + }, + "nix-gitlab-ci": { + "locked": { + "dir": "lib", + "lastModified": 1749124633, + "narHash": "sha256-vgYHrbAFRfgNYysW74Eam/S7KruYWMLCHG4U32xgHKY=", + "owner": "technofab", + "repo": "nix-gitlab-ci", + "rev": "f121b10dc9a7417906a886154e3065410a72462d", + "type": "gitlab" + }, + "original": { + "dir": "lib", + "owner": "technofab", + "ref": "2.1.0", + "repo": "nix-gitlab-ci", + "type": "gitlab" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1750441195, + "narHash": "sha256-yke+pm+MdgRb6c0dPt8MgDhv7fcBbdjmv1ZceNTyzKg=", + "owner": "cachix", + "repo": "devenv-nixpkgs", + "rev": "0ceffe312871b443929ff3006960d29b120dc627", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "rolling", + "repo": "devenv-nixpkgs", + "type": "github" + } + }, + "nixpkgs-lib": { + "locked": { + "lastModified": 1753579242, + "narHash": "sha256-zvaMGVn14/Zz8hnp4VWT9xVnhc8vuL3TStRqwk22biA=", + "owner": "nix-community", + "repo": "nixpkgs.lib", + "rev": "0f36c44e01a6129be94e3ade315a5883f0228a6e", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixpkgs.lib", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1754278406, + "narHash": "sha256-jvIQTMN5EzoOP5RaGztpVese8a3wqy0M/h6tNzycW28=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "6a489c9482ca676ce23c0bcd7f2e1795383325fa", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_3": { + "locked": { + "lastModified": 1747958103, + "narHash": "sha256-qmmFCrfBwSHoWw7cVK4Aj+fns+c54EBP8cGqp/yK410=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "fe51d34885f7b5e3e7b59572796e1bcb427eccb1", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "flake-parts": "flake-parts_2", + "nix-gitlab-ci": "nix-gitlab-ci", + "nixpkgs": "nixpkgs_2", + "systems": "systems", + "treefmt-nix": "treefmt-nix" + } + }, + "systems": { + "locked": { + "lastModified": 1689347949, + "narHash": "sha256-12tWmuL2zgBgZkdoB6qXZsgJEH9LR3oUgpaQq2RbI80=", + "owner": "nix-systems", + "repo": "default-linux", + "rev": "31732fcf5e8fea42e59c2488ad31a0e651500f68", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default-linux", + "type": "github" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": "nixpkgs_3" + }, + "locked": { + "lastModified": 1754061284, + "narHash": "sha256-ONcNxdSiPyJ9qavMPJYAXDNBzYobHRxw0WbT38lKbwU=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "58bd4da459f0a39e506847109a2a5cfceb837796", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..2b29ec4 --- /dev/null +++ b/flake.nix @@ -0,0 +1,91 @@ +{ + outputs = { + flake-parts, + systems, + ... + } @ inputs: + flake-parts.lib.mkFlake {inherit inputs;} { + imports = [ + inputs.devenv.flakeModule + inputs.treefmt-nix.flakeModule + inputs.nix-gitlab-ci.flakeModule + ]; + systems = import systems; + flake = {}; + perSystem = { + pkgs, + config, + ... + }: { + treefmt = { + projectRootFile = "flake.nix"; + programs = { + alejandra.enable = true; + mdformat.enable = true; + gofmt.enable = true; + }; + }; + devenv.shells.default = { + containers = pkgs.lib.mkForce {}; + packages = []; + + languages.go = { + enable = true; + enableHardeningWorkaround = true; + }; + + git-hooks.hooks = { + treefmt = { + enable = true; + packageOverrides.treefmt = config.treefmt.build.wrapper; + }; + convco.enable = true; + }; + }; + + ci = { + stages = ["build"]; + jobs = { + "build" = { + stage = "build"; + script = [ + # sh + '' + nix build .#default + '' + ]; + }; + }; + }; + + packages = { + default = pkgs.callPackage ./package.nix {}; + }; + }; + }; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + + # flake & devenv related + flake-parts.url = "github:hercules-ci/flake-parts"; + systems.url = "github:nix-systems/default-linux"; + devenv.url = "github:cachix/devenv"; + treefmt-nix.url = "github:numtide/treefmt-nix"; + nix-gitlab-ci.url = "gitlab:technofab/nix-gitlab-ci/2.1.0?dir=lib"; + }; + + nixConfig = { + extra-substituters = [ + "https://cache.nixos.org/" + "https://nix-community.cachix.org" + "https://devenv.cachix.org" + ]; + + extra-trusted-public-keys = [ + "cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=" + "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=" + "devenv.cachix.org-1:w1cLUi8dv3hnoSPGAuibQv+f9TZLr6cv/Hm9XgU50cw=" + ]; + }; +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e339d1c --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module gitlab.com/technofab/go-copilot-proxy + +go 1.24.2 + +require ( + github.com/atotto/clipboard v0.1.4 + github.com/fsnotify/fsnotify v1.9.0 + github.com/go-chi/chi/v5 v5.2.2 + github.com/google/uuid v1.6.0 + github.com/rs/zerolog v1.34.0 + github.com/spf13/cobra v1.9.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/spf13/pflag v1.0.6 // indirect + golang.org/x/sys v0.13.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8bb634f --- /dev/null +++ b/go.sum @@ -0,0 +1,34 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= +github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/auth/copilot.go b/internal/auth/copilot.go new file mode 100644 index 0000000..5779f43 --- /dev/null +++ b/internal/auth/copilot.go @@ -0,0 +1,375 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "sync" + "sync/atomic" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" + + "gitlab.com/technofab/go-copilot-proxy/internal/config" +) + +type GithubToken struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` +} + +type CopilotAuth struct { + oauthToken string + stateDir string + tokenFile string + tokenFileLock string + isSelfWriting atomic.Bool + mu sync.RWMutex + githubToken *GithubToken +} + +func NewCopilotAuth() (*CopilotAuth, error) { + stateDir, err := getStatePath("") + if err != nil { + return nil, err + } + + if err := os.MkdirAll(stateDir, 0755); err != nil { + return nil, fmt.Errorf("could not create state dir: %w", err) + } + + return &CopilotAuth{ + stateDir: stateDir, + tokenFile: filepath.Join(stateDir, "token.json"), + tokenFileLock: filepath.Join(stateDir, "token.json.lock"), + }, nil +} + +func (ca *CopilotAuth) GetToken() string { + ca.mu.RLock() + defer ca.mu.RUnlock() + if ca.githubToken != nil { + return ca.githubToken.Token + } + return "" +} + +func (ca *CopilotAuth) Start(ctx context.Context, wg *sync.WaitGroup) error { + var err error + ca.oauthToken, err = ca.getOAuthToken() + if err != nil { + return fmt.Errorf("could not get oauth token: %w", err) + } + + if err := ca.loadTokenFromFile(); err != nil { + log.Warn().Err(err).Msg("Initial token load failed, will attempt refresh") + } + + if !ca.RefreshToken(true) { + return errors.New("initial token refresh failed") + } + + wg.Add(3) + go ca.refreshTokenLoop(ctx, wg) + go ca.watchTokenFile(ctx, wg) + go ca.checkStaleLocks(ctx, wg) + + return nil +} + +func (ca *CopilotAuth) RefreshToken(force bool) bool { + if !force { + if ca.isTokenValid() { + log.Debug().Msg("Token still valid, skipping refresh") + return true + } + if err := ca.loadTokenFromFile(); err != nil { + log.Warn().Err(err).Msg("Could not reload token from file before refresh check") + } + if ca.isTokenValid() { + log.Debug().Msg("Valid token loaded from file, skipping refresh") + return true + } + } + + if !ca.acquireLock() { + return ca.waitForTokenRefresh() + } + defer ca.releaseLock() + + req, err := http.NewRequest("GET", config.AuthURL, nil) + if err != nil { + log.Error().Err(err).Msg("Failed to create token request") + return false + } + + req.Header.Set("Authorization", "Bearer "+ca.oauthToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", config.UserAgent) + req.Header.Set("Editor-Version", config.EditorVersion) + req.Header.Set("Editor-Plugin-Version", config.EditorPluginVersion) + req.Header.Set("Copilot-Integration-Id", config.CopilotIntegrationID) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + log.Error().Err(err).Msg("HTTP error during token refresh") + return false + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + log.Error().Int("status", resp.StatusCode).Str("body", string(body)).Msg("Token refresh failed") + return false + } + + var token GithubToken + if err := json.NewDecoder(resp.Body).Decode(&token); err != nil { + log.Error().Err(err).Msg("Failed to decode token response") + return false + } + + ca.mu.Lock() + ca.githubToken = &token + ca.mu.Unlock() + + if err := ca.saveTokenToFile(); err != nil { + log.Error().Err(err).Msg("Failed to save new token") + return false + } + log.Debug().Msg("Token successfully refreshed") + return true +} + +func getConfigPath(filename string) (string, error) { + configDir, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("could not find user config dir: %w", err) + } + return filepath.Join(configDir, "go-copilot-proxy", filename), nil +} + +func getStatePath(filename string) (string, error) { + stateDir := os.Getenv("XDG_STATE_HOME") + if stateDir == "" { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("could not find user home dir: %w", err) + } + stateDir = filepath.Join(homeDir, ".local", "state") + } + return filepath.Join(stateDir, "go-copilot-proxy", filename), nil +} + +func (ca *CopilotAuth) getOAuthToken() (string, error) { + path, err := getConfigPath("config.json") + if err != nil { + return "", err + } + if _, err := os.Stat(path); os.IsNotExist(err) { + return "", errors.New("config.json not found, please run the 'auth' command first") + } + + data, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read %s: %w", path, err) + } + + var configData map[string]string + if err := json.Unmarshal(data, &configData); err != nil { + return "", fmt.Errorf("failed to parse %s: %w", path, err) + } + + if token, ok := configData["oauth_token"]; ok && token != "" { + return token, nil + } + return "", errors.New("oauth_token not found in config.json") +} + +func (ca *CopilotAuth) acquireLock() bool { + f, err := os.OpenFile(ca.tokenFileLock, os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + if os.IsExist(err) { + log.Debug().Msg("Lock file already exists") + } else { + log.Error().Err(err).Msg("Error acquiring lock") + } + return false + } + f.Close() + return true +} + +func (ca *CopilotAuth) releaseLock() { + if err := os.Remove(ca.tokenFileLock); err != nil { + if !os.IsNotExist(err) { + log.Error().Err(err).Msg("Error releasing lock file") + } + } else { + log.Debug().Msg("Lock file released successfully") + } +} + +func (ca *CopilotAuth) saveTokenToFile() error { + ca.mu.RLock() + tokenData, err := json.Marshal(ca.githubToken) + ca.mu.RUnlock() + if err != nil { + return fmt.Errorf("failed to marshal token: %w", err) + } + + ca.isSelfWriting.Store(true) + defer ca.isSelfWriting.Store(false) + + tempFile := ca.tokenFile + ".tmp" + if err := os.WriteFile(tempFile, tokenData, 0644); err != nil { + return fmt.Errorf("failed to write to temporary token file: %w", err) + } + if err := os.Rename(tempFile, ca.tokenFile); err != nil { + return fmt.Errorf("failed to atomically move token file: %w", err) + } + log.Debug().Msg("Token successfully saved to file") + return nil +} + +func (ca *CopilotAuth) loadTokenFromFile() error { + data, err := os.ReadFile(ca.tokenFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read token file: %w", err) + } + + var token GithubToken + if err := json.Unmarshal(data, &token); err != nil { + return fmt.Errorf("failed to parse token file: %w", err) + } + + ca.mu.Lock() + ca.githubToken = &token + ca.mu.Unlock() + log.Debug().Msg("Token loaded from file") + return nil +} + +func (ca *CopilotAuth) isTokenValid() bool { + ca.mu.RLock() + defer ca.mu.RUnlock() + if ca.githubToken == nil { + return false + } + return ca.githubToken.ExpiresAt > time.Now().Add(config.TokenRefreshBuffer).Unix() +} + +func (ca *CopilotAuth) waitForTokenRefresh() bool { + log.Debug().Msg("Waiting for another process to refresh the token...") + time.Sleep(5 * time.Second) + if err := ca.loadTokenFromFile(); err != nil { + log.Error().Err(err).Msg("Failed to reload token after waiting") + return false + } + return ca.isTokenValid() +} + +func (ca *CopilotAuth) refreshTokenLoop(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + for { + var sleepDuration time.Duration + ca.mu.RLock() + if ca.githubToken != nil { + expiryTime := time.Unix(ca.githubToken.ExpiresAt, 0) + refreshTime := expiryTime.Add(-config.TokenRefreshBuffer) + sleepDuration = time.Until(refreshTime) + } + ca.mu.RUnlock() + + if sleepDuration <= 0 { + sleepDuration = config.RetryInterval + } + + log.Debug().Dur("duration", sleepDuration).Msg("Scheduling next token refresh") + select { + case <-time.After(sleepDuration): + ca.RefreshToken(false) + case <-ctx.Done(): + log.Debug().Msg("Refresh token loop shutting down.") + return + } + } +} + +func (ca *CopilotAuth) watchTokenFile(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Error().Err(err).Msg("Failed to create file watcher") + return + } + defer watcher.Close() + + if err := watcher.Add(ca.stateDir); err != nil { + log.Error().Err(err).Str("path", ca.stateDir).Msg("Failed to watch token directory") + return + } + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Name == ca.tokenFile && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) { + if ca.isSelfWriting.Load() { + continue + } + log.Debug().Msg("Token file changed externally, reloading.") + if err := ca.loadTokenFromFile(); err != nil { + log.Error().Err(err).Msg("Failed to reload token from watched file") + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Error().Err(err).Msg("File watcher error") + case <-ctx.Done(): + log.Debug().Msg("File watcher shutting down.") + return + } + } +} + +func (ca *CopilotAuth) checkStaleLocks(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + ticker := time.NewTicker(config.StaleLockCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + info, err := os.Stat(ca.tokenFileLock) + if err != nil { + if os.IsNotExist(err) { + continue + } + log.Error().Err(err).Msg("Error checking stale lock") + continue + } + if time.Since(info.ModTime()) > config.StaleLockTimeout { + log.Warn().Msg("Removing stale lock file") + ca.releaseLock() + } + case <-ctx.Done(): + log.Debug().Msg("Stale lock checker shutting down.") + return + } + } +} diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go new file mode 100644 index 0000000..bca4e80 --- /dev/null +++ b/internal/auth/oauth.go @@ -0,0 +1,142 @@ +package auth + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/atotto/clipboard" + "github.com/rs/zerolog/log" + + "gitlab.com/technofab/go-copilot-proxy/internal/config" +) + +type DeviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func RequestDeviceCode() (*DeviceCodeResponse, error) { + payload := strings.NewReader(fmt.Sprintf(`{"client_id":"%s","scope":"%s"}`, config.GHClientID, config.GHScope)) + req, err := http.NewRequest("POST", config.GHDeviceCodeURL, payload) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("User-Agent", config.UserAgent) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad response from GitHub: %s", resp.Status) + } + + var data DeviceCodeResponse + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return nil, err + } + return &data, nil +} + +func PollForAccessToken(deviceCodeInfo *DeviceCodeResponse) (string, error) { + interval := time.Duration(deviceCodeInfo.Interval) * time.Second + if interval == 0 { + interval = 5 * time.Second + } + + log.Info().Msg("Waiting for you to authorize in the browser...") + + for { + time.Sleep(interval) + fmt.Print(".") + + payload := strings.NewReader(fmt.Sprintf( + `{"client_id":"%s","device_code":"%s","grant_type":"urn:ietf:params:oauth:grant-type:device_code"}`, + config.GHClientID, deviceCodeInfo.DeviceCode, + )) + req, err := http.NewRequest("POST", config.GHOauthTokenURL, payload) + if err != nil { + return "", err + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("User-Agent", config.UserAgent) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + var data AccessTokenResponse + if resp.StatusCode == http.StatusOK { + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + resp.Body.Close() + return "", err + } + } + resp.Body.Close() + + if data.AccessToken != "" { + fmt.Println() + return data.AccessToken, nil + } + + if data.Error == "authorization_pending" { + continue + } + + if data.Error != "" { + return "", fmt.Errorf("authentication failed: %s - %s", data.Error, data.ErrorDescription) + } + } +} +func SaveOAuthToken(token string) error { + path, err := getConfigPath("config.json") + if err != nil { + return err + } + + tokenData := map[string]string{ + "oauth_token": token, + } + + jsonData, err := json.MarshalIndent(tokenData, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal token data: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + return os.WriteFile(path, jsonData, 0644) +} + +func PromptUserForAuth(deviceCodeResp *DeviceCodeResponse) { + log.Info().Msgf("Please open this URL in your browser: %s", deviceCodeResp.VerificationURI) + log.Info().Msgf("And enter this code: %s", deviceCodeResp.UserCode) + + if err := clipboard.WriteAll(deviceCodeResp.UserCode); err == nil { + log.Info().Msg("(The code has been copied to your clipboard!)") + } +} diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go new file mode 100644 index 0000000..69c2be5 --- /dev/null +++ b/internal/cmd/auth.go @@ -0,0 +1,38 @@ +package cmd + +import ( + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + "gitlab.com/technofab/go-copilot-proxy/internal/auth" +) + +var authCmd = &cobra.Command{ + Use: "auth", + Short: "Authenticates with GitHub to get the initial OAuth token", + Long: "Initiates the GitHub OAuth device flow to retrieve an OAuth token required for this proxy to work.", + Run: runAuth, +} + +func runAuth(cmd *cobra.Command, args []string) { + log.Info().Msg("Starting GitHub authentication for Copilot...") + + deviceCodeResp, err := auth.RequestDeviceCode() + if err != nil { + log.Fatal().Err(err).Msg("Failed to request device code") + } + + auth.PromptUserForAuth(deviceCodeResp) + + accessToken, err := auth.PollForAccessToken(deviceCodeResp) + if err != nil { + log.Fatal().Err(err).Msg("Failed to obtain OAuth token") + } + + if err := auth.SaveOAuthToken(accessToken); err != nil { + log.Fatal().Err(err).Msg("Failed to save OAuth token") + } + + log.Info().Msg("✅ Authentication successful! The OAuth token has been saved.") + log.Info().Msg("You can now run the 'serve' command.") +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go new file mode 100644 index 0000000..0949423 --- /dev/null +++ b/internal/cmd/root.go @@ -0,0 +1,23 @@ +package cmd + +import ( + "github.com/spf13/cobra" + + "gitlab.com/technofab/go-copilot-proxy/internal/config" +) + +var rootCmd = &cobra.Command{ + Use: "go-copilot-proxy", + Short: "A Go-based proxy for GitHub Copilot", + Long: "go-copilot-proxy provides a local proxy server for GitHub Copilot API requests with automatic token management.", +} + +func Execute() error { + return rootCmd.Execute() +} + +func init() { + config.InitLogging() + rootCmd.AddCommand(serveCmd) + rootCmd.AddCommand(authCmd) +} diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go new file mode 100644 index 0000000..d9f6fd7 --- /dev/null +++ b/internal/cmd/serve.go @@ -0,0 +1,86 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + "gitlab.com/technofab/go-copilot-proxy/internal/auth" + "gitlab.com/technofab/go-copilot-proxy/internal/server" +) + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Starts the proxy server", + Long: "Starts the HTTP proxy server that forwards requests to the GitHub Copilot API with automatic token management.", + Run: runServe, +} + +func init() { + serveCmd.Flags().String("host", "127.0.0.1", "Host to bind the server to") + serveCmd.Flags().Int("port", 8080, "Port to run the server on") +} + +func runServe(cmd *cobra.Command, args []string) { + accessToken := os.Getenv("GO_COPILOT_PROXY_TOKEN") + if accessToken == "" { + log.Fatal().Msg("GO_COPILOT_PROXY_TOKEN environment variable is not set. Please set it to a secure token to protect your proxy.") + } + log.Debug().Msg("Proxy access token is configured.") + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + + copilotAuth, err := auth.NewCopilotAuth() + if err != nil { + log.Fatal().Err(err).Msg("Failed to initialize Copilot Auth") + } + + if err := copilotAuth.Start(ctx, &wg); err != nil { + log.Fatal().Err(err).Msg("Failed to start Copilot auth background tasks. Did you run the 'auth' command first?") + } + + r := server.NewRouter(accessToken, copilotAuth) + + host, _ := cmd.Flags().GetString("host") + port, _ := cmd.Flags().GetInt("port") + addr := fmt.Sprintf("%s:%d", host, port) + httpServer := &http.Server{ + Addr: addr, + Handler: r, + } + + go func() { + log.Info().Str("address", addr).Msg("Starting server...") + if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatal().Err(err).Msg("Server failed to start") + } + }() + + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Info().Msg("Shutting down server...") + + cancel() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownCancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + log.Error().Err(err).Msg("Server forced to shutdown") + } + + log.Debug().Msg("Waiting for background tasks to complete...") + wg.Wait() + log.Info().Msg("Server exiting.") +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..0b2a396 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,34 @@ +package config + +import ( + "os" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +const ( + AuthURL = "https://api.github.com/copilot_internal/v2/token" + StaleLockTimeout = 5 * time.Minute + StaleLockCheckInterval = 1 * time.Minute + TokenRefreshBuffer = 2 * time.Minute + RetryInterval = 1 * time.Minute + + GHDeviceCodeURL = "https://github.com/login/device/code" + GHOauthTokenURL = "https://github.com/login/oauth/access_token" + GHClientID = "Iv1.b507a08c87ecfe98" + GHScope = "read:user" +) + +const ( + UserAgent = "GitHubCopilotChat/0.26.7" + EditorVersion = "vscode/1.99.3" + EditorPluginVersion = "copilot-chat/0.26.7" + CopilotIntegrationID = "vscode-chat" +) + +func InitLogging() { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}) + zerolog.SetGlobalLevel(zerolog.InfoLevel) +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..539c192 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,342 @@ +package server + +import ( + "bufio" + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + + "gitlab.com/technofab/go-copilot-proxy/internal/auth" + "gitlab.com/technofab/go-copilot-proxy/internal/config" +) + +func NewRouter(accessToken string, copilotAuth *auth.CopilotAuth) chi.Router { + r := chi.NewRouter() + r.Use(middleware.Recoverer) + r.Use(middleware.RealIP) + r.Use(middleware.RequestID) + r.Use(middleware.Heartbeat("/healthz")) + r.Use(corsHandler) + + r.Route("/v1", func(r chi.Router) { + r.Use(authMiddleware(accessToken)) + r.Post("/chat/completions", createChatHandler(copilotAuth)) + r.Post("/embeddings", createProxyHandler("https://api.githubcopilot.com/embeddings", copilotAuth)) + r.Get("/models", createProxyHandler("https://api.githubcopilot.com/models", copilotAuth)) + }) + + return r +} + +func authMiddleware(accessToken string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + token := strings.TrimPrefix(authHeader, "Bearer ") + if token != accessToken { + http.Error(w, "Invalid access token", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func corsHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") + w.Header().Set("Access-Control-Expose-Headers", "Content-Type, Content-Length, Transfer-Encoding") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + next.ServeHTTP(w, r) + }) +} + +func createChatHandler(copilotAuth *auth.CopilotAuth) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Handling chat completions request") + + token := copilotAuth.GetToken() + if token == "" { + http.Error(w, "No Copilot token available", http.StatusServiceUnavailable) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + r.Body.Close() + + req, err := http.NewRequest("POST", "https://api.githubcopilot.com/chat/completions", bytes.NewBuffer(body)) + if err != nil { + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + + headers := createRequestHeaders(token, "api.githubcopilot.com") + for key, value := range headers { + req.Header.Set(key, value) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + log.Error().Err(err).Msg("Failed to make request to GitHub Copilot API") + http.Error(w, "Failed to proxy request", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + copyHeaders(w.Header(), resp.Header) + w.Header().Set("Access-Control-Allow-Origin", "*") + + isStreaming := resp.Header.Get("Transfer-Encoding") == "chunked" || + strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") + + if resp.StatusCode == http.StatusOK && isStreaming { + handleStreamingResponse(w, resp, r) + } else if resp.StatusCode == http.StatusOK { + handleNonStreamingResponse(w, resp, r) + } else { + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) + } + } +} + +func createRequestHeaders(token, host string) map[string]string { + hash := sha256.Sum256([]byte(token)) + machineID := hex.EncodeToString(hash[:]) + + return map[string]string{ + "Authorization": "Bearer " + token, + "Host": host, + "X-Request-Id": uuid.New().String(), + "X-Github-Api-Version": "2025-04-01", + "Vscode-Sessionid": uuid.New().String() + fmt.Sprintf("%d", time.Now().Unix()), + "Vscode-Machineid": machineID, + "Editor-Version": config.EditorVersion, + "Editor-Plugin-Version": config.EditorPluginVersion, + "Openai-Organization": "github-copilot", + "Copilot-Integration-Id": config.CopilotIntegrationID, + "Openai-Intent": "conversation-panel", + "Content-Type": "application/json", + "User-Agent": config.UserAgent, + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } +} + +func copyHeaders(dst, src http.Header) { + for key, values := range src { + for _, value := range values { + dst.Add(key, value) + } + } +} + +func handleStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + log.Error().Msg("Response writer does not support flushing") + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + scanner := bufio.NewScanner(resp.Body) + buffer := "" + delimiter := "\n\n" + + for scanner.Scan() { + line := scanner.Text() + buffer += line + "\n" + + if strings.HasSuffix(buffer, delimiter) { + lines := strings.Split(strings.TrimSuffix(buffer, delimiter), delimiter) + for _, chunk := range lines { + if cleanedChunk := cleanStreamLine(chunk); cleanedChunk != "" { + fmt.Fprint(w, cleanedChunk+delimiter) + flusher.Flush() + } + } + buffer = "" + } + } + + if buffer != "" { + if cleanedChunk := cleanStreamLine(strings.TrimSuffix(buffer, "\n")); cleanedChunk != "" { + fmt.Fprint(w, cleanedChunk+"\n") + flusher.Flush() + } + } + + if err := scanner.Err(); err != nil { + log.Error().Err(err).Msg("Error reading streaming response") + } +} + +func handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, r *http.Request) { + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to read response body") + http.Error(w, "Failed to read response", http.StatusInternalServerError) + return + } + + cleanedBody := cleanResponse(string(body)) + w.WriteHeader(resp.StatusCode) + fmt.Fprint(w, cleanedBody) +} + +func cleanResponse(responseBody string) string { + var data map[string]interface{} + if err := json.Unmarshal([]byte(responseBody), &data); err != nil { + log.Error().Err(err).Msg("Failed to parse response JSON") + return responseBody + } + + data["object"] = "chat.completion" + delete(data, "prompt_filter_results") + + if choices, ok := data["choices"].([]interface{}); ok { + for _, choice := range choices { + if choiceMap, ok := choice.(map[string]interface{}); ok { + delete(choiceMap, "content_filter_results") + } + } + } + + cleanedBytes, err := json.Marshal(data) + if err != nil { + log.Error().Err(err).Msg("Failed to marshal cleaned response") + return responseBody + } + + return string(cleanedBytes) +} + +func cleanStreamLine(line string) string { + if !strings.HasPrefix(line, "data: ") { + return line + } + + dataContent := strings.TrimPrefix(line, "data: ") + if dataContent == "[DONE]" { + return line + } + + var data map[string]interface{} + if err := json.Unmarshal([]byte(dataContent), &data); err != nil { + log.Debug().Err(err).Msg("Failed to parse stream line JSON") + return line + } + + if choices, ok := data["choices"].([]interface{}); ok { + if len(choices) == 0 { + return "" + } + + data["object"] = "chat.completion.chunk" + + for _, choice := range choices { + if choiceMap, ok := choice.(map[string]interface{}); ok { + delete(choiceMap, "content_filter_offsets") + delete(choiceMap, "content_filter_results") + + if delta, ok := choiceMap["delta"].(map[string]interface{}); ok { + for key, value := range delta { + if value == nil { + delete(delta, key) + } + } + } + } + } + } + + cleanedBytes, err := json.Marshal(data) + if err != nil { + log.Error().Err(err).Msg("Failed to marshal cleaned stream data") + return line + } + + return "data: " + string(cleanedBytes) +} + +func createProxyHandler(targetURL string, copilotAuth *auth.CopilotAuth) http.HandlerFunc { + target, err := url.Parse(targetURL) + if err != nil { + log.Fatal().Err(err).Str("url", targetURL).Msg("Invalid target URL") + } + + return func(w http.ResponseWriter, r *http.Request) { + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("target", targetURL). + Msg("Proxying request") + + token := copilotAuth.GetToken() + if token == "" { + http.Error(w, "No Copilot token available", http.StatusServiceUnavailable) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + r.Body.Close() + + req, err := http.NewRequest(r.Method, target.String(), bytes.NewBuffer(body)) + if err != nil { + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + + headers := createRequestHeaders(token, target.Host) + for key, value := range headers { + req.Header.Set(key, value) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + log.Error().Err(err).Msg("Failed to make request") + http.Error(w, "Failed to proxy request", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + copyHeaders(w.Header(), resp.Header) + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) + } +} diff --git a/package.nix b/package.nix new file mode 100644 index 0000000..1c68bd1 --- /dev/null +++ b/package.nix @@ -0,0 +1,23 @@ +{ + lib, + buildGoModule, + ... +}: +buildGoModule { + name = "go-copilot-proxy"; + src = + # filter everything except for cmd/, internal/ and go.mod, go.sum + with lib.fileset; + toSource { + root = ./.; + fileset = unions [ + ./cmd + ./internal + ./go.mod + ./go.sum + ]; + }; + subPackages = ["cmd/go-copilot-proxy"]; + vendorHash = "sha256-/+6NnofnE3IrtUbHJrgDE2VFK6Gj40rodtE/42LvPMk="; + meta.mainProgram = "go-copilot-proxy"; +}