dreamfactory_cmd/cmd/v2/service/sshService.go
2023-06-09 21:58:02 +08:00

379 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"fmt"
"io"
"io/fs"
"io/ioutil"
"log"
"net"
"os"
"path"
"path/filepath"
"runtime"
"sync"
"time"
"github.com/pkg/errors"
"github.com/pkg/sftp"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
// sftpCli is a package-level SFTP client; every SSHService value that reads or
// assigns it shares this single variable.
// NOTE(review): an SFTP client is bound to one SSH connection, so sharing it
// between SSHService values connected to different hosts looks unsafe — confirm.
var sftpCli *sftp.Client

// sftpOne is a sync.Once meant to initialise sftpCli a single time.
var sftpOne sync.Once
// SSHService wraps one SSH connection and an SFTP session that is opened
// lazily on top of it for file transfers.
type SSHService struct {
	Client     *ssh.Client
	LastResult string // combined output captured by RunShell

	// sftpMu guards sftpClient, the SFTP session bound to Client. It lives on
	// the service (not in a package variable) so services connected to
	// different hosts never share a transfer channel.
	sftpMu     sync.Mutex
	sftpClient *sftp.Client
}

// NewSSHService returns an unconnected service; call Connect before use.
func NewSSHService() *SSHService {
	return &SSHService{}
}

// Connect dials host:port over TCP, authenticates as user and stores the new
// connection in ss.Client (nil if dialing fails).
//
// When key is empty, password is used for password authentication. Otherwise
// key is the path to a PEM private key file and password, if non-empty, is its
// passphrase. cipherList overrides the allowed ciphers; when it is empty a
// broad default list is used that still includes legacy ciphers (arcfour*,
// 3des-cbc, *-cbc) so old servers remain reachable.
//
// Any SFTP session opened on a previous connection is closed; the previous
// ss.Client itself is not closed.
func (ss *SSHService) Connect(user, password, host, key string, port int, cipherList []string) error {
	auth, err := buildAuthMethods(password, key)
	if err != nil {
		return err
	}

	ciphers := cipherList
	if len(ciphers) == 0 {
		ciphers = []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc"}
	}

	clientConfig := &ssh.ClientConfig{
		User:    user,
		Auth:    auth,
		Timeout: 30 * time.Second,
		Config:  ssh.Config{Ciphers: ciphers},
		// SECURITY: the server host key is accepted without verification, which
		// leaves the connection open to man-in-the-middle attacks. Kept for
		// compatibility with existing deployments; consider verifying against
		// known_hosts instead.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:22"), which a plain
	// "%s:%d" format does not.
	client, err := ssh.Dial("tcp", net.JoinHostPort(host, fmt.Sprint(port)), clientConfig)

	ss.sftpMu.Lock()
	defer ss.sftpMu.Unlock()
	// Any cached SFTP session belongs to the old connection; drop it either way.
	ss.closeSftpLocked()
	ss.Client = client
	return err
}

// buildAuthMethods returns password auth when key is empty; otherwise it loads
// the private key file at key (decrypting it with password when one is given)
// and returns public-key auth.
func buildAuthMethods(password, key string) ([]ssh.AuthMethod, error) {
	if key == "" {
		return []ssh.AuthMethod{ssh.Password(password)}, nil
	}
	pemBytes, err := ioutil.ReadFile(key)
	if err != nil {
		return nil, err
	}
	var signer ssh.Signer
	if password == "" {
		signer, err = ssh.ParsePrivateKey(pemBytes)
	} else {
		signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password))
	}
	if err != nil {
		return nil, err
	}
	return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
}

// getSftp returns the SFTP client bound to ss.Client, opening it on first use.
// A failed attempt is not cached, so a later call retries. Safe for concurrent
// use.
func (ss *SSHService) getSftp() (*sftp.Client, error) {
	ss.sftpMu.Lock()
	defer ss.sftpMu.Unlock()
	if ss.Client == nil {
		return nil, errors.New("ssh client is nil")
	}
	if ss.sftpClient == nil {
		client, err := sftp.NewClient(ss.Client)
		if err != nil {
			return nil, err
		}
		ss.sftpClient = client
	}
	return ss.sftpClient, nil
}

// Close closes the SFTP session (if one was opened) and then the SSH
// connection. Errors are logged rather than returned.
func (ss *SSHService) Close() {
	ss.sftpMu.Lock()
	defer ss.sftpMu.Unlock()
	// The SFTP session runs over the SSH connection, so close it first.
	ss.closeSftpLocked()
	if ss.Client != nil {
		if err := ss.Client.Close(); err != nil {
			logrus.WithField("err", err).Error("close ssh err")
		}
	}
}

