// giredore/internal/httpserver/checkallowedips.go
//
// 70 lines
// 1.7 KiB
// Go

package httpserver
import (
// stdlib
"net"
"net/http"
"strings"
// local
"go.dev.pztrn.name/giredore/internal/configuration"
"go.dev.pztrn.name/giredore/internal/structs"
// other
"github.com/labstack/echo"
)
// checkAllowedIPs returns an echo middleware that restricts the "/_api"
// namespace to clients whose address lies within one of the subnets returned
// by configuration.Cfg.GetAllowedIPs(). Requests outside that namespace are
// handed to the next handler untouched.
//
// The allowed list is fetched and parsed on every "/_api" request, so any
// change in what GetAllowedIPs returns is picked up without rebuilding the
// middleware. An entry without a "/" is treated as a single host: "/32" is
// appended for IPv4 and "/128" for IPv6 (any entry containing ':').
//
// Responses produced by the middleware itself:
//   - 500 with structs.ErrInvalidAllowedIPDefined if any configured entry
//     fails to parse (the parse error is also logged);
//   - 400 with structs.ErrIPAddressNotAllowed if the client address is not
//     inside any allowed subnet.
//
// Errors returned by the next handler are propagated to echo unchanged.
func checkAllowedIPs() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(ec echo.Context) error {
			// Match on the parsed URL path, not RequestURI: for absolute-form
			// request targets ("GET http://host/_api/...") RequestURI does
			// not begin with "/_api", which would let such requests skip
			// this check while still being routed to the API handlers.
			if !strings.HasPrefix(ec.Request().URL.Path, "/_api") {
				return next(ec)
			}

			allowedIPs := configuration.Cfg.GetAllowedIPs()
			subnets := make([]*net.IPNet, 0, len(allowedIPs))

			for _, ip := range allowedIPs {
				cidr := ip

				if !strings.Contains(ip, "/") {
					// A bare address means exactly one host, so use the full
					// prefix length of its address family. Appending "/32" to
					// an IPv6 address would allow an entire /32 block.
					if strings.Contains(ip, ":") {
						cidr = ip + "/128"
					} else {
						cidr = ip + "/32"
					}
				}

				_, subnet, err := net.ParseCIDR(cidr)
				if err != nil {
					log.Error().Err(err).Str("subnet", cidr).Msg("Failed to parse CIDR. /_api/ endpoint won't be accessible, this should be fixed manually in configuration file!")
					return ec.JSON(http.StatusInternalServerError, &structs.Reply{Status: structs.StatusFailure, Errors: []structs.Error{structs.ErrInvalidAllowedIPDefined}})
				}

				subnets = append(subnets, subnet)
			}

			// NOTE(review): echo's RealIP prefers the X-Forwarded-For /
			// X-Real-IP headers, which a client can forge unless a trusted
			// reverse proxy overwrites them — confirm the deployment setup.
			// An unparseable value yields a nil IP, which no subnet contains.
			clientIP := net.ParseIP(ec.RealIP())

			for _, subnet := range subnets {
				if subnet.Contains(clientIP) {
					// Return the handler's error so echo's error handler can
					// act on it instead of it being silently discarded.
					return next(ec)
				}
			}

			return ec.JSON(http.StatusBadRequest, &structs.Reply{Status: structs.StatusFailure, Errors: []structs.Error{structs.ErrIPAddressNotAllowed}})
		}
	}
}