From 4ea2337668f5af22516b2356323a27cbb825de03 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 5 May 2026 21:15:28 +0000 Subject: [PATCH] feat(dns): re-introduce `DNS_SERVER` option - force to set `DNS_UPSTREAM_RESOLVER_TYPE=plain` to avoid any confusion/security hole - force to set `DNS_UPSTREAM_PLAIN_ADDRESSES` to addresses only with port 53 --- Dockerfile | 1 + internal/configuration/settings/deprecated.go | 6 +-- internal/configuration/settings/dns.go | 51 ++++++++++++++++++- internal/dns/run.go | 34 ++++++++----- internal/dns/setup.go | 23 +++++++++ 5 files changed, 97 insertions(+), 18 deletions(-) diff --git a/Dockerfile b/Dockerfile index 281e46ee..6effc70f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -209,6 +209,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ HEALTH_SMALL_CHECK_TYPE=icmp \ HEALTH_RESTART_VPN=on \ # DNS + DNS_SERVER=on \ DNS_UPSTREAM_RESOLVER_TYPE=DoT \ # Note: DNS_UPSTREAM_RESOLVERS defaults to cloudflare in code if DNS_UPSTREAM_PLAIN_ADDRESSES is empty DNS_UPSTREAM_RESOLVERS= \ diff --git a/internal/configuration/settings/deprecated.go b/internal/configuration/settings/deprecated.go index 5da27053..70a416e5 100644 --- a/internal/configuration/settings/deprecated.go +++ b/internal/configuration/settings/deprecated.go @@ -14,10 +14,8 @@ func readObsolete(r *reader.Reader) (warnings []string) { "DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.", "HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete", "HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete", - "DNS_SERVER": "DNS_SERVER is obsolete because the forwarding server is always enabled.", - "DOT": "DOT is obsolete because the forwarding server is always enabled.", - "DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because the forwarding server is always used and " + - "forwards local names to private DNS resolvers found in /etc/resolv.conf", + "DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because you should use the built-in server which now " + + "forwards local names to private DNS resolvers found in /etc/resolv.conf at container start", } sortedKeys := maps.Keys(keyToMessage) slices.Sort(sortedKeys) diff --git a/internal/configuration/settings/dns.go b/internal/configuration/settings/dns.go index 789e5324..9dcb60a8 100644 --- a/internal/configuration/settings/dns.go +++ b/internal/configuration/settings/dns.go @@ -20,6 +20,9 @@ const ( // DNS contains settings to configure DNS. type DNS struct { + // ServerEnabled indicates if the DNS server should be enabled. + // It defaults to true and cannot be nil in the internal state. + ServerEnabled *bool `json:"enabled"` // UpstreamType can be [DNSUpstreamTypeDot], [DNSUpstreamTypeDoh] // or [DNSUpstreamTypePlain]. It defaults to [DNSUpstreamTypeDot]. UpstreamType string `json:"upstream_type"` @@ -52,6 +55,13 @@ func (d DNS) validate() (err error) { return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType) } + if !*d.ServerEnabled { + err = d.validateForServerOff() + if err != nil { + return err + } + } + const minUpdatePeriod = 30 * time.Second if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod { return fmt.Errorf("update period is too short: %s must be bigger than %s", @@ -90,8 +100,26 @@ func (d DNS) validate() (err error) { return nil } +func (d DNS) validateForServerOff() (err error) { + switch { + case d.UpstreamType != DNSUpstreamTypePlain: + return fmt.Errorf("upstream type %s must be %s if the built-in DNS server is disabled", + d.UpstreamType, DNSUpstreamTypePlain) + case len(d.UpstreamPlainAddresses) == 0: + return fmt.Errorf("if DNS is disabled, at least one upstream plain address must be set") + } + for _, addrPort := range d.UpstreamPlainAddresses { + const defaultDNSPort = 53 + if addrPort.Port() != defaultDNSPort { + return fmt.Errorf("invalid DNS port in %s: must be %d", addrPort, defaultDNSPort) + } + } + return nil +} + func (d *DNS) Copy() (copied DNS) { return DNS{ + ServerEnabled: gosettings.CopyPointer(d.ServerEnabled), UpstreamType: d.UpstreamType, UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod), Providers: gosettings.CopySlice(d.Providers), @@ -106,6 +134,7 @@ func (d *DNS) Copy() (copied DNS) { // settings object with any field set in the other // settings. func (d *DNS) overrideWith(other DNS) { + d.ServerEnabled = gosettings.OverrideWithPointer(d.ServerEnabled, other.ServerEnabled) d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType) d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod) d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers) @@ -116,7 +145,12 @@ func (d *DNS) overrideWith(other DNS) { } func (d *DNS) setDefaults() { - d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, DNSUpstreamTypeDot) + d.ServerEnabled = gosettings.DefaultPointer(d.ServerEnabled, true) + defaultUpstreamType := DNSUpstreamTypeDot + if !*d.ServerEnabled { + defaultUpstreamType = DNSUpstreamTypePlain + } + d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, defaultUpstreamType) const defaultUpdatePeriod = 24 * time.Hour d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod) d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{}) @@ -139,6 +173,14 @@ func (d DNS) String() string { func (d DNS) toLinesNode() (node *gotree.Node) { node = gotree.New("DNS settings:") + if !*d.ServerEnabled { + plainServers := node.Append("Plain DNS servers to use directly:") + for _, addr := range d.UpstreamPlainAddresses { + plainServers.Append(addr.String()) + } + return node + } + node.Appendf("Upstream resolver type: %s", d.UpstreamType) upstreamResolvers := node.Append("Upstream resolvers:") @@ -174,6 +216,11 @@ func (d DNS) toLinesNode() (node *gotree.Node) { } func (d *DNS) read(r *reader.Reader) (err error) { + d.ServerEnabled, err = r.BoolPtr("DNS_SERVER", reader.RetroKeys("DOT")) + if err != nil { + return err + } + d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE") d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD") @@ -207,7 +254,7 @@ func (d *DNS) read(r *reader.Reader) (err error) { } func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) { - // If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_TYPE=plain + // If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_RESOLVER_TYPE=plain // for these to be used. This is an added safety measure to reduce misunderstandings, and // reduce odd settings overrides. d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES") diff --git a/internal/dns/run.go b/internal/dns/run.go index 7f509168..658d71f9 100644 --- a/internal/dns/run.go +++ b/internal/dns/run.go @@ -33,9 +33,22 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { for { settings = l.GetSettings() var err error - runError, err = l.setupServer(ctx, settings) - if err == nil { - break + if *settings.ServerEnabled { //nolint:nestif + runError, err = l.setupServer(ctx, settings) + if err == nil { + l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType) + err = l.updateFiles(ctx, settings) + if err != nil { + l.logger.Warn("downloading block lists failed, skipping: " + err.Error()) + } + break + } + } else { + err = l.usePlainServers(settings.UpstreamPlainAddresses) + if err == nil { + l.logger.Infof("ready and using plain DNS resolvers: %v", settings.UpstreamPlainAddresses) + break + } } l.signalOrSetStatus(constants.Crashed) @@ -46,12 +59,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { } l.backoffTime = defaultBackoffTime - l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType) - - err = l.updateFiles(ctx, settings) - if err != nil { - l.logger.Warn("downloading block lists failed, skipping: " + err.Error()) - } l.signalOrSetStatus(constants.Running) l.userTrigger = false @@ -74,13 +81,13 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo for { select { case <-ctx.Done(): - l.stopServer() + l.stopServerIfAny() // TODO revert OS and Go nameserver when exiting return true case <-l.stop: l.userTrigger = true l.logger.Info("stopping") - l.stopServer() + l.stopServerIfAny() l.stopped <- struct{}{} case <-l.start: l.userTrigger = true @@ -94,7 +101,10 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo } } -func (l *Loop) stopServer() { +func (l *Loop) stopServerIfAny() { + if l.server == nil { + return + } stopErr := l.server.Stop() if stopErr != nil { l.logger.Error("stopping server: " + stopErr.Error()) diff --git a/internal/dns/setup.go b/internal/dns/setup.go index 4776f0fd..a7a85197 100644 --- a/internal/dns/setup.go +++ b/internal/dns/setup.go @@ -3,6 +3,7 @@ package dns import ( "context" "fmt" + "net/netip" "github.com/qdm12/dns/v2/pkg/middlewares/filter/update" "github.com/qdm12/dns/v2/pkg/nameserver" @@ -45,3 +46,25 @@ func (l *Loop) setupServer(ctx context.Context, settings settings.DNS) (runError return runError, nil } + +func (l *Loop) usePlainServers(addrPorts []netip.AddrPort) (err error) { + nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{ + AddrPort: addrPorts[0], + }) + addresses := make([]netip.Addr, len(addrPorts)) + for i, addrPort := range addrPorts { + const defaultDNSPort = 53 + if addrPort.Port() != defaultDNSPort { + return fmt.Errorf("invalid DNS port: %d, must be %d", addrPort.Port(), defaultDNSPort) + } + addresses[i] = addrPort.Addr() + } + err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{ + IPs: addresses, + ResolvPath: l.resolvConf, + }) + if err != nil { + return fmt.Errorf("using DNS system wide: %w", err) + } + return nil +}