// closeSftpLocked closes and forgets the cached SFTP client.
// ss.sftpMu must be held by the caller.
func (ss *SSHService) closeSftpLocked() {
	if ss.sftpClient == nil {
		return
	}
	if err := ss.sftpClient.Close(); err != nil {
		logrus.WithField("err", err).Error("close sftp err")
	}
	ss.sftpClient = nil
}
// RunShell runs shell on the remote host in a fresh session with a pseudo
// terminal and stores the command's combined stdout/stderr in ss.LastResult.
// Failures are logged, not returned; LastResult is cleared first so output
// from a previous command is never mistaken for this one's.
func (ss *SSHService) RunShell(shell string) {
	ss.LastResult = ""
	if ss.Client == nil {
		logrus.Error("error newsession: ssh client is nil")
		return
	}
	// Each remote command needs its own session.
	session, err := ss.Client.NewSession()
	if err != nil {
		logrus.Errorf("error newsession:%v", err)
		return
	}
	defer session.Close()

	// Terminal parameters for the pseudo terminal requested below.
	modes := ssh.TerminalModes{
		ssh.ECHO:          0,     // disable echoing
		ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4 kbaud
		ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4 kbaud
	}
	// RequestPty takes (term, height, width): 40 rows x 80 columns. A PTY
	// failure is logged but the command is still attempted.
	if err := session.RequestPty("xterm", 40, 80, modes); err != nil {
		logrus.Error(err)
	}

	output, err := session.CombinedOutput(shell)
	if err != nil {
		logrus.Errorf("error CombinedOutput:%v", err)
	}
	// Keep whatever the command printed even when it fails; it usually
	// explains the failure.
	ss.LastResult = string(output)
}
// ScpCopy uploads the single local file localFilePath into remoteDir over
// SFTP, keeping its base name. Directories are rejected.
func (ss *SSHService) ScpCopy(localFilePath, remoteDir string) error {
	fi, err := os.Stat(localFilePath)
	if err != nil {
		return err
	}
	if fi.IsDir() {
		return errors.New(localFilePath + " is not file")
	}
	// Use a local variable: writing the shared package-level client would
	// race with other services.
	client, err := ss.getSftp()
	if err != nil {
		return fmt.Errorf("new sftp client error: %w", err)
	}
	srcFile, err := os.Open(localFilePath)
	if err != nil {
		log.Println("scpCopy:", err)
		return err
	}
	defer srcFile.Close()

	// Remote paths are '/'-separated, so normalise Windows separators before
	// taking the base name.
	var remoteFileName string
	if runtime.GOOS == "windows" {
		remoteFileName = path.Base(filepath.ToSlash(localFilePath))
	} else {
		remoteFileName = path.Base(localFilePath)
	}
	dstFile, err := client.Create(path.Join(remoteDir, remoteFileName))
	if err != nil {
		log.Println("scpCopy:", err)
		return err
	}
	// Stream the file instead of loading it entirely into memory.
	if _, err := io.Copy(dstFile, srcFile); err != nil {
		_ = dstFile.Close() // the copy error is the one worth reporting
		return fmt.Errorf("copy to remote file failed: %w", err)
	}
	// A failed close can mean the upload is incomplete, so report it.
	if err := dstFile.Close(); err != nil {
		return fmt.Errorf("close remote file failed: %w", err)
	}
	return nil
}
// BatchScpCoy uploads each file in cfs (read locally from Dir/FileName) into
// remoteDir under its FileName. It stops at the first failure; files uploaded
// before that remain on the server.
func (ss *SSHService) BatchScpCoy(cfs []CopyFiles, remoteDir string) error {
	client, err := ss.getSftp()
	if err != nil {
		return fmt.Errorf("new sftp client error: %w", err)
	}
	for _, f := range cfs {
		localPath := filepath.Join(f.Dir, f.FileName) // local OS path
		remotePath := path.Join(remoteDir, f.FileName) // remote '/' path
		if err := uploadLocalFile(client, localPath, remotePath); err != nil {
			return err
		}
	}
	return nil
}

