mirror of
https://github.com/pelican-dev/wings.git
synced 2026-05-06 08:56:47 -04:00
Implement wings security fixes (#156)
* Implement pterodactyl security fixes * Implement pterodactyl 292 changes (#158) * Implement pterodactyl 292 changes Add the same change as https://github.com/pterodactyl/wings/pull/292 This adds configuration for the `machone-id` file that is required by hytale Creates and manages machine-id files on a per-server basis Adds code to remove machine-id files when a server is deleted as well. It also adds a group file for use along with the passwd file Updated config for passwd Updated mounts to not set default except for the the correct default. * Update machine-id generation Moved machine-id generation code outside of server create only called during initial server creation Create machine-id file for servers that already exists if the file is missing. Makes sure tempdir is created on wings start * remove append removes the append where not needed * Implement pterodactyl security fixes
This commit is contained in:
committed by
GitHub
parent
1e7c8cea49
commit
eb6db925a6
@@ -66,6 +66,7 @@ func Configure(m *wserver.Manager, client remote.Client) *gin.Engine {
|
||||
protected.GET("/api/servers", getAllServers)
|
||||
protected.POST("/api/servers", postCreateServer)
|
||||
protected.DELETE("/api/transfers/:server", deleteTransfer)
|
||||
protected.POST("/api/deauthorize-user", postDeauthorizeUser)
|
||||
|
||||
// These are server specific routes, and require that the request be authorized, and
|
||||
// that the server exist on the Daemon.
|
||||
|
||||
@@ -303,6 +303,8 @@ func deleteServer(c *gin.Context) {
|
||||
// Adds any of the JTIs passed through in the body to the deny list for the websocket
|
||||
// preventing any JWT generated before the current time from being used to connect to
|
||||
// the socket or send along commands.
|
||||
//
|
||||
// deprecated: prefer /api/deauthorize-user
|
||||
func postServerDenyWSTokens(c *gin.Context) {
|
||||
var data struct {
|
||||
JTIs []string `json:"jtis"`
|
||||
|
||||
+91
-19
@@ -2,14 +2,17 @@ package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/goccy/go-json"
|
||||
ws "github.com/gorilla/websocket"
|
||||
|
||||
"github.com/pelican-dev/wings/router/middleware"
|
||||
"github.com/pelican-dev/wings/router/websocket"
|
||||
"github.com/pelican-dev/wings/server"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var expectedCloseCodes = []int{
|
||||
@@ -25,6 +28,27 @@ func getServerWebsocket(c *gin.Context) {
|
||||
manager := middleware.ExtractManager(c)
|
||||
s, _ := manager.Get(c.Param("server"))
|
||||
|
||||
// Limit the total number of websockets that can be opened at any one time for
|
||||
// a server instance. This applies across all users connected to the server, and
|
||||
// is not applied on a per-user basis.
|
||||
//
|
||||
// todo: it would be great to make this per-user instead, but we need to modify
|
||||
// how we even request this endpoint in order for that to be possible. Some type
|
||||
// of signed identifier in the URL that is verified on this end and set by the
|
||||
// panel using a shared secret is likely the easiest option. The benefit of that
|
||||
// is that we can both scope things to the user before authentication, and also
|
||||
// verify that the JWT provided by the panel is assigned to the same user.
|
||||
if s.Websockets().Len() >= 30 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Too many open websocket connections.",
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Security-Policy", "default-src 'self'")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
|
||||
// Create a context that can be canceled when the user disconnects from this
|
||||
// socket that will also cancel listeners running in separate threads. If the
|
||||
// connection itself is terminated listeners using this context will also be
|
||||
@@ -37,36 +61,61 @@ func getServerWebsocket(c *gin.Context) {
|
||||
middleware.CaptureAndAbort(c, err)
|
||||
return
|
||||
}
|
||||
defer handler.Connection.Close()
|
||||
|
||||
// Track this open connection on the server so that we can close them all programmatically
|
||||
// if the server is deleted.
|
||||
s.Websockets().Push(handler.Uuid(), &cancel)
|
||||
handler.Logger().Debug("opening connection to server websocket")
|
||||
defer s.Websockets().Remove(handler.Uuid())
|
||||
|
||||
defer func() {
|
||||
s.Websockets().Remove(handler.Uuid())
|
||||
handler.Logger().Debug("closing connection to server websocket")
|
||||
}()
|
||||
|
||||
// If the server is deleted we need to send a close message to the connected client
|
||||
// so that they disconnect since there will be no more events sent along. Listen for
|
||||
// the request context being closed to break this loop, otherwise this routine will
|
||||
// be left hanging in the background.
|
||||
go func() {
|
||||
select {
|
||||
// When the main context is canceled (through disconnect, server deletion, or server
|
||||
// suspension) close the connection itself.
|
||||
case <-ctx.Done():
|
||||
break
|
||||
case <-s.Context().Done():
|
||||
_ = handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5))
|
||||
handler.Logger().Debug("closing connection to server websocket")
|
||||
if err := handler.Connection.Close(); err != nil {
|
||||
handler.Logger().WithError(err).Error("failed to close websocket connection")
|
||||
}
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
j := websocket.Message{}
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// If the server is deleted we need to send a close message to the connected client
|
||||
// so that they disconnect since there will be no more events sent along. Listen for
|
||||
// the request context being closed to break this loop, otherwise this routine will
|
||||
//be left hanging in the background.
|
||||
case <-s.Context().Done():
|
||||
cancel()
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
_, p, err := handler.Connection.ReadMessage()
|
||||
// Due to how websockets are handled we need to connect to the socket
|
||||
// and _then_ abort it if the server is suspended. You cannot capture
|
||||
// the HTTP response in the websocket client, thus we connect and then
|
||||
// immediately close with failure.
|
||||
if s.IsSuspended() {
|
||||
_ = handler.Connection.WriteMessage(ws.CloseMessage, ws.FormatCloseMessage(4409, "server is suspended"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// There is a separate rate limiter that applies to individual message types
|
||||
// within the actual websocket logic handler. _This_ rate limiter just exists
|
||||
// to avoid enormous floods of data through the socket since we need to parse
|
||||
// JSON each time. This rate limit realistically should never be hit since this
|
||||
// would require sending 50+ messages a second over the websocket (no more than
|
||||
// 10 per 200ms).
|
||||
var throttled bool
|
||||
rl := rate.NewLimiter(rate.Every(time.Millisecond*200), 10)
|
||||
|
||||
for {
|
||||
t, p, err := handler.Connection.ReadMessage()
|
||||
if err != nil {
|
||||
if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
|
||||
handler.Logger().WithField("error", err).Warn("error handling websocket message for server")
|
||||
@@ -74,16 +123,39 @@ func getServerWebsocket(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
|
||||
if !rl.Allow() {
|
||||
if !throttled {
|
||||
throttled = true
|
||||
_ = handler.Connection.WriteJSON(websocket.Message{Event: websocket.ThrottledEvent, Args: []string{"global"}})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
throttled = false
|
||||
|
||||
// If the message isn't a format we expect, or the length of the message is far larger
|
||||
// than we'd ever expect, drop it. The websocket upgrader logic does enforce a maximum
|
||||
// _compressed_ message size of 4Kb but that could decompress to a much larger amount
|
||||
// of data.
|
||||
if t != ws.TextMessage || len(p) > 32_768 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Discard and JSON parse errors into the void and don't continue processing this
|
||||
// specific socket request. If we did a break here the client would get disconnected
|
||||
// from the socket, which is NOT what we want to do.
|
||||
var j websocket.Message
|
||||
if err := json.Unmarshal(p, &j); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
go func(msg websocket.Message) {
|
||||
if err := handler.HandleInbound(ctx, msg); err != nil {
|
||||
_ = handler.SendErrorJson(msg, err)
|
||||
if errors.Is(err, server.ErrSuspended) {
|
||||
cancel()
|
||||
} else {
|
||||
_ = handler.SendErrorJson(msg, err)
|
||||
}
|
||||
}
|
||||
}(j)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/pelican-dev/wings/config"
|
||||
"github.com/pelican-dev/wings/internal/diagnostics"
|
||||
"github.com/pelican-dev/wings/router/middleware"
|
||||
"github.com/pelican-dev/wings/router/tokens"
|
||||
"github.com/pelican-dev/wings/server"
|
||||
"github.com/pelican-dev/wings/server/installer"
|
||||
"github.com/pelican-dev/wings/system"
|
||||
@@ -256,3 +257,33 @@ func postUpdateConfiguration(c *gin.Context) {
|
||||
Applied: true,
|
||||
})
|
||||
}
|
||||
|
||||
func postDeauthorizeUser(c *gin.Context) {
|
||||
var data struct {
|
||||
User string `json:"user"`
|
||||
Servers []string `json:"servers"`
|
||||
}
|
||||
|
||||
if err := c.BindJSON(&data); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// todo: disconnect websockets more gracefully
|
||||
m := middleware.ExtractManager(c)
|
||||
if len(data.Servers) > 0 {
|
||||
for _, uuid := range data.Servers {
|
||||
if s, ok := m.Get(uuid); ok {
|
||||
s.Websockets().CancelAll()
|
||||
s.Sftp().Cancel(data.User)
|
||||
tokens.DenyForServer(s.ID(), data.User)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, s := range m.All() {
|
||||
s.Websockets().CancelAll()
|
||||
s.Sftp().Cancel(data.User)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@@ -24,16 +24,29 @@ var wingsBootTime = time.Now()
|
||||
// This is used to allow the Panel to revoke tokens en-masse for a given user & server
|
||||
// combination since the JTI for tokens is just MD5(user.id + server.uuid). When a server
|
||||
// is booted this listing is fetched from the panel and the Websocket is dynamically updated.
|
||||
//
|
||||
// deprecated: prefer use of userDenylist
|
||||
var denylist sync.Map
|
||||
var userDenylist sync.Map
|
||||
|
||||
// Adds a JTI to the denylist by marking any JWTs generated before the current time as
|
||||
// being invalid if they use the same JTI.
|
||||
//
|
||||
// deprecated: prefer the use of DenyForServer
|
||||
func DenyJTI(jti string) {
|
||||
log.WithField("jti", jti).Debugf("adding \"%s\" to JTI denylist", jti)
|
||||
|
||||
denylist.Store(jti, time.Now())
|
||||
}
|
||||
|
||||
// DenyForServer adds a user UUID to the denylist marking any existing JWTs issued
|
||||
// to the user as being invalid. This is associated with the user.
|
||||
func DenyForServer(s string, u string) {
|
||||
log.WithField("user_uuid", u).WithField("server_uuid", s).Debugf("denying all JWTs created at or before current time for user \"%s\"", u)
|
||||
|
||||
userDenylist.Store(strings.Join([]string{s, u}, ":"), time.Now())
|
||||
}
|
||||
|
||||
// WebsocketPayload defines the JWT payload for a websocket connection. This JWT is passed along to
|
||||
// the websocket after it has been connected to by sending an "auth" event.
|
||||
type WebsocketPayload struct {
|
||||
@@ -79,12 +92,21 @@ func (p *WebsocketPayload) Denylisted() bool {
|
||||
|
||||
// Finally, if the token was issued before a time that is currently denied for this
|
||||
// token instance, ignore the permissions response.
|
||||
//
|
||||
// This list is deprecated, but we maintain the check here so that custom instances
|
||||
// are able to continue working. We'll remove it in a future release.
|
||||
if t, ok := denylist.Load(p.JWTID); ok {
|
||||
if p.IssuedAt.Time.Before(t.(time.Time)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if t, ok := userDenylist.Load(strings.Join([]string{p.ServerUUID, p.UserUUID}, ":")); ok {
|
||||
if p.IssuedAt.Time.Before(t.(time.Time)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type LimiterBucket struct {
|
||||
mu sync.RWMutex
|
||||
limits map[Event]*rate.Limiter
|
||||
throttles map[Event]bool
|
||||
}
|
||||
|
||||
func (h *Handler) IsThrottled(e Event) bool {
|
||||
l := h.limiter.For(e)
|
||||
|
||||
h.limiter.mu.Lock()
|
||||
defer h.limiter.mu.Unlock()
|
||||
|
||||
if l.Allow() {
|
||||
h.limiter.throttles[e] = false
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// If not allowed, track the throttling and send an event over the wire
|
||||
// if one wasn't already sent in the same throttling period.
|
||||
if v, ok := h.limiter.throttles[e]; !v || !ok {
|
||||
h.limiter.throttles[e] = true
|
||||
h.Logger().WithField("event", e).Debug("throttling websocket due to event volume")
|
||||
|
||||
_ = h.unsafeSendJson(&Message{Event: ThrottledEvent, Args: []string{string(e)}})
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func NewLimiter() *LimiterBucket {
|
||||
return &LimiterBucket{
|
||||
limits: make(map[Event]*rate.Limiter, 4),
|
||||
throttles: make(map[Event]bool, 4),
|
||||
}
|
||||
}
|
||||
|
||||
// For returns the internal rate limiter for the given event type. In most
|
||||
// cases this is a shared rate limiter for events, but certain "heavy" or low-frequency
|
||||
// events implement their own limiters.
|
||||
func (l *LimiterBucket) For(e Event) *rate.Limiter {
|
||||
name := limiterName(e)
|
||||
|
||||
l.mu.RLock()
|
||||
if v, ok := l.limits[name]; ok {
|
||||
l.mu.RUnlock()
|
||||
return v
|
||||
}
|
||||
|
||||
l.mu.RUnlock()
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
limit, burst := limitValuesFor(e)
|
||||
l.limits[name] = rate.NewLimiter(limit, burst)
|
||||
|
||||
return l.limits[name]
|
||||
}
|
||||
|
||||
// limitValuesFor returns the underlying limit and burst value for the given event.
|
||||
func limitValuesFor(e Event) (rate.Limit, int) {
|
||||
// Twice every five seconds.
|
||||
if e == AuthenticationEvent || e == SendServerLogsEvent {
|
||||
return rate.Every(time.Second * 5), 2
|
||||
}
|
||||
|
||||
// 10 per second.
|
||||
if e == SendCommandEvent {
|
||||
return rate.Every(time.Second), 10
|
||||
}
|
||||
|
||||
// 4 per second.
|
||||
return rate.Every(time.Second), 4
|
||||
}
|
||||
|
||||
func limiterName(e Event) Event {
|
||||
if e == AuthenticationEvent || e == SendServerLogsEvent || e == SendCommandEvent {
|
||||
return e
|
||||
}
|
||||
|
||||
return "_default"
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
var sendErr error
|
||||
message := Message{Event: e.Topic}
|
||||
message := Message{Event: Event(e.Topic)}
|
||||
if str, ok := e.Data.(string); ok {
|
||||
message.Args = []string{str}
|
||||
} else if b, ok := e.Data.([]byte); ok {
|
||||
@@ -147,7 +147,7 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
}
|
||||
onError(message.Event, sendErr)
|
||||
onError(string(message.Event), sendErr)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package websocket
|
||||
|
||||
type Event string
|
||||
|
||||
const (
|
||||
AuthenticationSuccessEvent = "auth success"
|
||||
TokenExpiringEvent = "token expiring"
|
||||
@@ -11,11 +13,12 @@ const (
|
||||
SendStatsEvent = "send stats"
|
||||
ErrorEvent = "daemon error"
|
||||
JwtErrorEvent = "jwt error"
|
||||
ThrottledEvent = Event("throttled")
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
// The event to perform.
|
||||
Event string `json:"event"`
|
||||
Event Event `json:"event"`
|
||||
|
||||
// The data to pass along, only used by power/command currently. Other requests
|
||||
// should either omit the field or pass an empty value as it is ignored.
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pelican-dev/wings/internal/models"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/apex/log"
|
||||
"github.com/gbrlsnchs/jwt/v3"
|
||||
@@ -23,6 +21,7 @@ import (
|
||||
"github.com/pelican-dev/wings/config"
|
||||
"github.com/pelican-dev/wings/environment"
|
||||
"github.com/pelican-dev/wings/environment/docker"
|
||||
"github.com/pelican-dev/wings/internal/models"
|
||||
"github.com/pelican-dev/wings/router/tokens"
|
||||
"github.com/pelican-dev/wings/server"
|
||||
)
|
||||
@@ -46,6 +45,7 @@ type Handler struct {
|
||||
server *server.Server
|
||||
ra server.RequestActivity
|
||||
uuid uuid.UUID
|
||||
limiter *LimiterBucket
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -84,6 +84,7 @@ func NewTokenPayload(token []byte) (*tokens.WebsocketPayload, error) {
|
||||
// GetHandler returns a new websocket handler using the context provided.
|
||||
func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request, c *gin.Context) (*Handler, error) {
|
||||
upgrader := websocket.Upgrader{
|
||||
EnableCompression: true,
|
||||
// Ensure that the websocket request is originating from the Panel itself,
|
||||
// and not some other location.
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
@@ -110,12 +111,16 @@ func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request, c *gin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.SetReadLimit(4096)
|
||||
_ = conn.SetCompressionLevel(5)
|
||||
|
||||
return &Handler{
|
||||
Connection: conn,
|
||||
jwt: nil,
|
||||
server: s,
|
||||
ra: s.NewRequestActivity("", c.ClientIP()),
|
||||
uuid: u,
|
||||
limiter: NewLimiter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -150,7 +155,7 @@ func (h *Handler) SendJson(v Message) error {
|
||||
|
||||
// If the user does not have permission to see backup events, do not emit
|
||||
// them over the socket.
|
||||
if strings.HasPrefix(v.Event, server.BackupCompletedEvent) {
|
||||
if strings.HasPrefix(string(v.Event), server.BackupCompletedEvent) {
|
||||
if !j.HasPermission(PermissionReceiveBackups) {
|
||||
return nil
|
||||
}
|
||||
@@ -277,6 +282,14 @@ func (h *Handler) setJwt(token *tokens.WebsocketPayload) {
|
||||
|
||||
// HandleInbound handles an inbound socket request and route it to the proper action.
|
||||
func (h *Handler) HandleInbound(ctx context.Context, m Message) error {
|
||||
if h.server.IsSuspended() {
|
||||
return server.ErrSuspended
|
||||
}
|
||||
|
||||
if h.IsThrottled(m.Event) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.Event != AuthenticationEvent {
|
||||
if err := h.TokenValid(); err != nil {
|
||||
h.unsafeSendJson(Message{
|
||||
@@ -287,6 +300,10 @@ func (h *Handler) HandleInbound(ctx context.Context, m Message) error {
|
||||
}
|
||||
}
|
||||
|
||||
if h.server.IsSuspended() {
|
||||
return server.ErrSuspended
|
||||
}
|
||||
|
||||
switch m.Event {
|
||||
case AuthenticationEvent:
|
||||
{
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/pelican-dev/wings/system"
|
||||
)
|
||||
|
||||
// Sftp returns the SFTP connection bag for the server instance. This bag tracks
|
||||
// all open SFTP connections by individual user and allows for a single user or
|
||||
// all users to be disconnected by other processes.
|
||||
func (s *Server) Sftp() *system.ContextBag {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if s.sftpBag == nil {
|
||||
s.sftpBag = system.NewContextBag(s.Context())
|
||||
}
|
||||
|
||||
return s.sftpBag
|
||||
}
|
||||
+13
-2
@@ -70,6 +70,7 @@ type Server struct {
|
||||
// The console throttler instance used to control outputs.
|
||||
throttler *ConsoleThrottle
|
||||
throttleOnce sync.Once
|
||||
sftpBag *system.ContextBag
|
||||
|
||||
// Tracks open websocket connections for the server.
|
||||
wsBag *WebsocketBag
|
||||
@@ -191,7 +192,7 @@ func parseInvocation(invocation string, envvars map[string]interface{}, memory i
|
||||
invocation = strings.Replace(invocation, segment, tempSegments[i], 1)
|
||||
}
|
||||
|
||||
// Replace the placeholders outside of protected segments
|
||||
// Replace the placeholders outside protected segments
|
||||
invocation = strings.ReplaceAll(invocation, placeholder, fmt.Sprint(varval))
|
||||
|
||||
// Restore protected segments
|
||||
@@ -201,6 +202,10 @@ func parseInvocation(invocation string, envvars map[string]interface{}, memory i
|
||||
}
|
||||
|
||||
// Replace the defaults with their configured values.
|
||||
// and any connected SFTP clients. We don't need to worry about revoking any JWTs
|
||||
// here since they'll be blocked from re-connecting to the websocket anyways. This
|
||||
// just forces the client to disconnect and attempt to reconnect (rather than waiting
|
||||
// on them to send a message and hit that disconnect logic).
|
||||
invocation = strings.ReplaceAll(invocation, "${SERVER_PORT}", strconv.Itoa(port))
|
||||
invocation = strings.ReplaceAll(invocation, "${SERVER_MEMORY}", strconv.Itoa(int(memory)))
|
||||
invocation = strings.ReplaceAll(invocation, "${SERVER_IP}", ip)
|
||||
@@ -263,11 +268,17 @@ func (s *Server) Sync() error {
|
||||
|
||||
s.SyncWithEnvironment()
|
||||
|
||||
// If the server is suspended immediately disconnect all open websocket connections.
|
||||
if s.IsSuspended() {
|
||||
s.Websockets().CancelAll()
|
||||
s.Sftp().CancelAll()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncWithConfiguration accepts a configuration object for a server and will
|
||||
// sync all of the values with the existing server state. This only replaces the
|
||||
// sync all values with the existing server state. This only replaces the
|
||||
// existing configuration and process configuration for the server. The
|
||||
// underlying environment will not be affected. This is because this function
|
||||
// can be called from scoped where the server may not be fully initialized,
|
||||
|
||||
@@ -25,6 +25,13 @@ func (s *Server) Websockets() *WebsocketBag {
|
||||
return s.wsBag
|
||||
}
|
||||
|
||||
func (w *WebsocketBag) Len() int {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return len(w.conns)
|
||||
}
|
||||
|
||||
// Push adds a new websocket connection to the end of the stack.
|
||||
func (w *WebsocketBag) Push(u uuid.UUID, cancel *context.CancelFunc) {
|
||||
w.mu.Lock()
|
||||
|
||||
+6
-2
@@ -107,7 +107,7 @@ func (h *Handler) Filewrite(request *sftp.Request) (io.WriterAt, error) {
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
|
||||
if err := h.fs.IsIgnored(request.Filepath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (h *Handler) Filecmd(request *sftp.Request) error {
|
||||
if err := h.fs.IsIgnored(request.Filepath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
switch request.Method {
|
||||
// Allows a user to make changes to the permissions of a given file or directory
|
||||
// on their server using their SFTP client.
|
||||
@@ -312,3 +312,7 @@ func (h *Handler) can(permission string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handler) User() string {
|
||||
return h.events.user
|
||||
}
|
||||
|
||||
+28
-25
@@ -126,10 +126,10 @@ func (c *SFTPServer) AcceptInbound(conn net.Conn, config *ssh.ServerConfig) erro
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
for ch := range chans {
|
||||
// If its not a session channel we just move on because its not something we
|
||||
// If not a session channel we just move on because it's not something we
|
||||
// know how to handle at this point.
|
||||
if ch.ChannelType() != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
_ = ch.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -143,37 +143,40 @@ func (c *SFTPServer) AcceptInbound(conn net.Conn, config *ssh.ServerConfig) erro
|
||||
// Channels have a type that is dependent on the protocol. For SFTP
|
||||
// this is "subsystem" with a payload that (should) be "sftp". Discard
|
||||
// anything else we receive ("pty", "shell", etc)
|
||||
req.Reply(req.Type == "subsystem" && string(req.Payload[4:]) == "sftp", nil)
|
||||
_ = req.Reply(req.Type == "subsystem" && string(req.Payload[4:]) == "sftp", nil)
|
||||
}
|
||||
}(requests)
|
||||
|
||||
// If no UUID has been set on this inbound request then we can assume we
|
||||
// have screwed up something in the authentication code. This is a sanity
|
||||
// check, but should never be encountered (ideally...).
|
||||
//
|
||||
// This will also attempt to match a specific server out of the global server
|
||||
// store and return nil if there is no match.
|
||||
uuid := sconn.Permissions.Extensions["uuid"]
|
||||
srv := c.manager.Find(func(s *server.Server) bool {
|
||||
if uuid == "" {
|
||||
return false
|
||||
if srv, ok := c.manager.Get(sconn.Permissions.Extensions["uuid"]); ok {
|
||||
if err := c.Handle(sconn, srv, channel); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.ID() == uuid
|
||||
})
|
||||
if srv == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Spin up a SFTP server instance for the authenticated user's server allowing
|
||||
// them access to the underlying filesystem.
|
||||
handler, err := NewHandler(sconn, srv)
|
||||
if err != nil {
|
||||
return errors.WithStackIf(err)
|
||||
}
|
||||
rs := sftp.NewRequestServer(channel, handler.Handlers())
|
||||
if err := rs.Serve(); err == io.EOF {
|
||||
// Handle spins up a SFTP server instance for the authenticated user's server allowing
|
||||
// them access to the underlying filesystem.
|
||||
func (c *SFTPServer) Handle(conn *ssh.ServerConn, srv *server.Server, channel ssh.Channel) error {
|
||||
handler, err := NewHandler(conn, srv)
|
||||
if err != nil {
|
||||
return errors.WithStackIf(err)
|
||||
}
|
||||
|
||||
ctx := srv.Sftp().Context(handler.User())
|
||||
rs := sftp.NewRequestServer(channel, handler.Handlers())
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
srv.Log().WithField("user", conn.User()).Warn("sftp: terminating active session")
|
||||
_ = rs.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := rs.Serve(); err == io.EOF {
|
||||
_ = rs.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ctxHolder struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type ContextBag struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
items map[string]ctxHolder
|
||||
}
|
||||
|
||||
func NewContextBag(ctx context.Context) *ContextBag {
|
||||
return &ContextBag{ctx: ctx, items: make(map[string]ctxHolder)}
|
||||
}
|
||||
|
||||
// Context returns a context for the given key. If a value already exists in the
|
||||
// internal map it is returned, otherwise a new cancelable context is returned.
|
||||
// This context is shared between all callers until the cancel function is called
|
||||
// by calling Cancel or CancelAll.
|
||||
func (cb *ContextBag) Context(key string) context.Context {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
if _, ok := cb.items[key]; !ok {
|
||||
ctx, cancel := context.WithCancel(cb.ctx)
|
||||
cb.items[key] = ctxHolder{ctx, cancel}
|
||||
}
|
||||
|
||||
return cb.items[key].ctx
|
||||
}
|
||||
|
||||
func (cb *ContextBag) Cancel(key string) {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
if v, ok := cb.items[key]; ok {
|
||||
v.cancel()
|
||||
delete(cb.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *ContextBag) CancelAll() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
for _, v := range cb.items {
|
||||
v.cancel()
|
||||
}
|
||||
|
||||
cb.items = make(map[string]ctxHolder)
|
||||
}
|
||||
Reference in New Issue
Block a user