// Source file: kport/port_forwarder.go

package main
import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"sync"
"time"
tea "github.com/charmbracelet/bubbletea"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// ForwardingStartedMsg is the message produced by the port-forwarding
// commands once a forwarder has been started successfully.
type ForwardingStartedMsg struct {
	// LocalPort is the local port the forwarder listens on; RemotePort is
	// the port on the remote host that connections are forwarded to.
	LocalPort, RemotePort int
}
// PortForwarder forwards TCP connections accepted on a local port to a port
// on the remote side of an SSH connection.
type PortForwarder struct {
	// Forwarding endpoints, set by NewPortForwarder.
	sshClient  *ssh.Client
	localPort  int
	remotePort int

	// Lifecycle state. Start and Stop hold mu while they change it.
	mu        sync.Mutex
	isRunning bool
	listener  net.Listener
	stopChan  chan struct{}
	wg        sync.WaitGroup
}
// NewPortForwarder returns a forwarder that will listen on localPort and
// relay each accepted connection to remotePort through sshClient. The
// forwarder does nothing until Start is called.
func NewPortForwarder(sshClient *ssh.Client, localPort, remotePort int) *PortForwarder {
	pf := new(PortForwarder)
	pf.sshClient = sshClient
	pf.localPort = localPort
	pf.remotePort = remotePort
	pf.stopChan = make(chan struct{})
	return pf
}
// Start opens the local listener on localhost:<localPort> and launches the
// accept loop in a background goroutine.
//
// It returns an error if the forwarder is already running or if the local
// listener cannot be created. Start may be called again after Stop.
func (pf *PortForwarder) Start() error {
	pf.mu.Lock()
	defer pf.mu.Unlock()

	if pf.isRunning {
		return fmt.Errorf("port forwarding already running")
	}

	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", pf.localPort))
	if err != nil {
		return fmt.Errorf("failed to create local listener: %w", err)
	}

	// Stop closes stopChan, so every run needs its own channel; reusing a
	// closed one would make the accept loop exit immediately and leave the
	// new listener open with nobody accepting on it. No goroutine of a
	// previous run can still be reading the old channel here, because Stop
	// waits for all of them before it returns.
	pf.stopChan = make(chan struct{})
	pf.listener = listener
	pf.isRunning = true

	// The accept loop is counted in wg so that Stop can wait for it.
	pf.wg.Add(1)
	go pf.acceptConnections()
	return nil
}
// Stop shuts the forwarder down: it signals the accept loop, closes the
// local listener and then blocks until the accept loop and every connection
// handler counted in wg have returned. Calling Stop on a forwarder that is
// not running does nothing.
func (pf *PortForwarder) Stop() {
	pf.mu.Lock()
	defer pf.mu.Unlock()

	if pf.isRunning {
		pf.isRunning = false

		// Signal first, then close the listener, so the accept loop sees the
		// stop signal when its Accept call fails.
		close(pf.stopChan)
		if ln := pf.listener; ln != nil {
			ln.Close()
		}

		pf.wg.Wait()
	}
}
// acceptConnections runs the accept loop until stopChan is closed. Every
// accepted connection is handed to handleConnection in its own goroutine,
// which is counted in wg. The loop itself is expected to have been counted
// in wg by its caller; it calls wg.Done on exit.
func (pf *PortForwarder) acceptConnections() {
	defer pf.wg.Done()

	// Pause between retries after an unexpected Accept error so that a
	// persistent failure (for example running out of file descriptors) does
	// not turn this loop into a busy spin.
	const acceptRetryDelay = 50 * time.Millisecond

	for {
		select {
		case <-pf.stopChan:
			return
		default:
		}

		// Bound each Accept call so the stop signal is re-checked at least
		// once per second even if no client connects. A failure to set the
		// deadline is harmless: Stop also closes the listener.
		if tcpListener, ok := pf.listener.(*net.TCPListener); ok {
			_ = tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
		}

		conn, err := pf.listener.Accept()
		if err != nil {
			// Deadline expiry is the normal idle case; just loop again.
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue
			}
			// Any other error: exit if we are stopping (Stop closes stopChan
			// before it closes the listener), otherwise back off and retry.
			select {
			case <-pf.stopChan:
				return
			case <-time.After(acceptRetryDelay):
				continue
			}
		}

		pf.wg.Add(1)
		go pf.handleConnection(conn)
	}
}
// handleConnection relays data between localConn and a new connection to
// localhost:<remotePort> dialed through the SSH client. It returns once both
// directions have finished, closes both connections on the way out and calls
// wg.Done for the handler slot its caller added.
func (pf *PortForwarder) handleConnection(localConn net.Conn) {
	defer pf.wg.Done()
	defer localConn.Close()

	remoteConn, err := pf.sshClient.Dial("tcp", fmt.Sprintf("localhost:%d", pf.remotePort))
	if err != nil {
		// Report the failure instead of dropping the client silently.
		fmt.Fprintf(os.Stderr, "Debug: Failed to dial remote port %d: %v\n", pf.remotePort, err)
		return
	}
	defer remoteConn.Close()

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		copyAndCloseWrite(localConn, remoteConn)
	}()
	go func() {
		defer wg.Done()
		copyAndCloseWrite(remoteConn, localConn)
	}()
	wg.Wait()
}

