mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 07:26:39 -04:00
chore: do not use sentinel errors when unneeded
- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly - only use sentinel errors when it has to be checked using `errors.Is` - replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error - exclude the sentinel error definition requirement from .golangci.yml - update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
@@ -68,6 +68,9 @@ linters:
|
|||||||
- err113
|
- err113
|
||||||
- mnd
|
- mnd
|
||||||
path: ci\/.+\.go
|
path: ci\/.+\.go
|
||||||
|
- linters:
|
||||||
|
- err113
|
||||||
|
text: "do not define dynamic errors, use wrapped static errors instead"
|
||||||
|
|
||||||
paths:
|
paths:
|
||||||
- third_party$
|
- third_party$
|
||||||
|
|||||||
+1
-3
@@ -142,8 +142,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errCommandUnknown = errors.New("command is unknown")
|
|
||||||
|
|
||||||
//nolint:gocognit,gocyclo,maintidx
|
//nolint:gocognit,gocyclo,maintidx
|
||||||
func _main(ctx context.Context, buildInfo models.BuildInformation,
|
func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||||
args []string, logger log.LoggerInterface, reader *reader.Reader,
|
args []string, logger log.LoggerInterface, reader *reader.Reader,
|
||||||
@@ -165,7 +163,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
case "genkey":
|
case "genkey":
|
||||||
return cli.GenKey(args[2:])
|
return cli.GenKey(args[2:])
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
return fmt.Errorf("command is unknown: %s", args[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package alpine
|
package alpine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
@@ -9,8 +8,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrUserAlreadyExists = errors.New("user already exists")
|
|
||||||
|
|
||||||
// CreateUser creates a user in Alpine with the given UID.
|
// CreateUser creates a user in Alpine with the given UID.
|
||||||
func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) {
|
func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) {
|
||||||
UIDStr := strconv.Itoa(uid)
|
UIDStr := strconv.Itoa(uid)
|
||||||
@@ -34,8 +31,8 @@ func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if u != nil {
|
if u != nil {
|
||||||
return "", fmt.Errorf("%w: with name %s for ID %s instead of %d",
|
return "", fmt.Errorf("user already exists: with name %s for ID %s instead of %d",
|
||||||
ErrUserAlreadyExists, username, u.Uid, uid)
|
username, u.Uid, uid)
|
||||||
}
|
}
|
||||||
|
|
||||||
const permission = fs.FileMode(0o644)
|
const permission = fs.FileMode(0o644)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package amneziawg
|
package amneziawg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ func Test_New(t *testing.T) {
|
|||||||
PrivateKey: "",
|
PrivateKey: "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
err: wireguard.ErrPrivateKeyMissing,
|
err: errors.New("private key is missing"),
|
||||||
},
|
},
|
||||||
"minimal valid settings": {
|
"minimal valid settings": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
|
|||||||
@@ -13,11 +13,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/wireguard"
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
|
||||||
errDeviceWaited = errors.New("device waited for")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
|
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
|
||||||
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||||
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||||
@@ -52,8 +47,7 @@ func setupUserspace(ctx context.Context,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||||
} else if tunName != interfaceName {
|
} else if tunName != interfaceName {
|
||||||
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
return 0, nil, fmt.Errorf("TUN device name is mismatching: expected %q and got %q", interfaceName, tunName)
|
||||||
errTunNameMismatch, interfaceName, tunName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err := netLinker.LinkByName(interfaceName)
|
link, err := netLinker.LinkByName(interfaceName)
|
||||||
@@ -106,7 +100,7 @@ func setupUserspace(ctx context.Context,
|
|||||||
case err = <-uapiAcceptErrorCh:
|
case err = <-uapiAcceptErrorCh:
|
||||||
close(uapiAcceptErrorCh)
|
close(uapiAcceptErrorCh)
|
||||||
case <-device.Wait():
|
case <-device.Wait():
|
||||||
err = errDeviceWaited
|
err = errors.New("device waited for")
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanups.Cleanup(logger)
|
cleanups.Cleanup(logger)
|
||||||
|
|||||||
@@ -16,11 +16,6 @@ import (
|
|||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
|
|
||||||
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
|
|
||||||
)
|
|
||||||
|
|
||||||
func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool,
|
func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool,
|
||||||
provider string, titleCaser cases.Caser,
|
provider string, titleCaser cases.Caser,
|
||||||
) {
|
) {
|
||||||
@@ -65,11 +60,10 @@ func (c *CLI) FormatServers(args []string) error {
|
|||||||
}
|
}
|
||||||
switch len(providers) {
|
switch len(providers) {
|
||||||
case 0:
|
case 0:
|
||||||
return fmt.Errorf("%w", ErrProviderUnspecified)
|
return errors.New("VPN provider to format was not specified")
|
||||||
case 1:
|
case 1:
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %d specified: %s",
|
return fmt.Errorf("more than one VPN provider to format were specified: %d specified: %s", len(providers),
|
||||||
ErrMultipleProvidersToFormat, len(providers),
|
|
||||||
strings.Join(providers, ", "))
|
strings.Join(providers, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,13 +24,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/updater/unzip"
|
"github.com/qdm12/gluetun/internal/updater/unzip"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
|
|
||||||
ErrNoProviderSpecified = errors.New("no provider was specified")
|
|
||||||
ErrUsernameMissing = errors.New("username is required for this provider")
|
|
||||||
ErrPasswordMissing = errors.New("password is required for this provider")
|
|
||||||
)
|
|
||||||
|
|
||||||
type UpdaterLogger interface {
|
type UpdaterLogger interface {
|
||||||
Info(s string)
|
Info(s string)
|
||||||
Warn(s string)
|
Warn(s string)
|
||||||
@@ -65,14 +58,14 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !endUserMode && !maintainerMode {
|
if !endUserMode && !maintainerMode {
|
||||||
return fmt.Errorf("%w", ErrModeUnspecified)
|
return errors.New("at least one of -enduser or -maintainer must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateAll {
|
if updateAll {
|
||||||
options.Providers = providers.All()
|
options.Providers = providers.All()
|
||||||
} else {
|
} else {
|
||||||
if csvProviders == "" {
|
if csvProviders == "" {
|
||||||
return fmt.Errorf("%w", ErrNoProviderSpecified)
|
return errors.New("no provider was specified")
|
||||||
}
|
}
|
||||||
options.Providers = strings.Split(csvProviders, ",")
|
options.Providers = strings.Split(csvProviders, ",")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,13 +8,6 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errCommandEmpty = errors.New("command is empty")
|
|
||||||
errSingleQuoteUnterminated = errors.New("unterminated single-quoted string")
|
|
||||||
errDoubleQuoteUnterminated = errors.New("unterminated double-quoted string")
|
|
||||||
errEscapeUnterminated = errors.New("unterminated backslash-escape")
|
|
||||||
)
|
|
||||||
|
|
||||||
// split splits a command string into a slice of arguments.
|
// split splits a command string into a slice of arguments.
|
||||||
// This is especially important for commands such as:
|
// This is especially important for commands such as:
|
||||||
// /bin/sh -c "echo hello"
|
// /bin/sh -c "echo hello"
|
||||||
@@ -25,7 +18,7 @@ var (
|
|||||||
// - expansion (brace, shell or pathname).
|
// - expansion (brace, shell or pathname).
|
||||||
func split(command string) (words []string, err error) {
|
func split(command string) (words []string, err error) {
|
||||||
if command == "" {
|
if command == "" {
|
||||||
return nil, fmt.Errorf("%w", errCommandEmpty)
|
return nil, errors.New("command is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
const bufferSize = 1024
|
const bufferSize = 1024
|
||||||
@@ -42,7 +35,7 @@ func split(command string) (words []string, err error) {
|
|||||||
case character == '\\':
|
case character == '\\':
|
||||||
// Look ahead to eventually skip an escaped newline
|
// Look ahead to eventually skip an escaped newline
|
||||||
if command[startIndex+runeSize:] == "" {
|
if command[startIndex+runeSize:] == "" {
|
||||||
return nil, fmt.Errorf("%w: %q", errEscapeUnterminated, command)
|
return nil, fmt.Errorf("unterminated backslash-escape: %q", command)
|
||||||
}
|
}
|
||||||
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
|
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
|
||||||
if character == '\n' {
|
if character == '\n' {
|
||||||
@@ -119,7 +112,7 @@ func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
|||||||
startIndex = cursor
|
startIndex = cursor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "", 0, fmt.Errorf("%w", errDoubleQuoteUnterminated)
|
return "", 0, errors.New("unterminated double-quoted string")
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||||
@@ -127,7 +120,7 @@ func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
|||||||
) {
|
) {
|
||||||
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
|
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
|
||||||
if closingQuoteIndex == -1 {
|
if closingQuoteIndex == -1 {
|
||||||
return "", 0, fmt.Errorf("%w", errSingleQuoteUnterminated)
|
return "", 0, errors.New("unterminated single-quoted string")
|
||||||
}
|
}
|
||||||
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
|
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
|
||||||
const singleQuoteRuneLength = 1
|
const singleQuoteRuneLength = 1
|
||||||
@@ -139,7 +132,7 @@ func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) (
|
|||||||
word string, newStartIndex int, err error,
|
word string, newStartIndex int, err error,
|
||||||
) {
|
) {
|
||||||
if input[startIndex:] == "" {
|
if input[startIndex:] == "" {
|
||||||
return "", 0, fmt.Errorf("%w", errEscapeUnterminated)
|
return "", 0, errors.New("unterminated backslash-escape")
|
||||||
}
|
}
|
||||||
character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
|
character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
|
||||||
if character != '\n' { // backslash-escaped newline is ignored
|
if character != '\n' { // backslash-escaped newline is ignored
|
||||||
|
|||||||
@@ -12,12 +12,10 @@ func Test_split(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
command string
|
command string
|
||||||
words []string
|
words []string
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"empty": {
|
"empty": {
|
||||||
command: "",
|
command: "",
|
||||||
errWrapped: errCommandEmpty,
|
|
||||||
errMessage: "command is empty",
|
errMessage: "command is empty",
|
||||||
},
|
},
|
||||||
"concrete_sh_command": {
|
"concrete_sh_command": {
|
||||||
@@ -74,22 +72,18 @@ func Test_split(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"unterminated_single_quote": {
|
"unterminated_single_quote": {
|
||||||
command: "'abc'\\''def",
|
command: "'abc'\\''def",
|
||||||
errWrapped: errSingleQuoteUnterminated,
|
|
||||||
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
|
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
|
||||||
},
|
},
|
||||||
"unterminated_double_quote": {
|
"unterminated_double_quote": {
|
||||||
command: "\"abc'def",
|
command: "\"abc'def",
|
||||||
errWrapped: errDoubleQuoteUnterminated,
|
|
||||||
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
|
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
|
||||||
},
|
},
|
||||||
"unterminated_escape": {
|
"unterminated_escape": {
|
||||||
command: "abc\\",
|
command: "abc\\",
|
||||||
errWrapped: errEscapeUnterminated,
|
|
||||||
errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
|
errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
|
||||||
},
|
},
|
||||||
"unterminated_escape_only": {
|
"unterminated_escape_only": {
|
||||||
command: " \\",
|
command: " \\",
|
||||||
errWrapped: errEscapeUnterminated,
|
|
||||||
errMessage: `unterminated backslash-escape: " \\"`,
|
errMessage: `unterminated backslash-escape: " \\"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -101,9 +95,10 @@ func Test_split(t *testing.T) {
|
|||||||
words, err := split(testCase.command)
|
words, err := split(testCase.command)
|
||||||
|
|
||||||
assert.Equal(t, testCase.words, words)
|
assert.Equal(t, testCase.words, words)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
assert.ErrorContains(t, err, testCase.errMessage)
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -177,14 +176,6 @@ func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
|
|||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
|
|
||||||
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
|
|
||||||
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
|
|
||||||
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
|
|
||||||
ErrHeaderRangeMalformed = errors.New("header range is malformed")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||||
const amneziaWG = true
|
const amneziaWG = true
|
||||||
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
|
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
|
||||||
@@ -194,16 +185,16 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
|||||||
|
|
||||||
if *a.JunkPacketCount == 0 {
|
if *a.JunkPacketCount == 0 {
|
||||||
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
|
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
|
||||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
return fmt.Errorf("junk packet count must be set when junk packet min or max is set: "+
|
||||||
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
|
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
|
||||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
return fmt.Errorf("junk packet min and max must be set when junk packet count is set: "+
|
||||||
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||||
} else if *a.JunkPacketMin > *a.JunkPacketMax {
|
} else if *a.JunkPacketMin > *a.JunkPacketMax {
|
||||||
return fmt.Errorf("%w: jmin=%d and jmax=%d",
|
return fmt.Errorf("junk packet minimum must be lower than or equal to maximum: "+
|
||||||
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
|
"jmin=%d and jmax=%d", *a.JunkPacketMin, *a.JunkPacketMax)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,20 +213,20 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
|||||||
case 1:
|
case 1:
|
||||||
_, err := strconv.Atoi(fields[0])
|
_, err := strconv.Atoi(fields[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s value %s is not a number",
|
return fmt.Errorf("header range is malformed: "+
|
||||||
ErrHeaderRangeMalformed, name, headerRange)
|
"%s value %s is not a number", name, headerRange)
|
||||||
}
|
}
|
||||||
case 2: //nolint:mnd
|
case 2: //nolint:mnd
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
_, err := strconv.Atoi(field)
|
_, err := strconv.Atoi(field)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s value %s is not a valid range",
|
return fmt.Errorf("header range is malformed: "+
|
||||||
ErrHeaderRangeMalformed, name, headerRange)
|
"%s value %s is not a valid range", name, headerRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s value %s must be in the form n or n-m",
|
return fmt.Errorf("header range is malformed: "+
|
||||||
ErrHeaderRangeMalformed, name, headerRange)
|
"%s value %s must be in the form n or n-m", name, headerRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
@@ -48,22 +47,15 @@ type DNS struct {
|
|||||||
UpstreamPlainAddresses []netip.AddrPort
|
UpstreamPlainAddresses []netip.AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
|
|
||||||
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
|
|
||||||
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
|
|
||||||
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d DNS) validate() (err error) {
|
func (d DNS) validate() (err error) {
|
||||||
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
|
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
|
||||||
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
|
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
|
||||||
}
|
}
|
||||||
|
|
||||||
const minUpdatePeriod = 30 * time.Second
|
const minUpdatePeriod = 30 * time.Second
|
||||||
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
|
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
|
||||||
return fmt.Errorf("%w: %s must be bigger than %s",
|
return fmt.Errorf("update period is too short: %s must be bigger than %s",
|
||||||
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
|
*d.UpdatePeriod, minUpdatePeriod)
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.UpstreamType == DNSUpstreamTypePlain {
|
if d.UpstreamType == DNSUpstreamTypePlain {
|
||||||
@@ -81,9 +73,11 @@ func (d DNS) validate() (err error) {
|
|||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
case *d.IPv6 && !selectedHasPlainIPv6:
|
case *d.IPv6 && !selectedHasPlainIPv6:
|
||||||
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
|
return fmt.Errorf("upstream plain addresses do not contain any IPv6 address: "+
|
||||||
|
"in %d addresses", len(d.UpstreamPlainAddresses))
|
||||||
case !*d.IPv6 && !selectedHasPlainIPv4:
|
case !*d.IPv6 && !selectedHasPlainIPv4:
|
||||||
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
|
return fmt.Errorf("upstream plain addresses do not contain any IPv4 address: "+
|
||||||
|
"in %d addresses", len(d.UpstreamPlainAddresses))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
|
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -37,22 +36,16 @@ func (b *DNSBlacklist) setDefaults() {
|
|||||||
|
|
||||||
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
|
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
|
||||||
|
|
||||||
var (
|
|
||||||
ErrAllowedHostNotValid = errors.New("allowed host is not valid")
|
|
||||||
ErrBlockedHostNotValid = errors.New("blocked host is not valid")
|
|
||||||
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b DNSBlacklist) validate() (err error) {
|
func (b DNSBlacklist) validate() (err error) {
|
||||||
for _, host := range b.AllowedHosts {
|
for _, host := range b.AllowedHosts {
|
||||||
if !hostRegex.MatchString(host) {
|
if !hostRegex.MatchString(host) {
|
||||||
return fmt.Errorf("%w: %s", ErrAllowedHostNotValid, host)
|
return fmt.Errorf("allowed host is not valid: %s", host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, host := range b.AddBlockedHosts {
|
for _, host := range b.AddBlockedHosts {
|
||||||
if !hostRegex.MatchString(host) {
|
if !hostRegex.MatchString(host) {
|
||||||
return fmt.Errorf("%w: %s", ErrBlockedHostNotValid, host)
|
return fmt.Errorf("blocked host is not valid: %s", host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +54,7 @@ func (b DNSBlacklist) validate() (err error) {
|
|||||||
host = host[2:]
|
host = host[2:]
|
||||||
}
|
}
|
||||||
if !hostRegex.MatchString(host) {
|
if !hostRegex.MatchString(host) {
|
||||||
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
|
return fmt.Errorf("rebinding protection exempt host is not valid: %s", host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,8 +202,6 @@ func readDNSBlockedIPs(r *reader.Reader) (ips []netip.Addr,
|
|||||||
return ips, ipPrefixes, nil
|
return ips, ipPrefixes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range")
|
|
||||||
|
|
||||||
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
||||||
ipPrefixes []netip.Prefix, err error,
|
ipPrefixes []netip.Prefix, err error,
|
||||||
) {
|
) {
|
||||||
@@ -236,8 +227,9 @@ func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, fmt.Errorf(
|
return nil, nil, fmt.Errorf(
|
||||||
"environment variable DOT_PRIVATE_ADDRESS: %w: %s",
|
"environment variable DOT_PRIVATE_ADDRESS: "+
|
||||||
ErrPrivateAddressNotValid, privateAddress)
|
"private address is not a valid IP or CIDR range: %s",
|
||||||
|
privateAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, ipPrefixes, nil
|
return ips, ipPrefixes, nil
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
package settings
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrValueUnknown = errors.New("value is unknown")
|
|
||||||
ErrCityNotValid = errors.New("the city specified is not valid")
|
|
||||||
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
|
||||||
ErrCategoryNotValid = errors.New("the category specified is not valid")
|
|
||||||
ErrCountryNotValid = errors.New("the country specified is not valid")
|
|
||||||
ErrFilepathMissing = errors.New("filepath is missing")
|
|
||||||
ErrFirewallZeroPort = errors.New("cannot have a zero port")
|
|
||||||
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet has an unspecified address")
|
|
||||||
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
|
||||||
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
|
||||||
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
|
|
||||||
ErrMissingValue = errors.New("missing value")
|
|
||||||
ErrNameNotValid = errors.New("the server name specified is not valid")
|
|
||||||
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
|
|
||||||
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
|
|
||||||
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
|
|
||||||
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
|
|
||||||
ErrOpenVPNKeyPassphraseIsEmpty = errors.New("key passphrase is empty")
|
|
||||||
ErrOpenVPNMSSFixIsTooHigh = errors.New("mssfix option value is too high")
|
|
||||||
ErrOpenVPNPasswordIsEmpty = errors.New("password is empty")
|
|
||||||
ErrOpenVPNTCPNotSupported = errors.New("TCP protocol is not supported")
|
|
||||||
ErrOpenVPNUserIsEmpty = errors.New("user is empty")
|
|
||||||
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
|
|
||||||
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
|
|
||||||
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
|
|
||||||
ErrPortForwardingUserEmpty = errors.New("port forwarding username is empty")
|
|
||||||
ErrPortForwardingPasswordEmpty = errors.New("port forwarding password is empty")
|
|
||||||
ErrRegionNotValid = errors.New("the region specified is not valid")
|
|
||||||
ErrServerAddressNotValid = errors.New("server listening address is not valid")
|
|
||||||
ErrSystemPGIDNotValid = errors.New("process group id is not valid")
|
|
||||||
ErrSystemPUIDNotValid = errors.New("process user id is not valid")
|
|
||||||
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
|
|
||||||
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
|
||||||
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
|
|
||||||
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
|
|
||||||
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
|
||||||
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
|
||||||
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
|
|
||||||
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
|
|
||||||
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
|
|
||||||
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
|
|
||||||
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
|
|
||||||
ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
|
|
||||||
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
|
|
||||||
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
|
|
||||||
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
|
|
||||||
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
|
|
||||||
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
|
||||||
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
|
||||||
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
|
||||||
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
|
|
||||||
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
|
|
||||||
)
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
@@ -20,16 +21,16 @@ type Firewall struct {
|
|||||||
|
|
||||||
func (f Firewall) validate() (err error) {
|
func (f Firewall) validate() (err error) {
|
||||||
if hasZeroPort(f.VPNInputPorts) {
|
if hasZeroPort(f.VPNInputPorts) {
|
||||||
return fmt.Errorf("VPN input ports: %w", ErrFirewallZeroPort)
|
return errors.New("VPN input ports: cannot have a zero port")
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasZeroPort(f.InputPorts) {
|
if hasZeroPort(f.InputPorts) {
|
||||||
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
|
return errors.New("input ports: cannot have a zero port")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, subnet := range f.OutboundSubnets {
|
for _, subnet := range f.OutboundSubnets {
|
||||||
if subnet.Addr().IsUnspecified() {
|
if subnet.Addr().IsUnspecified() {
|
||||||
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet)
|
return fmt.Errorf("outbound subnet has an unspecified address: %s", subnet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,25 +13,21 @@ func Test_Firewall_validate(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
firewall Firewall
|
firewall Firewall
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"empty": {
|
"empty": {
|
||||||
errWrapped: log.ErrLevelNotRecognized,
|
|
||||||
errMessage: "iptables settings: log level: level is not recognized: ",
|
errMessage: "iptables settings: log level: level is not recognized: ",
|
||||||
},
|
},
|
||||||
"zero_vpn_input_port": {
|
"zero_vpn_input_port": {
|
||||||
firewall: Firewall{
|
firewall: Firewall{
|
||||||
VPNInputPorts: []uint16{0},
|
VPNInputPorts: []uint16{0},
|
||||||
},
|
},
|
||||||
errWrapped: ErrFirewallZeroPort,
|
|
||||||
errMessage: "VPN input ports: cannot have a zero port",
|
errMessage: "VPN input ports: cannot have a zero port",
|
||||||
},
|
},
|
||||||
"zero_input_port": {
|
"zero_input_port": {
|
||||||
firewall: Firewall{
|
firewall: Firewall{
|
||||||
InputPorts: []uint16{0},
|
InputPorts: []uint16{0},
|
||||||
},
|
},
|
||||||
errWrapped: ErrFirewallZeroPort,
|
|
||||||
errMessage: "input ports: cannot have a zero port",
|
errMessage: "input ports: cannot have a zero port",
|
||||||
},
|
},
|
||||||
"unspecified_outbound_subnet": {
|
"unspecified_outbound_subnet": {
|
||||||
@@ -40,7 +36,6 @@ func Test_Firewall_validate(t *testing.T) {
|
|||||||
netip.MustParsePrefix("0.0.0.0/0"),
|
netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
errWrapped: ErrFirewallPublicOutboundSubnet,
|
|
||||||
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
|
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
|
||||||
},
|
},
|
||||||
"public_outbound_subnet": {
|
"public_outbound_subnet": {
|
||||||
@@ -70,9 +65,10 @@ func Test_Firewall_validate(t *testing.T) {
|
|||||||
|
|
||||||
err := testCase.firewall.validate()
|
err := testCase.firewall.validate()
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,12 +38,6 @@ type Health struct {
|
|||||||
RestartVPN *bool
|
RestartVPN *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
|
|
||||||
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
|
|
||||||
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (h Health) Validate() (err error) {
|
func (h Health) Validate() (err error) {
|
||||||
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
|
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -53,16 +47,16 @@ func (h Health) Validate() (err error) {
|
|||||||
for _, ip := range h.ICMPTargetIPs {
|
for _, ip := range h.ICMPTargetIPs {
|
||||||
switch {
|
switch {
|
||||||
case !ip.IsValid():
|
case !ip.IsValid():
|
||||||
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
|
return fmt.Errorf("ICMP target IP address is not valid: %s", ip)
|
||||||
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
|
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
|
||||||
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
|
return errors.New("ICMP target IP addresses are not compatible: " +
|
||||||
ErrICMPTargetIPsNotCompatible)
|
"only a single IP address must be set if it is to be unspecified")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
|
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
|
return fmt.Errorf("small check type is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (h HTTPProxy) validate() (err error) {
|
|||||||
// Do not validate user and password
|
// Do not validate user and password
|
||||||
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
|
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
|
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -176,7 +176,6 @@ func readHTTProxyLog(r *reader.Reader) (enabled *bool, err error) {
|
|||||||
case "disabled", "no", "off":
|
case "disabled", "no", "off":
|
||||||
return ptrTo(false), nil
|
return ptrTo(false), nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: %w: %s",
|
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: value is unknown: %s", value)
|
||||||
ErrValueUnknown, value)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package settings
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -92,7 +93,7 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
|||||||
// Validate version
|
// Validate version
|
||||||
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
|
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
|
||||||
if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
|
if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrOpenVPNVersionIsNotValid, err)
|
return fmt.Errorf("version is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
isCustom := vpnProvider == providers.Custom
|
isCustom := vpnProvider == providers.Custom
|
||||||
@@ -101,14 +102,14 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
|||||||
vpnProvider != providers.VPNSecure
|
vpnProvider != providers.VPNSecure
|
||||||
|
|
||||||
if isUserRequired && *o.User == "" {
|
if isUserRequired && *o.User == "" {
|
||||||
return fmt.Errorf("%w", ErrOpenVPNUserIsEmpty)
|
return errors.New("user is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
passwordRequired := isUserRequired &&
|
passwordRequired := isUserRequired &&
|
||||||
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
|
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
|
||||||
|
|
||||||
if passwordRequired && *o.Password == "" {
|
if passwordRequired && *o.Password == "" {
|
||||||
return fmt.Errorf("%w", ErrOpenVPNPasswordIsEmpty)
|
return errors.New("password is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
|
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
|
||||||
@@ -132,23 +133,20 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
|
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
|
||||||
return fmt.Errorf("%w", ErrOpenVPNKeyPassphraseIsEmpty)
|
return errors.New("key passphrase is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxMSSFix = 10000
|
const maxMSSFix = 10000
|
||||||
if *o.MSSFix > maxMSSFix {
|
if *o.MSSFix > maxMSSFix {
|
||||||
return fmt.Errorf("%w: %d is over the maximum value of %d",
|
return fmt.Errorf("mssfix option value is too high: %d is over the maximum value of %d", *o.MSSFix, maxMSSFix)
|
||||||
ErrOpenVPNMSSFixIsTooHigh, *o.MSSFix, maxMSSFix)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !regexpInterfaceName.MatchString(o.Interface) {
|
if !regexpInterfaceName.MatchString(o.Interface) {
|
||||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", o.Interface, regexpInterfaceName)
|
||||||
ErrOpenVPNInterfaceNotValid, o.Interface, regexpInterfaceName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if *o.Verbosity < 0 || *o.Verbosity > 6 {
|
if *o.Verbosity < 0 || *o.Verbosity > 6 {
|
||||||
return fmt.Errorf("%w: %d can only be between 0 and 5",
|
return fmt.Errorf("verbosity value is out of bounds: %d can only be between 0 and 5", o.Verbosity)
|
||||||
ErrOpenVPNVerbosityIsOutOfBounds, o.Verbosity)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -162,7 +160,7 @@ func validateOpenVPNConfigFilepath(isCustom bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if confFile == "" {
|
if confFile == "" {
|
||||||
return fmt.Errorf("%w", ErrFilepathMissing)
|
return errors.New("filepath is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.FileExists(confFile)
|
err = validate.FileExists(confFile)
|
||||||
@@ -189,7 +187,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
|
|||||||
providers.VPNSecure,
|
providers.VPNSecure,
|
||||||
providers.VPNUnlimited:
|
providers.VPNUnlimited:
|
||||||
if clientCert == "" {
|
if clientCert == "" {
|
||||||
return fmt.Errorf("%w", ErrMissingValue)
|
return errors.New("missing value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +209,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
|
|||||||
providers.Cyberghost,
|
providers.Cyberghost,
|
||||||
providers.VPNUnlimited:
|
providers.VPNUnlimited:
|
||||||
if clientKey == "" {
|
if clientKey == "" {
|
||||||
return fmt.Errorf("%w", ErrMissingValue)
|
return errors.New("missing value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +228,7 @@ func validateOpenVPNEncryptedKey(vpnProvider,
|
|||||||
encryptedPrivateKey string,
|
encryptedPrivateKey string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
|
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
|
||||||
return fmt.Errorf("%w", ErrMissingValue)
|
return errors.New("missing value")
|
||||||
}
|
}
|
||||||
|
|
||||||
if encryptedPrivateKey == "" {
|
if encryptedPrivateKey == "" {
|
||||||
|
|||||||
@@ -62,8 +62,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
|||||||
providers.Perfectprivacy,
|
providers.Perfectprivacy,
|
||||||
providers.Vyprvpn,
|
providers.Vyprvpn,
|
||||||
) {
|
) {
|
||||||
return fmt.Errorf("%w: for VPN service provider %s",
|
return fmt.Errorf("TCP protocol is not supported: for VPN service provider %s", vpnProvider)
|
||||||
ErrOpenVPNTCPNotSupported, vpnProvider)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate CustomPort
|
// Validate CustomPort
|
||||||
@@ -78,8 +77,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
|||||||
providers.Nordvpn, providers.Purevpn,
|
providers.Nordvpn, providers.Purevpn,
|
||||||
providers.Surfshark, providers.VPNSecure,
|
providers.Surfshark, providers.VPNSecure,
|
||||||
providers.VPNUnlimited, providers.Vyprvpn:
|
providers.VPNUnlimited, providers.Vyprvpn:
|
||||||
return fmt.Errorf("%w: for VPN service provider %s",
|
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s", vpnProvider)
|
||||||
ErrOpenVPNCustomPortNotAllowed, vpnProvider)
|
|
||||||
default:
|
default:
|
||||||
var allowedTCP, allowedUDP []uint16
|
var allowedTCP, allowedUDP []uint16
|
||||||
switch vpnProvider {
|
switch vpnProvider {
|
||||||
@@ -123,8 +121,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
|||||||
}
|
}
|
||||||
err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
|
err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: for VPN service provider %s: %w",
|
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
|
||||||
ErrOpenVPNCustomPortNotAllowed, vpnProvider, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -136,7 +133,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
|||||||
presets.Strong,
|
presets.Strong,
|
||||||
}
|
}
|
||||||
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
|
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrOpenVPNEncryptionPresetNotValid, err)
|
return fmt.Errorf("PIA encryption preset is not valid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
@@ -24,21 +23,16 @@ type PMTUD struct {
|
|||||||
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
|
||||||
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Validate validates PMTUD settings.
|
// Validate validates PMTUD settings.
|
||||||
func (p PMTUD) validate() (err error) {
|
func (p PMTUD) validate() (err error) {
|
||||||
for i, addr := range p.ICMPAddresses {
|
for i, addr := range p.ICMPAddresses {
|
||||||
if !addr.IsValid() {
|
if !addr.IsValid() {
|
||||||
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
return fmt.Errorf("PMTUD ICMP address is not valid: at index %d", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i, addr := range p.TCPAddresses {
|
for i, addr := range p.TCPAddresses {
|
||||||
if !addr.IsValid() {
|
if !addr.IsValid() {
|
||||||
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
return fmt.Errorf("PMTUD TCP address is not valid: at index %d", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -55,12 +55,6 @@ type PortForwarding struct {
|
|||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrPortsCountTooHigh = errors.New("ports count too high")
|
|
||||||
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
|
|
||||||
ErrListeningPortZero = errors.New("listening port cannot be 0")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||||
if !*p.Enabled {
|
if !*p.Enabled {
|
||||||
return nil
|
return nil
|
||||||
@@ -78,7 +72,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
|||||||
providers.Protonvpn,
|
providers.Protonvpn,
|
||||||
}
|
}
|
||||||
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
|
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
|
return fmt.Errorf("port forwarding cannot be enabled: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Filepath
|
// Validate Filepath
|
||||||
@@ -94,30 +88,31 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
|||||||
const maxPortsCount = 1
|
const maxPortsCount = 1
|
||||||
switch {
|
switch {
|
||||||
case p.PortsCount > maxPortsCount:
|
case p.PortsCount > maxPortsCount:
|
||||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||||
case p.Username == "":
|
case p.Username == "":
|
||||||
return fmt.Errorf("%w", ErrPortForwardingUserEmpty)
|
return errors.New("port forwarding username is empty")
|
||||||
case p.Password == "":
|
case p.Password == "":
|
||||||
return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty)
|
return errors.New("port forwarding password is empty")
|
||||||
}
|
}
|
||||||
case providers.Protonvpn:
|
case providers.Protonvpn:
|
||||||
const maxPortsCount = 4
|
const maxPortsCount = 4
|
||||||
if p.PortsCount > maxPortsCount {
|
if p.PortsCount > maxPortsCount {
|
||||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
const maxPortsCount = 1
|
const maxPortsCount = 1
|
||||||
if p.PortsCount > maxPortsCount {
|
if p.PortsCount > maxPortsCount {
|
||||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slices.Equal(p.ListeningPorts, []uint16{0}) {
|
if !slices.Equal(p.ListeningPorts, []uint16{0}) {
|
||||||
switch {
|
switch {
|
||||||
case len(p.ListeningPorts) != int(p.PortsCount):
|
case len(p.ListeningPorts) != int(p.PortsCount):
|
||||||
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount)
|
return fmt.Errorf("listening ports length must be equal to ports count: "+
|
||||||
|
"%d != %d", len(p.ListeningPorts), p.PortsCount)
|
||||||
case slices.Contains(p.ListeningPorts, 0):
|
case slices.Contains(p.ListeningPorts, 0):
|
||||||
return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts)
|
return fmt.Errorf("listening port cannot be 0: in %v", p.ListeningPorts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
||||||
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
|
return fmt.Errorf("VPN provider name is not valid for %s: %w", vpnType, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ func Test_PublicIP_read(t *testing.T) {
|
|||||||
makeReader func(ctrl *gomock.Controller) *reader.Reader
|
makeReader func(ctrl *gomock.Controller) *reader.Reader
|
||||||
makeWarner func(ctrl *gomock.Controller) Warner
|
makeWarner func(ctrl *gomock.Controller) Warner
|
||||||
settings PublicIP
|
settings PublicIP
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"nothing_read": {
|
"nothing_read": {
|
||||||
@@ -152,9 +151,10 @@ func Test_PublicIP_read(t *testing.T) {
|
|||||||
err := settings.read(reader, warner)
|
err := settings.read(reader, warner)
|
||||||
|
|
||||||
assert.Equal(t, testCase.settings, settings)
|
assert.Equal(t, testCase.settings, settings)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,8 +46,7 @@ func (c ControlServer) validate() (err error) {
|
|||||||
uid := os.Getuid()
|
uid := os.Getuid()
|
||||||
const maxPrivilegedPort = 1023
|
const maxPrivilegedPort = 1023
|
||||||
if uid != 0 && port != 0 && port <= maxPrivilegedPort {
|
if uid != 0 && port != 0 && port <= maxPrivilegedPort {
|
||||||
return fmt.Errorf("%w: %d when running with user ID %d",
|
return fmt.Errorf("cannot use privileged port without running as root: %d when running with user ID %d", port, uid)
|
||||||
ErrControlServerPrivilegedPort, port, uid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
|
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
|
||||||
|
|||||||
@@ -71,25 +71,13 @@ type ServerSelection struct {
|
|||||||
Wireguard WireguardSelection `json:"wireguard"`
|
Wireguard WireguardSelection `json:"wireguard"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported")
|
|
||||||
ErrFreeOnlyNotSupported = errors.New("free only filter is not supported")
|
|
||||||
ErrPremiumOnlyNotSupported = errors.New("premium only filter is not supported")
|
|
||||||
ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported")
|
|
||||||
ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported")
|
|
||||||
ErrPortForwardOnlyNotSupported = errors.New("port forwarding only filter is not supported")
|
|
||||||
ErrFreePremiumBothSet = errors.New("free only and premium only filters are both set")
|
|
||||||
ErrSecureCoreOnlyNotSupported = errors.New("secure core only filter is not supported")
|
|
||||||
ErrTorOnlyNotSupported = errors.New("tor only filter is not supported")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||||
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
switch ss.VPN {
|
switch ss.VPN {
|
||||||
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
|
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
return fmt.Errorf("VPN type is not valid: %s", ss.VPN)
|
||||||
}
|
}
|
||||||
|
|
||||||
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
|
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
|
||||||
@@ -150,7 +138,7 @@ func getLocationFilterChoices(vpnServiceProvider string,
|
|||||||
// Only return error comparing with newer regions, we don't want to confuse the user
|
// Only return error comparing with newer regions, we don't want to confuse the user
|
||||||
// with the retro regions in the error message.
|
// with the retro regions in the error message.
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
|
||||||
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
return models.FilterChoices{}, fmt.Errorf("the region specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,27 +152,27 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
|||||||
) (err error) {
|
) (err error) {
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
|
return fmt.Errorf("the country specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
return fmt.Errorf("the region specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
|
return fmt.Errorf("the city specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
|
return fmt.Errorf("the ISP specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
return fmt.Errorf("the hostname specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if vpnServiceProvider == providers.Custom {
|
if vpnServiceProvider == providers.Custom {
|
||||||
@@ -196,19 +184,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
|||||||
// which requires a server name for TLS verification.
|
// which requires a server name for TLS verification.
|
||||||
filterChoices.Names = settings.Names
|
filterChoices.Names = settings.Names
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %d names specified instead of "+
|
return fmt.Errorf("name is not valid: "+
|
||||||
"0 or 1 for the custom provider",
|
"%d names specified instead of 0 or 1 for the custom provider",
|
||||||
ErrNameNotValid, len(settings.Names))
|
len(settings.Names))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
|
return fmt.Errorf("the server name specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
|
return fmt.Errorf("the category specified is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -255,12 +243,12 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
|
|||||||
switch {
|
switch {
|
||||||
case *settings.FreeOnly &&
|
case *settings.FreeOnly &&
|
||||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||||
return fmt.Errorf("%w", ErrFreeOnlyNotSupported)
|
return errors.New("free only filter is not supported")
|
||||||
case *settings.PremiumOnly &&
|
case *settings.PremiumOnly &&
|
||||||
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
|
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
|
||||||
return fmt.Errorf("%w", ErrPremiumOnlyNotSupported)
|
return errors.New("premium only filter is not supported")
|
||||||
case *settings.FreeOnly && *settings.PremiumOnly:
|
case *settings.FreeOnly && *settings.PremiumOnly:
|
||||||
return fmt.Errorf("%w", ErrFreePremiumBothSet)
|
return errors.New("free only and premium only filters are both set")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -269,21 +257,21 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
|
|||||||
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
|
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
|
||||||
switch {
|
switch {
|
||||||
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
||||||
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
return errors.New("owned only filter is not supported")
|
||||||
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
||||||
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
|
return errors.New("port forwarding only filter is not supported: together with free only filter")
|
||||||
case *settings.StreamOnly &&
|
case *settings.StreamOnly &&
|
||||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||||
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
return errors.New("stream only filter is not supported")
|
||||||
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
|
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
|
||||||
return fmt.Errorf("%w", ErrMultiHopOnlyNotSupported)
|
return errors.New("multi hop only filter is not supported")
|
||||||
case *settings.PortForwardOnly &&
|
case *settings.PortForwardOnly &&
|
||||||
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
|
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
|
||||||
return fmt.Errorf("%w", ErrPortForwardOnlyNotSupported)
|
return errors.New("port forwarding only filter is not supported")
|
||||||
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
|
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
|
||||||
return fmt.Errorf("%w", ErrSecureCoreOnlyNotSupported)
|
return errors.New("secure core only filter is not supported")
|
||||||
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
|
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
|
||||||
return fmt.Errorf("%w", ErrTorOnlyNotSupported)
|
return errors.New("tor only filter is not supported")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -37,20 +38,20 @@ type Updater struct {
|
|||||||
func (u Updater) Validate() (err error) {
|
func (u Updater) Validate() (err error) {
|
||||||
const minPeriod = time.Minute
|
const minPeriod = time.Minute
|
||||||
if *u.Period > 0 && *u.Period < minPeriod {
|
if *u.Period > 0 && *u.Period < minPeriod {
|
||||||
return fmt.Errorf("%w: %d must be larger than %s",
|
return fmt.Errorf("VPN server data updater period is too small: "+
|
||||||
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
|
"%d must be larger than %s", *u.Period, minPeriod)
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.MinRatio <= 0 || u.MinRatio > 1 {
|
if u.MinRatio <= 0 || u.MinRatio > 1 {
|
||||||
return fmt.Errorf("%w: %.2f must be between 0+ and 1",
|
return fmt.Errorf("minimum ratio is not valid: "+
|
||||||
ErrMinRatioNotValid, u.MinRatio)
|
"%.2f must be between 0+ and 1", u.MinRatio)
|
||||||
}
|
}
|
||||||
|
|
||||||
validProviders := providers.All()
|
validProviders := providers.All()
|
||||||
for _, provider := range u.Providers {
|
for _, provider := range u.Providers {
|
||||||
err = validate.IsOneOf(provider, validProviders...)
|
err = validate.IsOneOf(provider, validProviders...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err)
|
return fmt.Errorf("VPN provider name is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider == providers.Protonvpn {
|
if provider == providers.Protonvpn {
|
||||||
@@ -58,9 +59,9 @@ func (u Updater) Validate() (err error) {
|
|||||||
if authenticatedAPI {
|
if authenticatedAPI {
|
||||||
switch {
|
switch {
|
||||||
case *u.ProtonEmail == "":
|
case *u.ProtonEmail == "":
|
||||||
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing)
|
return errors.New("proton email is missing")
|
||||||
case *u.ProtonPassword == "":
|
case *u.ProtonPassword == "":
|
||||||
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
|
return errors.New("proton password is missing")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
|||||||
// Validate Type
|
// Validate Type
|
||||||
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
|
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
|
||||||
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
return fmt.Errorf("VPN type is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
|
err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -54,7 +55,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
|||||||
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
|
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
|
||||||
// Validate PrivateKey
|
// Validate PrivateKey
|
||||||
if *w.PrivateKey == "" {
|
if *w.PrivateKey == "" {
|
||||||
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
|
return errors.New("private key is not set")
|
||||||
}
|
}
|
||||||
_, err = wgtypes.ParseKey(*w.PrivateKey)
|
_, err = wgtypes.ParseKey(*w.PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -68,7 +69,7 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
|||||||
|
|
||||||
if vpnProvider == providers.Airvpn {
|
if vpnProvider == providers.Airvpn {
|
||||||
if *w.PreSharedKey == "" {
|
if *w.PreSharedKey == "" {
|
||||||
return fmt.Errorf("%w", ErrWireguardPreSharedKeyNotSet)
|
return errors.New("pre-shared key is not set")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,17 +83,15 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
|||||||
|
|
||||||
// Validate Addresses
|
// Validate Addresses
|
||||||
if len(w.Addresses) == 0 {
|
if len(w.Addresses) == 0 {
|
||||||
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet)
|
return errors.New("interface address is not set")
|
||||||
}
|
}
|
||||||
for i, ipNet := range w.Addresses {
|
for i, ipNet := range w.Addresses {
|
||||||
if !ipNet.IsValid() {
|
if !ipNet.IsValid() {
|
||||||
return fmt.Errorf("%w: for address at index %d",
|
return fmt.Errorf("interface address is not set: for address at index %d", i)
|
||||||
ErrWireguardInterfaceAddressNotSet, i)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ipv6Supported && ipNet.Addr().Is6() {
|
if !ipv6Supported && ipNet.Addr().Is6() {
|
||||||
return fmt.Errorf("%w: address %s",
|
return fmt.Errorf("interface address is IPv6 but IPv6 is not supported: address %s", ipNet.String())
|
||||||
ErrWireguardInterfaceAddressIPv6, ipNet.String())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,30 +99,27 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
|||||||
// WARNING: do not check for IPv6 networks in the allowed IPs,
|
// WARNING: do not check for IPv6 networks in the allowed IPs,
|
||||||
// the wireguard code will take care to ignore it.
|
// the wireguard code will take care to ignore it.
|
||||||
if len(w.AllowedIPs) == 0 {
|
if len(w.AllowedIPs) == 0 {
|
||||||
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
|
return errors.New("allowed IPs is not set")
|
||||||
}
|
}
|
||||||
for i, allowedIP := range w.AllowedIPs {
|
for i, allowedIP := range w.AllowedIPs {
|
||||||
if !allowedIP.IsValid() {
|
if !allowedIP.IsValid() {
|
||||||
return fmt.Errorf("%w: for allowed ip %d of %d",
|
return fmt.Errorf("allowed IP is not set: for allowed ip %d of %d", i+1, len(w.AllowedIPs))
|
||||||
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if *w.PersistentKeepaliveInterval < 0 {
|
if *w.PersistentKeepaliveInterval < 0 {
|
||||||
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
|
return fmt.Errorf("persistent keep alive interval is negative: %s", *w.PersistentKeepaliveInterval)
|
||||||
*w.PersistentKeepaliveInterval)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate interface
|
// Validate interface
|
||||||
if !regexpInterfaceName.MatchString(w.Interface) {
|
if !regexpInterfaceName.MatchString(w.Interface) {
|
||||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", w.Interface, regexpInterfaceName)
|
||||||
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
|
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
|
||||||
validImplementations := []string{"auto", "userspace", "kernelspace"}
|
validImplementations := []string{"auto", "userspace", "kernelspace"}
|
||||||
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
|
return fmt.Errorf("implementation is not valid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
@@ -44,7 +45,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
|||||||
// endpoint IP addresses are baked in
|
// endpoint IP addresses are baked in
|
||||||
case providers.Custom:
|
case providers.Custom:
|
||||||
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
|
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
|
||||||
return fmt.Errorf("%w", ErrWireguardEndpointIPNotSet)
|
return errors.New("endpoint IP is not set")
|
||||||
}
|
}
|
||||||
default: // Providers not supporting Wireguard
|
default: // Providers not supporting Wireguard
|
||||||
}
|
}
|
||||||
@@ -54,13 +55,13 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
|||||||
// EndpointPort is required
|
// EndpointPort is required
|
||||||
case providers.Custom:
|
case providers.Custom:
|
||||||
if *w.EndpointPort == 0 {
|
if *w.EndpointPort == 0 {
|
||||||
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
|
return errors.New("endpoint port is not set")
|
||||||
}
|
}
|
||||||
// EndpointPort cannot be set
|
// EndpointPort cannot be set
|
||||||
case providers.Fastestvpn, providers.Nordvpn,
|
case providers.Fastestvpn, providers.Nordvpn,
|
||||||
providers.Protonvpn, providers.Surfshark:
|
providers.Protonvpn, providers.Surfshark:
|
||||||
if *w.EndpointPort != 0 {
|
if *w.EndpointPort != 0 {
|
||||||
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
|
return errors.New("endpoint port is set")
|
||||||
}
|
}
|
||||||
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
|
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
|
||||||
// EndpointPort is optional and can be 0
|
// EndpointPort is optional and can be 0
|
||||||
@@ -84,8 +85,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: for VPN service provider %s: %w",
|
return fmt.Errorf("endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
|
||||||
ErrWireguardEndpointPortNotAllowed, vpnProvider, err)
|
|
||||||
default: // Providers not supporting Wireguard
|
default: // Providers not supporting Wireguard
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,15 +96,14 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
|||||||
// public keys are baked in
|
// public keys are baked in
|
||||||
case providers.Custom:
|
case providers.Custom:
|
||||||
if w.PublicKey == "" {
|
if w.PublicKey == "" {
|
||||||
return fmt.Errorf("%w", ErrWireguardPublicKeyNotSet)
|
return errors.New("public key is not set")
|
||||||
}
|
}
|
||||||
default: // Providers not supporting Wireguard
|
default: // Providers not supporting Wireguard
|
||||||
}
|
}
|
||||||
if w.PublicKey != "" {
|
if w.PublicKey != "" {
|
||||||
_, err := wgtypes.ParseKey(w.PublicKey)
|
_, err := wgtypes.ParseKey(w.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s: %s",
|
return fmt.Errorf("public key is not valid: %s: %s", w.PublicKey, err)
|
||||||
ErrWireguardPublicKeyNotValid, w.PublicKey, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -74,8 +74,6 @@ func parseWireguardInterfaceSection(interfaceSection *ini.Section) (
|
|||||||
return privateKey, addresses
|
return privateKey, addresses
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrEndpointHostNotIP = errors.New("endpoint host is not an IP")
|
|
||||||
|
|
||||||
func parseWireguardPeerSection(peerSection *ini.Section) (
|
func parseWireguardPeerSection(peerSection *ini.Section) (
|
||||||
preSharedKey, publicKey, endpointIP, endpointPort *string,
|
preSharedKey, publicKey, endpointIP, endpointPort *string,
|
||||||
) {
|
) {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
@@ -63,8 +62,6 @@ func generateRandomString(length uint) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errIPLeakSessionMismatch = errors.New("ipleak.net session mismatch")
|
|
||||||
|
|
||||||
func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
||||||
dnsToCount map[string]uint, err error,
|
dnsToCount map[string]uint, err error,
|
||||||
) {
|
) {
|
||||||
@@ -93,7 +90,7 @@ func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decoding response: %w", err)
|
return nil, fmt.Errorf("decoding response: %w", err)
|
||||||
} else if data.Session != session {
|
} else if data.Session != session {
|
||||||
return nil, fmt.Errorf("%w: expected %s, got %s", errIPLeakSessionMismatch, session, data.Session)
|
return nil, fmt.Errorf("ipleak.net session mismatch: expected %s, got %s", session, data.Session)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data.IP, nil
|
return data.IP, nil
|
||||||
|
|||||||
@@ -57,18 +57,15 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
const iptablesBinary = "/sbin/iptables"
|
const iptablesBinary = "/sbin/iptables"
|
||||||
errTest := errors.New("test error")
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
instruction string
|
instruction string
|
||||||
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
|
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
|
||||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"invalid_instruction": {
|
"invalid_instruction": {
|
||||||
instruction: "invalid",
|
instruction: "invalid",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
|
||||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
@@ -78,7 +75,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
runner := NewMockCmdRunner(ctrl)
|
runner := NewMockCmdRunner(ctrl)
|
||||||
runner.EXPECT().
|
runner.EXPECT().
|
||||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||||
Return("", errTest)
|
Return("", errors.New("test error"))
|
||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||||
@@ -86,7 +83,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||||
return logger
|
return logger
|
||||||
},
|
},
|
||||||
errWrapped: errTest,
|
|
||||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||||
},
|
},
|
||||||
@@ -120,7 +116,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||||
nil)
|
nil)
|
||||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
"^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
|
||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||||
@@ -131,7 +127,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||||
return logger
|
return logger
|
||||||
},
|
},
|
||||||
errWrapped: errTest,
|
|
||||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||||
},
|
},
|
||||||
"rule_found_delete_success": {
|
"rule_found_delete_success": {
|
||||||
@@ -177,9 +172,10 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
|
|
||||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,13 +82,11 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
|
||||||
|
|
||||||
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||||
switch policy {
|
switch policy {
|
||||||
case "ACCEPT", "DROP":
|
case "ACCEPT", "DROP":
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
return fmt.Errorf("policy is not valid: %s", policy)
|
||||||
}
|
}
|
||||||
return c.runIP6tablesInstructions(ctx, []string{
|
return c.runIP6tablesInstructions(ctx, []string{
|
||||||
"--policy INPUT " + policy,
|
"--policy INPUT " + policy,
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -13,10 +12,8 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
needIP6Tables = "ip6tables is required, please upgrade your kernel"
|
||||||
ErrPolicyUnknown = errors.New("unknown policy")
|
|
||||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func appendOrDelete(remove bool) string {
|
func appendOrDelete(remove bool) string {
|
||||||
@@ -36,7 +33,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
|
|||||||
words := strings.Fields(output)
|
words := strings.Fields(output)
|
||||||
const minWords = 2
|
const minWords = 2
|
||||||
if len(words) < minWords {
|
if len(words) < minWords {
|
||||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
return "", fmt.Errorf("iptables version string is too short: %s", output)
|
||||||
}
|
}
|
||||||
return "iptables " + words[1], nil
|
return "iptables " + words[1], nil
|
||||||
}
|
}
|
||||||
@@ -102,7 +99,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
|||||||
switch policy {
|
switch policy {
|
||||||
case "ACCEPT", "DROP":
|
case "ACCEPT", "DROP":
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
return fmt.Errorf("unknown policy: %s", policy)
|
||||||
}
|
}
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
return c.runIptablesInstructions(ctx, []string{
|
||||||
"--policy INPUT " + policy,
|
"--policy INPUT " + policy,
|
||||||
@@ -129,7 +126,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
|
|||||||
return c.runIptablesInstruction(ctx, instruction)
|
return c.runIptablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
if c.ip6Tables == "" {
|
if c.ip6Tables == "" {
|
||||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
|
||||||
}
|
}
|
||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
@@ -157,7 +154,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
|||||||
if connection.IP.Is4() {
|
if connection.IP.Is4() {
|
||||||
return c.runIptablesInstruction(ctx, instruction)
|
return c.runIptablesInstruction(ctx, instruction)
|
||||||
} else if c.ip6Tables == "" {
|
} else if c.ip6Tables == "" {
|
||||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
|
||||||
}
|
}
|
||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
@@ -175,7 +172,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
|||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
return c.runIptablesInstruction(ctx, instruction)
|
return c.runIptablesInstruction(ctx, instruction)
|
||||||
} else if c.ip6Tables == "" {
|
} else if c.ip6Tables == "" {
|
||||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
|
||||||
}
|
}
|
||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
@@ -200,7 +197,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
|||||||
if doIPv4 {
|
if doIPv4 {
|
||||||
return c.runIptablesInstruction(ctx, instruction)
|
return c.runIptablesInstruction(ctx, instruction)
|
||||||
} else if c.ip6Tables == "" {
|
} else if c.ip6Tables == "" {
|
||||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
|
||||||
}
|
}
|
||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
@@ -350,7 +347,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
|||||||
case ipv4:
|
case ipv4:
|
||||||
err = c.runIptablesInstructionNoSave(ctx, rule)
|
err = c.runIptablesInstructionNoSave(ctx, rule)
|
||||||
case c.ip6Tables == "":
|
case c.ip6Tables == "":
|
||||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
|
||||||
default: // ipv6
|
default: // ipv6
|
||||||
err = c.runIP6tablesInstructionNoSave(ctx, rule)
|
err = c.runIP6tablesInstructionNoSave(ctx, rule)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,8 +40,6 @@ type mark struct {
|
|||||||
value uint
|
value uint
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
|
||||||
|
|
||||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||||
// Text example:
|
// Text example:
|
||||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
@@ -63,8 +61,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
|||||||
|
|
||||||
const minLines = 2 // chain general information line + legend line
|
const minLines = 2 // chain general information line + legend line
|
||||||
if len(lines) < minLines {
|
if len(lines) < minLines {
|
||||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
|
||||||
ErrChainListMalformed, iptablesOutput)
|
iptablesOutput)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err = parseChainGeneralDataLine(lines[0])
|
c, err = parseChainGeneralDataLine(lines[0])
|
||||||
@@ -77,8 +75,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
|||||||
legendLine := strings.TrimSpace(lines[1])
|
legendLine := strings.TrimSpace(lines[1])
|
||||||
legendFields := strings.Fields(legendLine)
|
legendFields := strings.Fields(legendLine)
|
||||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
|
||||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
legendLine, strings.Join(expectedLegendFields, " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
lines = lines[2:] // remove chain general information line and legend line
|
lines = lines[2:] // remove chain general information line and legend line
|
||||||
@@ -111,8 +109,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
|||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
const expectedNumberOfFields = 8
|
const expectedNumberOfFields = 8
|
||||||
if len(fields) != expectedNumberOfFields {
|
if len(fields) != expectedNumberOfFields {
|
||||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
|
||||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
expectedNumberOfFields, line)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sanity checks
|
// Sanity checks
|
||||||
@@ -126,8 +124,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
|||||||
if fields[index] == expectedValue {
|
if fields[index] == expectedValue {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
|
||||||
ErrChainListMalformed, expectedValue, index, line)
|
expectedValue, index, line)
|
||||||
}
|
}
|
||||||
|
|
||||||
base.name = fields[1] // chain name could be custom
|
base.name = fields[1] // chain name could be custom
|
||||||
@@ -152,19 +150,17 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
|||||||
return base, nil
|
return base, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
|
||||||
|
|
||||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
if line == "" {
|
if line == "" {
|
||||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
return chainRule{}, errors.New("chain rule is malformed: empty line")
|
||||||
}
|
}
|
||||||
|
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
|
|
||||||
const minFields = 10
|
const minFields = 10
|
||||||
if len(fields) < minFields {
|
if len(fields) < minFields {
|
||||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
return chainRule{}, errors.New("chain rule is malformed: not enough fields")
|
||||||
}
|
}
|
||||||
|
|
||||||
for fieldIndex, field := range fields[:minFields] {
|
for fieldIndex, field := range fields[:minFields] {
|
||||||
@@ -186,7 +182,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
|
|||||||
|
|
||||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||||
if field == "" {
|
if field == "" {
|
||||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -278,8 +274,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
|||||||
rule.redirPorts = ports
|
rule.redirPorts = ports
|
||||||
i++
|
i++
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unexpected %q after redir",
|
return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
|
||||||
ErrChainRuleMalformed, optionalFields[1])
|
optionalFields[1])
|
||||||
}
|
}
|
||||||
case "ctstate":
|
case "ctstate":
|
||||||
i++
|
i++
|
||||||
@@ -294,15 +290,13 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
|||||||
rule.mark = mark
|
rule.mark = mark
|
||||||
i += consumed
|
i += consumed
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
|
||||||
ErrChainRuleMalformed, optionalFields[i])
|
optionalFields[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
|
||||||
|
|
||||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||||
for _, value := range optionalFields {
|
for _, value := range optionalFields {
|
||||||
if !strings.ContainsRune(value, ':') {
|
if !strings.ContainsRune(value, ':') {
|
||||||
@@ -323,14 +317,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
|||||||
}
|
}
|
||||||
consumed++
|
consumed++
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
return 0, fmt.Errorf("unknown UDP optional field: %s", value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return consumed, nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
|
||||||
|
|
||||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||||
for _, value := range optionalFields {
|
for _, value := range optionalFields {
|
||||||
if !strings.ContainsRune(value, ':') {
|
if !strings.ContainsRune(value, ':') {
|
||||||
@@ -357,7 +349,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
|||||||
}
|
}
|
||||||
consumed++
|
consumed++
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
return 0, fmt.Errorf("unknown TCP optional field: %s", value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return consumed, nil
|
return consumed, nil
|
||||||
@@ -373,15 +365,13 @@ func parseSourcePort(value string) (port uint16, err error) {
|
|||||||
return parsePort(value)
|
return parsePort(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
|
||||||
|
|
||||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||||
value = strings.TrimPrefix(value, "flags:")
|
value = strings.TrimPrefix(value, "flags:")
|
||||||
fields := strings.Split(value, "/")
|
fields := strings.Split(value, "/")
|
||||||
const expectedFields = 2
|
const expectedFields = 2
|
||||||
if len(fields) != expectedFields {
|
if len(fields) != expectedFields {
|
||||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
|
||||||
errTCPFlagsMalformed, value)
|
value)
|
||||||
}
|
}
|
||||||
maskFlags := strings.Split(fields[0], ",")
|
maskFlags := strings.Split(fields[0], ",")
|
||||||
mask := make([]tcpFlag, len(maskFlags))
|
mask := make([]tcpFlag, len(maskFlags))
|
||||||
@@ -422,8 +412,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
|||||||
return ports, nil
|
return ports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
|
||||||
|
|
||||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||||
switch optionalFields[consumed] {
|
switch optionalFields[consumed] {
|
||||||
case "match":
|
case "match":
|
||||||
@@ -437,42 +425,36 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
|||||||
const bits = 32
|
const bits = 32
|
||||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
|
||||||
}
|
}
|
||||||
m.value = uint(value)
|
m.value = uint(value)
|
||||||
consumed++
|
consumed++
|
||||||
default:
|
default:
|
||||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
|
||||||
ErrChainRuleMalformed, optionalFields[consumed])
|
optionalFields[consumed])
|
||||||
}
|
}
|
||||||
return m, consumed, nil
|
return m, consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
|
||||||
|
|
||||||
func parseLineNumber(s string) (n uint16, err error) {
|
func parseLineNumber(s string) (n uint16, err error) {
|
||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
} else if lineNumber == 0 {
|
} else if lineNumber == 0 {
|
||||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
return 0, errors.New("line number is zero")
|
||||||
}
|
}
|
||||||
return uint16(lineNumber), nil
|
return uint16(lineNumber), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrTargetUnknown = errors.New("unknown target")
|
|
||||||
|
|
||||||
func checkTarget(target string) (err error) {
|
func checkTarget(target string) (err error) {
|
||||||
switch target {
|
switch target {
|
||||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
return fmt.Errorf("unknown target: %s", target)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrProtocolUnknown = errors.New("unknown protocol")
|
|
||||||
|
|
||||||
func parseProtocol(s string) (protocol string, err error) {
|
func parseProtocol(s string) (protocol string, err error) {
|
||||||
switch s {
|
switch s {
|
||||||
case "0", "all":
|
case "0", "all":
|
||||||
@@ -483,18 +465,16 @@ func parseProtocol(s string) (protocol string, err error) {
|
|||||||
case "17", "udp":
|
case "17", "udp":
|
||||||
protocol = "udp"
|
protocol = "udp"
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
return "", fmt.Errorf("unknown protocol: %s", s)
|
||||||
}
|
}
|
||||||
return protocol, nil
|
return protocol, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
|
||||||
|
|
||||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||||
// returns the raw integer matching it.
|
// returns the raw integer matching it.
|
||||||
func parseMetricSize(size string) (n uint64, err error) {
|
func parseMetricSize(size string) (n uint64, err error) {
|
||||||
if size == "" {
|
if size == "" {
|
||||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
return n, errors.New("metric size is malformed: empty string")
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:mnd
|
//nolint:mnd
|
||||||
@@ -516,7 +496,7 @@ func parseMetricSize(size string) (n uint64, err error) {
|
|||||||
const base, bitLength = 10, 64
|
const base, bitLength = 10, 64
|
||||||
n, err = strconv.ParseUint(size, base, bitLength)
|
n, err = strconv.ParseUint(size, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
return n, fmt.Errorf("metric size is malformed: %w", err)
|
||||||
}
|
}
|
||||||
n *= multiplier
|
n *= multiplier
|
||||||
return n, nil
|
return n, nil
|
||||||
|
|||||||
@@ -13,30 +13,25 @@ func Test_parseChain(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
iptablesOutput string
|
iptablesOutput string
|
||||||
table chain
|
table chain
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"no_output": {
|
"no_output": {
|
||||||
errWrapped: ErrChainListMalformed,
|
|
||||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||||
},
|
},
|
||||||
"single_line_only": {
|
"single_line_only": {
|
||||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||||
errWrapped: ErrChainListMalformed,
|
|
||||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||||
},
|
},
|
||||||
"malformed_general_data_line": {
|
"malformed_general_data_line": {
|
||||||
iptablesOutput: `Chain INPUT
|
iptablesOutput: `Chain INPUT
|
||||||
num pkts bytes target prot opt in out source destination`,
|
num pkts bytes target prot opt in out source destination`,
|
||||||
errWrapped: ErrChainListMalformed,
|
|
||||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||||
"expected 8 fields in \"Chain INPUT\"",
|
"expected 8 fields in \"Chain INPUT\"",
|
||||||
},
|
},
|
||||||
"malformed_legend": {
|
"malformed_legend": {
|
||||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
num pkts bytes target prot opt in out source`,
|
num pkts bytes target prot opt in out source`,
|
||||||
errWrapped: ErrChainListMalformed,
|
|
||||||
errMessage: "iptables chain list output is malformed: legend " +
|
errMessage: "iptables chain list output is malformed: legend " +
|
||||||
"\"num pkts bytes target prot opt in out source\" " +
|
"\"num pkts bytes target prot opt in out source\" " +
|
||||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||||
@@ -135,9 +130,10 @@ num pkts bytes target prot opt in out source destinati
|
|||||||
table, err := parseChain(testCase.iptablesOutput)
|
table, err := parseChain(testCase.iptablesOutput)
|
||||||
|
|
||||||
assert.Equal(t, testCase.table, table)
|
assert.Equal(t, testCase.table, table)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,11 +80,9 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
|||||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
|
||||||
|
|
||||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
|
||||||
}
|
}
|
||||||
fields := strings.Fields(s)
|
fields := strings.Fields(s)
|
||||||
|
|
||||||
@@ -173,7 +171,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
|||||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
|
||||||
}
|
}
|
||||||
return consumed, nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
@@ -185,15 +183,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
|||||||
case "--tcp-flags": // -m can have 1 or 2 values
|
case "--tcp-flags": // -m can have 1 or 2 values
|
||||||
const expected = 3
|
const expected = 3
|
||||||
if len(fields) < expected {
|
if len(fields) < expected {
|
||||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
|
||||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
flag, strings.Join(fields, " "))
|
||||||
}
|
}
|
||||||
return expected, nil
|
return expected, nil
|
||||||
default:
|
default:
|
||||||
const expected = 2
|
const expected = 2
|
||||||
if len(fields) < expected {
|
if len(fields) < expected {
|
||||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
|
||||||
ErrIptablesCommandMalformed, flag)
|
flag)
|
||||||
}
|
}
|
||||||
return expected, nil
|
return expected, nil
|
||||||
}
|
}
|
||||||
@@ -239,12 +237,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
|||||||
consumed++
|
consumed++
|
||||||
instruction.mark.invert = true
|
instruction.mark.invert = true
|
||||||
default:
|
default:
|
||||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
|
||||||
ErrIptablesCommandMalformed, fields[2])
|
fields[2])
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
|
||||||
ErrIptablesCommandMalformed, fields[consumed])
|
fields[consumed])
|
||||||
}
|
}
|
||||||
return consumed, nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,21 +13,17 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
s string
|
s string
|
||||||
instruction iptablesInstruction
|
instruction iptablesInstruction
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"no_instruction": {
|
"no_instruction": {
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
|
||||||
errMessage: "iptables command is malformed: empty instruction",
|
errMessage: "iptables command is malformed: empty instruction",
|
||||||
},
|
},
|
||||||
"uneven_fields": {
|
"uneven_fields": {
|
||||||
s: "-A",
|
s: "-A",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
|
||||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"unknown_key": {
|
"unknown_key": {
|
||||||
s: "-x something",
|
s: "-x something",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
|
||||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||||
},
|
},
|
||||||
"one_pair": {
|
"one_pair": {
|
||||||
@@ -74,9 +70,10 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
rule, err := parseIptablesInstruction(testCase.s)
|
rule, err := parseIptablesInstruction(testCase.s)
|
||||||
|
|
||||||
assert.Equal(t, testCase.instruction, rule)
|
assert.Equal(t, testCase.instruction, rule)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,12 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var ErrNotSupported = errors.New("no iptables supported found")
|
||||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
|
||||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
|
||||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
|
||||||
ErrNotSupported = errors.New("no iptables supported found")
|
|
||||||
)
|
|
||||||
|
|
||||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||||
iptablesPathsToTry ...string,
|
iptablesPathsToTry ...string,
|
||||||
@@ -53,7 +48,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
|||||||
if allArePermissionDenied {
|
if allArePermissionDenied {
|
||||||
// If the error is related to a denied permission for all iptables path,
|
// If the error is related to a denied permission for all iptables path,
|
||||||
// return an error describing what to do from an end-user perspective.
|
// return an error describing what to do from an end-user perspective.
|
||||||
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
|
return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||||
@@ -85,7 +80,7 @@ func testIptablesPath(ctx context.Context, path string,
|
|||||||
output, err = runner.Run(cmd)
|
output, err = runner.Run(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// this is a critical error, we want to make sure our test rule gets removed.
|
// this is a critical error, we want to make sure our test rule gets removed.
|
||||||
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
|
criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
|
||||||
return false, "", criticalErr
|
return false, "", criticalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +103,7 @@ func testIptablesPath(ctx context.Context, path string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if inputPolicy == "" {
|
if inputPolicy == "" {
|
||||||
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
|
criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
|
||||||
return false, "", criticalErr
|
return false, "", criticalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||||
@@ -43,7 +42,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
|||||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||||
iptablesPathsToTry []string
|
iptablesPathsToTry []string
|
||||||
iptablesPath string
|
iptablesPath string
|
||||||
errSentinel error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"critical error when checking": {
|
"critical error when checking": {
|
||||||
@@ -56,7 +54,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
|||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
iptablesPathsToTry: []string{"path1", "path2"},
|
iptablesPathsToTry: []string{"path1", "path2"},
|
||||||
errSentinel: ErrTestRuleCleanup,
|
|
||||||
errMessage: "for path1: failed cleaning up test rule: " +
|
errMessage: "for path1: failed cleaning up test rule: " +
|
||||||
"output (exit code 4)",
|
"output (exit code 4)",
|
||||||
},
|
},
|
||||||
@@ -86,7 +83,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
|||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
iptablesPathsToTry: []string{"path1", "path2"},
|
iptablesPathsToTry: []string{"path1", "path2"},
|
||||||
errSentinel: ErrNetAdminMissing,
|
|
||||||
errMessage: "NET_ADMIN capability is missing: " +
|
errMessage: "NET_ADMIN capability is missing: " +
|
||||||
"path1: Permission denied (you must be root) more context (exit code 4); " +
|
"path1: Permission denied (you must be root) more context (exit code 4); " +
|
||||||
"path2: context: Permission denied (you must be root) (exit code 4)",
|
"path2: context: Permission denied (you must be root) (exit code 4)",
|
||||||
@@ -101,7 +97,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
|||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
iptablesPathsToTry: []string{"path1", "path2"},
|
iptablesPathsToTry: []string{"path1", "path2"},
|
||||||
errSentinel: ErrNotSupported,
|
|
||||||
errMessage: "no iptables supported found: " +
|
errMessage: "no iptables supported found: " +
|
||||||
"errors encountered are: " +
|
"errors encountered are: " +
|
||||||
"path1: output 1 (exit code 4); " +
|
"path1: output 1 (exit code 4); " +
|
||||||
@@ -118,9 +113,10 @@ func Test_checkIptablesSupport(t *testing.T) {
|
|||||||
|
|
||||||
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
|
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
|
||||||
|
|
||||||
require.ErrorIs(t, err, testCase.errSentinel)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errSentinel != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, testCase.iptablesPath, iptablesPath)
|
assert.Equal(t, testCase.iptablesPath, iptablesPath)
|
||||||
})
|
})
|
||||||
@@ -139,7 +135,6 @@ func Test_testIptablesPath(t *testing.T) {
|
|||||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||||
ok bool
|
ok bool
|
||||||
unsupportedMessage string
|
unsupportedMessage string
|
||||||
criticalErrWrapped error
|
|
||||||
criticalErrMessage string
|
criticalErrMessage string
|
||||||
}{
|
}{
|
||||||
"append test rule permission denied": {
|
"append test rule permission denied": {
|
||||||
@@ -168,7 +163,6 @@ func Test_testIptablesPath(t *testing.T) {
|
|||||||
Return("some output", errDummy)
|
Return("some output", errDummy)
|
||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
criticalErrWrapped: ErrTestRuleCleanup,
|
|
||||||
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
||||||
},
|
},
|
||||||
"list input rules permission denied": {
|
"list input rules permission denied": {
|
||||||
@@ -202,7 +196,6 @@ func Test_testIptablesPath(t *testing.T) {
|
|||||||
Return("some\noutput", nil)
|
Return("some\noutput", nil)
|
||||||
return runner
|
return runner
|
||||||
},
|
},
|
||||||
criticalErrWrapped: ErrInputPolicyNotFound,
|
|
||||||
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
||||||
},
|
},
|
||||||
"set policy permission denied": {
|
"set policy permission denied": {
|
||||||
@@ -257,9 +250,10 @@ func Test_testIptablesPath(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, testCase.ok, ok)
|
assert.Equal(t, testCase.ok, ok)
|
||||||
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
||||||
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
|
if testCase.criticalErrMessage != "" {
|
||||||
if testCase.criticalErrWrapped != nil {
|
|
||||||
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, criticalErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,12 +45,10 @@ func (f tcpFlag) String() string {
|
|||||||
case tcpFlagCWR:
|
case tcpFlagCWR:
|
||||||
return "CWR"
|
return "CWR"
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
panic(fmt.Sprintf("unknown TCP flag: %d", f))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
|
||||||
|
|
||||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||||
allFlags := []tcpFlag{
|
allFlags := []tcpFlag{
|
||||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||||
@@ -61,7 +59,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
|
|||||||
return flag, nil
|
return flag, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
return 0, fmt.Errorf("unknown TCP flag: %s", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||||
|
|||||||
@@ -266,8 +266,6 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
|
|||||||
return address, nil
|
return address, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
|
|
||||||
|
|
||||||
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
||||||
logger Logger, checkName string, check func(ctx context.Context, try int) error,
|
logger Logger, checkName string, check func(ctx context.Context, try int) error,
|
||||||
) error {
|
) error {
|
||||||
@@ -297,7 +295,7 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
|||||||
for i, err := range errs {
|
for i, err := range errs {
|
||||||
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
|
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w:\n\t%s", ErrAllCheckTriesFailed, strings.Join(errStrings, "\n\t"))
|
return fmt.Errorf("all check tries failed:\n\t%s", strings.Join(errStrings, "\n\t"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Checker) startupCheck(ctx context.Context) error {
|
func (c *Checker) startupCheck(ctx context.Context) error {
|
||||||
@@ -342,7 +340,7 @@ func (c *Checker) startupCheck(ctx context.Context) error {
|
|||||||
for i, err := range errs {
|
for i, err := range errs {
|
||||||
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
|
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
|
return fmt.Errorf("all check tries failed: %s", strings.Join(errStrings, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package healthcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -68,7 +67,7 @@ func Test_makeAddressToDial(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
address string
|
address string
|
||||||
addressToDial string
|
addressToDial string
|
||||||
err error
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"host without port": {
|
"host without port": {
|
||||||
address: "test.com",
|
address: "test.com",
|
||||||
@@ -79,8 +78,8 @@ func Test_makeAddressToDial(t *testing.T) {
|
|||||||
addressToDial: "test.com:80",
|
addressToDial: "test.com:80",
|
||||||
},
|
},
|
||||||
"bad address": {
|
"bad address": {
|
||||||
address: "test.com::",
|
address: "test.com::",
|
||||||
err: fmt.Errorf("splitting host and port from address: address test.com::: too many colons in address"), //nolint:lll
|
errMessage: "splitting host and port from address: address test.com::: too many colons in address",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,8 +90,8 @@ func Test_makeAddressToDial(t *testing.T) {
|
|||||||
addressToDial, err := makeAddressToDial(testCase.address)
|
addressToDial, err := makeAddressToDial(testCase.address)
|
||||||
|
|
||||||
assert.Equal(t, testCase.addressToDial, addressToDial)
|
assert.Equal(t, testCase.addressToDial, addressToDial)
|
||||||
if testCase.err != nil {
|
if testCase.errMessage != "" {
|
||||||
assert.EqualError(t, err, testCase.err.Error())
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,12 @@ package healthcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
|
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
@@ -41,6 +38,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK,
|
return fmt.Errorf("HTTP response status is not OK: %d %s: %s",
|
||||||
response.StatusCode, response.Status, string(b))
|
response.StatusCode, response.Status, string(b))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -41,8 +40,6 @@ func concatAddrPorts(addrs [][]netip.AddrPort) []netip.AddrPort {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
|
|
||||||
|
|
||||||
func (c *Client) Check(ctx context.Context) error {
|
func (c *Client) Check(ctx context.Context) error {
|
||||||
dnsAddr := c.serverAddrs[c.dnsIPIndex].String()
|
dnsAddr := c.serverAddrs[c.dnsIPIndex].String()
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
@@ -59,7 +56,7 @@ func (c *Client) Check(ctx context.Context) error {
|
|||||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
|
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
|
||||||
case len(ips) == 0:
|
case len(ips) == 0:
|
||||||
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
||||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs)
|
return fmt.Errorf("with DNS server %s: no IPs found from DNS lookup", dnsAddr)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,11 +12,9 @@ type handler struct {
|
|||||||
logger Logger
|
logger Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet")
|
|
||||||
|
|
||||||
func newHandler(logger Logger) *handler {
|
func newHandler(logger Logger) *handler {
|
||||||
return &handler{
|
return &handler{
|
||||||
healthErr: errHealthcheckNotRunYet,
|
healthErr: errors.New("healthcheck did not run yet"),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,11 +19,6 @@ import (
|
|||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
|
||||||
ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Echoer struct {
|
type Echoer struct {
|
||||||
buffer []byte
|
buffer []byte
|
||||||
randomSource io.Reader
|
randomSource io.Reader
|
||||||
@@ -60,10 +55,7 @@ func (e *Echoer) Reset() {
|
|||||||
e.seqStart = time.Now()
|
e.seqStart = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var ErrNotPermitted = errors.New("not permitted")
|
||||||
ErrTimedOut = errors.New("timed out waiting for ICMP echo reply")
|
|
||||||
ErrNotPermitted = errors.New("not permitted")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
||||||
var ipVersion string
|
var ipVersion string
|
||||||
@@ -114,14 +106,14 @@ func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
|||||||
receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger)
|
receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
||||||
return fmt.Errorf("%w from %s", ErrTimedOut, ip)
|
return fmt.Errorf("timed out waiting for ICMP echo reply from %s", ip)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
|
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
||||||
if !bytes.Equal(receivedData, sentData) {
|
if !bytes.Equal(receivedData, sentData) {
|
||||||
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData)
|
return fmt.Errorf("ICMP data mismatch: sent %x to %s and received %x", sentData, ip, receivedData)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -216,8 +208,9 @@ func receiveEchoReply(conn net.PacketConn, id, seq int, buffer []byte, ipVersion
|
|||||||
message.Code, returnAddr, id, seq)
|
message.Code, returnAddr, id, seq)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("%w: %T (type %d, code %d, return address %s, expected id %d and seq %d)",
|
return nil, fmt.Errorf("ICMP body type is not supported: "+
|
||||||
ErrICMPBodyUnsupported, body, message.Type, message.Code, returnAddr, id, seq)
|
"%T (type %d, code %d, return address %s, expected id %d and seq %d)",
|
||||||
|
body, message.Type, message.Code, returnAddr, id, seq)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger
|
//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger
|
||||||
@@ -20,11 +19,9 @@ func Test_New(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
settings Settings
|
settings Settings
|
||||||
expected *Server
|
expected *Server
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"empty settings": {
|
"empty settings": {
|
||||||
errWrapped: ErrHandlerIsNotSet,
|
|
||||||
errMessage: "http server settings validation failed: HTTP handler cannot be left unset",
|
errMessage: "http server settings validation failed: HTTP handler cannot be left unset",
|
||||||
},
|
},
|
||||||
"filled settings": {
|
"filled settings": {
|
||||||
@@ -52,9 +49,10 @@ func Test_New(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
server, err := New(testCase.settings)
|
server, err := New(testCase.settings)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
require.EqualError(t, err, testCase.errMessage)
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if server != nil {
|
if server != nil {
|
||||||
|
|||||||
@@ -64,14 +64,6 @@ func (s *Settings) OverrideWith(other Settings) {
|
|||||||
s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout)
|
s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrHandlerIsNotSet = errors.New("HTTP handler cannot be left unset")
|
|
||||||
ErrLoggerIsNotSet = errors.New("logger cannot be left unset")
|
|
||||||
ErrReadHeaderTimeoutTooSmall = errors.New("read header timeout is too small")
|
|
||||||
ErrReadTimeoutTooSmall = errors.New("read timeout is too small")
|
|
||||||
ErrShutdownTimeoutTooSmall = errors.New("shutdown timeout is too small")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s Settings) Validate() (err error) {
|
func (s Settings) Validate() (err error) {
|
||||||
err = validate.ListeningAddress(s.Address, os.Getuid())
|
err = validate.ListeningAddress(s.Address, os.Getuid())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,31 +71,25 @@ func (s Settings) Validate() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.Handler == nil {
|
if s.Handler == nil {
|
||||||
return fmt.Errorf("%w", ErrHandlerIsNotSet)
|
return errors.New("HTTP handler cannot be left unset")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Logger == nil {
|
if s.Logger == nil {
|
||||||
return fmt.Errorf("%w", ErrLoggerIsNotSet)
|
return errors.New("logger cannot be left unset")
|
||||||
}
|
}
|
||||||
|
|
||||||
const minReadTimeout = time.Millisecond
|
const minReadTimeout = time.Millisecond
|
||||||
if s.ReadHeaderTimeout < minReadTimeout {
|
if s.ReadHeaderTimeout < minReadTimeout {
|
||||||
return fmt.Errorf("%w: %s must be at least %s",
|
return fmt.Errorf("read header timeout is too small: %s must be at least %s", s.ReadHeaderTimeout, minReadTimeout)
|
||||||
ErrReadHeaderTimeoutTooSmall,
|
|
||||||
s.ReadHeaderTimeout, minReadTimeout)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ReadTimeout < minReadTimeout {
|
if s.ReadTimeout < minReadTimeout {
|
||||||
return fmt.Errorf("%w: %s must be at least %s",
|
return fmt.Errorf("read timeout is too small: %s must be at least %s", s.ReadTimeout, minReadTimeout)
|
||||||
ErrReadTimeoutTooSmall,
|
|
||||||
s.ReadTimeout, minReadTimeout)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const minShutdownTimeout = 5 * time.Millisecond
|
const minShutdownTimeout = 5 * time.Millisecond
|
||||||
if s.ShutdownTimeout < minShutdownTimeout {
|
if s.ShutdownTimeout < minShutdownTimeout {
|
||||||
return fmt.Errorf("%w: %s must be at least %s",
|
return fmt.Errorf("shutdown timeout is too small: %s must be at least %s", s.ShutdownTimeout, minShutdownTimeout)
|
||||||
ErrShutdownTimeoutTooSmall,
|
|
||||||
s.ShutdownTimeout, minShutdownTimeout)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gosettings/validate"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -189,30 +188,26 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
settings Settings
|
settings Settings
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"bad_address": {
|
"bad_address": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
Address: "address:notanint",
|
Address: "address:notanint",
|
||||||
},
|
},
|
||||||
errWrapped: validate.ErrPortNotAnInteger,
|
|
||||||
errMessage: "port value is not an integer: notanint",
|
errMessage: "port value is not an integer: notanint",
|
||||||
},
|
},
|
||||||
"nil handler": {
|
"nil handler": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
Address: ":8000",
|
Address: ":8000",
|
||||||
},
|
},
|
||||||
errWrapped: ErrHandlerIsNotSet,
|
errMessage: "HTTP handler cannot be left unset",
|
||||||
errMessage: ErrHandlerIsNotSet.Error(),
|
|
||||||
},
|
},
|
||||||
"nil logger": {
|
"nil logger": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
Address: ":8000",
|
Address: ":8000",
|
||||||
Handler: someHandler,
|
Handler: someHandler,
|
||||||
},
|
},
|
||||||
errWrapped: ErrLoggerIsNotSet,
|
errMessage: "logger cannot be left unset",
|
||||||
errMessage: ErrLoggerIsNotSet.Error(),
|
|
||||||
},
|
},
|
||||||
"read header timeout too small": {
|
"read header timeout too small": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -221,7 +216,6 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
Logger: someLogger,
|
Logger: someLogger,
|
||||||
ReadHeaderTimeout: time.Nanosecond,
|
ReadHeaderTimeout: time.Nanosecond,
|
||||||
},
|
},
|
||||||
errWrapped: ErrReadHeaderTimeoutTooSmall,
|
|
||||||
errMessage: "read header timeout is too small: 1ns must be at least 1ms",
|
errMessage: "read header timeout is too small: 1ns must be at least 1ms",
|
||||||
},
|
},
|
||||||
"read timeout too small": {
|
"read timeout too small": {
|
||||||
@@ -232,7 +226,6 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
ReadHeaderTimeout: time.Millisecond,
|
ReadHeaderTimeout: time.Millisecond,
|
||||||
ReadTimeout: time.Nanosecond,
|
ReadTimeout: time.Nanosecond,
|
||||||
},
|
},
|
||||||
errWrapped: ErrReadTimeoutTooSmall,
|
|
||||||
errMessage: "read timeout is too small: 1ns must be at least 1ms",
|
errMessage: "read timeout is too small: 1ns must be at least 1ms",
|
||||||
},
|
},
|
||||||
"shutdown timeout too small": {
|
"shutdown timeout too small": {
|
||||||
@@ -244,7 +237,6 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
ReadTimeout: time.Millisecond,
|
ReadTimeout: time.Millisecond,
|
||||||
ShutdownTimeout: time.Millisecond,
|
ShutdownTimeout: time.Millisecond,
|
||||||
},
|
},
|
||||||
errWrapped: ErrShutdownTimeoutTooSmall,
|
|
||||||
errMessage: "shutdown timeout is too small: 1ms must be at least 5ms",
|
errMessage: "shutdown timeout is too small: 1ms must be at least 5ms",
|
||||||
},
|
},
|
||||||
"valid settings": {
|
"valid settings": {
|
||||||
@@ -265,9 +257,10 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
|
|
||||||
err := testCase.settings.Validate()
|
err := testCase.settings.Validate()
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
|
||||||
if testCase.errMessage != "" {
|
if testCase.errMessage != "" {
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,12 @@ package loopstate
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidStatus = errors.New("invalid status")
|
|
||||||
|
|
||||||
// ApplyStatus sends signals to the running loop depending on the
|
// ApplyStatus sends signals to the running loop depending on the
|
||||||
// current status and status requested, such that its next status
|
// current status and status requested, such that its next status
|
||||||
// matches the requested one. It is thread safe and a synchronous call
|
// matches the requested one. It is thread safe and a synchronous call
|
||||||
@@ -73,7 +70,7 @@ func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) (
|
|||||||
return newStatus.String(), nil
|
return newStatus.String(), nil
|
||||||
default:
|
default:
|
||||||
s.statusMu.Unlock()
|
s.statusMu.Unlock()
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("invalid status: %s: it can only be one of: %s, %s",
|
||||||
ErrInvalidStatus, status, constants.Running, constants.Stopped)
|
status, constants.Running, constants.Stopped)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,19 +3,11 @@ package mod
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errModuleNameUnknown = errors.New("unknown module name")
|
|
||||||
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
|
|
||||||
errKernelFeatureNotSet = errors.New("kernel feature not set")
|
|
||||||
errKernelFeatureNotFound = errors.New("kernel feature not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
|
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
|
||||||
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
|
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
|
||||||
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
|
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
|
||||||
@@ -39,7 +31,7 @@ func checkProcConfig(moduleName string) error {
|
|||||||
// If any group of kernel features is satisfied, then the module is considered supported.
|
// If any group of kernel features is satisfied, then the module is considered supported.
|
||||||
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
|
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
|
return fmt.Errorf("unknown module name: %s", moduleName)
|
||||||
}
|
}
|
||||||
groups := make([]map[string]bool, len(kernelFeatureGroups))
|
groups := make([]map[string]bool, len(kernelFeatureGroups))
|
||||||
for i, group := range kernelFeatureGroups {
|
for i, group := range kernelFeatureGroups {
|
||||||
@@ -58,20 +50,20 @@ func checkProcConfig(moduleName string) error {
|
|||||||
switch {
|
switch {
|
||||||
case ok:
|
case ok:
|
||||||
case strings.HasPrefix(line, name+"=m"):
|
case strings.HasPrefix(line, name+"=m"):
|
||||||
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
|
return fmt.Errorf("kernel feature is a module, not built-in: %s", name)
|
||||||
case strings.HasPrefix(line, name+"=y"):
|
case strings.HasPrefix(line, name+"=y"):
|
||||||
featureToOK[name] = true
|
featureToOK[name] = true
|
||||||
if allFeaturesOK(featureToOK) {
|
if allFeaturesOK(featureToOK) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case strings.HasPrefix(line, "# "+name+" is not set"):
|
case strings.HasPrefix(line, "# "+name+" is not set"):
|
||||||
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
|
return fmt.Errorf("kernel feature not set: %s", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
|
return fmt.Errorf("kernel feature not found: for module name %s", moduleName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
|
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
|
||||||
|
|||||||
@@ -181,8 +181,6 @@ func getLoadedModules(modulesInfo map[string]moduleInfo) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrModulePathNotFound = errors.New("module path not found")
|
|
||||||
|
|
||||||
func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) {
|
func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) {
|
||||||
// Kernel module names can have underscores or hyphens in their names,
|
// Kernel module names can have underscores or hyphens in their names,
|
||||||
// but only one or the other in one particular name.
|
// but only one or the other in one particular name.
|
||||||
@@ -205,5 +203,5 @@ func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modul
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("%w: for %q", ErrModulePathNotFound, moduleName)
|
return "", fmt.Errorf("module path not found: for %q", moduleName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package mod
|
package mod
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@@ -14,15 +13,10 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrModuleInfoNotFound = errors.New("module info not found")
|
|
||||||
ErrCircularDependency = errors.New("circular dependency")
|
|
||||||
)
|
|
||||||
|
|
||||||
func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) {
|
func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) {
|
||||||
info, ok := modulesInfo[path]
|
info, ok := modulesInfo[path]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%w: %s", ErrModuleInfoNotFound, path)
|
return fmt.Errorf("module info not found: %s", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch info.state {
|
switch info.state {
|
||||||
@@ -30,8 +24,7 @@ func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error
|
|||||||
case loaded, builtin:
|
case loaded, builtin:
|
||||||
return nil
|
return nil
|
||||||
case loading:
|
case loading:
|
||||||
return fmt.Errorf("%w: %s is already in the loading state",
|
return fmt.Errorf("circular dependency: %s is already in the loading state", path)
|
||||||
ErrCircularDependency, path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info.state = loading
|
info.state = loading
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -109,8 +108,6 @@ func (s *Servers) toMarkdown(vpnProvider string) (formatted string, err error) {
|
|||||||
return formatted, nil
|
return formatted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrMarkdownHeadersNotDefined = errors.New("markdown headers not defined")
|
|
||||||
|
|
||||||
func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
||||||
switch vpnProvider {
|
switch vpnProvider {
|
||||||
case providers.Airvpn:
|
case providers.Airvpn:
|
||||||
@@ -169,6 +166,6 @@ func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
|||||||
case providers.Windscribe:
|
case providers.Windscribe:
|
||||||
return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil
|
return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("%w: for %s", ErrMarkdownHeadersNotDefined, vpnProvider)
|
return nil, fmt.Errorf("markdown headers not defined: for %s", vpnProvider)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,12 +15,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
|
|||||||
provider string
|
provider string
|
||||||
servers Servers
|
servers Servers
|
||||||
formatted string
|
formatted string
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"unsupported_provider": {
|
"unsupported_provider": {
|
||||||
provider: "unsupported",
|
provider: "unsupported",
|
||||||
errWrapped: ErrMarkdownHeadersNotDefined,
|
|
||||||
errMessage: "getting markdown headers: markdown headers not defined: for unsupported",
|
errMessage: "getting markdown headers: markdown headers not defined: for unsupported",
|
||||||
},
|
},
|
||||||
providers.Cyberghost: {
|
providers.Cyberghost: {
|
||||||
@@ -58,9 +56,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
|
|||||||
markdown, err := testCase.servers.toMarkdown(testCase.provider)
|
markdown, err := testCase.servers.toMarkdown(testCase.provider)
|
||||||
|
|
||||||
assert.Equal(t, testCase.formatted, markdown)
|
assert.Equal(t, testCase.formatted, markdown)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,27 +38,18 @@ type Server struct {
|
|||||||
IPs []netip.Addr `json:"ips,omitempty"`
|
IPs []netip.Addr `json:"ips,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrVPNFieldEmpty = errors.New("vpn field is empty")
|
|
||||||
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
|
|
||||||
ErrIPsFieldEmpty = errors.New("ips field is empty")
|
|
||||||
ErrNoNetworkProtocol = errors.New("both TCP and UDP fields are false for OpenVPN")
|
|
||||||
ErrNetworkProtocolSet = errors.New("no network protocol should be set")
|
|
||||||
ErrWireguardPublicKeyEmpty = errors.New("wireguard public key field is empty")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *Server) HasMinimumInformation() (err error) {
|
func (s *Server) HasMinimumInformation() (err error) {
|
||||||
switch {
|
switch {
|
||||||
case s.VPN == "":
|
case s.VPN == "":
|
||||||
return fmt.Errorf("%w", ErrVPNFieldEmpty)
|
return errors.New("vpn field is empty")
|
||||||
case len(s.IPs) == 0:
|
case len(s.IPs) == 0:
|
||||||
return fmt.Errorf("%w", ErrIPsFieldEmpty)
|
return errors.New("ips field is empty")
|
||||||
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
|
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
|
||||||
return fmt.Errorf("%w", ErrNetworkProtocolSet)
|
return errors.New("no network protocol should be set")
|
||||||
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
|
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
|
||||||
return fmt.Errorf("%w", ErrNoNetworkProtocol)
|
return errors.New("both TCP and UDP fields are false for OpenVPN")
|
||||||
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
|
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
|
||||||
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
|
return errors.New("wireguard public key field is empty")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package models
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -158,8 +157,6 @@ type Servers struct {
|
|||||||
Servers []Server `json:"servers,omitempty"`
|
Servers []Server `json:"servers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrServersFormatNotSupported = errors.New("servers format not supported")
|
|
||||||
|
|
||||||
func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) {
|
func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) {
|
||||||
switch format {
|
switch format {
|
||||||
case "markdown":
|
case "markdown":
|
||||||
@@ -167,7 +164,7 @@ func (s *Servers) Format(vpnProvider, format string) (formatted string, err erro
|
|||||||
case "json":
|
case "json":
|
||||||
return s.toJSON()
|
return s.toJSON()
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s", ErrServersFormatNotSupported, format)
|
return "", fmt.Errorf("servers format not supported: %s", format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
allServers *AllServers
|
allServers *AllServers
|
||||||
dataString string
|
dataString string
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"no provider": {
|
"no provider": {
|
||||||
@@ -58,16 +57,18 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
data, err := testCase.allServers.MarshalJSON()
|
data, err := testCase.allServers.MarshalJSON()
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if err != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
require.Equal(t, testCase.dataString, string(data))
|
require.Equal(t, testCase.dataString, string(data))
|
||||||
|
|
||||||
data, err = json.Marshal(testCase.allServers)
|
data, err = json.Marshal(testCase.allServers)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if err != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
require.Equal(t, testCase.dataString, string(data))
|
require.Equal(t, testCase.dataString, string(data))
|
||||||
|
|
||||||
@@ -87,7 +88,6 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
dataString string
|
dataString string
|
||||||
allServers AllServers
|
allServers AllServers
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"empty": {
|
"empty": {
|
||||||
@@ -131,9 +131,10 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
|||||||
|
|
||||||
err := json.Unmarshal(data, &allServers)
|
err := json.Unmarshal(data, &allServers)
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if err != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, testCase.allServers, allServers)
|
assert.Equal(t, testCase.allServers, allServers)
|
||||||
})
|
})
|
||||||
|
|||||||
+16
-33
@@ -6,48 +6,40 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrRequestSizeTooSmall = errors.New("message size is too small")
|
|
||||||
|
|
||||||
func checkRequest(request []byte) (err error) {
|
func checkRequest(request []byte) (err error) {
|
||||||
const minMessageSize = 2 // version number + operation code
|
const minMessageSize = 2 // version number + operation code
|
||||||
if len(request) < minMessageSize {
|
if len(request) < minMessageSize {
|
||||||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
return fmt.Errorf("message size is too small: need at least %d bytes and got %d byte(s)",
|
||||||
ErrRequestSizeTooSmall, minMessageSize, len(request))
|
minMessageSize, len(request))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrResponseSizeTooSmall = errors.New("response size is too small")
|
|
||||||
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
|
|
||||||
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
|
|
||||||
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
|
|
||||||
)
|
|
||||||
|
|
||||||
func checkResponse(response []byte, expectedOperationCode byte,
|
func checkResponse(response []byte, expectedOperationCode byte,
|
||||||
expectedResponseSize uint,
|
expectedResponseSize uint,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
const minResponseSize = 4
|
const minResponseSize = 4
|
||||||
if len(response) < minResponseSize {
|
if len(response) < minResponseSize {
|
||||||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
return fmt.Errorf("response size is too small: "+
|
||||||
ErrResponseSizeTooSmall, minResponseSize, len(response))
|
"need at least %d bytes and got %d byte(s)",
|
||||||
|
minResponseSize, len(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
if uint(len(response)) != expectedResponseSize {
|
if uint(len(response)) != expectedResponseSize {
|
||||||
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)",
|
return fmt.Errorf("response size is unexpected: "+
|
||||||
ErrResponseSizeUnexpected, expectedResponseSize, len(response))
|
"expected %d bytes and got %d byte(s)",
|
||||||
|
expectedResponseSize, len(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
protocolVersion := response[0]
|
protocolVersion := response[0]
|
||||||
if protocolVersion != 0 {
|
if protocolVersion != 0 {
|
||||||
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion)
|
return fmt.Errorf("protocol version is unknown: %d", protocolVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
operationCode := response[1]
|
operationCode := response[1]
|
||||||
if operationCode != expectedOperationCode {
|
if operationCode != expectedOperationCode {
|
||||||
return fmt.Errorf("%w: expected 0x%x and got 0x%x",
|
return fmt.Errorf("operation code is unexpected: expected 0x%x and got 0x%x", expectedOperationCode, operationCode)
|
||||||
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resultCode := binary.BigEndian.Uint16(response[2:4])
|
resultCode := binary.BigEndian.Uint16(response[2:4])
|
||||||
@@ -59,15 +51,6 @@ func checkResponse(response []byte, expectedOperationCode byte,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrVersionNotSupported = errors.New("version is not supported")
|
|
||||||
ErrNotAuthorized = errors.New("not authorized")
|
|
||||||
ErrNetworkFailure = errors.New("network failure")
|
|
||||||
ErrOutOfResources = errors.New("out of resources")
|
|
||||||
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
|
|
||||||
ErrResultCodeUnknown = errors.New("result code is unknown")
|
|
||||||
)
|
|
||||||
|
|
||||||
// checkResultCode checks the result code and returns an error
|
// checkResultCode checks the result code and returns an error
|
||||||
// if the result code is not a success (0).
|
// if the result code is not a success (0).
|
||||||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
|
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
|
||||||
@@ -78,16 +61,16 @@ func checkResultCode(resultCode uint16) (err error) {
|
|||||||
case 0:
|
case 0:
|
||||||
return nil
|
return nil
|
||||||
case 1:
|
case 1:
|
||||||
return fmt.Errorf("%w", ErrVersionNotSupported)
|
return errors.New("version is not supported")
|
||||||
case 2:
|
case 2:
|
||||||
return fmt.Errorf("%w", ErrNotAuthorized)
|
return errors.New("not authorized")
|
||||||
case 3:
|
case 3:
|
||||||
return fmt.Errorf("%w", ErrNetworkFailure)
|
return errors.New("network failure")
|
||||||
case 4:
|
case 4:
|
||||||
return fmt.Errorf("%w", ErrOutOfResources)
|
return errors.New("out of resources")
|
||||||
case 5:
|
case 5:
|
||||||
return fmt.Errorf("%w", ErrOperationCodeNotSupported)
|
return errors.New("operation code is not supported")
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode)
|
return fmt.Errorf("result code is unknown: %d", resultCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package natpmp
|
package natpmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -11,12 +12,10 @@ func Test_checkRequest(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
request []byte
|
request []byte
|
||||||
err error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"too_short": {
|
"too_short": {
|
||||||
request: []byte{1},
|
request: []byte{1},
|
||||||
err: ErrRequestSizeTooSmall,
|
|
||||||
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
|
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
|
||||||
},
|
},
|
||||||
"success": {
|
"success": {
|
||||||
@@ -30,9 +29,10 @@ func Test_checkRequest(t *testing.T) {
|
|||||||
|
|
||||||
err := checkRequest(testCase.request)
|
err := checkRequest(testCase.request)
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.err)
|
if testCase.errMessage != "" {
|
||||||
if testCase.err != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -50,33 +50,33 @@ func Test_checkResponse(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
"too_short": {
|
"too_short": {
|
||||||
response: []byte{1},
|
response: []byte{1},
|
||||||
err: ErrResponseSizeTooSmall,
|
err: errors.New("response size is too small"),
|
||||||
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
|
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
|
||||||
},
|
},
|
||||||
"size_mismatch": {
|
"size_mismatch": {
|
||||||
response: []byte{0, 0, 0, 0},
|
response: []byte{0, 0, 0, 0},
|
||||||
expectedResponseSize: 5,
|
expectedResponseSize: 5,
|
||||||
err: ErrResponseSizeUnexpected,
|
err: errors.New("response size is unexpected"),
|
||||||
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
|
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
|
||||||
},
|
},
|
||||||
"protocol_unknown": {
|
"protocol_unknown": {
|
||||||
response: []byte{1, 0, 0, 0},
|
response: []byte{1, 0, 0, 0},
|
||||||
expectedResponseSize: 4,
|
expectedResponseSize: 4,
|
||||||
err: ErrProtocolVersionUnknown,
|
err: errors.New("protocol version is unknown"),
|
||||||
errMessage: "protocol version is unknown: 1",
|
errMessage: "protocol version is unknown: 1",
|
||||||
},
|
},
|
||||||
"operation_code_unexpected": {
|
"operation_code_unexpected": {
|
||||||
response: []byte{0, 2, 0, 0},
|
response: []byte{0, 2, 0, 0},
|
||||||
expectedOperationCode: 1,
|
expectedOperationCode: 1,
|
||||||
expectedResponseSize: 4,
|
expectedResponseSize: 4,
|
||||||
err: ErrOperationCodeUnexpected,
|
err: errors.New("operation code is unexpected"),
|
||||||
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
|
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
|
||||||
},
|
},
|
||||||
"result_code_failure": {
|
"result_code_failure": {
|
||||||
response: []byte{0, 1, 0, 1},
|
response: []byte{0, 1, 0, 1},
|
||||||
expectedOperationCode: 1,
|
expectedOperationCode: 1,
|
||||||
expectedResponseSize: 4,
|
expectedResponseSize: 4,
|
||||||
err: ErrVersionNotSupported,
|
err: errors.New("version is not supported"),
|
||||||
errMessage: "result code: version is not supported",
|
errMessage: "result code: version is not supported",
|
||||||
},
|
},
|
||||||
"success": {
|
"success": {
|
||||||
@@ -94,9 +94,11 @@ func Test_checkResponse(t *testing.T) {
|
|||||||
testCase.expectedOperationCode,
|
testCase.expectedOperationCode,
|
||||||
testCase.expectedResponseSize)
|
testCase.expectedResponseSize)
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.err)
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
|
assert.ErrorContains(t, err, testCase.err.Error())
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -113,32 +115,32 @@ func Test_checkResultCode(t *testing.T) {
|
|||||||
"success": {},
|
"success": {},
|
||||||
"version_unsupported": {
|
"version_unsupported": {
|
||||||
resultCode: 1,
|
resultCode: 1,
|
||||||
err: ErrVersionNotSupported,
|
err: errors.New("version is not supported"),
|
||||||
errMessage: "version is not supported",
|
errMessage: "version is not supported",
|
||||||
},
|
},
|
||||||
"not_authorized": {
|
"not_authorized": {
|
||||||
resultCode: 2,
|
resultCode: 2,
|
||||||
err: ErrNotAuthorized,
|
err: errors.New("not authorized"),
|
||||||
errMessage: "not authorized",
|
errMessage: "not authorized",
|
||||||
},
|
},
|
||||||
"network_failure": {
|
"network_failure": {
|
||||||
resultCode: 3,
|
resultCode: 3,
|
||||||
err: ErrNetworkFailure,
|
err: errors.New("network failure"),
|
||||||
errMessage: "network failure",
|
errMessage: "network failure",
|
||||||
},
|
},
|
||||||
"out_of_resources": {
|
"out_of_resources": {
|
||||||
resultCode: 4,
|
resultCode: 4,
|
||||||
err: ErrOutOfResources,
|
err: errors.New("out of resources"),
|
||||||
errMessage: "out of resources",
|
errMessage: "out of resources",
|
||||||
},
|
},
|
||||||
"unsupported_operation_code": {
|
"unsupported_operation_code": {
|
||||||
resultCode: 5,
|
resultCode: 5,
|
||||||
err: ErrOperationCodeNotSupported,
|
err: errors.New("operation code is not supported"),
|
||||||
errMessage: "operation code is not supported",
|
errMessage: "operation code is not supported",
|
||||||
},
|
},
|
||||||
"unknown": {
|
"unknown": {
|
||||||
resultCode: 6,
|
resultCode: 6,
|
||||||
err: ErrResultCodeUnknown,
|
err: errors.New("result code is unknown"),
|
||||||
errMessage: "result code is unknown: 6",
|
errMessage: "result code is unknown: 6",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -149,9 +151,11 @@ func Test_checkResultCode(t *testing.T) {
|
|||||||
|
|
||||||
err := checkResultCode(testCase.resultCode)
|
err := checkResultCode(testCase.resultCode)
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.err)
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
|
assert.ErrorContains(t, err, testCase.err.Error())
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,17 +3,11 @@ package natpmp
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrNetworkProtocolUnknown = errors.New("network protocol is unknown")
|
|
||||||
ErrLifetimeTooLong = errors.New("lifetime is too long")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Add or delete a port mapping. To delete a mapping, set both the
|
// Add or delete a port mapping. To delete a mapping, set both the
|
||||||
// requestedExternalPort and lifetime to 0.
|
// requestedExternalPort and lifetime to 0.
|
||||||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3
|
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3
|
||||||
@@ -26,8 +20,9 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
|||||||
lifetimeSecondsFloat := lifetime.Seconds()
|
lifetimeSecondsFloat := lifetime.Seconds()
|
||||||
const maxLifetimeSeconds = uint64(^uint32(0))
|
const maxLifetimeSeconds = uint64(^uint32(0))
|
||||||
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
|
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
|
||||||
return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds",
|
return 0, 0, 0, 0, fmt.Errorf("lifetime is too long: "+
|
||||||
ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
"%d seconds must at most %d seconds",
|
||||||
|
uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
||||||
}
|
}
|
||||||
const messageSize = 12
|
const messageSize = 12
|
||||||
message := make([]byte, messageSize)
|
message := make([]byte, messageSize)
|
||||||
@@ -38,7 +33,7 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
|||||||
case "tcp":
|
case "tcp":
|
||||||
message[1] = 2 // operationCode 2
|
message[1] = 2 // operationCode 2
|
||||||
default:
|
default:
|
||||||
return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol)
|
return 0, 0, 0, 0, fmt.Errorf("network protocol is unknown: %s", protocol)
|
||||||
}
|
}
|
||||||
// [2:3] are reserved.
|
// [2:3] are reserved.
|
||||||
binary.BigEndian.PutUint16(message[4:6], internalPort)
|
binary.BigEndian.PutUint16(message[4:6], internalPort)
|
||||||
|
|||||||
@@ -25,18 +25,15 @@ func Test_Client_AddPortMapping(t *testing.T) {
|
|||||||
assignedInternalPort uint16
|
assignedInternalPort uint16
|
||||||
assignedExternalPort uint16
|
assignedExternalPort uint16
|
||||||
assignedLifetime time.Duration
|
assignedLifetime time.Duration
|
||||||
err error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"lifetime_too_long": {
|
"lifetime_too_long": {
|
||||||
lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second,
|
lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second,
|
||||||
err: ErrLifetimeTooLong,
|
|
||||||
errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds",
|
errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds",
|
||||||
},
|
},
|
||||||
"protocol_unknown": {
|
"protocol_unknown": {
|
||||||
lifetime: time.Second,
|
lifetime: time.Second,
|
||||||
protocol: "xyz",
|
protocol: "xyz",
|
||||||
err: ErrNetworkProtocolUnknown,
|
|
||||||
errMessage: "network protocol is unknown: xyz",
|
errMessage: "network protocol is unknown: xyz",
|
||||||
},
|
},
|
||||||
"rpc_error": {
|
"rpc_error": {
|
||||||
@@ -48,7 +45,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
|
|||||||
lifetime: 1200 * time.Second,
|
lifetime: 1200 * time.Second,
|
||||||
initialConnectionDuration: time.Millisecond,
|
initialConnectionDuration: time.Millisecond,
|
||||||
exchanges: []udpExchange{{close: true}},
|
exchanges: []udpExchange{{close: true}},
|
||||||
err: ErrConnectionTimeout,
|
|
||||||
errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
|
errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
|
||||||
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
|
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
|
||||||
},
|
},
|
||||||
@@ -136,9 +132,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
|
|||||||
assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort)
|
assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort)
|
||||||
assert.Equal(t, testCase.assignedLifetime, assignedLifetime)
|
assert.Equal(t, testCase.assignedLifetime, assignedLifetime)
|
||||||
if testCase.errMessage != "" {
|
if testCase.errMessage != "" {
|
||||||
if testCase.err != nil {
|
|
||||||
assert.ErrorIs(t, err, testCase.err)
|
|
||||||
}
|
|
||||||
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -11,17 +11,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified")
|
|
||||||
ErrConnectionTimeout = errors.New("connection timeout")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
||||||
request []byte, responseSize uint) (
|
request []byte, responseSize uint) (
|
||||||
response []byte, err error,
|
response []byte, err error,
|
||||||
) {
|
) {
|
||||||
if gateway.IsUnspecified() || !gateway.IsValid() {
|
if gateway.IsUnspecified() || !gateway.IsValid() {
|
||||||
return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified)
|
return nil, errors.New("gateway IP is unspecified")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkRequest(request)
|
err = checkRequest(request)
|
||||||
@@ -114,8 +109,7 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if retryCount == c.maxRetries {
|
if retryCount == c.maxRetries {
|
||||||
return nil, fmt.Errorf("%w: failed attempts: %s",
|
return nil, fmt.Errorf("connection timeout: failed attempts: %s", dedupFailedAttempts(failedAttempts))
|
||||||
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
|
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
|
||||||
|
|||||||
@@ -20,20 +20,17 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
initialConnectionDuration time.Duration
|
initialConnectionDuration time.Duration
|
||||||
exchanges []udpExchange
|
exchanges []udpExchange
|
||||||
expectedResponse []byte
|
expectedResponse []byte
|
||||||
err error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"gateway_ip_unspecified": {
|
"gateway_ip_unspecified": {
|
||||||
gateway: netip.IPv6Unspecified(),
|
gateway: netip.IPv6Unspecified(),
|
||||||
request: []byte{0, 0},
|
request: []byte{0, 0},
|
||||||
err: ErrGatewayIPUnspecified,
|
|
||||||
errMessage: "gateway IP is unspecified",
|
errMessage: "gateway IP is unspecified",
|
||||||
},
|
},
|
||||||
"request_too_small": {
|
"request_too_small": {
|
||||||
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
request: []byte{0},
|
request: []byte{0},
|
||||||
initialConnectionDuration: time.Nanosecond, // doesn't matter
|
initialConnectionDuration: time.Nanosecond, // doesn't matter
|
||||||
err: ErrRequestSizeTooSmall,
|
|
||||||
errMessage: `checking request: message size is too small: ` +
|
errMessage: `checking request: message size is too small: ` +
|
||||||
`need at least 2 bytes and got 1 byte\(s\)`,
|
`need at least 2 bytes and got 1 byte\(s\)`,
|
||||||
},
|
},
|
||||||
@@ -53,7 +50,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
exchanges: []udpExchange{
|
exchanges: []udpExchange{
|
||||||
{request: []byte{0, 1}, close: true},
|
{request: []byte{0, 1}, close: true},
|
||||||
},
|
},
|
||||||
err: ErrConnectionTimeout,
|
|
||||||
errMessage: "connection timeout: failed attempts: " +
|
errMessage: "connection timeout: failed attempts: " +
|
||||||
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
|
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
|
||||||
},
|
},
|
||||||
@@ -66,7 +62,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
request: []byte{0, 0},
|
request: []byte{0, 0},
|
||||||
response: []byte{1},
|
response: []byte{1},
|
||||||
}},
|
}},
|
||||||
err: ErrResponseSizeTooSmall,
|
|
||||||
errMessage: `checking response: response size is too small: ` +
|
errMessage: `checking response: response size is too small: ` +
|
||||||
`need at least 4 bytes and got 1 byte\(s\)`,
|
`need at least 4 bytes and got 1 byte\(s\)`,
|
||||||
},
|
},
|
||||||
@@ -80,7 +75,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
response: []byte{0, 1, 2, 3}, // size 4
|
response: []byte{0, 1, 2, 3}, // size 4
|
||||||
}},
|
}},
|
||||||
err: ErrResponseSizeUnexpected,
|
|
||||||
errMessage: `checking response: response size is unexpected: ` +
|
errMessage: `checking response: response size is unexpected: ` +
|
||||||
`expected 5 bytes and got 4 byte\(s\)`,
|
`expected 5 bytes and got 4 byte\(s\)`,
|
||||||
},
|
},
|
||||||
@@ -94,7 +88,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
}},
|
}},
|
||||||
err: ErrProtocolVersionUnknown,
|
|
||||||
errMessage: "checking response: protocol version is unknown: 1",
|
errMessage: "checking response: protocol version is unknown: 1",
|
||||||
},
|
},
|
||||||
"unexpected_operation_code": {
|
"unexpected_operation_code": {
|
||||||
@@ -107,7 +100,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
}},
|
}},
|
||||||
err: ErrOperationCodeUnexpected,
|
|
||||||
errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88",
|
errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88",
|
||||||
},
|
},
|
||||||
"failure_result_code": {
|
"failure_result_code": {
|
||||||
@@ -120,7 +112,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
}},
|
}},
|
||||||
err: ErrResultCodeUnknown,
|
|
||||||
errMessage: "checking response: result code: result code is unknown: 17",
|
errMessage: "checking response: result code: result code is unknown: 17",
|
||||||
},
|
},
|
||||||
"success": {
|
"success": {
|
||||||
@@ -153,9 +144,6 @@ func Test_Client_rpc(t *testing.T) {
|
|||||||
testCase.request, testCase.responseSize)
|
testCase.request, testCase.responseSize)
|
||||||
|
|
||||||
if testCase.errMessage != "" {
|
if testCase.errMessage != "" {
|
||||||
if testCase.err != nil {
|
|
||||||
assert.ErrorIs(t, err, testCase.err)
|
|
||||||
}
|
|
||||||
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -41,8 +41,6 @@ func findAvailableTCPPort(t *testing.T) (port uint16) {
|
|||||||
func Test_dialAddrThroughFirewall(t *testing.T) {
|
func Test_dialAddrThroughFirewall(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
errTest := errors.New("test error")
|
|
||||||
|
|
||||||
const ipv6InternetWorks = false
|
const ipv6InternetWorks = false
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
@@ -102,7 +100,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"firewall_add_error": {
|
"firewall_add_error": {
|
||||||
firewallAddErr: errTest,
|
firewallAddErr: errors.New("test error"),
|
||||||
errMessageRegex: func() string {
|
errMessageRegex: func() string {
|
||||||
return "accepting output traffic: test error"
|
return "accepting output traffic: test error"
|
||||||
},
|
},
|
||||||
@@ -122,7 +120,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
|
|||||||
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
||||||
return netip.AddrPortFrom(loopback, addrPort.Port())
|
return netip.AddrPortFrom(loopback, addrPort.Port())
|
||||||
},
|
},
|
||||||
firewallRemoveErr: errTest,
|
firewallRemoveErr: errors.New("test error"),
|
||||||
errMessageRegex: func() string {
|
errMessageRegex: func() string {
|
||||||
return "removing output traffic rule: test error"
|
return "removing output traffic rule: test error"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jsimonetti/rtnetlink"
|
"github.com/jsimonetti/rtnetlink"
|
||||||
@@ -47,8 +46,6 @@ func (n *NetLink) LinkList() (links []Link, err error) {
|
|||||||
return links, nil
|
return links, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrLinkNotFound = errors.New("link not found")
|
|
||||||
|
|
||||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||||
links, err := n.LinkList()
|
links, err := n.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -61,7 +58,7 @@ func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
|
return Link{}, fmt.Errorf("link not found: for name %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
||||||
@@ -76,7 +73,7 @@ func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
|
return Link{}, fmt.Errorf("link not found: for index %d", index)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
||||||
@@ -114,7 +111,7 @@ func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
|
return 0, fmt.Errorf("link not found: matching name %s", link.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
|
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
|
||||||
|
|||||||
@@ -1,17 +1,11 @@
|
|||||||
package extract
|
package extract
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrRead = errors.New("cannot read file")
|
|
||||||
ErrExtractConnection = errors.New("cannot extract connection from file")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Data extracts the lines and connection from the OpenVPN configuration file.
|
// Data extracts the lines and connection from the OpenVPN configuration file.
|
||||||
func (e *Extractor) Data(filepath string) (lines []string,
|
func (e *Extractor) Data(filepath string) (lines []string,
|
||||||
connection models.Connection, err error,
|
connection models.Connection, err error,
|
||||||
|
|||||||
@@ -11,8 +11,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errRemoteLineNotFound = errors.New("remote line not found")
|
|
||||||
|
|
||||||
func extractDataFromLines(lines []string) (
|
func extractDataFromLines(lines []string) (
|
||||||
connection models.Connection, err error,
|
connection models.Connection, err error,
|
||||||
) {
|
) {
|
||||||
@@ -35,7 +33,7 @@ func extractDataFromLines(lines []string) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !connection.IP.IsValid() {
|
if !connection.IP.IsValid() {
|
||||||
return connection, errRemoteLineNotFound
|
return connection, errors.New("remote line not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if connection.Protocol == "" {
|
if connection.Protocol == "" {
|
||||||
@@ -81,19 +79,15 @@ func extractDataFromLine(line string) (
|
|||||||
return ip, 0, "", nil
|
return ip, 0, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected")
|
|
||||||
|
|
||||||
func extractProto(line string) (protocol string, err error) {
|
func extractProto(line string) (protocol string, err error) {
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
if len(fields) != 2 { //nolint:mnd
|
if len(fields) != 2 { //nolint:mnd
|
||||||
return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line)
|
return "", fmt.Errorf("proto line has not 2 fields as expected: %s", line)
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseProto(fields[1])
|
return parseProto(fields[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
var errProtocolNotSupported = errors.New("network protocol not supported")
|
|
||||||
|
|
||||||
func parseProto(field string) (protocol string, err error) {
|
func parseProto(field string) (protocol string, err error) {
|
||||||
switch field {
|
switch field {
|
||||||
case "tcp", "tcp4", "tcp6", "tcp-client":
|
case "tcp", "tcp4", "tcp6", "tcp-client":
|
||||||
@@ -106,16 +100,10 @@ func parseProto(field string) (protocol string, err error) {
|
|||||||
// determined by the remote IP address version.
|
// determined by the remote IP address version.
|
||||||
return constants.UDP, nil
|
return constants.UDP, nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s", errProtocolNotSupported, field)
|
return "", fmt.Errorf("network protocol not supported: %s", field)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errRemoteLineFieldsCount = errors.New("remote line has not 2 fields as expected")
|
|
||||||
errHostNotIP = errors.New("host is not an IP address")
|
|
||||||
errPortNotValid = errors.New("port is not valid")
|
|
||||||
)
|
|
||||||
|
|
||||||
func extractRemote(line string) (ip netip.Addr, port uint16,
|
func extractRemote(line string) (ip netip.Addr, port uint16,
|
||||||
protocol string, err error,
|
protocol string, err error,
|
||||||
) {
|
) {
|
||||||
@@ -123,13 +111,13 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
|
|||||||
n := len(fields)
|
n := len(fields)
|
||||||
|
|
||||||
if n < 2 || n > 4 {
|
if n < 2 || n > 4 {
|
||||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errRemoteLineFieldsCount, line)
|
return netip.Addr{}, 0, "", fmt.Errorf("remote line has not 2 fields as expected: %s", line)
|
||||||
}
|
}
|
||||||
|
|
||||||
host := fields[1]
|
host := fields[1]
|
||||||
ip, err = netip.ParseAddr(host)
|
ip, err = netip.ParseAddr(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errHostNotIP, host)
|
return netip.Addr{}, 0, "", fmt.Errorf("host is not an IP address: %s", host)
|
||||||
// TODO resolve hostname once there is an option to allow it through
|
// TODO resolve hostname once there is an option to allow it through
|
||||||
// the firewall before the VPN is up.
|
// the firewall before the VPN is up.
|
||||||
}
|
}
|
||||||
@@ -137,9 +125,9 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
|
|||||||
if n > 2 { //nolint:mnd
|
if n > 2 { //nolint:mnd
|
||||||
portInt, err := strconv.Atoi(fields[2])
|
portInt, err := strconv.Atoi(fields[2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line)
|
return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %s", line)
|
||||||
} else if portInt < 1 || portInt > 65535 {
|
} else if portInt < 1 || portInt > 65535 {
|
||||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
|
return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
|
||||||
}
|
}
|
||||||
port = uint16(portInt)
|
port = uint16(portInt)
|
||||||
}
|
}
|
||||||
@@ -154,20 +142,18 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
|
|||||||
return ip, port, protocol, nil
|
return ip, port, protocol, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errPostLineFieldsCount = errors.New("post line has not 2 fields as expected")
|
|
||||||
|
|
||||||
func extractPort(line string) (port uint16, err error) {
|
func extractPort(line string) (port uint16, err error) {
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
const expectedFieldsCount = 2
|
const expectedFieldsCount = 2
|
||||||
if len(fields) != expectedFieldsCount {
|
if len(fields) != expectedFieldsCount {
|
||||||
return 0, fmt.Errorf("%w: %s", errPostLineFieldsCount, line)
|
return 0, fmt.Errorf("post line has not 2 fields as expected: %s", line)
|
||||||
}
|
}
|
||||||
|
|
||||||
portInt, err := strconv.Atoi(fields[1])
|
portInt, err := strconv.Atoi(fields[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("%w: %s", errPortNotValid, line)
|
return 0, fmt.Errorf("port is not valid: %s", line)
|
||||||
} else if portInt < 1 || portInt > 65535 {
|
} else if portInt < 1 || portInt > 65535 {
|
||||||
return 0, fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
|
return 0, fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
|
||||||
}
|
}
|
||||||
port = uint16(portInt)
|
port = uint16(portInt)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func Test_extractDataFromLines(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
lines []string
|
lines []string
|
||||||
connection models.Connection
|
connection models.Connection
|
||||||
err error
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
lines: []string{"bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
|
lines: []string{"bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
|
||||||
@@ -28,8 +28,8 @@ func Test_extractDataFromLines(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"extraction error": {
|
"extraction error": {
|
||||||
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
|
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
|
||||||
err: errors.New("on line 2: extracting protocol from proto line: network protocol not supported: bad"),
|
errMessage: "on line 2: extracting protocol from proto line: network protocol not supported: bad",
|
||||||
},
|
},
|
||||||
"only use first values found": {
|
"only use first values found": {
|
||||||
lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"},
|
lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"},
|
||||||
@@ -44,7 +44,7 @@ func Test_extractDataFromLines(t *testing.T) {
|
|||||||
connection: models.Connection{
|
connection: models.Connection{
|
||||||
Protocol: constants.TCP,
|
Protocol: constants.TCP,
|
||||||
},
|
},
|
||||||
err: errRemoteLineNotFound,
|
errMessage: "remote line not found",
|
||||||
},
|
},
|
||||||
"default TCP port": {
|
"default TCP port": {
|
||||||
lines: []string{"remote 1.2.3.4", "proto tcp"},
|
lines: []string{"remote 1.2.3.4", "proto tcp"},
|
||||||
@@ -70,9 +70,8 @@ func Test_extractDataFromLines(t *testing.T) {
|
|||||||
|
|
||||||
connection, err := extractDataFromLines(testCase.lines)
|
connection, err := extractDataFromLines(testCase.lines)
|
||||||
|
|
||||||
if testCase.err != nil {
|
if testCase.errMessage != "" {
|
||||||
require.Error(t, err)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
@@ -86,18 +85,18 @@ func Test_extractDataFromLine(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
line string
|
line string
|
||||||
ip netip.Addr
|
ip netip.Addr
|
||||||
port uint16
|
port uint16
|
||||||
protocol string
|
protocol string
|
||||||
isErr error
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"irrelevant line": {
|
"irrelevant line": {
|
||||||
line: "bla",
|
line: "bla",
|
||||||
},
|
},
|
||||||
"extract proto error": {
|
"extract proto error": {
|
||||||
line: "proto bad",
|
line: "proto bad",
|
||||||
isErr: errProtocolNotSupported,
|
errMessage: "network protocol not supported",
|
||||||
},
|
},
|
||||||
"extract proto success": {
|
"extract proto success": {
|
||||||
line: "proto tcp",
|
line: "proto tcp",
|
||||||
@@ -108,8 +107,8 @@ func Test_extractDataFromLine(t *testing.T) {
|
|||||||
protocol: constants.TCP,
|
protocol: constants.TCP,
|
||||||
},
|
},
|
||||||
"extract remote error": {
|
"extract remote error": {
|
||||||
line: "remote bad",
|
line: "remote bad",
|
||||||
isErr: errHostNotIP,
|
errMessage: "host is not an IP address",
|
||||||
},
|
},
|
||||||
"extract remote success": {
|
"extract remote success": {
|
||||||
line: "remote 1.2.3.4 1194 udp",
|
line: "remote 1.2.3.4 1194 udp",
|
||||||
@@ -118,8 +117,8 @@ func Test_extractDataFromLine(t *testing.T) {
|
|||||||
protocol: constants.UDP,
|
protocol: constants.UDP,
|
||||||
},
|
},
|
||||||
"extract_port_fail": {
|
"extract_port_fail": {
|
||||||
line: "port a",
|
line: "port a",
|
||||||
isErr: errPortNotValid,
|
errMessage: "port is not valid",
|
||||||
},
|
},
|
||||||
"extract_port_success": {
|
"extract_port_success": {
|
||||||
line: "port 1194",
|
line: "port 1194",
|
||||||
@@ -133,8 +132,8 @@ func Test_extractDataFromLine(t *testing.T) {
|
|||||||
|
|
||||||
ip, port, protocol, err := extractDataFromLine(testCase.line)
|
ip, port, protocol, err := extractDataFromLine(testCase.line)
|
||||||
|
|
||||||
if testCase.isErr != nil {
|
if testCase.errMessage != "" {
|
||||||
assert.ErrorIs(t, err, testCase.isErr)
|
assert.ErrorContains(t, err, testCase.errMessage)
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,15 +4,12 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errPEMDecode = errors.New("cannot decode PEM encoded block")
|
|
||||||
|
|
||||||
func PEM(b []byte) (encodedData string, err error) {
|
func PEM(b []byte) (encodedData string, err error) {
|
||||||
pemBlock, _ := pem.Decode(b)
|
pemBlock, _ := pem.Decode(b)
|
||||||
if pemBlock == nil {
|
if pemBlock == nil {
|
||||||
return "", fmt.Errorf("%w", errPEMDecode)
|
return "", errors.New("cannot decode PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
der := pemBlock.Bytes
|
der := pemBlock.Bytes
|
||||||
|
|||||||
@@ -13,16 +13,13 @@ func Test_PEM(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
b []byte
|
b []byte
|
||||||
encodedData string
|
encodedData string
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"no input": {
|
"no input": {
|
||||||
errWrapped: errPEMDecode,
|
|
||||||
errMessage: "cannot decode PEM encoded block",
|
errMessage: "cannot decode PEM encoded block",
|
||||||
},
|
},
|
||||||
"bad input": {
|
"bad input": {
|
||||||
b: []byte{1, 2, 3},
|
b: []byte{1, 2, 3},
|
||||||
errWrapped: errPEMDecode,
|
|
||||||
errMessage: "cannot decode PEM encoded block",
|
errMessage: "cannot decode PEM encoded block",
|
||||||
},
|
},
|
||||||
"valid data with extras": {
|
"valid data with extras": {
|
||||||
@@ -46,9 +43,10 @@ func Test_PEM(t *testing.T) {
|
|||||||
encodedData, err := PEM(testCase.b)
|
encodedData, err := PEM(testCase.b)
|
||||||
|
|
||||||
assert.Equal(t, testCase.encodedData, encodedData)
|
assert.Equal(t, testCase.encodedData, encodedData)
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
if testCase.errMessage != "" {
|
||||||
if testCase.errWrapped != nil {
|
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package pkcs8
|
|||||||
import (
|
import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -11,8 +10,6 @@ import (
|
|||||||
// https://www.ibm.com/docs/en/zos/2.3.0?topic=programming-object-identifiers
|
// https://www.ibm.com/docs/en/zos/2.3.0?topic=programming-object-identifiers
|
||||||
var oidDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} //nolint:gochecknoglobals
|
var oidDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} //nolint:gochecknoglobals
|
||||||
|
|
||||||
var ErrEncryptionAlgorithmNotPBES2 = errors.New("encryption algorithm is not PBES2")
|
|
||||||
|
|
||||||
type encryptedPrivateKey struct {
|
type encryptedPrivateKey struct {
|
||||||
EncryptionAlgorithm pkix.AlgorithmIdentifier
|
EncryptionAlgorithm pkix.AlgorithmIdentifier
|
||||||
EncryptedData []byte
|
EncryptedData []byte
|
||||||
@@ -35,8 +32,8 @@ func getEncryptionAlgorithmOid(der []byte) (
|
|||||||
oidPBES2 := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
|
oidPBES2 := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
|
||||||
oidAlgorithm := encryptedPrivateKeyData.EncryptionAlgorithm.Algorithm
|
oidAlgorithm := encryptedPrivateKeyData.EncryptionAlgorithm.Algorithm
|
||||||
if !oidAlgorithm.Equal(oidPBES2) {
|
if !oidAlgorithm.Equal(oidPBES2) {
|
||||||
return nil, fmt.Errorf("%w: %s instead of PBES2 %s",
|
return nil, fmt.Errorf("encryption algorithm is not PBES2: %s instead of PBES2 %s",
|
||||||
ErrEncryptionAlgorithmNotPBES2, oidAlgorithm, oidPBES2)
|
oidAlgorithm, oidPBES2)
|
||||||
}
|
}
|
||||||
|
|
||||||
var encryptionAlgorithmParams encryptedAlgorithmParams
|
var encryptionAlgorithmParams encryptedAlgorithmParams
|
||||||
|
|||||||
@@ -2,14 +2,11 @@ package pkcs8
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
pkcs8lib "github.com/youmark/pkcs8"
|
pkcs8lib "github.com/youmark/pkcs8"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrUnsupportedKeyType = errors.New("unsupported key type")
|
|
||||||
|
|
||||||
// UpgradeEncryptedKey eventually upgrades an encrypted key to a newer encryption
|
// UpgradeEncryptedKey eventually upgrades an encrypted key to a newer encryption
|
||||||
// if its encryption is too weak for Openvpn/Openssl.
|
// if its encryption is too weak for Openvpn/Openssl.
|
||||||
// If the key is encrypted using DES-CBC, it is decrypted and re-encrypted using AES-256-CBC.
|
// If the key is encrypted using DES-CBC, it is decrypted and re-encrypted using AES-256-CBC.
|
||||||
|
|||||||
@@ -2,15 +2,12 @@ package openvpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrVersionUnknown = errors.New("OpenVPN version is unknown")
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
binOpenvpn25 = "openvpn2.5"
|
binOpenvpn25 = "openvpn2.5"
|
||||||
binOpenvpn26 = "openvpn2.6"
|
binOpenvpn26 = "openvpn2.6"
|
||||||
@@ -26,7 +23,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
|
|||||||
case openvpn.Openvpn26:
|
case openvpn.Openvpn26:
|
||||||
bin = binOpenvpn26
|
bin = binOpenvpn26
|
||||||
default:
|
default:
|
||||||
return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version)
|
return nil, nil, nil, fmt.Errorf("OpenVPN version is unknown: %s", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{"--config", configPath}
|
args := []string{"--config", configPath}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openvpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -16,8 +15,6 @@ func (c *Configurator) Version26(ctx context.Context) (version string, err error
|
|||||||
return c.version(ctx, binOpenvpn26)
|
return c.version(ctx, binOpenvpn26)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrVersionTooShort = errors.New("version output is too short")
|
|
||||||
|
|
||||||
func (c *Configurator) version(ctx context.Context, binName string) (version string, err error) {
|
func (c *Configurator) version(ctx context.Context, binName string) (version string, err error) {
|
||||||
cmd := exec.CommandContext(ctx, binName, "--version")
|
cmd := exec.CommandContext(ctx, binName, "--version")
|
||||||
output, err := c.cmder.Run(cmd)
|
output, err := c.cmder.Run(cmd)
|
||||||
@@ -28,7 +25,7 @@ func (c *Configurator) version(ctx context.Context, binName string) (version str
|
|||||||
words := strings.Fields(firstLine)
|
words := strings.Fields(firstLine)
|
||||||
const minWords = 2
|
const minWords = 2
|
||||||
if len(words) < minWords {
|
if len(words) < minWords {
|
||||||
return "", fmt.Errorf("%w: %s", ErrVersionTooShort, firstLine)
|
return "", fmt.Errorf("version output is too short: %s", firstLine)
|
||||||
}
|
}
|
||||||
return words[1], nil
|
return words[1], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,24 +2,17 @@ package icmp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
|
||||||
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
|
||||||
)
|
|
||||||
|
|
||||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case mtu < minMTU:
|
case mtu < minMTU:
|
||||||
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
|
return fmt.Errorf("ICMP Next Hop MTU is too low: %d", mtu)
|
||||||
case mtu > physicalLinkMTU:
|
case mtu > physicalLinkMTU:
|
||||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
return fmt.Errorf("ICMP Next Hop MTU is too high: %d is larger than physical link MTU %d", mtu, physicalLinkMTU)
|
||||||
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -34,14 +27,12 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
|||||||
}
|
}
|
||||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
return false, fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
|
||||||
}
|
}
|
||||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||||
return inboundBody.ID == outboundBody.ID, nil
|
return inboundBody.ID == outboundBody.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrIDMismatch = errors.New("ICMP id mismatch")
|
|
||||||
|
|
||||||
func checkEchoReply(icmpProtocol int, received []byte,
|
func checkEchoReply(icmpProtocol int, received []byte,
|
||||||
outboundMessage *icmp.Message, truncatedBody bool,
|
outboundMessage *icmp.Message, truncatedBody bool,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
@@ -51,12 +42,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
|||||||
}
|
}
|
||||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
return fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
|
||||||
}
|
}
|
||||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||||
if inboundBody.ID != outboundBody.ID {
|
if inboundBody.ID != outboundBody.ID {
|
||||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
return fmt.Errorf("ICMP id mismatch: sent id %d and received id %d",
|
||||||
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
|
outboundBody.ID, inboundBody.ID)
|
||||||
}
|
}
|
||||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -65,19 +56,17 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
|
|
||||||
|
|
||||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||||
if len(received) > len(sent) {
|
if len(received) > len(sent) {
|
||||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
return fmt.Errorf("ICMP data mismatch: sent %d bytes and received %d bytes",
|
||||||
ErrEchoDataMismatch, len(sent), len(received))
|
len(sent), len(received))
|
||||||
}
|
}
|
||||||
if receivedTruncated {
|
if receivedTruncated {
|
||||||
sent = sent[:len(received)]
|
sent = sent[:len(received)]
|
||||||
}
|
}
|
||||||
if !bytes.Equal(received, sent) {
|
if !bytes.Equal(received, sent) {
|
||||||
return fmt.Errorf("%w: sent %x and received %x",
|
return fmt.Errorf("ICMP data mismatch: sent %x and received %x",
|
||||||
ErrEchoDataMismatch, sent, received)
|
sent, received)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNotPermitted = errors.New("ICMP not permitted")
|
ErrNotPermitted = errors.New("ICMP not permitted")
|
||||||
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
|
errCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||||
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
|
||||||
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
|
|
||||||
ErrMTUNotFound = errors.New("MTU not found")
|
ErrMTUNotFound = errors.New("MTU not found")
|
||||||
errTimeout = errors.New("operation timed out")
|
errTimeout = errors.New("operation timed out")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
|||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
return mtu, nil
|
return mtu, nil
|
||||||
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
|
case errors.Is(err, errTimeout) || errors.Is(err, errCommunicationAdministrativelyProhibited): // blackhole
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
|
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,13 +117,10 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
case portUnreachable: // triggered by TCP or UDP from applications
|
case portUnreachable: // triggered by TCP or UDP from applications
|
||||||
continue // ignore and wait for the next message
|
continue // ignore and wait for the next message
|
||||||
case communicationAdministrativelyProhibitedCode:
|
case communicationAdministrativelyProhibitedCode:
|
||||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
return 0, fmt.Errorf("ICMP destination unreachable: %w (code %d)", errCommunicationAdministrativelyProhibited,
|
||||||
ErrDestinationUnreachable,
|
|
||||||
ErrCommunicationAdministrativelyProhibited,
|
|
||||||
inboundMessage.Code)
|
inboundMessage.Code)
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: code %d",
|
return 0, fmt.Errorf("ICMP destination unreachable: code %d", inboundMessage.Code)
|
||||||
ErrDestinationUnreachable, inboundMessage.Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||||
@@ -158,7 +155,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
inboundID, outboundID)
|
inboundID, outboundID)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
return 0, fmt.Errorf("ICMP body type is not supported: %T", typedBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package icmp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -115,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||||
} else if idMatch {
|
} else if idMatch {
|
||||||
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
|
return 0, errors.New("ICMP destination unreachable")
|
||||||
}
|
}
|
||||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||||
continue
|
continue
|
||||||
@@ -128,7 +129,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
inboundID, outboundID)
|
inboundID, outboundID)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
return 0, fmt.Errorf("ICMP body type %T is not supported", typedBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
|||||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
err = ErrNotPermitted
|
||||||
}
|
}
|
||||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
}
|
}
|
||||||
@@ -157,7 +157,7 @@ func collectReplies(conn net.PacketConn, ipVersion string,
|
|||||||
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
|
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
|
return fmt.Errorf("ICMP body type is not supported: %T", message.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
echoBody, _ := message.Body.(*icmp.Echo)
|
echoBody, _ := message.Body.(*icmp.Echo)
|
||||||
@@ -183,8 +183,8 @@ func collectReplies(conn net.PacketConn, ipVersion string,
|
|||||||
ipPacketLength == conservativeReplyLength
|
ipPacketLength == conservativeReplyLength
|
||||||
// Check the packet size is the same if the reply is not truncated
|
// Check the packet size is the same if the reply is not truncated
|
||||||
if !truncated && sentBytes != ipPacketLength {
|
if !truncated && sentBytes != ipPacketLength {
|
||||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
return fmt.Errorf("ICMP data mismatch: sent %dB and received %dB",
|
||||||
ErrEchoDataMismatch, sentBytes, ipPacketLength)
|
sentBytes, ipPacketLength)
|
||||||
}
|
}
|
||||||
// Truncated reply or matching reply size
|
// Truncated reply or matching reply size
|
||||||
tests[testIndex].ok = true
|
tests[testIndex].ok = true
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(),
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNoRoute = fmt.Errorf("no route to destination")
|
errNoRoute = errors.New("no route to destination")
|
||||||
ErrNetworkUnreachable = errors.New("network unreachable")
|
ErrNetworkUnreachable = errors.New("network unreachable")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrICMPOkTCPFail = errors.New("PMTUD succeeded with ICMP but failed with TCP")
|
|
||||||
ErrICMPFailTCPFail = errors.New("PMTUD failed with both ICMP and TCP")
|
|
||||||
)
|
|
||||||
|
|
||||||
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
||||||
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
||||||
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
||||||
@@ -81,10 +76,10 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if icmpSuccess {
|
if icmpSuccess {
|
||||||
return 0, fmt.Errorf("%w - discarding ICMP obtained MTU %d",
|
return 0, fmt.Errorf("PMTUD succeeded with ICMP but failed with TCP "+
|
||||||
ErrICMPOkTCPFail, maxPossibleMTU)
|
"- discarding ICMP obtained MTU %d", maxPossibleMTU)
|
||||||
}
|
}
|
||||||
return 0, fmt.Errorf("%w", ErrICMPFailTCPFail)
|
return 0, errors.New("PMTUD failed with both ICMP and TCP")
|
||||||
}
|
}
|
||||||
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
|
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
|
||||||
return mtu, nil
|
return mtu, nil
|
||||||
|
|||||||
@@ -57,8 +57,6 @@ func (l *noopLogger) Warn(_ string) {}
|
|||||||
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
||||||
func (l *noopLogger) Error(_ string) {}
|
func (l *noopLogger) Error(_ string) {}
|
||||||
|
|
||||||
var errRouteNotFound = errors.New("route not found")
|
|
||||||
|
|
||||||
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -76,7 +74,7 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
|||||||
return min(link.MTU, maxMTU), nil
|
return min(link.MTU, maxMTU), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
return 0, errors.New("route not found: no loopback route found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
@@ -100,7 +98,7 @@ func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if mtu == 0 {
|
if mtu == 0 {
|
||||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
return 0, errors.New("route not found: no default route found")
|
||||||
}
|
}
|
||||||
return mtu, nil
|
return mtu, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errTCPServersUnreachable = errors.New("all TCP servers are unreachable")
|
|
||||||
|
|
||||||
// findHighestMSSDestination finds the destination with the highest
|
// findHighestMSSDestination finds the destination with the highest
|
||||||
// MSS amongst the provided destinations.
|
// MSS amongst the provided destinations.
|
||||||
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
|
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
|
||||||
@@ -68,7 +66,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mss == 0 { // no MSS found for any destination
|
if mss == 0 { // no MSS found for any destination
|
||||||
return netip.AddrPort{}, 0, fmt.Errorf("%w (%d servers)", errTCPServersUnreachable, len(dsts))
|
return netip.AddrPort{}, 0, fmt.Errorf("all %d TCP servers are unreachable", len(dsts))
|
||||||
}
|
}
|
||||||
|
|
||||||
maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
|
maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
|
||||||
@@ -77,8 +75,6 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
|
|||||||
return dst, mss, nil
|
return dst, mss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errMSSNotFound = errors.New("MSS option not found in reply")
|
|
||||||
|
|
||||||
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
|
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
|
||||||
excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
|
excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
|
||||||
mss uint32, err error,
|
mss uint32, err error,
|
||||||
@@ -132,11 +128,12 @@ func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
|
|||||||
case err != nil:
|
case err != nil:
|
||||||
return 0, fmt.Errorf("parsing reply TCP header: %w", err)
|
return 0, fmt.Errorf("parsing reply TCP header: %w", err)
|
||||||
case replyHeader.typ != packetTypeSYNACK:
|
case replyHeader.typ != packetTypeSYNACK:
|
||||||
return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ)
|
return 0, fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", replyHeader.typ)
|
||||||
case replyHeader.ack != synSeq+1:
|
case replyHeader.ack != synSeq+1:
|
||||||
return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack)
|
return 0, fmt.Errorf("TCP SYN-ACK ACK number %d does not match expected value %d",
|
||||||
|
replyHeader.ack, synSeq+1)
|
||||||
case replyHeader.options.mss == 0:
|
case replyHeader.options.mss == 0:
|
||||||
return 0, fmt.Errorf("%w: MSS option not found in reply", errMSSNotFound)
|
return 0, errors.New("MSS option not found in reply")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sendRST(fd, src, dst, replyHeader.ack)
|
err = sendRST(fd, src, dst, replyHeader.ack)
|
||||||
|
|||||||
@@ -12,11 +12,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/pmtud/test"
|
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrMTUNotFound = errors.New("MTU not found")
|
|
||||||
ErrMSSTooSmall = errors.New("TCP MSS is too small to find the MTU")
|
|
||||||
)
|
|
||||||
|
|
||||||
type testUnit struct {
|
type testUnit struct {
|
||||||
mtu uint32
|
mtu uint32
|
||||||
ok bool
|
ok bool
|
||||||
@@ -178,5 +173,5 @@ func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
|
return 0, errors.New("MTU not found: your connection might not be working at all")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,13 +75,6 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
|
|||||||
return fileDescriptor(fdPlatform), stop, nil
|
return fileDescriptor(fdPlatform), stop, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
|
|
||||||
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
|
|
||||||
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
|
|
||||||
errTCPPacketLost = errors.New("TCP packet was lost")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Craft and send a raw TCP packet to test the MTU.
|
// Craft and send a raw TCP packet to test the MTU.
|
||||||
// It expects either an RST reply (if no server is listening)
|
// It expects either an RST reply (if no server is listening)
|
||||||
// or a SYN-ACK/ACK reply (if a server is listening).
|
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||||
@@ -142,9 +135,10 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
|
|||||||
// server actively closed the connection, try sending a SYN with data
|
// server actively closed the connection, try sending a SYN with data
|
||||||
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
|
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
|
||||||
case firstReplyHeader.typ != packetTypeSYNACK:
|
case firstReplyHeader.typ != packetTypeSYNACK:
|
||||||
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, firstReplyHeader.typ)
|
return fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", firstReplyHeader.typ)
|
||||||
case firstReplyHeader.ack != synSeq+1:
|
case firstReplyHeader.ack != synSeq+1:
|
||||||
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, firstReplyHeader.ack)
|
return fmt.Errorf("TCP SYN-ACK ACK number does not match expected value: "+
|
||||||
|
"expected %d, got %d", synSeq+1, firstReplyHeader.ack)
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstReplyHeader.options.mss != 0 {
|
if firstReplyHeader.options.mss != 0 {
|
||||||
@@ -191,15 +185,13 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case packetTypeSYNACK: // server never received our MTU-test ACK packet
|
case packetTypeSYNACK: // server never received our MTU-test ACK packet
|
||||||
return fmt.Errorf("%w: server responded with second SYN-ACK packet", errTCPPacketLost)
|
return errors.New("TCP packet was lost: server responded with second SYN-ACK packet")
|
||||||
default:
|
default:
|
||||||
_ = sendRST(fd, src, dst, finalPacketHeader.ack)
|
_ = sendRST(fd, src, dst, finalPacketHeader.ack)
|
||||||
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, finalPacketHeader.typ)
|
return fmt.Errorf("final TCP packet type is unexpected: %s", finalPacketHeader.typ)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
|
|
||||||
|
|
||||||
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
||||||
src, dst netip.AddrPort, mtu uint32,
|
src, dst netip.AddrPort, mtu uint32,
|
||||||
) error {
|
) error {
|
||||||
@@ -223,7 +215,7 @@ func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
|||||||
return fmt.Errorf("parsing reply TCP header: %w", err)
|
return fmt.Errorf("parsing reply TCP header: %w", err)
|
||||||
} else if replyPacketHeader.typ != packetTypeRST &&
|
} else if replyPacketHeader.typ != packetTypeRST &&
|
||||||
replyPacketHeader.typ != packetTypeRSTACK {
|
replyPacketHeader.typ != packetTypeRSTACK {
|
||||||
return fmt.Errorf("%w: %s", errTCPPacketNotRST, replyPacketHeader.typ)
|
return fmt.Errorf("TCP packet is not an RST: %s", replyPacketHeader.typ)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package tcp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
@@ -120,17 +119,11 @@ type tcpHeader struct {
|
|||||||
options options
|
options options
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errTCPHeaderTooShort = errors.New("TCP header is too short")
|
|
||||||
errTCPHeaderDataOffset = errors.New("TCP header data offset is invalid")
|
|
||||||
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseTCPHeader parses the TCP header from b.
|
// parseTCPHeader parses the TCP header from b.
|
||||||
// b should be the entire TCP packet bytes.
|
// b should be the entire TCP packet bytes.
|
||||||
func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
||||||
if len(b) < int(constants.BaseTCPHeaderLength) {
|
if len(b) < int(constants.BaseTCPHeaderLength) {
|
||||||
return tcpHeader{}, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(b))
|
return tcpHeader{}, fmt.Errorf("TCP header is too short: %d bytes", len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
header.srcPort = binary.BigEndian.Uint16(b[0:2])
|
header.srcPort = binary.BigEndian.Uint16(b[0:2])
|
||||||
@@ -146,11 +139,11 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
|
case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
|
||||||
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, expected at least %d bytes",
|
return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
|
||||||
errTCPHeaderDataOffset, header.dataOffset, constants.BaseTCPHeaderLength)
|
"data offset is %d bytes, expected at least %d bytes", header.dataOffset, constants.BaseTCPHeaderLength)
|
||||||
case int(header.dataOffset) > len(b):
|
case int(header.dataOffset) > len(b):
|
||||||
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, but packet is only %d bytes",
|
return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
|
||||||
errTCPHeaderDataOffset, header.dataOffset, len(b))
|
"data offset is %d bytes, but packet is only %d bytes", header.dataOffset, len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
|
if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
|
||||||
@@ -186,7 +179,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
|||||||
case flags&ackFlag != 0:
|
case flags&ackFlag != 0:
|
||||||
header.typ = packetTypeACK
|
header.typ = packetTypeACK
|
||||||
default:
|
default:
|
||||||
return tcpHeader{}, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
|
return tcpHeader{}, fmt.Errorf("TCP packet type is unknown: flags are 0x%02x", flags)
|
||||||
}
|
}
|
||||||
|
|
||||||
header.seq = binary.BigEndian.Uint32(b[4:8])
|
header.seq = binary.BigEndian.Uint32(b[4:8])
|
||||||
@@ -206,15 +199,6 @@ type optionTimestamps struct {
|
|||||||
echo uint32
|
echo uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errTCPOptionLengthTruncated = errors.New("TCP option length is truncated")
|
|
||||||
ErrTCPOptionLengthInvalid = errors.New("TCP option length is invalid")
|
|
||||||
ErrTCPOptionMSSInvalid = errors.New("TCP option MSS value is invalid")
|
|
||||||
ErrTCPOptionWindowScaleInvalid = errors.New("TCP option Window Scale value is invalid")
|
|
||||||
ErrTCPOptionTimestampsInvalid = errors.New("TCP option Timestamps value is invalid")
|
|
||||||
errTCPOptionTypeUnknown = errors.New("TCP option type is unknown")
|
|
||||||
)
|
|
||||||
|
|
||||||
func parseTCPOptions(b []byte) (parsed options, err error) {
|
func parseTCPOptions(b []byte) (parsed options, err error) {
|
||||||
i := 0
|
i := 0
|
||||||
for i < len(b) {
|
for i < len(b) {
|
||||||
@@ -232,7 +216,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
|||||||
// Handle TLV (Type-Length-Value) options
|
// Handle TLV (Type-Length-Value) options
|
||||||
if i+1 >= len(b) {
|
if i+1 >= len(b) {
|
||||||
// This should not happen for DF packets.
|
// This should not happen for DF packets.
|
||||||
return options{}, fmt.Errorf("%w: at offset %d", errTCPOptionLengthTruncated, i)
|
return options{}, fmt.Errorf("TCP option length is truncated: at offset %d", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
length := int(b[i+1])
|
length := int(b[i+1])
|
||||||
@@ -240,11 +224,11 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
|||||||
maxLength := len(b) - i
|
maxLength := len(b) - i
|
||||||
switch {
|
switch {
|
||||||
case length < minLength:
|
case length < minLength:
|
||||||
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d < %d",
|
return options{}, fmt.Errorf("TCP option length is invalid: "+
|
||||||
ErrTCPOptionLengthInvalid, optionType, i, length, minLength)
|
"type %d at offset %d has length %d < %d", optionType, i, length, minLength)
|
||||||
case length > maxLength:
|
case length > maxLength:
|
||||||
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d > %d",
|
return options{}, fmt.Errorf("TCP option length is invalid: "+
|
||||||
ErrTCPOptionLengthInvalid, optionType, i, length, maxLength)
|
"type %d at offset %d has length %d > %d", optionType, i, length, maxLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := b[i+2 : i+length]
|
data := b[i+2 : i+length]
|
||||||
@@ -259,15 +243,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
|||||||
case optionTypeMSS:
|
case optionTypeMSS:
|
||||||
const expectedLength = 4
|
const expectedLength = 4
|
||||||
if length != expectedLength {
|
if length != expectedLength {
|
||||||
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
|
return options{}, fmt.Errorf("TCP option MSS value is invalid: "+
|
||||||
ErrTCPOptionMSSInvalid, i, length, expectedLength)
|
"MSS option at offset %d has length %d, expected %d", i, length, expectedLength)
|
||||||
}
|
}
|
||||||
parsed.mss = uint32(binary.BigEndian.Uint16(data))
|
parsed.mss = uint32(binary.BigEndian.Uint16(data))
|
||||||
case optionTypeWindowScale:
|
case optionTypeWindowScale:
|
||||||
const expectedLength = 3
|
const expectedLength = 3
|
||||||
if length != expectedLength {
|
if length != expectedLength {
|
||||||
return options{}, fmt.Errorf("%w: window scale option at offset %d has length %d, expected %d",
|
return options{}, fmt.Errorf("TCP option Window Scale value is invalid: "+
|
||||||
ErrTCPOptionWindowScaleInvalid, i, length, expectedLength)
|
"window scale option at offset %d has length %d, expected %d", i, length, expectedLength)
|
||||||
}
|
}
|
||||||
windowScale := data[0]
|
windowScale := data[0]
|
||||||
parsed.windowScale = &windowScale
|
parsed.windowScale = &windowScale
|
||||||
@@ -276,15 +260,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
|||||||
case optionTypeTimestamps:
|
case optionTypeTimestamps:
|
||||||
const expectedLength = 10
|
const expectedLength = 10
|
||||||
if length != expectedLength {
|
if length != expectedLength {
|
||||||
return options{}, fmt.Errorf("%w: timestamps option at offset %d has length %d, expected %d",
|
return options{}, fmt.Errorf("TCP option Timestamps value is invalid: "+
|
||||||
ErrTCPOptionTimestampsInvalid, i, length, expectedLength)
|
"timestamps option at offset %d has length %d, expected %d", i, length, expectedLength)
|
||||||
}
|
}
|
||||||
parsed.timestamps = &optionTimestamps{
|
parsed.timestamps = &optionTimestamps{
|
||||||
value: binary.BigEndian.Uint32(data[:4]),
|
value: binary.BigEndian.Uint32(data[:4]),
|
||||||
echo: binary.BigEndian.Uint32(data[4:]),
|
echo: binary.BigEndian.Uint32(data[4:]),
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return options{}, fmt.Errorf("%w: type %d", errTCPOptionTypeUnknown, optionType)
|
return options{}, fmt.Errorf("TCP option type is unknown: type %d", optionType)
|
||||||
}
|
}
|
||||||
|
|
||||||
i += length
|
i += length
|
||||||
|
|||||||
@@ -177,11 +177,9 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
|
|||||||
return l.service.GetPortsForwarded()
|
return l.service.GetPortsForwarded()
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrServiceNotStarted = errors.New("port forwarding service not started")
|
|
||||||
|
|
||||||
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
|
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
|
||||||
if l.service == nil {
|
if l.service == nil {
|
||||||
return fmt.Errorf("%w", ErrServiceNotStarted)
|
return errors.New("port forwarding service not started")
|
||||||
}
|
}
|
||||||
|
|
||||||
return l.service.SetPortsForwarded(l.runCtx, ports)
|
return l.service.SetPortsForwarded(l.runCtx, ports)
|
||||||
|
|||||||
@@ -55,23 +55,10 @@ func (s *Settings) OverrideWith(update Settings) {
|
|||||||
s.Password = gosettings.OverrideWithComparable(s.Password, update.Password)
|
s.Password = gosettings.OverrideWithComparable(s.Password, update.Password)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrPortForwarderNotSet = errors.New("port forwarder not set")
|
|
||||||
ErrServerNameNotSet = errors.New("server name not set")
|
|
||||||
ErrUsernameNotSet = errors.New("username not set")
|
|
||||||
ErrPasswordNotSet = errors.New("password not set")
|
|
||||||
ErrFilepathNotSet = errors.New("file path not set")
|
|
||||||
ErrInterfaceNotSet = errors.New("interface not set")
|
|
||||||
ErrPortsCountZero = errors.New("ports count cannot be zero")
|
|
||||||
ErrPortsCountTooHigh = errors.New("ports count too high")
|
|
||||||
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
|
|
||||||
ErrListeningPortZero = errors.New("listening port cannot be 0")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *Settings) Validate(forStartup bool) (err error) {
|
func (s *Settings) Validate(forStartup bool) (err error) {
|
||||||
// Minimal validation
|
// Minimal validation
|
||||||
if s.Filepath == "" {
|
if s.Filepath == "" {
|
||||||
return fmt.Errorf("%w", ErrFilepathNotSet)
|
return errors.New("file path not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !forStartup {
|
if !forStartup {
|
||||||
@@ -83,41 +70,42 @@ func (s *Settings) Validate(forStartup bool) (err error) {
|
|||||||
// Startup validation requires additional fields set.
|
// Startup validation requires additional fields set.
|
||||||
switch {
|
switch {
|
||||||
case s.PortForwarder == nil:
|
case s.PortForwarder == nil:
|
||||||
return fmt.Errorf("%w", ErrPortForwarderNotSet)
|
return errors.New("port forwarder not set")
|
||||||
case s.Interface == "":
|
case s.Interface == "":
|
||||||
return fmt.Errorf("%w", ErrInterfaceNotSet)
|
return errors.New("interface not set")
|
||||||
case s.PortsCount == 0:
|
case s.PortsCount == 0:
|
||||||
return fmt.Errorf("%w", ErrPortsCountZero)
|
return errors.New("ports count cannot be zero")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch s.PortForwarder.Name() {
|
switch s.PortForwarder.Name() {
|
||||||
case providers.PrivateInternetAccess:
|
case providers.PrivateInternetAccess:
|
||||||
switch {
|
switch {
|
||||||
case s.ServerName == "":
|
case s.ServerName == "":
|
||||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
return errors.New("server name not set")
|
||||||
case s.Username == "":
|
case s.Username == "":
|
||||||
return fmt.Errorf("%w", ErrUsernameNotSet)
|
return errors.New("username not set")
|
||||||
case s.Password == "":
|
case s.Password == "":
|
||||||
return fmt.Errorf("%w", ErrPasswordNotSet)
|
return errors.New("password not set")
|
||||||
}
|
}
|
||||||
case providers.Protonvpn:
|
case providers.Protonvpn:
|
||||||
const maxPortsCount = 4
|
const maxPortsCount = 4
|
||||||
if s.PortsCount > maxPortsCount {
|
if s.PortsCount > maxPortsCount {
|
||||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
|
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
const maxPortsCount = 1
|
const maxPortsCount = 1
|
||||||
if s.PortsCount > maxPortsCount {
|
if s.PortsCount > maxPortsCount {
|
||||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
|
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slices.Equal(s.ListeningPorts, []uint16{0}) {
|
if !slices.Equal(s.ListeningPorts, []uint16{0}) {
|
||||||
switch {
|
switch {
|
||||||
case len(s.ListeningPorts) != int(s.PortsCount):
|
case len(s.ListeningPorts) != int(s.PortsCount):
|
||||||
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(s.ListeningPorts), s.PortsCount)
|
return fmt.Errorf("listening ports length must be equal to ports count: %d != %d",
|
||||||
|
len(s.ListeningPorts), s.PortsCount)
|
||||||
case slices.Contains(s.ListeningPorts, 0):
|
case slices.Contains(s.ListeningPorts, 0):
|
||||||
return fmt.Errorf("%w: in %v", ErrListeningPortZero, s.ListeningPorts)
|
return fmt.Errorf("listening port cannot be 0: in %v", s.ListeningPorts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ func Test_Server_BadSettings(t *testing.T) {
|
|||||||
|
|
||||||
server, err := New(settings)
|
server, err := New(settings)
|
||||||
assert.Nil(t, server)
|
assert.Nil(t, server)
|
||||||
assert.ErrorIs(t, err, ErrBlockProfileRateNegative)
|
assert.ErrorContains(t, err, "block profile rate cannot be negative")
|
||||||
const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative"
|
const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative"
|
||||||
assert.EqualError(t, err, expectedErrMessage)
|
assert.EqualError(t, err, expectedErrMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package pprof
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/httpserver"
|
"github.com/qdm12/gluetun/internal/httpserver"
|
||||||
@@ -51,18 +50,13 @@ func (s *Settings) OverrideWith(other Settings) {
|
|||||||
s.HTTPServer.OverrideWith(other.HTTPServer)
|
s.HTTPServer.OverrideWith(other.HTTPServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrBlockProfileRateNegative = errors.New("block profile rate cannot be negative")
|
|
||||||
ErrMutexProfileRateNegative = errors.New("mutex profile rate cannot be negative")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s Settings) Validate() (err error) {
|
func (s Settings) Validate() (err error) {
|
||||||
if *s.BlockProfileRate < 0 {
|
if *s.BlockProfileRate < 0 {
|
||||||
return fmt.Errorf("%w", ErrBlockProfileRateNegative)
|
return errors.New("block profile rate cannot be negative")
|
||||||
}
|
}
|
||||||
|
|
||||||
if *s.MutexProfileRate < 0 {
|
if *s.MutexProfileRate < 0 {
|
||||||
return fmt.Errorf("%w", ErrMutexProfileRateNegative)
|
return errors.New("mutex profile rate cannot be negative")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.HTTPServer.Validate()
|
return s.HTTPServer.Validate()
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/httpserver"
|
"github.com/qdm12/gluetun/internal/httpserver"
|
||||||
"github.com/qdm12/gosettings/validate"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -195,7 +194,6 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
settings Settings
|
settings Settings
|
||||||
errWrapped error
|
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"negative block profile rate": {
|
"negative block profile rate": {
|
||||||
@@ -203,16 +201,14 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
BlockProfileRate: intPtr(-1),
|
BlockProfileRate: intPtr(-1),
|
||||||
MutexProfileRate: intPtr(0),
|
MutexProfileRate: intPtr(0),
|
||||||
},
|
},
|
||||||
errWrapped: ErrBlockProfileRateNegative,
|
errMessage: "block profile rate cannot be negative",
|
||||||
errMessage: ErrBlockProfileRateNegative.Error(),
|
|
||||||
},
|
},
|
||||||
"negative mutex profile rate": {
|
"negative mutex profile rate": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
BlockProfileRate: intPtr(0),
|
BlockProfileRate: intPtr(0),
|
||||||
MutexProfileRate: intPtr(-1),
|
MutexProfileRate: intPtr(-1),
|
||||||
},
|
},
|
||||||
errWrapped: ErrMutexProfileRateNegative,
|
errMessage: "mutex profile rate cannot be negative",
|
||||||
errMessage: ErrMutexProfileRateNegative.Error(),
|
|
||||||
},
|
},
|
||||||
"http server validation error": {
|
"http server validation error": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -222,7 +218,6 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
Address: ":x",
|
Address: ":x",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
errWrapped: validate.ErrPortNotAnInteger,
|
|
||||||
errMessage: "port value is not an integer: x",
|
errMessage: "port value is not an integer: x",
|
||||||
},
|
},
|
||||||
"valid settings": {
|
"valid settings": {
|
||||||
@@ -247,9 +242,10 @@ func Test_Settings_Validate(t *testing.T) {
|
|||||||
|
|
||||||
err := testCase.settings.Validate()
|
err := testCase.settings.Validate()
|
||||||
|
|
||||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
|
||||||
if testCase.errMessage != "" {
|
if testCase.errMessage != "" {
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/provider/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type apiData struct {
|
type apiData struct {
|
||||||
@@ -48,8 +46,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
_ = response.Body.Close()
|
_ = response.Body.Close()
|
||||||
return data, fmt.Errorf("%w: %d %s",
|
return data, fmt.Errorf("HTTP status code not OK: %d %s",
|
||||||
common.ErrHTTPStatusCodeNotOK, response.StatusCode, response.Status)
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var ErrPortForwardNotSupported = errors.New("port forwarding not supported")
|
|
||||||
@@ -10,10 +10,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNotEnoughServers = errors.New("not enough servers found")
|
ErrNotEnoughServers = errors.New("not enough servers found")
|
||||||
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
ErrCredentialsMissing = errors.New("credentials are missing")
|
||||||
ErrIPFetcherUnsupported = errors.New("IP fetcher not supported")
|
|
||||||
ErrCredentialsMissing = errors.New("credentials missing")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Fetcher interface {
|
type Fetcher interface {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package custom
|
package custom
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
@@ -10,8 +9,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
|
|
||||||
|
|
||||||
// GetConnection gets the connection from the OpenVPN configuration file.
|
// GetConnection gets the connection from the OpenVPN configuration file.
|
||||||
func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
|
func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
|
||||||
connection models.Connection, err error,
|
connection models.Connection, err error,
|
||||||
@@ -22,7 +19,7 @@ func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
|
|||||||
case vpn.Wireguard, vpn.AmneziaWg:
|
case vpn.Wireguard, vpn.AmneziaWg:
|
||||||
return getWireguardConnection(selection), nil
|
return getWireguardConnection(selection), nil
|
||||||
default:
|
default:
|
||||||
return connection, fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, selection.VPN)
|
return connection, fmt.Errorf("VPN type not supported for custom provider: %s", selection.VPN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package custom
|
package custom
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -12,8 +11,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrExtractData = errors.New("failed extracting information from custom configuration file")
|
|
||||||
|
|
||||||
func (p *Provider) OpenVPNConfig(connection models.Connection,
|
func (p *Provider) OpenVPNConfig(connection models.Connection,
|
||||||
settings settings.OpenVPN, ipv6Supported bool,
|
settings settings.OpenVPN, ipv6Supported bool,
|
||||||
) (lines []string) {
|
) (lines []string) {
|
||||||
|
|||||||
@@ -3,13 +3,10 @@ package updater
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
|
||||||
|
|
||||||
type apiData struct {
|
type apiData struct {
|
||||||
Servers []apiServer `json:"servers"`
|
Servers []apiServer `json:"servers"`
|
||||||
}
|
}
|
||||||
@@ -42,8 +39,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
_ = response.Body.Close()
|
_ = response.Body.Close()
|
||||||
return data, fmt.Errorf("%w: %d %s",
|
return data, fmt.Errorf("HTTP status code not OK: %d %s",
|
||||||
errHTTPStatusCodeNotOK, response.StatusCode, response.Status)
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user