-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathauth.go
149 lines (129 loc) · 4.16 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package sshtun
import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path/filepath"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)
// defaultSSHKeys lists the private-key file names that are tried, in this
// order, under the user's ~/.ssh directory when no explicit key file has been
// configured.
var defaultSSHKeys = []string{
	"id_rsa",
	"id_dsa",
	"id_ecdsa",
	"id_ecdsa_sk",
	"id_ed25519",
	"id_ed25519_sk",
}
// AuthType is the type of authentication to use for SSH.
//
// The numeric values are assigned by iota in declaration order, so new
// authentication types should be appended at the end to keep existing values
// stable.
type AuthType int

const (
	// AuthTypeKeyFile uses the keys from an SSH key file read from the system.
	// The configured key file is used if set; otherwise the names in
	// defaultSSHKeys are tried under the user's ~/.ssh directory.
	AuthTypeKeyFile AuthType = iota
	// AuthTypeEncryptedKeyFile uses the keys from an encrypted SSH key file read
	// from the system. The key is decrypted with the configured password.
	AuthTypeEncryptedKeyFile
	// AuthTypeKeyReader uses the keys from an SSH key reader. The reader is read
	// to the end.
	AuthTypeKeyReader
	// AuthTypeEncryptedKeyReader uses the keys from an encrypted SSH key reader.
	// The key is decrypted with the configured password.
	AuthTypeEncryptedKeyReader
	// AuthTypePassword uses a password directly.
	AuthTypePassword
	// AuthTypeSSHAgent will use the keys registered in the ssh-agent reachable
	// through the SSH_AUTH_SOCK environment variable.
	AuthTypeSSHAgent
	// AuthTypeAuto tries to get the authentication method automatically: the
	// unencrypted key-file lookup is tried first, then the ssh-agent. See
	// SSHTun.Start for details on this.
	AuthTypeAuto
)
// getSSHAuthMethod builds the ssh.AuthMethod selected by tun.authType.
//
// The plain and encrypted variants of the key-file and key-reader modes share
// one helper each, distinguished by the encrypted flag. AuthTypeAuto tries the
// unencrypted key-file lookup first and falls back to the ssh-agent; if both
// fail, the error reports both causes. Any other value yields an
// "unknown auth type" error.
func (tun *SSHTun) getSSHAuthMethod() (ssh.AuthMethod, error) {
	switch kind := tun.authType; kind {
	case AuthTypeKeyFile, AuthTypeEncryptedKeyFile:
		return tun.getSSHAuthMethodForKeyFile(kind == AuthTypeEncryptedKeyFile)
	case AuthTypeKeyReader, AuthTypeEncryptedKeyReader:
		return tun.getSSHAuthMethodForKeyReader(kind == AuthTypeEncryptedKeyReader)
	case AuthTypePassword:
		return ssh.Password(tun.authPassword), nil
	case AuthTypeSSHAgent:
		return tun.getSSHAuthMethodForSSHAgent()
	case AuthTypeAuto:
		fileMethod, fileErr := tun.getSSHAuthMethodForKeyFile(false)
		if fileErr == nil {
			return fileMethod, nil
		}
		agentMethod, agentErr := tun.getSSHAuthMethodForSSHAgent()
		if agentErr == nil {
			return agentMethod, nil
		}
		return nil, fmt.Errorf("auto auth failed (file based: %v) (ssh-agent: %v)", fileErr, agentErr)
	}
	return nil, fmt.Errorf("unknown auth type: %d", tun.authType)
}
// getSSHAuthMethodForKeyFile returns a public-key auth method loaded from a
// private key file on disk.
//
// If tun.authKeyFile is set, only that file is used. Otherwise each name in
// defaultSSHKeys is tried, in order, under <home>/.ssh, and the first key that
// can be read and parsed is returned. When encrypted is true the key is
// decrypted with tun.authPassword.
//
// Default keys that do not exist are skipped silently. If a default key exists
// but cannot be used (for example it is malformed, unreadable, or passphrase
// protected), the first such error is wrapped into the returned error so the
// real cause is not hidden behind a generic message.
func (tun *SSHTun) getSSHAuthMethodForKeyFile(encrypted bool) (ssh.AuthMethod, error) {
	if tun.authKeyFile != "" {
		return tun.readPrivateKey(tun.authKeyFile, encrypted)
	}

	// Best effort: when the current user cannot be resolved, or has no home
	// directory recorded, fall back to /root instead of failing outright.
	homeDir := "/root"
	if usr, err := user.Current(); err == nil && usr.HomeDir != "" {
		homeDir = usr.HomeDir
	}

	var firstFailure error
	for _, keyName := range defaultSSHKeys {
		keyFile := filepath.Join(homeDir, ".ssh", keyName)
		authMethod, err := tun.readPrivateKey(keyFile, encrypted)
		if err == nil {
			return authMethod, nil
		}
		// A missing default key is expected; anything else is worth reporting.
		if firstFailure == nil && !errors.Is(err, os.ErrNotExist) {
			firstFailure = err
		}
	}
	if firstFailure != nil {
		return nil, fmt.Errorf("could not read any default SSH key (%v): %w", defaultSSHKeys, firstFailure)
	}
	return nil, fmt.Errorf("could not read any default SSH key (%v)", defaultSSHKeys)
}
// readPrivateKey loads the key stored at keyFile and turns it into a
// public-key auth method via parsePrivateKey (decrypting with the configured
// password when encrypted is true). Read and parse failures are wrapped with
// the file path for context.
func (tun *SSHTun) readPrivateKey(keyFile string, encrypted bool) (ssh.AuthMethod, error) {
	contents, readErr := os.ReadFile(keyFile)
	if readErr != nil {
		return nil, fmt.Errorf("reading SSH key file %s: %w", keyFile, readErr)
	}

	method, parseErr := tun.parsePrivateKey(contents, encrypted)
	if parseErr != nil {
		return nil, fmt.Errorf("parsing SSH key file %s: %w", keyFile, parseErr)
	}
	return method, nil
}
// getSSHAuthMethodForKeyReader reads a private key from tun.authKeyReader and
// turns it into a public-key auth method. When encrypted is true the key is
// decrypted with tun.authPassword.
//
// The reader is read to the end. NOTE(review): a second call will presumably
// see an exhausted reader and fail to parse; confirm whether callers ever
// rebuild the auth method (e.g. on reconnect).
func (tun *SSHTun) getSSHAuthMethodForKeyReader(encrypted bool) (ssh.AuthMethod, error) {
	// io.ReadAll on a nil reader would panic; report a configuration error instead.
	if tun.authKeyReader == nil {
		return nil, fmt.Errorf("reading from SSH key reader: no key reader configured")
	}
	buf, err := io.ReadAll(tun.authKeyReader)
	if err != nil {
		return nil, fmt.Errorf("reading from SSH key reader: %w", err)
	}
	key, err := tun.parsePrivateKey(buf, encrypted)
	if err != nil {
		return nil, fmt.Errorf("parsing key from SSH key reader: %w", err)
	}
	return key, nil
}
// parsePrivateKey parses the private key in buf and wraps the resulting signer
// in a public-key auth method.
//
// When encrypted is true, the key is decrypted with tun.authPassword.
// Otherwise it is parsed as an unencrypted key. In that case a
// passphrase-protected key fails, and the underlying ssh error stays available
// through errors.As because it is wrapped with %w.
func (tun *SSHTun) parsePrivateKey(buf []byte, encrypted bool) (ssh.AuthMethod, error) {
	if encrypted {
		signer, err := ssh.ParsePrivateKeyWithPassphrase(buf, []byte(tun.authPassword))
		if err != nil {
			return nil, fmt.Errorf("parsing encrypted key: %w", err)
		}
		return ssh.PublicKeys(signer), nil
	}

	signer, err := ssh.ParsePrivateKey(buf)
	if err != nil {
		return nil, fmt.Errorf("parsing key: %w", err)
	}
	return ssh.PublicKeys(signer), nil
}
// getSSHAuthMethodForSSHAgent returns a public-key auth method backed by all
// keys held by the ssh-agent reachable through the SSH_AUTH_SOCK unix socket.
//
// It fails if SSH_AUTH_SOCK is unset, the socket cannot be dialed, the agent
// cannot list its keys, or the agent holds no keys. On every failure path the
// agent connection is closed.
func (tun *SSHTun) getSSHAuthMethodForSSHAgent() (ssh.AuthMethod, error) {
	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		return nil, fmt.Errorf("opening unix socket: SSH_AUTH_SOCK is not set (is ssh-agent running?)")
	}
	conn, err := net.Dial("unix", socket)
	if err != nil {
		return nil, fmt.Errorf("opening unix socket %s: %w", socket, err)
	}

	agentClient := agent.NewClient(conn)
	signers, err := agentClient.Signers()
	if err != nil {
		_ = conn.Close() // best-effort cleanup; the Signers error is the one to report
		return nil, fmt.Errorf("getting ssh-agent signers: %w", err)
	}
	if len(signers) == 0 {
		_ = conn.Close() // best-effort cleanup; nothing will use this connection
		return nil, fmt.Errorf("no signers from ssh-agent (use 'ssh-add' to add keys to agent)")
	}

	// The agent signers send their signing requests over conn, so it must stay
	// open on success. NOTE(review): it is never closed afterwards; consider
	// tying its lifetime to the tunnel.
	return ssh.PublicKeys(signers...), nil
}