package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
	"github.com/armon/go-socks5"
	"ssh-client/tunnel"
)

var (
	// sshClient is the SSH session shared by every SOCKS5 connection. main
	// assigns it at startup and dialThroughSSH may replace it when a dial
	// through the current session fails.
	sshClient *ssh.Client
	// clientMux is held by dialThroughSSH for the whole of each dial
	// (including any reconnect), guarding reads and writes of sshClient.
	clientMux sync.Mutex
)

func main() {
	// Define command-line flags
	server := flag.String("server", " ", "SSH server address")
	port := flag.String("port", " ", "SSH server port")
	user := flag.String("user", " ", "SSH username")
	pass := flag.String("pass", " ", "SSH password")
	listen := flag.String("listen", "localhost:1080", "Local address to listen on")
	proxyHost := flag.String("proxyHost", " ", "HTTP proxy host")
	proxyPort := flag.Int("proxyPort",  , "HTTP proxy port")

	// Parse command-line flags
	flag.Parse()

	sshConfig := &ssh.ClientConfig{
		User: *user,
		Auth: []ssh.AuthMethod{
			ssh.Password(*pass),
		},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	sshHost := fmt.Sprintf("%s:%s", *server, *port)
	var err error

	// Create HTTP proxy
	httpProxy, err := tunnel.NewHttpProxy(*proxyHost, *proxyPort)
	if err != nil {
		log.Fatalf("Failed to create HTTP proxy: %s", err)
	}

	// Initial SSH connection
	sshClient, err = connectToSSH(sshHost, sshConfig, httpProxy)
	if err != nil {
		log.Fatalf("Failed to establish initial SSH connection: %s", err)
	}
	defer httpProxy.Close()

	listener, err := net.Listen("tcp", *listen)
	if err != nil {
		log.Fatalf("Failed to listen on %s: %s", *listen, err)
	}
	defer listener.Close()

	log.Printf("SOCKS5 proxy listening on %s", *listen)

	socksConf := &socks5.Config{
		Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
			return dialThroughSSH(ctx, network, addr, sshHost, sshConfig, httpProxy)
		},
	}
	socksServer, err := socks5.New(socksConf)
	if err != nil {
		log.Fatalf("Failed to create SOCKS5 server: %s", err)
	}

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Printf("Failed to accept connection: %s", err)
			continue
		}
		go func() {
			defer conn.Close()
			if err := socksServer.ServeConn(conn); err != nil {
				log.Printf("Failed to serve connection: %s", err)
			}
		}()
	}
}

// connectToSSH keeps trying until an SSH client session to sshHost has been
// established over a connection opened through proxy.
//
// A failure to open the proxied connection is logged and retried after a
// two-second pause. A failure of the SSH handshake is logged, the proxied
// connection is closed, and the attempt is repeated immediately.
//
// As written the loop only exits on success, so the returned error is
// always nil.
func connectToSSH(sshHost string, sshConfig *ssh.ClientConfig, proxy *tunnel.HttpProxy) (*ssh.Client, error) {
	for {
		tunnelConn, openErr := proxy.OpenConnection(sshHost, 0, 10000, 10000)
		if openErr != nil {
			log.Printf("Failed to connect to SSH server via proxy: %s. Retrying...", openErr)
			time.Sleep(2 * time.Second)
			continue
		}

		sshConn, newChans, globalReqs, handshakeErr := ssh.NewClientConn(tunnelConn, sshHost, sshConfig)
		if handshakeErr != nil {
			log.Printf("Failed to dial SSH: %s. Retrying...", handshakeErr)
			tunnelConn.Close()
			continue
		}

		log.Printf("Successfully connected to SSH server at %s via proxy", sshHost)
		client := ssh.NewClient(sshConn, newChans, globalReqs)
		return client, nil
	}
}

// dialThroughSSH opens a connection to addr on the given network through the
// shared SSH client. If the dial fails for a reason other than the server
// rejecting the channel, the SSH session is re-established with connectToSSH
// and the dial is attempted once more.
//
// clientMux is held for the whole call, so dials and reconnects are
// serialized. ctx is accepted to match the SOCKS5 Dial signature but is not
// consulted.
//
// It returns the tunneled connection, or the error from the dial or from
// connectToSSH.
func dialThroughSSH(ctx context.Context, network, addr, sshHost string, sshConfig *ssh.ClientConfig, proxy *tunnel.HttpProxy) (net.Conn, error) {
	clientMux.Lock()
	defer clientMux.Unlock()

	conn, err := sshClient.Dial(network, addr)
	if err == nil {
		return conn, nil
	}

	// An OpenChannelError means the server answered and refused this one
	// channel (for example, the destination is unreachable). The SSH session
	// itself still works, so there is nothing to reconnect.
	if _, rejected := err.(*ssh.OpenChannelError); rejected {
		return nil, err
	}

	// Reconnect if the connection is lost
	log.Printf("SSH connection lost: %s. Reconnecting...", err)

	// Release the dead session before replacing it so its resources are not
	// leaked. The close error is ignored on purpose: the session is already
	// considered broken and is about to be discarded.
	_ = sshClient.Close()

	fresh, err := connectToSSH(sshHost, sshConfig, proxy)
	if err != nil {
		// Leave sshClient pointing at the old (closed) client rather than
		// nil, so a later call fails with an error instead of panicking.
		return nil, err
	}
	sshClient = fresh

	return sshClient.Dial(network, addr)
}