diff --git a/README.md b/README.md index 0b89e65ff..9b4e352c8 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Status"> - provides **end-to-end encryption** (using PAKE) - enables easy **cross-platform** transfers (Windows, Linux, Mac) - allows **multiple file** transfers +- allows **multiple sequential** transfers - allows **resuming transfers** that are interrupted - local server or port-forwarding **not needed** - **ipv6-first** with ipv4 fallback @@ -134,6 +135,33 @@ The code phrase is used to establish password-authenticated key agreement ([PAKE There are a number of configurable options (see `--help`). A set of options (like custom relay, ports, and code phrase) can be set using `--remember`. +### Transfer on LAN only + +You can transfer files using only local connections. + +``` +croc --local send [file(s)-or-folder] +``` + +### Allow multiple sequential transfers + +By default, after a transfer is done, the program stops. +You can allow more than one transfer to happen (one after another) by using the `--multiple` flag, which requires a value >= 1. + +``` +croc send --multiple [nr-of-transfers] [file(s)-or-folder] +``` + +After all `[nr-of-transfers]` were done, the program will stop. To prevent keeping the program running forever if not all transfers +possibilities are used, a timeout is set on the connection with the relay. By default, this `timeout` is set to `30 seconds`, which is +likely not enough. If you want to keep the connection alive for more time you can use the `--timeout` flag like this: + +``` +croc send --timeout [nr-of-seconds] (--multiple [nr-of-transfers]) [file(s)-or-folder] +``` + +*NOTE*: You can't keep the connection alive for more than `1 hour`. + ### Custom code phrase You can send with your own code phrase (must be more than 6 characters). diff --git a/go.mod b/go.mod index b90716245..61ccb071f 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/OneOfOne/xxhash v1.2.8 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/uuid v1.3.1 github.com/magisterquis/connectproxy v0.0.0-20200725203833-3582e84f0c9b github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index fcd625ec9..fb6cc7791 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ= github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/kalafut/imohash v1.0.2 h1:j/cUPa15YvXv7abJlM+kdJIycbBMpmO7WqhPl4YB76I= github.com/kalafut/imohash v1.0.2/go.mod h1:PjHBF0vpo1q7zMqiTn0qwSTQU2wDn5QIe8S8sFQuZS8= diff --git a/src/cli/cli.go b/src/cli/cli.go index 5237c62a8..4853e3d92 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -44,7 +44,7 @@ func Run() (err error) { app.UsageText = `Send a file: croc send file.txt - -git to respect your .gitignore + -git to respect your .gitignore Send multiple files: croc send file1.txt file2.txt file3.txt or @@ -66,6 +66,8 @@ func Run() (err error) { ArgsUsage: "[filename(s) or folder]", Flags: []cli.Flag{ &cli.BoolFlag{Name: "zip", Usage: "zip folder before sending"}, + &cli.IntFlag{Name: "timelimit", Value: 30, Usage: "timelimit in secods for sender to allow all transfers"}, + &cli.IntFlag{Name: "multiple", Value: 1, Usage: "maximum number of transfers"}, &cli.StringFlag{Name: "code", Aliases: []string{"c"}, Usage: "codephrase used to connect to relay"}, &cli.StringFlag{Name: "hash", Value: "xxhash", Usage: "hash algorithm (xxhash, imohash, md5)"}, &cli.StringFlag{Name: "text", Aliases: []string{"t"}, Usage: "send some text"}, @@ -181,6 +183,8 @@ func send(c *cli.Context) (err error) { crocOptions := croc.Options{ SharedSecret: c.String("code"), IsSender: true, + TimeLimit: c.Int("timelimit"), + MaxTransfers: c.Int("multiple"), Debug: c.Bool("debug"), NoPrompt: c.Bool("yes"), RelayAddress: c.String("relay"), @@ -202,11 +206,30 @@ func send(c *cli.Context) (err error) { ZipFolder: c.Bool("zip"), GitIgnore: c.Bool("git"), } + + if crocOptions.TimeLimit <= 0 { + fmt.Println("timelimit must be greater than 0. Defaulting to 30 seconds.") + crocOptions.TimeLimit = 30 + } else if crocOptions.TimeLimit > 3600 { + fmt.Println("timelimit must be less than 3600. Defaulting to 30 seconds.") + crocOptions.TimeLimit = 30 + } + + if crocOptions.MaxTransfers <= 0 { + fmt.Println("multiple must be greater than 0. Defaulting to 1 transfers.") + crocOptions.MaxTransfers = 1 + } else if crocOptions.MaxTransfers > 1 { + fmt.Println("Allowing multiple transfers.") + fmt.Println("The connection will stay open until all transfers are complete, or the timelimit is reached.") + fmt.Printf("The current timelimit is %d seconds.\n", crocOptions.TimeLimit) + } + if crocOptions.RelayAddress != models.DEFAULT_RELAY { crocOptions.RelayAddress6 = "" } else if crocOptions.RelayAddress6 != models.DEFAULT_RELAY6 { crocOptions.RelayAddress = "" } + b, errOpen := os.ReadFile(getConfigFile()) if errOpen == nil && !c.Bool("remember") { var rememberedOptions croc.Options diff --git a/src/comm/comm.go b/src/comm/comm.go index d578ac14e..8cf648bd3 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/magisterquis/connectproxy" "github.com/schollz/croc/v9/src/utils" log "github.com/schollz/logger" @@ -23,6 +24,7 @@ var MAGIC_BYTES = []byte("croc") // Comm is some basic TCP communication type Comm struct { + id string connection net.Conn } @@ -100,6 +102,7 @@ func New(c net.Conn) *Comm { log.Errorf("error setting write deadline: %v", err) } comm := new(Comm) + comm.id = uuid.New().String() comm.connection = c return comm } @@ -109,6 +112,10 @@ func (c *Comm) Connection() net.Conn { return c.connection } +func (c *Comm) ID() string { + return c.id +} + // Close closes the connection func (c *Comm) Close() { if err := c.connection.Close(); err != nil { diff --git a/src/comm/comm_test.go b/src/comm/comm_test.go index afba91f80..ea89dcb07 100644 --- a/src/comm/comm_test.go +++ b/src/comm/comm_test.go @@ -51,6 +51,7 @@ func TestComm(t *testing.T) { time.Sleep(300 * time.Millisecond) a, err := NewConnection("127.0.0.1:"+port, 10*time.Minute) assert.Nil(t, err) + assert.NotNil(t, a.id) data, err := a.Receive() assert.Equal(t, []byte("hello, world"), data) assert.Nil(t, err) diff --git a/src/croc/croc.go b/src/croc/croc.go index a4e23a195..9beff521d 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -38,6 +38,7 @@ import ( var ( ipRequest = []byte("ips?") handshakeRequest = []byte("handshake") + wgTransfer sync.WaitGroup ) func init() { @@ -56,6 +57,8 @@ func Debug(debug bool) { // Options specifies user specific options type Options struct { IsSender bool + TimeLimit int + MaxTransfers int SharedSecret string Debug bool RelayAddress string @@ -578,11 +581,24 @@ func (c *Client) broadcastOnLocalNetwork(useipv6 bool) { } } -func (c *Client) transferOverLocalRelay(errchan chan<- error) { +func (c *Client) resetFlagsAndRealeaseLock() { + // reset flags and keys for next transfer + c.Key = nil + c.Step1ChannelSecured = false + c.Step2FileInfoTransferred = false + c.Step3RecipientRequestFile = false + c.Step4FileTransferred = false + c.Step5CloseChannels = false + c.SuccessfulTransfer = false + + wgTransfer.Done() +} + +func (c *Client) transferOverLocalRelay(errchan chan error) { time.Sleep(500 * time.Millisecond) log.Debug("establishing connection") var banner string - conn, banner, ipaddr, err := tcp.ConnectToTCPServer("127.0.0.1:"+c.Options.RelayPorts[0], c.Options.RelayPassword, c.Options.SharedSecret[:3]) + conn, banner, ipaddr, err := tcp.ConnectToTCPServer("127.0.0.1:"+c.Options.RelayPorts[0], c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, c.Options.MaxTransfers) log.Debugf("banner: %s", banner) if err != nil { err = fmt.Errorf("could not connect to 127.0.0.1:%s: %w", c.Options.RelayPorts[0], err) @@ -590,17 +606,33 @@ func (c *Client) transferOverLocalRelay(errchan chan<- error) { // not really an error because it will try to connect over the actual relay return } - log.Debugf("local connection established: %+v", conn) + log.Debugf("local sender connection established: %+v", conn) + err = nil for { - data, _ := conn.Receive() - if bytes.Equal(data, handshakeRequest) { + data, errConn := conn.Receive() + if errConn != nil { + log.Debugf("[%+v] had error: %s", conn, errConn.Error()) break + } + if bytes.Equal(data, ipRequest) { + log.Debug("Got ip request, sending nil since we are local") + if err = conn.Send(nil); err != nil { + log.Errorf("error sending: %v", err) + } + } else if bytes.Equal(data, handshakeRequest) { + wgTransfer.Add(1) + go c.makeLocalTransfer(conn, ipaddr, banner, errchan) + wgTransfer.Wait() } else if bytes.Equal(data, []byte{1}) { log.Debug("got ping") } else { log.Debugf("instead of handshake got: %s", data) } } + errchan <- err +} + +func (c *Client) makeLocalTransfer(conn *comm.Comm, ipaddr, banner string, errchan chan error) (err error) { c.conn[0] = conn log.Debug("exchanged header message") c.Options.RelayAddress = "127.0.0.1" @@ -610,7 +642,112 @@ func (c *Client) transferOverLocalRelay(errchan chan<- error) { c.Options.RelayPorts = []string{c.Options.RelayPorts[0]} } c.ExternalIP = ipaddr - errchan <- c.transfer() + + err = c.transfer() + if err != nil { + errchan <- err + fmt.Print("Did not transfer successfully locally\n") + } else { + fmt.Print("Local transfer was successful!\n") + } + + c.resetFlagsAndRealeaseLock() + + return +} + +func (c *Client) establishSecureConnectionWithTCPServer(errchan chan error) (conn *comm.Comm, ipaddr, banner string, err error) { + for _, address := range []string{c.Options.RelayAddress6, c.Options.RelayAddress} { + if address == "" { + continue + } + host, port, _ := net.SplitHostPort(address) + log.Debugf("host: '%s', port: '%s'", host, port) + // Default port to :9009 + if port == "" { + host = address + port = models.DEFAULT_PORT + } + log.Debugf("got host '%v' and port '%v'", host, port) + address = net.JoinHostPort(host, port) + log.Debugf("trying connection to %s", address) + conn, banner, ipaddr, err = tcp.ConnectToTCPServer(address, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, c.Options.MaxTransfers, time.Duration(c.Options.TimeLimit)*time.Second) + if err == nil { + c.Options.RelayAddress = address + break + } + log.Debugf("could not establish '%s'", address) + } + + return +} + +func (c *Client) listenToMainConn(conn *comm.Comm, ipaddr, banner string, errchan chan error) { + var err error + err = nil + for { + log.Debug("waiting for bytes") + data, errConn := conn.Receive() + if errConn != nil { + log.Debugf("[%+v] had error: %s", conn, errConn.Error()) + break + } + if bytes.Equal(data, ipRequest) { + log.Debug("Got ip request") + // recipient wants to try to connect to local ips + var ips []string + // only get local ips if the local is enabled + if !c.Options.DisableLocal { + // get list of local ips + ips, err = utils.GetLocalIPs() + if err != nil { + log.Debugf("error getting local ips: %v", err) + } + // prepend the port that is being listened to + ips = append([]string{c.Options.RelayPorts[0]}, ips...) + } + bips, _ := json.Marshal(ips) + if err = conn.Send(bips); err != nil { + log.Errorf("error sending: %v", err) + } + } else if bytes.Equal(data, handshakeRequest) { + wgTransfer.Add(1) + go c.makeTheTransfer(conn, ipaddr, banner, errchan) + wgTransfer.Wait() + } else if bytes.Equal(data, []byte{1}) { + log.Debug("got ping") + continue + } else { + log.Debugf("[%+v] got weird bytes: %+v", conn, data) + // throttle the reading + errchan <- fmt.Errorf("gracefully refusing using the public relay") + return + } + } + errchan <- err +} + +func (c *Client) makeTheTransfer(conn *comm.Comm, ipaddr, banner string, errchan chan error) (err error) { + c.conn[0] = conn + c.Options.RelayPorts = strings.Split(banner, ",") + if c.Options.NoMultiplexing { + log.Debug("no multiplexing") + c.Options.RelayPorts = []string{c.Options.RelayPorts[0]} + } + c.ExternalIP = ipaddr + log.Debug("exchanged header message") + + err = c.transfer() + if err != nil { + errchan <- err + fmt.Print("Did not transfer successfully\n") + } else { + fmt.Print("Transfer successful!\n") + } + + c.resetFlagsAndRealeaseLock() + + return } // Send will send the specified file @@ -652,88 +789,21 @@ func (c *Client) Send(filesInfo []FileInfo, emptyFoldersToTransfer []FileInfo, t } if !c.Options.OnlyLocal { - go func() { - var ipaddr, banner string - var conn *comm.Comm - durations := []time.Duration{100 * time.Millisecond, 5 * time.Second} - for i, address := range []string{c.Options.RelayAddress6, c.Options.RelayAddress} { - if address == "" { - continue - } - host, port, _ := net.SplitHostPort(address) - log.Debugf("host: '%s', port: '%s'", host, port) - // Default port to :9009 - if port == "" { - host = address - port = models.DEFAULT_PORT - } - log.Debugf("got host '%v' and port '%v'", host, port) - address = net.JoinHostPort(host, port) - log.Debugf("trying connection to %s", address) - conn, banner, ipaddr, err = tcp.ConnectToTCPServer(address, c.Options.RelayPassword, c.Options.SharedSecret[:3], durations[i]) - if err == nil { - c.Options.RelayAddress = address - break - } - log.Debugf("could not establish '%s'", address) - } - if conn == nil && err == nil { - err = fmt.Errorf("could not connect") - } - if err != nil { - err = fmt.Errorf("could not connect to %s: %w", c.Options.RelayAddress, err) - log.Debug(err) - errchan <- err - return - } + conn, ipaddr, banner, err := c.establishSecureConnectionWithTCPServer(errchan) + + if conn == nil && err == nil { + err = fmt.Errorf("could not connect") + } + if err != nil { + err = fmt.Errorf("could not connect to %s: %w", c.Options.RelayAddress, err) + log.Debug(err) + errchan <- err + } else { log.Debugf("banner: %s", banner) - log.Debugf("connection established: %+v", conn) - for { - log.Debug("waiting for bytes") - data, errConn := conn.Receive() - if errConn != nil { - log.Debugf("[%+v] had error: %s", conn, errConn.Error()) - } - if bytes.Equal(data, ipRequest) { - // recipient wants to try to connect to local ips - var ips []string - // only get local ips if the local is enabled - if !c.Options.DisableLocal { - // get list of local ips - ips, err = utils.GetLocalIPs() - if err != nil { - log.Debugf("error getting local ips: %v", err) - } - // prepend the port that is being listened to - ips = append([]string{c.Options.RelayPorts[0]}, ips...) - } - bips, _ := json.Marshal(ips) - if err = conn.Send(bips); err != nil { - log.Errorf("error sending: %v", err) - } - } else if bytes.Equal(data, handshakeRequest) { - break - } else if bytes.Equal(data, []byte{1}) { - log.Debug("got ping") - continue - } else { - log.Debugf("[%+v] got weird bytes: %+v", conn, data) - // throttle the reading - errchan <- fmt.Errorf("gracefully refusing using the public relay") - return - } - } + log.Debugf("sender connection established: %+v", conn) - c.conn[0] = conn - c.Options.RelayPorts = strings.Split(banner, ",") - if c.Options.NoMultiplexing { - log.Debug("no multiplexing") - c.Options.RelayPorts = []string{c.Options.RelayPorts[0]} - } - c.ExternalIP = ipaddr - log.Debug("exchanged header message") - errchan <- c.transfer() - }() + c.listenToMainConn(conn, ipaddr, banner, errchan) + } } err = <-errchan @@ -851,9 +921,8 @@ func (c *Client) Receive() (err error) { log.Debug("establishing connection") } var banner string - durations := []time.Duration{200 * time.Millisecond, 5 * time.Second} err = fmt.Errorf("found no addresses to connect") - for i, address := range []string{c.Options.RelayAddress6, c.Options.RelayAddress} { + for _, address := range []string{c.Options.RelayAddress6, c.Options.RelayAddress} { if address == "" { continue } @@ -867,9 +936,10 @@ func (c *Client) Receive() (err error) { log.Debugf("got host '%v' and port '%v'", host, port) address = net.JoinHostPort(host, port) log.Debugf("trying connection to %s", address) - c.conn[0], banner, c.ExternalIP, err = tcp.ConnectToTCPServer(address, c.Options.RelayPassword, c.Options.SharedSecret[:3], durations[i]) + c.conn[0], banner, c.ExternalIP, err = tcp.ConnectToTCPServer(address, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, 1, time.Duration(c.Options.TimeLimit)*time.Second) if err == nil { c.Options.RelayAddress = address + log.Debug("receiver connection established") break } log.Debugf("could not establish '%s'", address) @@ -925,7 +995,7 @@ func (c *Client) Receive() (err error) { } serverTry := net.JoinHostPort(ip, port) - conn, banner2, externalIP, errConn := tcp.ConnectToTCPServer(serverTry, c.Options.RelayPassword, c.Options.SharedSecret[:3], 500*time.Millisecond) + conn, banner2, externalIP, errConn := tcp.ConnectToTCPServer(serverTry, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, false, 1, 500*time.Millisecond) if errConn != nil { log.Debug(errConn) log.Debugf("could not connect to " + serverTry) @@ -945,6 +1015,7 @@ func (c *Client) Receive() (err error) { } } + log.Debug("sending handshake message") if err = c.conn[0].Send(handshakeRequest); err != nil { log.Errorf("handshake send error: %v", err) } @@ -973,6 +1044,7 @@ func (c *Client) transfer() (err error) { // if recipient, initialize with sending pake information log.Debug("ready") if !c.Options.IsSender && !c.Step1ChannelSecured { + log.Debug("sending pake information") err = message.Send(c.conn[0], c.Key, message.Message{ Type: message.TypePAKE, Bytes: c.Pake.Bytes(), @@ -1030,7 +1102,7 @@ func (c *Client) transfer() (err error) { } } - if c.Options.Stdout && !c.Options.IsSender { + if c.Options.Stdout && !c.Options.IsSender && c.FilesToTransfer != nil && len(c.FilesToTransfer) > 0 { pathToFile := path.Join( c.FilesToTransfer[c.FilesToTransferCurrentNum].FolderRemote, c.FilesToTransfer[c.FilesToTransferCurrentNum].Name, @@ -1276,6 +1348,9 @@ func (c *Client) processMessagePake(m message.Message) (err error) { server, c.Options.RelayPassword, fmt.Sprintf("%s-%d", utils.SHA256(c.Options.SharedSecret[:5])[:6], j), + c.Options.IsSender, + false, + 2, ) if err != nil { panic(err) @@ -1338,9 +1413,12 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) { switch m.Type { case message.TypeFinished: - err = message.Send(c.conn[0], c.Key, message.Message{ - Type: message.TypeFinished, - }) + // only senders should respond to "finished" messages + if c.Options.IsSender { + err = message.Send(c.conn[0], c.Key, message.Message{ + Type: message.TypeFinished, + }) + } done = true c.SuccessfulTransfer = true return @@ -1516,6 +1594,10 @@ func (c *Client) recipientGetFileReady(finished bool) (err error) { } c.SuccessfulTransfer = true c.FilesHasFinished[c.FilesToTransferCurrentNum] = struct{}{} + + if !c.Options.IsSender { + return + } } err = c.recipientInitializeFile() diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index 445ec5afe..71b80d1c8 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -32,6 +32,8 @@ func TestCrocReadme(t *testing.T) { log.Debug("setting up sender") sender, err := New(Options{ IsSender: true, + TimeLimit: 30, + MaxTransfers: 1, SharedSecret: "8123-testingthecroc", Debug: true, RelayAddress: "127.0.0.1:8281", @@ -99,6 +101,8 @@ func TestCrocEmptyFolder(t *testing.T) { log.Debug("setting up sender") sender, err := New(Options{ IsSender: true, + TimeLimit: 30, + MaxTransfers: 1, SharedSecret: "8123-testingthecroc", Debug: true, RelayAddress: "127.0.0.1:8281", @@ -166,6 +170,8 @@ func TestCrocSymlink(t *testing.T) { log.Debug("setting up sender") sender, err := New(Options{ IsSender: true, + TimeLimit: 30, + MaxTransfers: 1, SharedSecret: "8124-testingthecroc", Debug: true, RelayAddress: "127.0.0.1:8281", @@ -232,7 +238,8 @@ func TestCrocSymlink(t *testing.T) { t.Errorf("symlink transfer failed: %s", err.Error()) } } -func testCrocIgnoreGit(t *testing.T) { + +func TestCrocIgnoreGit(t *testing.T) { log.SetLevel("trace") defer os.Remove(".gitignore") time.Sleep(300 * time.Millisecond) @@ -268,7 +275,9 @@ func TestCrocLocal(t *testing.T) { log.Debug("setting up sender") sender, err := New(Options{ IsSender: true, - SharedSecret: "8123-testingthecroc", + TimeLimit: 30, + MaxTransfers: 1, + SharedSecret: "2813-testingthecroc", Debug: true, RelayAddress: "127.0.0.1:8181", RelayPorts: []string{"8181", "8182"}, @@ -288,7 +297,7 @@ func TestCrocLocal(t *testing.T) { log.Debug("setting up receiver") receiver, err := New(Options{ IsSender: false, - SharedSecret: "8123-testingthecroc", + SharedSecret: "2813-testingthecroc", Debug: true, RelayAddress: "127.0.0.1:8181", RelayPassword: "pass123", @@ -316,7 +325,7 @@ func TestCrocLocal(t *testing.T) { } wg.Done() }() - time.Sleep(100 * time.Millisecond) + time.Sleep(300 * time.Millisecond) go func() { err := receiver.Receive() if err != nil { @@ -348,6 +357,8 @@ func TestCrocError(t *testing.T) { log.SetLevel("warn") sender, _ := New(Options{ IsSender: true, + TimeLimit: 30, + MaxTransfers: 1, SharedSecret: "8123-testingthecroc2", Debug: true, RelayAddress: "doesntexistok.com:8381", @@ -380,6 +391,14 @@ func TestCleanUp(t *testing.T) { log.Debug("Full cleanup") var err error + for _, file := range []string{".gitignore", ".gitignore"} { + err = os.Remove(file) + if err == nil { + log.Debugf("Successfully purged %s", file) + } else { + log.Debugf("%s was already purged.", file) + } + } for _, file := range []string{"README.md", "./README.md"} { err = os.Remove(file) if err == nil { diff --git a/src/message/message.go b/src/message/message.go index 1d94b80ae..00433e620 100644 --- a/src/message/message.go +++ b/src/message/message.go @@ -1,6 +1,7 @@ package message import ( + "bytes" "encoding/json" "github.com/schollz/croc/v9/src/comm" @@ -56,7 +57,9 @@ func Encode(key []byte, m Message) (b []byte, err error) { b = compress.Compress(b) if key != nil { log.Debugf("writing %s message (encrypted)", m.Type) - b, err = crypt.Encrypt(b, key) + if m.Type != TypeFinished { + b, err = crypt.Encrypt(b, key) + } } else { log.Debugf("writing %s message (unencrypted)", m.Type) } @@ -65,7 +68,7 @@ func Encode(key []byte, m Message) (b []byte, err error) { // Decode will convert from bytes func Decode(key []byte, b []byte) (m Message, err error) { - if key != nil { + if key != nil && !bytes.Contains(b, []byte(TypeFinished)) { b, err = crypt.Decrypt(b, key) if err != nil { return diff --git a/src/message/message_test.go b/src/message/message_test.go index eae58f5f0..b0f3a67c8 100644 --- a/src/message/message_test.go +++ b/src/message/message_test.go @@ -1,6 +1,7 @@ package message import ( + "bytes" "crypto/rand" "fmt" "net" @@ -15,6 +16,18 @@ import ( var TypeMessage Type = "message" +// Test that the finished message is not encrypted +func TestMessageDontEncryptFinished(t *testing.T) { + log.SetLevel("debug") + m := Message{Type: TypeFinished, Message: "hello, world"} + e, salt, err := crypt.New([]byte("pass"), nil) + assert.Nil(t, err) + fmt.Println(string(salt)) + b, err := Encode(e, m) + assert.Nil(t, err) + assert.True(t, bytes.Contains(b, []byte("hello, world"))) +} + func TestMessage(t *testing.T) { log.SetLevel("debug") m := Message{Type: TypeMessage, Message: "hello, world"} diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index b7803b000..a9f49498b 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -2,8 +2,10 @@ package tcp import ( "bytes" + "container/list" "fmt" "net" + "strconv" "strings" "sync" "time" @@ -26,20 +28,30 @@ type server struct { } type roomInfo struct { - first *comm.Comm - second *comm.Comm - opened time.Time - full bool + sender *comm.Comm + receiver *comm.Comm + queue *list.List + isMainRoom bool + maxTransfers int + doneTransfers int + opened time.Time } type roomMap struct { - rooms map[string]roomInfo + rooms map[string]roomInfo + roomLocks map[string]*sync.Mutex sync.Mutex } const pingRoom = "pinglkasjdlfjsaldjf" -var timeToRoomDeletion = 10 * time.Minute +var ( + fullRoom = []byte("room_full") + senderGone = []byte("sender_gone") + noRoom = []byte("room_non_existent") +) + +var timeToRoomDeletion = 60 * time.Minute // Run starts a tcp listener, run async func Run(debugLevel, host, port, password string, banner ...string) (err error) { @@ -59,6 +71,7 @@ func (s *server) start() (err error) { log.Debugf("starting with password '%s'", s.password) s.rooms.Lock() s.rooms.rooms = make(map[string]roomInfo) + s.rooms.roomLocks = make(map[string]*sync.Mutex) s.rooms.Unlock() // delete old rooms @@ -143,19 +156,19 @@ func (s *server) run() (err error) { log.Debugf("checking connection of room %s for %+v", room, c) deleteIt := false s.rooms.Lock() - if _, ok := s.rooms.rooms[room]; !ok { - log.Debug("room is gone") + if _, ok := s.rooms.rooms[room]; !ok || (s.rooms.rooms[room].sender == nil && s.rooms.rooms[room].receiver == nil) { + log.Debugf("room %s is gone", room) s.rooms.Unlock() return } log.Debugf("room: %+v", s.rooms.rooms[room]) - if s.rooms.rooms[room].first != nil && s.rooms.rooms[room].second != nil { + if s.rooms.rooms[room].sender != nil && s.rooms.rooms[room].receiver != nil { log.Debug("rooms ready") s.rooms.Unlock() break } else { - if s.rooms.rooms[room].first != nil { - errSend := s.rooms.rooms[room].first.Send([]byte{1}) + if s.rooms.rooms[room].sender != nil { + errSend := s.rooms.rooms[room].sender.Send([]byte{1}) if errSend != nil { log.Debug(errSend) deleteIt = true @@ -249,64 +262,281 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er return } - // wait for client to tell me which room they want - log.Debug("waiting for answer") + isSender := false + log.Debug("wait for client to tell if they want to send or receive") enc, err := c.Receive() if err != nil { return } + data, err := crypt.Decrypt(enc, strongKeyForEncryption) + if err != nil { + return + } + if !bytes.Equal(data, []byte("send")) && !bytes.Equal(data, []byte("receive")) { + err = fmt.Errorf("got bad response: %s", data) + return + } else if bytes.Equal(data, []byte("send")) { + log.Debug("client wants to send") + isSender = true + } + + // wait for client to tell me which room they want + log.Debug("waiting for room") + enc, err = c.Receive() + if err != nil { + return + } roomBytes, err := crypt.Decrypt(enc, strongKeyForEncryption) if err != nil { return } room = string(roomBytes) + log.Debug("Check if this is a main room") + enc, err = c.Receive() + if err != nil { + return + } + data, err = crypt.Decrypt(enc, strongKeyForEncryption) + if err != nil { + return + } + if !bytes.Equal(data, []byte("main")) && !bytes.Equal(data, []byte("secondary")) { + err = fmt.Errorf("got bad response: %s", data) + return + } + isMainRoom := bytes.Equal(data, []byte("main")) + log.Debugf("isMainRoom: %v", isMainRoom) + s.rooms.Lock() + _, roomExists := s.rooms.rooms[room] // create the room if it is new - if _, ok := s.rooms.rooms[room]; !ok { + if !roomExists || isSender { + if roomExists && isSender && s.rooms.rooms[room].sender != nil { + // if the room exists and the sender is already connected + // then signal to the client that the room is full + err = s.sendRoomIsFull(c, strongKeyForEncryption) + return + } + err = s.createOrUpdateRoom(c, room, strongKeyForEncryption, isMainRoom, isSender, roomExists) + if err != nil { + log.Error(err) + } + + // if the room is new then return + if !roomExists { + return + } + } else if s.rooms.rooms[room].receiver != nil { + // if the room has a transfer going on + if s.rooms.rooms[room].maxTransfers > 1 { + // if the room is a multi-transfer room then add to queue + var keepGoing bool + err, keepGoing = s.handleWaitingRoomForReceivers(c, room, strongKeyForEncryption) + if err != nil { + log.Error(err) + } + if !keepGoing { + s.rooms.roomLocks[room].Unlock() + return + } + } else { + // otherwise, tell the client that the room is full + err = s.sendRoomIsFull(c, strongKeyForEncryption) + return + } + } else { + log.Debugf("room %s has 2", room) s.rooms.rooms[room] = roomInfo{ - first: c, - opened: time.Now(), + sender: s.rooms.rooms[room].sender, + receiver: c, + queue: s.rooms.rooms[room].queue, + isMainRoom: s.rooms.rooms[room].isMainRoom, + maxTransfers: s.rooms.rooms[room].maxTransfers, + doneTransfers: s.rooms.rooms[room].doneTransfers, + opened: s.rooms.rooms[room].opened, } - s.rooms.Unlock() - // tell the client that they got the room + s.rooms.roomLocks[room].Lock() + } - bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) + err = s.beginTransfer(c, room, strongKeyForEncryption) + if err != nil { + log.Error(err) + } + + return +} + +func (s *server) sendRoomIsFull(c *comm.Comm, strongKeyForEncryption []byte) (err error) { + s.rooms.Unlock() + bSend, err := crypt.Encrypt([]byte(fullRoom), strongKeyForEncryption) + if err != nil { + return + } + err = c.Send(bSend) + if err != nil { + log.Error(err) + return + } + return +} + +func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncryption []byte, isMainRoom, isSender, updateRoom bool) (err error) { + var enc, data, bSend []byte + + if !updateRoom { + log.Debugf("Creating room %s", room) + } else { + log.Debugf("Updating room %s", room) + } + + maxTransfers := 1 + if isMainRoom && isSender { + log.Debug("Wait for maxTransfers") + enc, err = c.Receive() if err != nil { return } - err = c.Send(bSend) + data, err = crypt.Decrypt(enc, strongKeyForEncryption) if err != nil { - log.Error(err) - s.deleteRoom(room) return } - log.Debugf("room %s has 1", room) - return + + maxTransfers, err = strconv.Atoi(string(data)) + if err != nil { + return + } + log.Debugf("maxTransfers: %v", maxTransfers) + } + + var sender, receiver *comm.Comm + var queue *list.List + opened := time.Now() + if isSender { + sender = c + if updateRoom { + receiver = s.rooms.rooms[room].receiver + queue = s.rooms.rooms[room].queue + opened = s.rooms.rooms[room].opened + } + } else { + receiver = c + if updateRoom { + sender = s.rooms.rooms[room].sender + queue = s.rooms.rooms[room].queue + opened = s.rooms.rooms[room].opened + } } - if s.rooms.rooms[room].full { - s.rooms.Unlock() - bSend, err = crypt.Encrypt([]byte("room full"), strongKeyForEncryption) + + s.rooms.rooms[room] = roomInfo{ + sender: sender, + receiver: receiver, + queue: queue, + isMainRoom: isMainRoom, + maxTransfers: maxTransfers, + doneTransfers: 0, + opened: opened, + } + + if !updateRoom { + log.Debugf("Client crated main room %s, %v", room, isSender) + s.rooms.roomLocks[room] = &sync.Mutex{} + // tell the client that they got the room + bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) if err != nil { return } err = c.Send(bSend) if err != nil { log.Error(err) + s.deleteRoom(room) return } - return + log.Debugf("room %s has 1", room) + s.rooms.Unlock() } - log.Debugf("room %s has 2", room) + + return +} + +func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error, keepGoing bool) { + var bSend []byte + log.Debugf("room %s is full, adding to queue", room) + queue := s.rooms.rooms[room].queue + if queue == nil { + queue = list.New() + } + queue.PushBack(c.ID()) s.rooms.rooms[room] = roomInfo{ - first: s.rooms.rooms[room].first, - second: c, - opened: s.rooms.rooms[room].opened, - full: true, + sender: s.rooms.rooms[room].sender, + receiver: s.rooms.rooms[room].receiver, + isMainRoom: s.rooms.rooms[room].isMainRoom, + opened: s.rooms.rooms[room].opened, + maxTransfers: s.rooms.rooms[room].maxTransfers, + doneTransfers: s.rooms.rooms[room].doneTransfers, + queue: queue, + } + s.rooms.Unlock() + + keepGoing = false + for { + s.rooms.roomLocks[room].Lock() + + if s.rooms.rooms[room].receiver != nil || s.rooms.rooms[room].queue.Front().Value.(string) != c.ID() { + time.Sleep(1 * time.Second) + // tell the client that they need to wait + bSend, err = crypt.Encrypt([]byte("wait"), strongKeyForEncryption) + if err != nil { + return + } + err = c.Send(bSend) + if err != nil { + log.Error(err) + return + } + s.rooms.roomLocks[room].Unlock() + } else if s.rooms.rooms[room].doneTransfers >= s.rooms.rooms[room].maxTransfers { + // tell the client that the sender is no longer available + bSend, err = crypt.Encrypt([]byte(senderGone), strongKeyForEncryption) + if err != nil { + return + } + err = c.Send(bSend) + if err != nil { + log.Error(err) + return + } + break + } else { + s.rooms.Lock() + // remove the client from the queue + newQueue := s.rooms.rooms[room].queue + newQueue.Remove(newQueue.Front()) + s.rooms.rooms[room] = roomInfo{ + sender: s.rooms.rooms[room].sender, + receiver: c, + queue: newQueue, + isMainRoom: s.rooms.rooms[room].isMainRoom, + maxTransfers: s.rooms.rooms[room].maxTransfers, + doneTransfers: s.rooms.rooms[room].doneTransfers, + opened: s.rooms.rooms[room].opened, + } + keepGoing = true + break + } } - otherConnection := s.rooms.rooms[room].first + return +} + +func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) { s.rooms.Unlock() + // safety check (it should never happen) + if s.rooms.rooms[room].sender == nil || s.rooms.rooms[room].receiver == nil { + err = fmt.Errorf("sender or receiver is nil") + return + } + // second connection is the sender, time to staple connections var wg sync.WaitGroup wg.Add(1) @@ -317,22 +547,49 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er pipe(com1.Connection(), com2.Connection()) wg.Done() log.Debug("done piping") - }(otherConnection, c, &wg) + }(s.rooms.rooms[room].sender, s.rooms.rooms[room].receiver, &wg) - // tell the sender everything is ready - bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) + // tell the client everything is ready + log.Debug("sending ok to client") + bSend, err := crypt.Encrypt([]byte("ok"), strongKeyForEncryption) if err != nil { return } err = c.Send(bSend) if err != nil { - s.deleteRoom(room) return } wg.Wait() - // delete room - s.deleteRoom(room) + // check if room is done and delete it if so + newDoneTransfers := s.rooms.rooms[room].doneTransfers + 1 + + // update the room info + s.rooms.Lock() + lengthOfQueue := 0 + if s.rooms.rooms[room].queue != nil { + lengthOfQueue = s.rooms.rooms[room].queue.Len() + } + log.Debugf("room %s has %d left in queue", room, lengthOfQueue) + s.rooms.rooms[room] = roomInfo{ + sender: s.rooms.rooms[room].sender, + receiver: nil, + queue: s.rooms.rooms[room].queue, + isMainRoom: s.rooms.rooms[room].isMainRoom, + maxTransfers: s.rooms.rooms[room].maxTransfers, + doneTransfers: newDoneTransfers, + opened: s.rooms.rooms[room].opened, + } + s.rooms.Unlock() + + // delete the room if it is done or unlock it if it is not + if newDoneTransfers >= s.rooms.rooms[room].maxTransfers { + log.Debugf("room %s is done, deleting it", room) + s.deleteRoom(room) + } else { + log.Debugf("room %s has %d done", room, newDoneTransfers) + s.rooms.roomLocks[room].Unlock() + } return } @@ -342,21 +599,45 @@ func (s *server) deleteRoom(room string) { if _, ok := s.rooms.rooms[room]; !ok { return } + if s.rooms.rooms[room].queue != nil && s.rooms.rooms[room].queue.Len() > 0 && s.rooms.roomLocks[room] != nil { + // signal to all waiting that the room will be deleted + for { + s.rooms.roomLocks[room].Unlock() + time.Sleep(250 * time.Millisecond) + s.rooms.roomLocks[room].Lock() + // remove the client from the queue + newQueue := s.rooms.rooms[room].queue + newQueue.Remove(newQueue.Front()) + s.rooms.rooms[room] = roomInfo{ + sender: s.rooms.rooms[room].sender, + receiver: s.rooms.rooms[room].receiver, + isMainRoom: s.rooms.rooms[room].isMainRoom, + opened: s.rooms.rooms[room].opened, + maxTransfers: s.rooms.rooms[room].maxTransfers, + doneTransfers: s.rooms.rooms[room].doneTransfers, + queue: newQueue, + } + if s.rooms.rooms[room].queue.Len() == 0 { + break + } + } + delete(s.rooms.roomLocks, room) + } log.Debugf("deleting room: %s", room) - if s.rooms.rooms[room].first != nil { - s.rooms.rooms[room].first.Close() + if s.rooms.rooms[room].sender != nil { + s.rooms.rooms[room].sender.Close() } - if s.rooms.rooms[room].second != nil { - s.rooms.rooms[room].second.Close() + if s.rooms.rooms[room].receiver != nil { + s.rooms.rooms[room].receiver.Close() } - s.rooms.rooms[room] = roomInfo{first: nil, second: nil} + s.rooms.rooms[room] = roomInfo{sender: nil, receiver: nil} delete(s.rooms.rooms, room) } // chanFromConn creates a channel from a Conn object, and sends everything it // // Read()s from the socket to the channel. -func chanFromConn(conn net.Conn) chan []byte { +func chanFromConn(conn net.Conn, isSender bool) chan []byte { c := make(chan []byte, 1) if err := conn.SetReadDeadline(time.Now().Add(3 * time.Hour)); err != nil { log.Warnf("can't set read deadline: %v", err) @@ -371,14 +652,19 @@ func chanFromConn(conn net.Conn) chan []byte { // Copy the buffer so it doesn't get changed while read by the recipient. copy(res, b[:n]) c <- res + // if finished, then we must exit in order to prevent zombie listeners + if bytes.Contains(res, []byte("finished")) && isSender { + log.Debugf("closing sender channel for %s", conn.RemoteAddr().String()) + close(c) + break + } } if err != nil { - log.Debug(err) + log.Debugf("closing channel for %s: %v", conn.RemoteAddr().String(), err) c <- nil break } } - log.Debug("exiting") }() return c @@ -387,27 +673,34 @@ func chanFromConn(conn net.Conn) chan []byte { // pipe creates a full-duplex pipe between the two sockets and // transfers data from one to the other. func pipe(conn1 net.Conn, conn2 net.Conn) { - chan1 := chanFromConn(conn1) - chan2 := chanFromConn(conn2) + chan1 := chanFromConn(conn1, true) + chan2 := chanFromConn(conn2, false) for { + log.Debugf("running in pipe %v - %v", conn1.RemoteAddr().String(), conn2.RemoteAddr().String()) select { - case b1 := <-chan1: - if b1 == nil { + case b1, ok := <-chan1: + if b1 == nil || !ok { return } + log.Debugf("got %s bytes from conn 1, sending it to conn 2", b1) if _, err := conn2.Write(b1); err != nil { log.Errorf("write error on channel 1: %v", err) } - case b2 := <-chan2: - if b2 == nil { + case b2, ok := <-chan2: + if b2 == nil || !ok { return } + log.Debugf("got %s bytes from conn 2, sending it to conn 1", b2) if _, err := conn1.Write(b2); err != nil { log.Errorf("write error on channel 2: %v", err) } } + + if chan1 == nil || chan2 == nil { + break + } } } @@ -436,7 +729,7 @@ func PingServer(address string) (err error) { // ConnectToTCPServer will initiate a new connection // to the specified address, room with optional time limit -func ConnectToTCPServer(address, password, room string, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) { +func ConnectToTCPServer(address, password, room string, isSender, isMainRoom bool, maxTransfers int, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) { if len(timelimit) > 0 { c, err = comm.NewConnection(address, timelimit[0]) } else { @@ -498,7 +791,7 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati log.Debug(err) return } - log.Debug("waiting for first ok") + log.Debug("waiting for sender ok") enc, err := c.Receive() if err != nil { log.Debug(err) @@ -516,8 +809,13 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati } banner = strings.Split(string(data), "|||")[0] ipaddr = strings.Split(string(data), "|||")[1] - log.Debug("sending room") - bSend, err = crypt.Encrypt([]byte(room), strongKeyForEncryption) + + log.Debug("tell server if you want to send or receive") + clientType := "receive" + if isSender { + clientType = "send" + } + bSend, err = crypt.Encrypt([]byte(clientType), strongKeyForEncryption) if err != nil { log.Debug(err) return @@ -527,22 +825,93 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati log.Debug(err) return } - log.Debug("waiting for room confirmation") - enc, err = c.Receive() + + log.Debug("sending room") + bSend, err = crypt.Encrypt([]byte(room), strongKeyForEncryption) if err != nil { log.Debug(err) return } - data, err = crypt.Decrypt(enc, strongKeyForEncryption) + err = c.Send(bSend) if err != nil { log.Debug(err) return } - if !bytes.Equal(data, []byte("ok")) { - err = fmt.Errorf("got bad response: %s", data) + + log.Debug("tell server if this is a main room") + roomType := "secondary" + if isMainRoom { + roomType = "main" + } + bSend, err = crypt.Encrypt([]byte(roomType), strongKeyForEncryption) + if err != nil { log.Debug(err) return } + err = c.Send(bSend) + if err != nil { + log.Debug(err) + return + } + + if isMainRoom && isSender { + log.Debug("tell server maxTransfers") + bSend, err = crypt.Encrypt([]byte(strconv.Itoa(maxTransfers)), strongKeyForEncryption) + if err != nil { + log.Debug(err) + return + } + err = c.Send(bSend) + if err != nil { + log.Debug(err) + return + } + } + + log.Debug("waiting for room confirmation") + for { + enc, err = c.Receive() + if err != nil { + log.Debug(err) + return + } + data, err = crypt.Decrypt(enc, strongKeyForEncryption) + if err != nil { + log.Debug(err) + return + } + + if bytes.Equal(data, []byte(fullRoom)) { + err = fmt.Errorf("room is full") + c = nil + return + } + + if !isSender { + if bytes.Equal(data, []byte("wait")) { + log.Debug("waiting for sender to be free") + time.Sleep(1 * time.Second) + continue + } else if bytes.Equal(data, []byte(senderGone)) { + err = fmt.Errorf("sender is gone") + c = nil + return + } else if bytes.Equal(data, []byte(noRoom)) { + err = fmt.Errorf("room does not exist") + c = nil + return + } + } + + if !bytes.Equal(data, []byte("ok")) { + err = fmt.Errorf("got bad response: %s", data) + log.Debug(err) + return + } else { + break + } + } + log.Debug("all set") return } diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index 769d999eb..d3a26ad8c 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -3,9 +3,11 @@ package tcp import ( "bytes" "fmt" + "strings" "testing" "time" + "github.com/schollz/croc/v9/src/comm" log "github.com/schollz/logger" "github.com/stretchr/testify/assert" ) @@ -16,35 +18,48 @@ func BenchmarkConnection(b *testing.B) { time.Sleep(100 * time.Millisecond) b.ResetTimer() for i := 0; i < b.N; i++ { - c, _, _, _ := ConnectToTCPServer("127.0.0.1:8283", "pass123", fmt.Sprintf("testroom%d", i), 1*time.Minute) + c, _, _, _ := ConnectToTCPServer("127.0.0.1:8283", "pass123", fmt.Sprintf("testroom%d", i), true, true, 1, 1*time.Minute) c.Close() } } -func TestTCP(t *testing.T) { +func TestTCPServerPing(t *testing.T) { log.SetLevel("error") - timeToRoomDeletion = 100 * time.Millisecond go Run("debug", "127.0.0.1", "8381", "pass123", "8382") - time.Sleep(timeToRoomDeletion) + time.Sleep(100 * time.Millisecond) err := PingServer("127.0.0.1:8381") assert.Nil(t, err) err = PingServer("127.0.0.1:8333") assert.NotNil(t, err) +} - time.Sleep(timeToRoomDeletion) - c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", 1*time.Minute) - assert.Equal(t, banner, "8382") - assert.Nil(t, err) - c2, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom") +func TestOnlyOneSenderPerRoom(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8280", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8280", "pass123", "testRoom", true, true, 1, 1*time.Minute) assert.Nil(t, err) - _, _, _, err = ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom") - assert.NotNil(t, err) - _, _, _, err = ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", 1*time.Nanosecond) + assert.NotNil(t, c1) + assert.Equal(t, banner, "8382") + + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8280", "pass123", "testRoom", true, true, 1, 1*time.Minute) assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "room is full")) + assert.Nil(t, c2) - // try sending data - assert.Nil(t, c1.Send([]byte("hello, c2"))) + c1.Close() + time.Sleep(300 * time.Millisecond) +} + +// This is helper function to test that a mocks a transfer +// between two clients connected to the server, +// and checks that the data is transferred correctly +func mockTransfer(c1, c2 *comm.Comm, t *testing.T) { + // try sending data to check the pipe is working properly var data []byte + var err error + assert.Nil(t, c1.Send([]byte("hello, c2"))) for { data, err = c2.Receive() if bytes.Equal(data, []byte{1}) { @@ -65,7 +80,216 @@ func TestTCP(t *testing.T) { } assert.Nil(t, err) assert.Equal(t, []byte("hello, c1"), data) +} + +// Test that a successful transfer can be made +func TestTCPServerSingleConnectionTransfer(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8381", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c1) + assert.Equal(t, banner, "8382") + + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1) + assert.Nil(t, err) + assert.NotNil(t, c2) + + mockTransfer(c1, c2, t) + + c2.Close() + c1.Close() + time.Sleep(300 * time.Millisecond) +} + +// Test that a receiver can connect before a sender +func TestTCPRecieverFirst(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8383", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + receiver, banner, _, err := ConnectToTCPServer("127.0.0.1:8383", "pass123", "testRoom", false, true, 1, 1*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, receiver) + assert.Equal(t, banner, "8382") + + sender, _, _, err := ConnectToTCPServer("127.0.0.1:8383", "pass123", "testRoom", true, true, 1, 1*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, sender) + + mockTransfer(receiver, sender, t) + + receiver.Close() + sender.Close() + time.Sleep(300 * time.Millisecond) +} + +// Test that a third client cannot connect +// to a room that already has two clients +// connected to it with maxTransfers=1 +func TestTCPSingleConnectionOnly2Clients(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8384", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8384", "pass123", "testRoom", true, true, 1, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c1) + assert.Equal(t, banner, "8382") + + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8384", "pass123", "testRoom", false, true, 1, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c2) + closeChan := make(chan int) + + // we need to run this transfer in a goroutine because + // otherwise connections will be idle and the server will + // close them when we try to connect a third client + go func() { + for { + select { + case <-closeChan: + fmt.Println("Closing go routine") + return + default: + mockTransfer(c1, c2, t) + } + } + }() + c3, _, _, err := ConnectToTCPServer("127.0.0.1:8384", "pass123", "testRoom", false, true, 1, 5*time.Minute) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "room is full")) + assert.Nil(t, c3) + closeChan <- 1 + + c1.Close() + c2.Close() + time.Sleep(300 * time.Millisecond) +} + +// Test that the server can handle multiple +// successive transfers from the same sender +func TestTCPMultipleConnectionTransfer(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8385", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8385", "pass123", "testRoom", true, true, 2, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c1) + assert.Equal(t, banner, "8382") + + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8385", "pass123", "testRoom", false, true, 1, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c2) + + mockTransfer(c1, c2, t) + c2.Close() + // tell c1 to close pipe listener + c1.Send([]byte("finished")) + time.Sleep(100 * time.Millisecond) + + c3, _, _, err := ConnectToTCPServer("127.0.0.1:8385", "pass123", "testRoom", false, true, 1, 5*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c3) + + mockTransfer(c1, c3, t) + + c1.Close() + c3.Close() + time.Sleep(300 * time.Millisecond) +} + +// Test that for a room with maxTransfers>=2, +// the receivers are queued if there is a transfer +// in progress already, and the receiver is allowed +// to connect when the transfer is finished +func TestTCPMultipleConnectionWaitingRoom(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8386", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8386", "pass123", "testRoom", true, true, 2, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c1) + assert.Equal(t, banner, "8382") + + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8386", "pass123", "testRoom", false, true, 1, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c2) + + // we need to run this transfer in a goroutine because + // otherwise connections will be idle and the server will + // close them when we try to connect a third client + go func() { + counter := 1 + time.Sleep(100 * time.Millisecond) + for { + mockTransfer(c1, c2, t) + if counter == 5 { + c2.Close() + // tell c1 to close pipe listener + c1.Send([]byte("finished")) + break + } + counter++ + } + }() + + c3, _, _, err := ConnectToTCPServer("127.0.0.1:8386", "pass123", "testRoom", false, true, 1, 5*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c3) + + mockTransfer(c1, c3, t) + + c1.Close() + c2.Close() + c3.Close() + time.Sleep(300 * time.Millisecond) +} + +// Test that for a room with maxTransfers>=2, +// if there are receivers queued they will get a +// nottification that the room is no longer available +// when the sender the maxTransfers limit is reached +func TestTCPMultipleConnectionWaitingRoomCloses(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8387", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8387", "pass123", "testRoom", true, true, 2, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c1) + assert.Equal(t, banner, "8382") + + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8387", "pass123", "testRoom", false, true, 1, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c2) + + // one transfer + mockTransfer(c1, c2, t) + c2.Close() + // tell c1 to close pipe listener + c1.Send([]byte("finished")) + + c2, _, _, err = ConnectToTCPServer("127.0.0.1:8387", "pass123", "testRoom", false, true, 1, 10*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, c2) + + go func() { + c3, _, _, err := ConnectToTCPServer("127.0.0.1:8387", "pass123", "testRoom", false, true, 1, 5*time.Minute) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "sender is gone")) + assert.Nil(t, c3) + }() + + time.Sleep(100 * time.Millisecond) + c2.Close() + // tell c1 to close pipe listener + c1.Send([]byte("finished")) c1.Close() time.Sleep(300 * time.Millisecond) }