// uploadLocalFile copies the local file at localPath to remotePath through
// client. Both handles are closed before it returns, so calling it in a loop
// does not accumulate open files.
func uploadLocalFile(client *sftp.Client, localPath, remotePath string) error {
	srcFile, err := os.Open(localPath)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	dstFile, err := client.Create(remotePath)
	if err != nil {
		logrus.Error("scpCopy:", err)
		return err
	}
	// Stream the file instead of loading it entirely into memory.
	if _, err := io.Copy(dstFile, srcFile); err != nil {
		_ = dstFile.Close() // the copy error is the one worth reporting
		return fmt.Errorf("copy to remote file failed: %w", err)
	}
	// A failed close can mean the upload is incomplete, so report it.
	if err := dstFile.Close(); err != nil {
		return fmt.Errorf("close remote file failed: %w", err)
	}
	return nil
}
// Scp uploads the local file srcFileName into targetDir using a dedicated SFTP
// session and returns the number of bytes written (0 on error).
//
// Deprecated: use ScpCopy, which reuses the service's SFTP client.
func (ss *SSHService) Scp(targetDir, srcFileName string) (int, error) {
	if ss.Client == nil {
		return 0, errors.New("ssh client is nil")
	}
	sftpClient, err := sftp.NewClient(ss.Client)
	if err != nil {
		return 0, fmt.Errorf("new sftp client error: %w", err)
	}
	defer sftpClient.Close()

	srcFile, err := os.Open(srcFileName)
	if err != nil {
		return 0, fmt.Errorf("open local file error: %w", err)
	}
	defer srcFile.Close()

	// srcFileName is a local path, so take its base name with the OS rules;
	// the remote path is always '/'-separated.
	dstFile, err := sftpClient.Create(path.Join(targetDir, filepath.Base(srcFileName)))
	if err != nil {
		return 0, fmt.Errorf("scpCopy:%w", err)
	}
	written, err := io.Copy(dstFile, srcFile)
	if err != nil {
		_ = dstFile.Close() // the copy error is the one worth reporting
		return 0, fmt.Errorf("copy file error: %w", err)
	}
	if err := dstFile.Close(); err != nil {
		return 0, fmt.Errorf("close remote file error: %w", err)
	}
	return int(written), nil
}
// ScpDownload downloads the remote file remoteFilePath into localDir (created
// if missing), keeping its base name and overwriting any existing local file.
func (ss *SSHService) ScpDownload(localDir, remoteFilePath string) error {
	client, err := ss.getSftp()
	if err != nil {
		return fmt.Errorf("new sftp client error: %w", err)
	}
	remoteFile, err := client.Open(remoteFilePath)
	if err != nil {
		log.Println("scpCopy:", err)
		return err
	}
	defer remoteFile.Close()

	if err := os.MkdirAll(localDir, fs.ModePerm); err != nil {
		logrus.Error(err)
		return err
	}
	// The remote name is '/'-separated; the local destination is joined with
	// the OS separator, so it works whether or not localDir ends in one.
	localPath := filepath.Join(localDir, path.Base(remoteFile.Name()))
	target, err := os.OpenFile(localPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
	if err != nil {
		return fmt.Errorf("open local file error: %w", err)
	}
	if _, err := io.Copy(target, remoteFile); err != nil {
		_ = target.Close() // the copy error is the one worth reporting
		return fmt.Errorf("write file error: %w", err)
	}
	// Buffered write errors can surface only on close.
	if err := target.Close(); err != nil {
		return fmt.Errorf("write file error: %w", err)
	}
	return nil
}
// GetRemoteDir lists the entries directly inside remoteDir on the server.
// FilePath is built with '/' separators because it is a remote path, not a
// local one. An empty directory yields a nil slice.
func (ss *SSHService) GetRemoteDir(remoteDir string) ([]File, error) {
	client, err := ss.getSftp()
	if err != nil {
		return nil, fmt.Errorf("new sftp client error: %w", err)
	}
	remoteFiles, err := client.ReadDir(remoteDir)
	if err != nil {
		log.Println("read remote Dir:", err)
		return nil, err
	}
	var files []File
	for _, f := range remoteFiles {
		files = append(files, File{
			FileName: f.Name(),
			FilePath: path.Join(remoteDir, f.Name()),
			Size:     f.Size(),
		})
	}
	return files, nil
}
// BatchScpDownload downloads every non-directory entry directly inside
// remoteDir into localDir via ScpDownload. Subdirectories are skipped (they
// cannot be copied as files); the first failed download stops the batch.
func (ss *SSHService) BatchScpDownload(localDir, remoteDir string) error {
	client, err := ss.getSftp()
	if err != nil {
		return fmt.Errorf("new sftp client error: %w", err)
	}
	remoteFiles, err := client.ReadDir(remoteDir)
	if err != nil {
		log.Println("read remote Dir:", err)
		return err
	}
	for _, f := range remoteFiles {
		if f.IsDir() {
			continue
		}
		// Remote paths are always '/'-separated, regardless of the local OS.
		if err := ss.ScpDownload(localDir, path.Join(remoteDir, f.Name())); err != nil {
			return err
		}
	}
	return nil
}
// CopyFiles names one local file for BatchScpCoy: the file is opened from
// Dir joined with FileName and uploaded to the remote directory under the
// same FileName.
type CopyFiles struct {
	Dir      string // local directory containing the file
	FileName string // file name, reused as the remote file name
}
// File describes one entry of a remote directory listing produced by
// GetRemoteDir.
type File struct {
	FileName string // entry name as reported by the server
	FilePath string // the listed directory joined with FileName
	Size     int64  // entry size in bytes, as reported by the server
}