379 lines
8.2 KiB
Go
379 lines
8.2 KiB
Go
package service
|
||
|
||
import (
|
||
"fmt"
|
||
"io"
|
||
"io/fs"
|
||
"io/ioutil"
|
||
"log"
|
||
"net"
|
||
"os"
|
||
"path"
|
||
"path/filepath"
|
||
"runtime"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/pkg/errors"
|
||
"github.com/pkg/sftp"
|
||
"github.com/sirupsen/logrus"
|
||
"golang.org/x/crypto/ssh"
|
||
)
|
||
|
||
var sftpCli *sftp.Client
|
||
var sftpOne sync.Once
|
||
|
||
type SSHService struct {
|
||
Client *ssh.Client
|
||
LastResult string //执行结果
|
||
}
|
||
|
||
func NewSSHService() *SSHService {
|
||
ss := &SSHService{}
|
||
return ss
|
||
}
|
||
|
||
// Connect dials host:port over TCP, authenticates as user and stores the
// resulting connection in ss.Client (ss.Client is set to nil if dialing fails).
//
// Authentication:
//   - key == "": password authentication using password.
//   - key != "": key is the path of a private-key file; when password is
//     non-empty it is used as that key's passphrase.
//
// cipherList overrides the allowed ciphers; when it is empty a built-in list
// is used. The connection timeout is 30 seconds.
//
// After a successful dial Connect also opens an SFTP session and publishes it
// in the package-level sftpCli. Failing to open that session is logged but
// does not fail Connect, so callers that only run shell commands keep working
// against servers without the SFTP subsystem.
//
// SECURITY: the server host key is NOT verified (see HostKeyCallback below),
// so the connection is open to man-in-the-middle attacks.
func (ss *SSHService) Connect(user, password, host, key string, port int, cipherList []string) error {
	var auth []ssh.AuthMethod
	if key == "" {
		// Password authentication.
		auth = append(auth, ssh.Password(password))
	} else {
		// Private-key authentication; password doubles as the key passphrase.
		pemBytes, err := ioutil.ReadFile(key)
		if err != nil {
			return err
		}

		var signer ssh.Signer
		if password == "" {
			signer, err = ssh.ParsePrivateKey(pemBytes)
		} else {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password))
		}
		if err != nil {
			return err
		}
		auth = append(auth, ssh.PublicKeys(signer))
	}

	ciphers := cipherList
	if len(ciphers) == 0 {
		// Default set of ciphers offered to the server.
		// NOTE(review): arcfour* and *-cbc are weak ciphers; x/crypto/ssh is
		// expected to ignore names it does not implement — confirm whether
		// this list should be trimmed.
		ciphers = []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc"}
	}

	clientConfig := &ssh.ClientConfig{
		User:    user,
		Auth:    auth,
		Timeout: 30 * time.Second,
		Config:  ssh.Config{Ciphers: ciphers},

		// Accept every host key. The ssh package rejects the connection from
		// HostKeyCallback when the server key is not trusted, and this service
		// connects to hosts whose keys are not known in advance.
		// NOTE(review): this disables MITM protection; consider ssh.FixedHostKey
		// or a known_hosts based callback.
		HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error {
			return nil
		},
	}

	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:22"); a plain
	// "%s:%d" format produced an undialable address for them.
	addr := net.JoinHostPort(host, fmt.Sprint(port))

	var err error
	if ss.Client, err = ssh.Dial("tcp", addr, clientConfig); err != nil {
		return err
	}

	// Best effort: report the failure instead of silently discarding it; the
	// SSH connection itself is still usable for RunShell.
	if sftpCli, err = sftp.NewClient(ss.Client); err != nil {
		log.Printf("connect %s: open sftp session: %v", addr, err)
	}
	return nil
}
|
||
|
||
// getSftp returns the package-level SFTP client for ss.Client.
//
// The client Connect already stored in sftpCli is reused. Only when none is
// present does getSftp open one, and at most once per process, guarded by
// sftpOne. An error is returned when ss.Client is nil, when opening the
// session fails, or when no client is available. Previously a failed
// sftp.NewClient yielded (nil, nil) and callers dereferenced nil.
func (ss *SSHService) getSftp() (*sftp.Client, error) {
	if ss.Client == nil {
		return nil, errors.New("ssh client is nil")
	}

	var err error
	sftpOne.Do(func() {
		// Connect normally opens the session already; creating a second one
		// here would overwrite and leak it.
		if sftpCli != nil {
			return
		}
		sftpCli, err = sftp.NewClient(ss.Client)
	})
	if err != nil {
		return nil, err
	}
	if sftpCli == nil {
		// Reached when the one-time initialization already ran (and failed)
		// and Connect has not published a client since.
		return nil, errors.New("sftp client is nil")
	}
	return sftpCli, nil
}
|
||
|
||
func (ss *SSHService) Close() {
|
||
if ss.Client != nil {
|
||
if err := ss.Client.Close(); err != nil {
|
||
logrus.WithField("err", err).Error("close ssh err")
|
||
}
|
||
}
|
||
// if sftpCli != nil {
|
||
// if err := sftpCli.Close(); err != nil {
|
||
// logrus.WithField("err", err).Error("close sftp err")
|
||
// }
|
||
// }
|
||
}
|
||
|
||
func (ss *SSHService) RunShell(shell string) {
|
||
var (
|
||
session *ssh.Session
|
||
err error
|
||
)
|
||
|
||
//获取session,这个session是用来远程执行操作的
|
||
if session, err = ss.Client.NewSession(); err != nil {
|
||
logrus.Errorf("error newsession:%v", err)
|
||
}
|
||
// 使用 session.Shell() 模拟终端时,所建立的终端参数
|
||
modes := ssh.TerminalModes{
|
||
ssh.ECHO: 0, //disable echoing
|
||
ssh.TTY_OP_ISPEED: 14400, //input speed=14.4kbaud
|
||
ssh.TTY_OP_OSPEED: 14400,
|
||
}
|
||
|
||
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
|
||
logrus.Error(err)
|
||
}
|
||
|
||
//执行shell
|
||
if output, err := session.CombinedOutput(shell); err != nil {
|
||
logrus.Errorf("error CombinedOutput:%v", err)
|
||
} else {
|
||
ss.LastResult = string(output)
|
||
}
|
||
}
|
||
|
||
//单个copy
|
||
func (ss *SSHService) ScpCopy(localFilePath, remoteDir string) error {
|
||
var (
|
||
err error
|
||
)
|
||
|
||
fi, err := os.Stat(localFilePath)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if fi.IsDir() {
|
||
return errors.New(localFilePath + " is not file")
|
||
}
|
||
|
||
sftpCli, err = ss.getSftp()
|
||
if err != nil {
|
||
return fmt.Errorf("new sftp client error: %w", err)
|
||
}
|
||
|
||
// defer sftpCli.Close()
|
||
|
||
srcFile, err := os.Open(localFilePath)
|
||
if err != nil {
|
||
log.Println("scpCopy:", err)
|
||
return err
|
||
}
|
||
defer srcFile.Close()
|
||
|
||
var remoteFileName string
|
||
sysTyep := runtime.GOOS
|
||
if sysTyep == "windows" {
|
||
remoteFileName = path.Base(filepath.ToSlash(localFilePath))
|
||
} else {
|
||
remoteFileName = path.Base(localFilePath)
|
||
}
|
||
|
||
dstFile, err := sftpCli.Create(path.Join(remoteDir, remoteFileName))
|
||
if err != nil {
|
||
log.Println("scpCopy:", err)
|
||
return err
|
||
}
|
||
defer dstFile.Close()
|
||
|
||
fileByte, err := ioutil.ReadAll(srcFile)
|
||
if nil != err {
|
||
return fmt.Errorf("read local file failed: %w", err)
|
||
}
|
||
|
||
if _, err := dstFile.Write(fileByte); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
//批量copy
|
||
func (ss *SSHService) BatchScpCoy(cfs []CopyFiles, remoteDir string) error {
|
||
var err error
|
||
sftpCli, err = ss.getSftp()
|
||
if err != nil {
|
||
return fmt.Errorf("new sftp client error: %w", err)
|
||
}
|
||
|
||
for _, f := range cfs {
|
||
srcFile, err := os.Open(path.Join(f.Dir, f.FileName))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer srcFile.Close()
|
||
|
||
dstFile, err := sftpCli.Create(path.Join(remoteDir, f.FileName))
|
||
if err != nil {
|
||
logrus.Error("scpCopy:", err)
|
||
return err
|
||
}
|
||
defer dstFile.Close()
|
||
|
||
fileByte, err := ioutil.ReadAll(srcFile)
|
||
if nil != err {
|
||
return fmt.Errorf("read local file failed: %w", err)
|
||
}
|
||
|
||
if _, err := dstFile.Write(fileByte); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
//Deprecated Scp
|
||
func (ss *SSHService) Scp(targetDir, srcFileName string) (int, error) {
|
||
sftpClient, err := sftp.NewClient(ss.Client)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("new sftp client error: %w", err)
|
||
}
|
||
defer sftpClient.Close()
|
||
// source, err := sftpClient.Open(srcFileName)
|
||
// if err != nil {
|
||
// return 0, fmt.Errorf("sftp client open src file error: %w", err)
|
||
// }
|
||
// defer source.Close()
|
||
srcFile, err := os.Open(srcFileName)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("open local file error: %w", err)
|
||
}
|
||
defer srcFile.Close()
|
||
|
||
var remoteFileName = path.Base(srcFileName)
|
||
dstFile, err := sftpClient.Create(path.Join(targetDir, remoteFileName))
|
||
if err != nil {
|
||
fmt.Errorf("scpCopy:%v", err)
|
||
return 0, err
|
||
}
|
||
// n, err := io.Copy(target, source)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("copy file error: %w", err)
|
||
}
|
||
defer dstFile.Close()
|
||
|
||
buf := make([]byte, 1024)
|
||
n := 0
|
||
for {
|
||
n, _ = srcFile.Read(buf)
|
||
if n == 0 {
|
||
break
|
||
}
|
||
dstFile.Write(buf[0:n])
|
||
}
|
||
return n, nil
|
||
}
|
||
|
||
//Download
|
||
func (ss *SSHService) ScpDownload(localDir, remoteFilePath string) error {
|
||
var err error
|
||
sftpCli, err = ss.getSftp()
|
||
if err != nil {
|
||
return fmt.Errorf("new sftp client error: %w", err)
|
||
}
|
||
|
||
remoteFile, err := sftpCli.Open(remoteFilePath)
|
||
if err != nil {
|
||
log.Println("scpCopy:", err)
|
||
return err
|
||
}
|
||
defer remoteFile.Close()
|
||
|
||
fileName := path.Base(remoteFile.Name())
|
||
|
||
if err := os.MkdirAll(localDir, fs.ModePerm); err != nil {
|
||
logrus.Error(err)
|
||
return err
|
||
}
|
||
|
||
target, err := os.OpenFile(localDir+fileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fs.ModePerm)
|
||
if err != nil {
|
||
return fmt.Errorf("open local file error: %w", err)
|
||
}
|
||
defer target.Close()
|
||
|
||
_, err = io.Copy(target, remoteFile)
|
||
if err != nil {
|
||
return fmt.Errorf("write file error: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetRemoteDir lists the entries of remoteDir on the remote host over SFTP. It
// is not recursive, and sub-directories are included. files is nil when the
// directory is empty.
func (ss *SSHService) GetRemoteDir(remoteDir string) (files []File, err error) {
	// Local variable: do not overwrite the package-level sftpCli.
	client, err := ss.getSftp()
	if err != nil {
		return nil, fmt.Errorf("new sftp client error: %w", err)
	}

	entries, err := client.ReadDir(remoteDir)
	if err != nil {
		log.Println("read remote Dir:", err)
		return nil, err
	}

	for _, entry := range entries {
		files = append(files, File{
			FileName: entry.Name(),
			// SFTP paths are slash-separated regardless of the local OS;
			// filepath.Join would emit backslashes on Windows.
			FilePath: path.Join(remoteDir, entry.Name()),
			Size:     entry.Size(),
		})
	}
	return files, nil
}
|
||
|
||
func (ss *SSHService) BatchScpDownload(localDir, remoteDir string) error {
|
||
var err error
|
||
sftpCli, err = ss.getSftp()
|
||
if err != nil {
|
||
return fmt.Errorf("new sftp client error: %w", err)
|
||
}
|
||
|
||
remoteFiles, err := sftpCli.ReadDir(remoteDir)
|
||
if err != nil {
|
||
log.Println("read remote Dir:", err)
|
||
return err
|
||
}
|
||
|
||
for _, f := range remoteFiles {
|
||
if err := ss.ScpDownload(localDir, filepath.Join(remoteDir, f.Name())); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// CopyFiles names one local file for BatchScpCoy: the file is read from Dir
// joined with FileName, and it is uploaded under FileName.
type CopyFiles struct {
	Dir, FileName string
}
|
||
|
||
// File describes one directory entry returned by GetRemoteDir.
type File struct {
	// FileName is the entry's base name.
	FileName string

	// FilePath is the listed directory joined with FileName.
	FilePath string

	// Size is the entry size reported by the server's file info.
	Size int64
}
|