diff --git a/internal/sshproxy/proxy_connection.go b/internal/sshproxy/proxy_connection.go index c0cd6eb..f2a5ba2 100644 --- a/internal/sshproxy/proxy_connection.go +++ b/internal/sshproxy/proxy_connection.go @@ -11,9 +11,12 @@ type proxyConnection struct { internalPort int externalPort int listener net.Listener + + // closedChan is used to notify that the proxy connection is closed + closedChan chan<- *proxyConnection } -func newProxyConnection(toPort int) (*proxyConnection, error) { +func newProxyConnection(toPort int, closedChan chan *proxyConnection) (*proxyConnection, error) { l, err := net.Listen("tcp", ":0") if err != nil { return nil, err @@ -25,6 +28,7 @@ func newProxyConnection(toPort int) (*proxyConnection, error) { internalPort: toPort, externalPort: externalPort, listener: l, + closedChan: closedChan, }, nil } @@ -44,6 +48,9 @@ func (c *proxyConnection) forwardConnectionToSSH(conn net.Conn) { fmt.Printf("error connecting to container ssh at port %d\n", c.internalPort) return } + defer func() { + c.closedChan <- c + }() defer containerConn.Close() var wg sync.WaitGroup diff --git a/internal/sshproxy/sshproxy.go b/internal/sshproxy/sshproxy.go index 1fc87d0..a47552a 100644 --- a/internal/sshproxy/sshproxy.go +++ b/internal/sshproxy/sshproxy.go @@ -6,17 +6,24 @@ type SSHProxy struct { internalPorts map[int]int connections map[int]*proxyConnection + + closedConnections chan *proxyConnection } func New() *SSHProxy { - return &SSHProxy{ - internalPorts: map[int]int{}, - connections: map[int]*proxyConnection{}, + p := &SSHProxy{ + internalPorts: map[int]int{}, + connections: map[int]*proxyConnection{}, + closedConnections: make(chan *proxyConnection), } + + go p.handleClosedConnections() + + return p } func (p *SSHProxy) NewProxyEntryTo(toPort int) error { - c, err := newProxyConnection(toPort) + c, err := newProxyConnection(toPort, p.closedConnections) if err != nil { return err } @@ -35,3 +42,10 @@ func (p *SSHProxy) FindExternalPort(internalPort int) int { } return -1 } + +func (p *SSHProxy) handleClosedConnections() { + for c := range p.closedConnections { + delete(p.internalPorts, c.internalPort) + delete(p.connections, c.internalPort) + } +} diff --git a/internal/workspace/http_handlers.go b/internal/workspace/http_handlers.go index bf18307..3157df4 100644 --- a/internal/workspace/http_handlers.go +++ b/internal/workspace/http_handlers.go @@ -11,7 +11,7 @@ import ( "net/http" "strconv" "sync" - docker2 "tesseract/internal/docker" + "tesseract/internal/docker" "tesseract/internal/service" "tesseract/internal/template" "time" @@ -42,7 +42,7 @@ func fetchAllWorkspaces(c echo.Context) error { return c.JSON(http.StatusOK, make([]workspace, 0)) } - docker := service.DockerClient(c) + dockerClient := service.DockerClient(c) sshProxy := service.SSHProxy(c) var wg sync.WaitGroup @@ -54,7 +54,7 @@ func fetchAllWorkspaces(c echo.Context) error { go func() { defer wg.Done() - inspect, err := docker.ContainerInspect(ctx, w.ContainerID) + inspect, err := dockerClient.ContainerInspect(ctx, w.ContainerID) if err != nil { mu.Lock() errs = append(errs, err) @@ -73,7 +73,7 @@ func fetchAllWorkspaces(c echo.Context) error { workspaces[i].Status = statusUnknown } - if internalPort := docker2.ContainerSSHHostPort(ctx, inspect); internalPort > 0 { + if internalPort := docker.ContainerSSHHostPort(ctx, inspect); internalPort > 0 { if port := sshProxy.FindExternalPort(internalPort); port > 0 { workspaces[i].SSHPort = port } @@ -120,20 +120,36 @@ func updateOrCreateWorkspace(c echo.Context) error { return err } - docker := service.DockerClient(c) + dockerClient := service.DockerClient(c) + sshProxy := service.SSHProxy(c) switch status(body.Status) { case statusStopped: - if err = stopContainer(ctx, docker, workspaceName); err != nil { + if err = stopContainer(ctx, dockerClient, workspaceName); err != nil { return err } w.Status = statusStopped break + case statusRunning: - if err = startContainer(ctx, docker, workspaceName); err != nil { + if err = startContainer(ctx, dockerClient, workspaceName); err != nil { return err } + + inspect, err := dockerClient.ContainerInspect(ctx, w.ContainerID) + if err != nil { + return err + } + + sshPort := docker.ContainerSSHHostPort(ctx, inspect) + if sshPort > 0 { + if err = sshProxy.NewProxyEntryTo(sshPort); err != nil { + return err + } + } + w.Status = statusRunning + break } @@ -165,7 +181,7 @@ func createWorkspace(c echo.Context, workspaceName string) error { return err } - docker := service.DockerClient(c) + dockerClient := service.DockerClient(c) containerSSHPort := nat.Port("22/tcp") containerConfig := &container.Config{ @@ -184,17 +200,17 @@ func createWorkspace(c echo.Context, workspaceName string) error { }, } - res, err := docker.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, workspaceName) + res, err := dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, workspaceName) if err != nil { return err } - err = docker.ContainerStart(ctx, res.ID, container.StartOptions{}) + err = dockerClient.ContainerStart(ctx, res.ID, container.StartOptions{}) if err != nil { return err } - inspect, err := docker.ContainerInspect(ctx, res.ID) + inspect, err := dockerClient.ContainerInspect(ctx, res.ID) if err != nil { return err } @@ -248,7 +264,7 @@ func deleteWorkspace(c echo.Context) error { } db := service.Database(c) - docker := service.DockerClient(c) + dockerClient := service.DockerClient(c) ctx := c.Request().Context() tx, err := db.BeginTx(ctx, nil) @@ -262,20 +278,20 @@ func deleteWorkspace(c echo.Context) error { return echo.NewHTTPError(http.StatusNotFound) } - inspect, err := inspectContainer(ctx, docker, w.ContainerID) + inspect, err := inspectContainer(ctx, dockerClient, w.ContainerID) if err != nil { _ = tx.Rollback() return err } if inspect.State.Running { - if err = stopContainer(ctx, docker, w.ContainerID); err != nil { + if err = stopContainer(ctx, dockerClient, w.ContainerID); err != nil { _ = tx.Rollback() return err } } - if err = deleteContainer(ctx, docker, w.ContainerID); err != nil { + if err = deleteContainer(ctx, dockerClient, w.ContainerID); err != nil { return err }