yakumo.izuru 065e07da4b Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@chaotic.ninja>

git-svn-id: file:///srv/svn/repo/suika/trunk@822 f0ae65fe-ee39-954e-97ec-027ff2717ef4
2023-08-20 14:36:11 +00:00

281 lines
9.0 KiB
Go

// Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification:
// https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt
package proxyproto
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"time"
)
var (
// Protocol
SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'}
SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'}
ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header")
ErrVersion1HeaderTooLong = errors.New("proxyproto: version 1 header must be 107 bytes or less")
ErrLineMustEndWithCrlf = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n")
ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command")
ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol")
ErrCantReadLength = errors.New("proxyproto: can't read length")
ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address")
ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address")
ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present")
ErrUnknownProxyProtocolVersion = errors.New("proxyproto: unknown proxy protocol version")
ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command")
ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol")
ErrInvalidLength = errors.New("proxyproto: invalid length")
ErrInvalidAddress = errors.New("proxyproto: invalid address")
ErrInvalidPortNumber = errors.New("proxyproto: invalid port number")
ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one")
)
// Header is the placeholder for proxy protocol header.
type Header struct {
Version byte
Command ProtocolVersionAndCommand
TransportProtocol AddressFamilyAndProtocol
SourceAddr net.Addr
DestinationAddr net.Addr
rawTLVs []byte
}
// HeaderProxyFromAddrs creates a new PROXY header from a source and a
// destination address. If version is zero, the latest protocol version is
// used.
//
// The header is filled on a best-effort basis: if hints cannot be inferred
// from the provided addresses, the header will be left unspecified.
func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header {
if version < 1 || version > 2 {
version = 2
}
h := &Header{
Version: version,
Command: LOCAL,
TransportProtocol: UNSPEC,
}
switch sourceAddr := sourceAddr.(type) {
case *net.TCPAddr:
if _, ok := destAddr.(*net.TCPAddr); !ok {
break
}
if len(sourceAddr.IP.To4()) == net.IPv4len {
h.TransportProtocol = TCPv4
} else if len(sourceAddr.IP) == net.IPv6len {
h.TransportProtocol = TCPv6
}
case *net.UDPAddr:
if _, ok := destAddr.(*net.UDPAddr); !ok {
break
}
if len(sourceAddr.IP.To4()) == net.IPv4len {
h.TransportProtocol = UDPv4
} else if len(sourceAddr.IP) == net.IPv6len {
h.TransportProtocol = UDPv6
}
case *net.UnixAddr:
if _, ok := destAddr.(*net.UnixAddr); !ok {
break
}
switch sourceAddr.Net {
case "unix":
h.TransportProtocol = UnixStream
case "unixgram":
h.TransportProtocol = UnixDatagram
}
}
if h.TransportProtocol != UNSPEC {
h.Command = PROXY
h.SourceAddr = sourceAddr
h.DestinationAddr = destAddr
}
return h
}
func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) {
if !header.TransportProtocol.IsStream() {
return nil, nil, false
}
sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr)
destAddr, destOK := header.DestinationAddr.(*net.TCPAddr)
return sourceAddr, destAddr, sourceOK && destOK
}
func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) {
if !header.TransportProtocol.IsDatagram() {
return nil, nil, false
}
sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr)
destAddr, destOK := header.DestinationAddr.(*net.UDPAddr)
return sourceAddr, destAddr, sourceOK && destOK
}
func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) {
if !header.TransportProtocol.IsUnix() {
return nil, nil, false
}
sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr)
destAddr, destOK := header.DestinationAddr.(*net.UnixAddr)
return sourceAddr, destAddr, sourceOK && destOK
}
func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) {
if sourceAddr, destAddr, ok := header.TCPAddrs(); ok {
return sourceAddr.IP, destAddr.IP, true
} else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok {
return sourceAddr.IP, destAddr.IP, true
} else {
return nil, nil, false
}
}
func (header *Header) Ports() (sourcePort, destPort int, ok bool) {
if sourceAddr, destAddr, ok := header.TCPAddrs(); ok {
return sourceAddr.Port, destAddr.Port, true
} else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok {
return sourceAddr.Port, destAddr.Port, true
} else {
return 0, 0, false
}
}
// EqualTo returns true if headers are equivalent, false otherwise.
// Deprecated: use EqualsTo instead. This method will eventually be removed.
func (header *Header) EqualTo(otherHeader *Header) bool {
return header.EqualsTo(otherHeader)
}
// EqualsTo returns true if headers are equivalent, false otherwise.
func (header *Header) EqualsTo(otherHeader *Header) bool {
if otherHeader == nil {
return false
}
// TLVs only exist for version 2
if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
return false
}
if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol {
return false
}
// Return early for header with LOCAL command, which contains no address information
if header.Command == LOCAL {
return true
}
return header.SourceAddr.String() == otherHeader.SourceAddr.String() &&
header.DestinationAddr.String() == otherHeader.DestinationAddr.String()
}
// WriteTo renders a proxy protocol header in a format and writes it to an io.Writer.
func (header *Header) WriteTo(w io.Writer) (int64, error) {
buf, err := header.Format()
if err != nil {
return 0, err
}
return bytes.NewBuffer(buf).WriteTo(w)
}
// Format renders a proxy protocol header in a format to write over the wire.
func (header *Header) Format() ([]byte, error) {
switch header.Version {
case 1:
return header.formatVersion1()
case 2:
return header.formatVersion2()
default:
return nil, ErrUnknownProxyProtocolVersion
}
}
// TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol.
func (header *Header) TLVs() ([]TLV, error) {
return SplitTLVs(header.rawTLVs)
}
// SetTLVs sets the TLVs stored in this header. This method replaces any
// previous TLV.
func (header *Header) SetTLVs(tlvs []TLV) error {
raw, err := JoinTLVs(tlvs)
if err != nil {
return err
}
header.rawTLVs = raw
return nil
}
// Read identifies the proxy protocol version and reads the remaining of
// the header, accordingly.
//
// If proxy protocol header signature is not present, the reader buffer remains untouched
// and is safe for reading outside of this code.
//
// If proxy protocol header signature is present but an error is raised while processing
// the remaining header, assume the reader buffer to be in a corrupt state.
// Also, this operation will block until enough bytes are available for peeking.
func Read(reader *bufio.Reader) (*Header, error) {
// In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone.
b1, err := reader.Peek(1)
if err != nil {
if err == io.EOF {
return nil, ErrNoProxyProtocol
}
return nil, err
}
if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) {
signature, err := reader.Peek(5)
if err != nil {
if err == io.EOF {
return nil, ErrNoProxyProtocol
}
return nil, err
}
if bytes.Equal(signature[:5], SIGV1) {
return parseVersion1(reader)
}
signature, err = reader.Peek(12)
if err != nil {
if err == io.EOF {
return nil, ErrNoProxyProtocol
}
return nil, err
}
if bytes.Equal(signature[:12], SIGV2) {
return parseVersion2(reader)
}
}
return nil, ErrNoProxyProtocol
}
// ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed
// there's no proxy protocol header.
func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) {
type header struct {
h *Header
e error
}
read := make(chan *header, 1)
go func() {
h := &header{}
h.h, h.e = Read(reader)
read <- h
}()
timer := time.NewTimer(timeout)
select {
case result := <-read:
timer.Stop()
return result.h, result.e
case <-timer.C:
return nil, ErrNoProxyProtocol
}
}