Skip to content

Commit

Permalink
enhance: add functions for daemon servers for mTLS (#87)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Dec 16, 2024
1 parent 2ddfb8e commit 79a6682
Showing 1 changed file with 102 additions and 0 deletions.
102 changes: 102 additions & 0 deletions pkg/daemon/daemon.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package daemon

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net/http"
"os"
)

type Server struct {
mux *http.ServeMux
tlsConfig *tls.Config
}

// CreateServer creates a new HTTP server with TLS configured for GPTScript.
// This function should be used when creating a new server for a daemon tool.
// The server should then be started with the StartServer function.
func CreateServer() (*Server, error) {
return CreateServerWithMux(http.DefaultServeMux)
}

// CreateServerWithMux creates a new HTTP server with TLS configured for GPTScript.
// This function should be used when creating a new server for a daemon tool with a custom ServeMux.
// The server should then be started with the StartServer function.
func CreateServerWithMux(mux *http.ServeMux) (*Server, error) {
tlsConfig, err := getTLSConfig()
if err != nil {
return nil, fmt.Errorf("failed to get TLS config: %v", err)
}

return &Server{
mux: mux,
tlsConfig: tlsConfig,
}, nil
}

// Start starts an HTTP server created by the CreateServer function.
// This is for use with daemon tools.
func (s *Server) Start() error {
server := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%s", os.Getenv("PORT")),
TLSConfig: s.tlsConfig,
Handler: s.mux,
}

if err := server.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("stopped serving: %v", err)
}
return nil
}

func (s *Server) HandleFunc(pattern string, handler http.HandlerFunc) {
s.mux.HandleFunc(pattern, handler)
}

func getTLSConfig() (*tls.Config, error) {
certB64 := os.Getenv("CERT")
privateKeyB64 := os.Getenv("PRIVATE_KEY")
gptscriptCertB64 := os.Getenv("GPTSCRIPT_CERT")

if certB64 == "" {
return nil, fmt.Errorf("CERT not set")
} else if privateKeyB64 == "" {
return nil, fmt.Errorf("PRIVATE_KEY not set")
} else if gptscriptCertB64 == "" {
return nil, fmt.Errorf("GPTSCRIPT_CERT not set")
}

certBytes, err := base64.StdEncoding.DecodeString(certB64)
if err != nil {
return nil, fmt.Errorf("failed to decode cert base64: %v", err)
}

privateKeyBytes, err := base64.StdEncoding.DecodeString(privateKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to decode private key base64: %v", err)
}

gptscriptCertBytes, err := base64.StdEncoding.DecodeString(gptscriptCertB64)
if err != nil {
return nil, fmt.Errorf("failed to decode gptscript cert base64: %v", err)
}

cert, err := tls.X509KeyPair(certBytes, privateKeyBytes)
if err != nil {
return nil, fmt.Errorf("failed to create X509 key pair: %v", err)
}

pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(gptscriptCertBytes) {
return nil, fmt.Errorf("failed to append gptscript cert to pool")
}

return &tls.Config{
Certificates: []tls.Certificate{cert},
ClientCAs: pool,
ClientAuth: tls.RequireAndVerifyClientCert,
}, nil
}

0 comments on commit 79a6682

Please sign in to comment.