diff --git a/cmd/rterm/main.go b/cmd/rterm/main.go index 0ad49bb..829a18c 100644 --- a/cmd/rterm/main.go +++ b/cmd/rterm/main.go @@ -2,10 +2,12 @@ package main import ( "log" + "net/http" + _ "net/http/pprof" + "strings" + "github.com/dev6699/rterm" "github.com/dev6699/rterm/command" - "github.com/dev6699/rterm/server" - "github.com/dev6699/rterm/ui" ) func main() { @@ -16,24 +18,43 @@ func main() { } func run() error { - assets, err := ui.Assets() - if err != nil { - return err - } + rterm.SetPrefix("/") + mux := http.NewServeMux() - srv, err := server.New( - assets, - func() (*command.Command, error) { - return command.New("bash", nil) + rterm.Register( + mux, + rterm.Command{ + Factory: func() (*command.Command, error) { + return command.New("bash", nil) + }, + Name: "bash", + Description: "Bash (Unix shell)", + Writable: true, + }, + rterm.Command{ + Factory: func() (*command.Command, error) { + return command.New("htop", nil) + }, + Name: "htop", + Description: "Interactive system monitor process viewer and process manager", + Writable: false, + }, + rterm.Command{ + Factory: func() (*command.Command, error) { + return command.New("nvidia-smi", strings.Split("--query-gpu=utilization.gpu --format=csv -l 1", " ")) + }, + Name: "nvidia-smi", + Description: "Monitors and outputs the GPU utilization percentage every second", + Writable: false, }, ) - if err != nil { - return err - } addr := ":5000" + server := &http.Server{ + Addr: addr, + Handler: mux, + } log.Println("⚠️ CAUTION USE AT YOUR OWN RISK!!! ⚠️") log.Printf("Server listening on http://0.0.0.0%s", addr) - - return srv.Run(addr) + return server.ListenAndServe() } diff --git a/rterm.go b/rterm.go new file mode 100644 index 0000000..55ecb99 --- /dev/null +++ b/rterm.go @@ -0,0 +1,166 @@ +package rterm + +import ( + "bytes" + "fmt" + "html" + "io" + "io/fs" + "log" + "net/http" + "net/url" + "path/filepath" + "sort" + "strings" + + "github.com/dev6699/rterm/server" + "github.com/dev6699/rterm/ui" + "github.com/gorilla/websocket" +) + +var ( + defaultPrefix = "/rterm" + wsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + assets fs.FS + registeredCommands []Command +) + +func init() { + var err error + assets, err = ui.Assets() + if err != nil { + log.Fatalf("rterm: failed to load assets; err = %v", err) + } +} + +// SetPrefix to override default url prefix +func SetPrefix(prefix string) { + // Check if the prefix starts with "/" + if !strings.HasPrefix(prefix, "/") { + prefix = "/" + prefix + } + + // Check if the prefix ends with "/" + if strings.HasSuffix(prefix, "/") { + prefix = strings.TrimSuffix(prefix, "/") + } + + defaultPrefix = prefix +} + +// SetWSUpgrader to override default websocket upgrader +func SetWSUpgrader(u websocket.Upgrader) { + wsUpgrader = u +} + +type Command struct { + Factory server.CommandFactory + // Name of the command, will be used as the url to execute the command + Name string + // Description of the command + Description string + // Writable indicate whether server should process inputs from clients. + Writable bool +} + +// Register binds all command handlers to the http mux. +// GET / -> commands listing index page. +// GET /{command} -> command page. +// GET /{command}/ws -> websocket for command inputs handling. +func Register(mux *http.ServeMux, commands ...Command) { + commandsMap := map[string]Command{} + for _, cmd := range commands { + commandsMap[cmd.Name] = cmd + registeredCommands = append(registeredCommands, cmd) + log.Printf("server: command[%s] -> %s", cmd.Name, defaultPrefix+"/"+cmd.Name) + } + + sort.Slice(commands, func(i, j int) bool { + return commands[i].Name < commands[j].Name + }) + + baseURL := "GET " + defaultPrefix + if baseURL == "GET " { + baseURL = "GET /" + } + mux.HandleFunc(baseURL, index) + mux.Handle("GET "+defaultPrefix+"/{command}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ext := filepath.Ext(r.URL.String()) + stripPrefix := r.URL.String() + if ext != "" { + stripPrefix = defaultPrefix + } + http.StripPrefix(stripPrefix, http.FileServer(http.FS(assets))).ServeHTTP(w, r) + })) + mux.HandleFunc("GET "+defaultPrefix+"/{command}/ws", func(w http.ResponseWriter, r *http.Request) { + c := r.PathValue("command") + cmd, ok := commandsMap[c] + if !ok { + http.NotFound(w, r) + return + } + server.HandleWebSocket(&wsUpgrader, cmd.Factory, cmd.Writable)(w, r) + }) +} + +// index responds with an HTML page listing the available commands. +func index(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + err := indexTmplExecute(w) + if err != nil { + log.Printf("rterm: failed to serve index; err = %v", err) + } +} + +func indexTmplExecute(w io.Writer) error { + var b bytes.Buffer + fmt.Fprintf(&b, ` + +%s + + + +%s +
+
+Types of commands available: + + +`, defaultPrefix, defaultPrefix) + + for _, command := range registeredCommands { + link := &url.URL{Path: defaultPrefix + "/" + command.Name} + fmt.Fprintf(&b, "\n", link, html.EscapeString(command.Name)) + } + + b.WriteString(`
Command
%s
+
+

+Command Descriptions: +

    +`) + for _, command := range registeredCommands { + fmt.Fprintf(&b, "
  • %s:
    %s
  • \n", html.EscapeString(command.Name), html.EscapeString(command.Description)) + } + b.WriteString(`
+

+ +`) + + _, err := w.Write(b.Bytes()) + return err +} diff --git a/server/server.go b/server/server.go index 456f834..27d4ae6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,7 +1,6 @@ package server import ( - "io/fs" "log" "net/http" @@ -12,50 +11,27 @@ import ( type CommandFactory = func() (*command.Command, error) -type Server struct { - wsUpgrader *websocket.Upgrader - cmdFac CommandFactory -} - -func New(assets fs.FS, cmdFac CommandFactory) (*Server, error) { - s := &Server{ - wsUpgrader: &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true - }, - }, - cmdFac: cmdFac, - } - - http.Handle("/", http.FileServer(http.FS(assets))) - http.HandleFunc("/ws", s.handleWebSocket) - - return s, nil -} - -func (s *Server) Run(addr string) error { - return http.ListenAndServe(addr, nil) -} - -func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { - conn, err := s.wsUpgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("server: failed to upgrade websocket; err = %v", err) - return +func HandleWebSocket(wsUpgrader *websocket.Upgrader, cmdFac CommandFactory, writable bool) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + + conn, err := wsUpgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("server: failed to upgrade websocket; err = %v", err) + return + } + defer conn.Close() + + cmd, err := cmdFac() + if err != nil { + log.Printf("server: failed to start command; err = %v", err) + return + } + + t := tty.New(WSController{Conn: conn}, cmd, writable) + err = t.Run(r.Context()) + if err != nil { + log.Printf("server: socket connection closed; err = %v", err) + } } - defer conn.Close() - cmd, err := s.cmdFac() - if err != nil { - log.Printf("server: failed to start command; err = %v", err) - return - } - - t := tty.New(WSController{Conn: conn}, cmd) - err = t.Run(r.Context()) - if err != nil { - log.Printf("server: socket connection closed; err = %v", err) - } } diff --git a/tty/tty.go b/tty/tty.go index 7b6d300..6d0dd12 100644 --- a/tty/tty.go +++ b/tty/tty.go @@ -16,12 +16,12 @@ type TTY struct { writable bool } -func New(controller Controller, agent Agent) *TTY { +func New(controller Controller, agent Agent, writable bool) *TTY { return &TTY{ controller: controller, agent: agent, bufferSize: 1024, - writable: true, + writable: writable, } } diff --git a/ui/src/script.js b/ui/src/script.js index f08006b..82d549f 100644 --- a/ui/src/script.js +++ b/ui/src/script.js @@ -13,7 +13,7 @@ fitAddon.fit(); const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const wsHost = window.location.hostname; const wsPort = window.location.port ? ':' + window.location.port : ''; -const wsPath = '/ws'; +const wsPath = window.location.pathname + '/ws'; const wsURL = wsProtocol + wsHost + wsPort + wsPath; const socket = new WebSocket(wsURL);