// copyAndCloseWrite copies src to dst until src is exhausted or an error
// occurs, then shuts down the write side of dst so the peer behind dst sees
// end-of-stream. Without this, a peer that waits for EOF before closing its
// own side would keep the opposite copy (and therefore the handler) blocked.
// Connections that expose CloseWrite (as *net.TCPConn does) are half-closed;
// any other connection is closed completely.
func copyAndCloseWrite(dst, src net.Conn) {
	// Copy errors are not actionable here; either way the stream is over.
	_, _ = io.Copy(dst, src)

	if hc, ok := dst.(interface{ CloseWrite() error }); ok {
		_ = hc.CloseWrite()
		return
	}
	_ = dst.Close()
}
// StartPortForwarding returns a command that connects to host over SSH,
// picks a free local port and starts forwarding it to remotePort on the
// remote side.
//
// The command yields a ForwardingStartedMsg on success and an ErrorMsg if
// the SSH connection, the port lookup or the forwarder start-up fails.
// Progress is written to stderr as "Debug:" lines.
func StartPortForwarding(host SSHHost, remotePort int) tea.Cmd {
	return func() tea.Msg {
		fmt.Fprintf(os.Stderr, "Debug: Starting port forwarding for %s:%d\n", host.Name, remotePort)

		// Connect before choosing the local port: the SSH dial can take
		// seconds, and findAvailablePort releases the port it finds, so
		// probing first would leave a long window in which another process
		// could grab it.
		client, err := createSSHClient(host)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Debug: Failed to create SSH client: %v\n", err)
			return ErrorMsg{Error: fmt.Errorf("failed to connect to %s: %w", host.Name, err)}
		}
		fmt.Fprintf(os.Stderr, "Debug: SSH client created successfully\n")

		localPort, err := findAvailablePort()
		if err != nil {
			client.Close()
			fmt.Fprintf(os.Stderr, "Debug: Failed to find available port: %v\n", err)
			return ErrorMsg{Error: fmt.Errorf("failed to find available local port: %w", err)}
		}
		fmt.Fprintf(os.Stderr, "Debug: Found available local port: %d\n", localPort)

		forwarder := NewPortForwarder(client, localPort, remotePort)
		if err := forwarder.Start(); err != nil {
			client.Close()
			fmt.Fprintf(os.Stderr, "Debug: Failed to start port forwarder: %v\n", err)
			return ErrorMsg{Error: fmt.Errorf("failed to start port forwarding: %w", err)}
		}
		fmt.Fprintf(os.Stderr, "Debug: Port forwarder started successfully\n")

		// NOTE(review): neither the forwarder nor the SSH client is handed
		// back to the caller, so nothing in this function can stop them later.
		return ForwardingStartedMsg{
			LocalPort:  localPort,
			RemotePort: remotePort,
		}
	}
}
// StartManualPortForwarding returns a command that parses portStr as a
// remote port number and, if it is valid, starts forwarding to it exactly as
// StartPortForwarding does.
//
// The command yields an ErrorMsg when portStr is not an integer or is outside
// 1-65535; otherwise it yields whatever StartPortForwarding's command yields.
func StartManualPortForwarding(host SSHHost, portStr string) tea.Cmd {
	return func() tea.Msg {
		fmt.Fprintf(os.Stderr, "Debug: Manual port forwarding requested for %s:%s\n", host.Name, portStr)

		remotePort, err := strconv.Atoi(portStr)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Debug: Invalid port number: %s\n", portStr)
			return ErrorMsg{Error: fmt.Errorf("invalid port number: %s", portStr)}
		}
		if remotePort <= 0 || remotePort > 65535 {
			fmt.Fprintf(os.Stderr, "Debug: Port number out of range: %d\n", remotePort)
			return ErrorMsg{Error: fmt.Errorf("port number must be between 1 and 65535")}
		}

		// Run the shared connect-and-forward sequence instead of keeping a
		// second copy of it here.
		return StartPortForwarding(host, remotePort)()
	}
}
// createSSHClient opens an SSH connection to host using public-key
// authentication.
//
// Keys are offered in this order: the host's identity file (if set), keys
// held by the SSH agent at SSH_AUTH_SOCK (if reachable), and, when no
// identity file is configured, the default keys id_rsa, id_ecdsa and
// id_ed25519 from ~/.ssh. Keys that cannot be loaded are skipped.
//
// It returns an error if no key source is available at all or if the
// connection or handshake fails.
//
// SECURITY: host keys are not verified (ssh.InsecureIgnoreHostKey), so the
// connection is open to man-in-the-middle attacks. Replace the callback with
// real host key verification before relying on this outside a trusted
// network.
func createSSHClient(host SSHHost) (*ssh.Client, error) {
	// x/crypto/ssh uses only the first AuthMethod of each kind in
	// ClientConfig.Auth, so adding one "publickey" method per key source
	// would silently ignore every source after the first. All signers are
	// therefore merged into a single PublicKeysCallback below.
	var identitySigners, defaultSigners []ssh.Signer

	if host.Identity != "" {
		fmt.Fprintf(os.Stderr, "Debug: Trying identity file: %s\n", host.Identity)
		key, err := loadPrivateKey(host.Identity)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Debug: Failed to load identity file: %v\n", err)
		} else {
			identitySigners = append(identitySigners, key)
			fmt.Fprintf(os.Stderr, "Debug: Added key-based auth\n")
		}
	}

	// The agent is dialed here rather than through sshAgentAuth because the
	// AuthMethod that helper returns cannot be merged with other signers.
	var agentClient agent.Agent
	if agentConn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err != nil {
		fmt.Fprintf(os.Stderr, "Debug: SSH agent not available: %v\n", err)
	} else {
		// The agent is only consulted during the handshake, which is over
		// by the time this function returns.
		defer agentConn.Close()
		agentClient = agent.NewClient(agentConn)
		fmt.Fprintf(os.Stderr, "Debug: Added SSH agent auth\n")
	}

	if host.Identity == "" {
		if homeDir, err := os.UserHomeDir(); err == nil {
			for _, keyName := range []string{"id_rsa", "id_ecdsa", "id_ed25519"} {
				key, err := loadPrivateKey(filepath.Join(homeDir, ".ssh", keyName))
				if err != nil {
					continue
				}
				defaultSigners = append(defaultSigners, key)
				fmt.Fprintf(os.Stderr, "Debug: Added default key: %s\n", keyName)
			}
		}
	}

	if len(identitySigners) == 0 && len(defaultSigners) == 0 && agentClient == nil {
		return nil, fmt.Errorf("no SSH authentication methods available - please set up SSH keys or SSH agent")
	}

	publicKeyAuth := ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
		signers := append([]ssh.Signer(nil), identitySigners...)
		if agentClient != nil {
			agentSigners, err := agentClient.Signers()
			if err != nil {
				// A broken agent must not hide the file-based keys.
				fmt.Fprintf(os.Stderr, "Debug: Failed to list SSH agent keys: %v\n", err)
			} else {
				signers = append(signers, agentSigners...)
			}
		}
		return append(signers, defaultSigners...), nil
	})

	config := &ssh.ClientConfig{
		User:            host.User,
		Auth:            []ssh.AuthMethod{publicKeyAuth},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // see SECURITY note above
		Timeout:         10 * time.Second,            // limit on establishing the TCP connection
	}

	addr := net.JoinHostPort(host.Hostname, host.Port)
	fmt.Fprintf(os.Stderr, "Debug: Connecting to %s\n", addr)
	client, err := ssh.Dial("tcp", addr, config)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to %s (%s): %w", host.Name, addr, err)
	}
	fmt.Fprintf(os.Stderr, "Debug: Successfully connected to %s\n", host.Name)
	return client, nil
}
// sshAgentAuth returns a public-key auth method backed by the SSH agent
// listening on the Unix socket named by SSH_AUTH_SOCK.
//
// It returns an error if SSH_AUTH_SOCK is unset or the socket cannot be
// dialed. The agent connection is left open because the returned method
// queries the agent lazily; this function offers no way to close it.
func sshAgentAuth() (ssh.AuthMethod, error) {
	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		// Say what is wrong instead of surfacing "dial unix: missing address".
		return nil, fmt.Errorf("SSH_AUTH_SOCK is not set")
	}

	agentConn, err := net.Dial("unix", socket)
	if err != nil {
		return nil, fmt.Errorf("connecting to SSH agent: %w", err)
	}

	return ssh.PublicKeysCallback(agent.NewClient(agentConn).Signers), nil
}
// findAvailablePort asks the operating system for a free TCP port on
// localhost and returns its number.
//
// The probe listener is closed before returning, so the port is only known
// to have been free at the time of the call; another process may take it
// before the caller binds it. Callers should bind it as soon as possible.
func findAvailablePort() (int, error) {
	// Probe the same address the forwarder listens on rather than the
	// wildcard address, so the probe never opens a port on external
	// interfaces.
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}
	defer listener.Close()

	addr, ok := listener.Addr().(*net.TCPAddr)
	if !ok {
		return 0, fmt.Errorf("unexpected listener address type %T", listener.Addr())
	}
	return addr.Port, nil
}