package service import ( "fmt" "io" "io/fs" "io/ioutil" "log" "net" "os" "path" "path/filepath" "runtime" "sync" "time" "github.com/pkg/errors" "github.com/pkg/sftp" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) var sftpCli *sftp.Client var sftpOne sync.Once type SSHService struct { Client *ssh.Client LastResult string //执行结果 } func NewSSHService() *SSHService { ss := &SSHService{} return ss } func (ss *SSHService) Connect(user, password, host, key string, port int, cipherList []string) error { var ( auth []ssh.AuthMethod //认证方式 addr string clientConfig *ssh.ClientConfig config ssh.Config err error ) auth = make([]ssh.AuthMethod, 0) if key == "" { // 密码认证 auth = append(auth, ssh.Password(password)) } else { // 秘钥认证 pemBytes, err := ioutil.ReadFile(key) if err != nil { return err } var signer ssh.Signer if password == "" { signer, err = ssh.ParsePrivateKey(pemBytes) } else { signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password)) } if err != nil { return err } // 加载秘钥 auth = append(auth, ssh.PublicKeys(signer)) } // 设置ssh 的配置参数 if len(cipherList) == 0 { config = ssh.Config{ // 连接所允许的加密算法, go的SSH包允许的算法 Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc"}, } } else { config = ssh.Config{ Ciphers: cipherList, } } clientConfig = &ssh.ClientConfig{ User: user, Auth: auth, Timeout: time.Second * 30, Config: config, // 默认密钥不受信任时,Go 的 ssh 包会在 HostKeyCallback 里把连接干掉(1.8 之后加的应该)。但是我们使用用户名密码连接的时候,这个太正常了,所以让他 return nil 就好了 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, } addr = fmt.Sprintf("%s:%d", host, port) if ss.Client, err = ssh.Dial("tcp", addr, clientConfig); err != nil { return err } sftpCli, err = sftp.NewClient(ss.Client) return nil } func (ss *SSHService) getSftp() (*sftp.Client, error) { if ss.Client == nil { return nil, errors.New("ssh client is nil") } var err error sftpOne.Do(func() { if sftpCli, err = sftp.NewClient(ss.Client); err != nil { return } }) return sftpCli, nil } func (ss *SSHService) Close() { if ss.Client != nil { if err := ss.Client.Close(); err != nil { logrus.WithField("err", err).Error("close ssh err") } } // if sftpCli != nil { // if err := sftpCli.Close(); err != nil { // logrus.WithField("err", err).Error("close sftp err") // } // } } func (ss *SSHService) RunShell(shell string) { var ( session *ssh.Session err error ) //获取session,这个session是用来远程执行操作的 if session, err = ss.Client.NewSession(); err != nil { logrus.Errorf("error newsession:%v", err) } // 使用 session.Shell() 模拟终端时,所建立的终端参数 modes := ssh.TerminalModes{ ssh.ECHO: 0, //disable echoing ssh.TTY_OP_ISPEED: 14400, //input speed=14.4kbaud ssh.TTY_OP_OSPEED: 14400, } if err := session.RequestPty("xterm", 80, 40, modes); err != nil { logrus.Error(err) } //执行shell if output, err := session.CombinedOutput(shell); err != nil { logrus.Errorf("error CombinedOutput:%v", err) } else { ss.LastResult = string(output) } } //单个copy func (ss *SSHService) ScpCopy(localFilePath, remoteDir string) error { var ( err error ) fi, err := os.Stat(localFilePath) if err != nil { return err } if fi.IsDir() { return errors.New(localFilePath + " is not file") } sftpCli, err = ss.getSftp() if err != nil { return fmt.Errorf("new sftp client error: %w", err) } // defer sftpCli.Close() srcFile, err := os.Open(localFilePath) if err != nil { log.Println("scpCopy:", err) return err } defer srcFile.Close() var remoteFileName string sysTyep := runtime.GOOS if sysTyep == "windows" { remoteFileName = path.Base(filepath.ToSlash(localFilePath)) } else { remoteFileName = path.Base(localFilePath) } dstFile, err := sftpCli.Create(path.Join(remoteDir, remoteFileName)) if err != nil { log.Println("scpCopy:", err) return err } defer dstFile.Close() fileByte, err := ioutil.ReadAll(srcFile) if nil != err { return fmt.Errorf("read local file failed: %w", err) } if _, err := dstFile.Write(fileByte); err != nil { return err } return nil } //批量copy func (ss *SSHService) BatchScpCoy(cfs []CopyFiles, remoteDir string) error { var err error sftpCli, err = ss.getSftp() if err != nil { return fmt.Errorf("new sftp client error: %w", err) } for _, f := range cfs { srcFile, err := os.Open(path.Join(f.Dir, f.FileName)) if err != nil { return err } defer srcFile.Close() dstFile, err := sftpCli.Create(path.Join(remoteDir, f.FileName)) if err != nil { logrus.Error("scpCopy:", err) return err } defer dstFile.Close() fileByte, err := ioutil.ReadAll(srcFile) if nil != err { return fmt.Errorf("read local file failed: %w", err) } if _, err := dstFile.Write(fileByte); err != nil { return err } } return nil } //Deprecated Scp func (ss *SSHService) Scp(targetDir, srcFileName string) (int, error) { sftpClient, err := sftp.NewClient(ss.Client) if err != nil { return 0, fmt.Errorf("new sftp client error: %w", err) } defer sftpClient.Close() // source, err := sftpClient.Open(srcFileName) // if err != nil { // return 0, fmt.Errorf("sftp client open src file error: %w", err) // } // defer source.Close() srcFile, err := os.Open(srcFileName) if err != nil { return 0, fmt.Errorf("open local file error: %w", err) } defer srcFile.Close() var remoteFileName = path.Base(srcFileName) dstFile, err := sftpClient.Create(path.Join(targetDir, remoteFileName)) if err != nil { fmt.Errorf("scpCopy:%v", err) return 0, err } // n, err := io.Copy(target, source) if err != nil { return 0, fmt.Errorf("copy file error: %w", err) } defer dstFile.Close() buf := make([]byte, 1024) n := 0 for { n, _ = srcFile.Read(buf) if n == 0 { break } dstFile.Write(buf[0:n]) } return n, nil } //Download func (ss *SSHService) ScpDownload(localDir, remoteFilePath string) error { var err error sftpCli, err = ss.getSftp() if err != nil { return fmt.Errorf("new sftp client error: %w", err) } remoteFile, err := sftpCli.Open(remoteFilePath) if err != nil { log.Println("scpCopy:", err) return err } defer remoteFile.Close() fileName := path.Base(remoteFile.Name()) if err := os.MkdirAll(localDir, fs.ModePerm); err != nil { logrus.Error(err) return err } target, err := os.OpenFile(localDir+fileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fs.ModePerm) if err != nil { return fmt.Errorf("open local file error: %w", err) } defer target.Close() _, err = io.Copy(target, remoteFile) if err != nil { return fmt.Errorf("write file error: %w", err) } return nil } func (ss *SSHService) GetRemoteDir(remoteDir string) (files []File, err error) { sftpCli, err = ss.getSftp() if err != nil { return nil, fmt.Errorf("new sftp client error: %w", err) } remoteFiles, err := sftpCli.ReadDir(remoteDir) if err != nil { log.Println("read remote Dir:", err) return nil, err } for _, f := range remoteFiles { fi := File{ FileName: f.Name(), FilePath: filepath.Join(remoteDir, f.Name()), Size: f.Size(), } files = append(files, fi) // logrus.WithFields(logrus.Fields{"name": f.Name(), "size": f.Size()}).Debug("远程日志文件") } return } func (ss *SSHService) BatchScpDownload(localDir, remoteDir string) error { var err error sftpCli, err = ss.getSftp() if err != nil { return fmt.Errorf("new sftp client error: %w", err) } remoteFiles, err := sftpCli.ReadDir(remoteDir) if err != nil { log.Println("read remote Dir:", err) return err } for _, f := range remoteFiles { if err := ss.ScpDownload(localDir, filepath.Join(remoteDir, f.Name())); err != nil { return err } } return nil } type CopyFiles struct { Dir string FileName string } type File struct { FileName string FilePath string Size int64 }