
for convenience. Signed-off-by: Izuru Yakumo <yakumo.izuru@chaotic.ninja> git-svn-id: file:///srv/svn/repo/suika/trunk@822 f0ae65fe-ee39-954e-97ec-027ff2717ef4
286 lines
7.3 KiB
Go
286 lines
7.3 KiB
Go
package proxyproto
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
)
|
|
|
|
var (
	// Fixed sizes, in bytes, of the PROXY v2 address block for each address
	// family. They are the minimum accepted values of the length field and
	// the base to which the TLV length is added when formatting.
	lengthUnspec = uint16(0)   // no address block
	lengthV4     = uint16(12)  // 2 x 4-byte addresses + 2 x 2-byte ports
	lengthV6     = uint16(36)  // 2 x 16-byte addresses + 2 x 2-byte ports
	lengthUnix   = uint16(216) // 2 x 108-byte socket paths

	// Big-endian (network order) encodings of the sizes above, used as the
	// header's 2-byte length field when there are no TLVs.
	lengthUnspecBytes = []byte{byte(lengthUnspec >> 8), byte(lengthUnspec)}
	lengthV4Bytes     = []byte{byte(lengthV4 >> 8), byte(lengthV4)}
	lengthV6Bytes     = []byte{byte(lengthV6 >> 8), byte(lengthV6)}
	lengthUnixBytes   = []byte{byte(lengthUnix >> 8), byte(lengthUnix)}

	// errUint16Overflow is returned by addTLVLen when the address block plus
	// the TLVs no longer fit in the 16-bit length field.
	errUint16Overflow = errors.New("proxyproto: uint16 overflow")
)
|
|
|
|
// _ports holds the source and destination ports that trail the addresses in
// an address block. It is embedded in _addr6 and decoded with binary.Read in
// big-endian order, so the field order is part of the wire layout.
type _ports struct {
	SrcPort, DstPort uint16
}
|
|
|
|
// _addr4 is the 12-byte IPv4 address block (see lengthV4) decoded with
// binary.Read: raw 4-byte source and destination addresses followed by
// big-endian source and destination ports. Field order is the wire order.
type _addr4 struct {
	Src, Dst         [4]byte
	SrcPort, DstPort uint16
}
|
|
|
|
type _addr6 struct {
|
|
Src [16]byte
|
|
Dst [16]byte
|
|
_ports
|
|
}
|
|
|
|
// _addrUnix is the 216-byte Unix address block (see lengthUnix) decoded with
// binary.Read: two fixed-width 108-byte socket path fields (source, then
// destination), NUL-padded when shorter. There are no ports.
type _addrUnix struct {
	Src, Dst [108]byte
}
|
|
|
|
func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
|
|
// Skip first 12 bytes (signature)
|
|
for i := 0; i < 12; i++ {
|
|
if _, err = reader.ReadByte(); err != nil {
|
|
return nil, ErrCantReadProtocolVersionAndCommand
|
|
}
|
|
}
|
|
|
|
header = new(Header)
|
|
header.Version = 2
|
|
|
|
// Read the 13th byte, protocol version and command
|
|
b13, err := reader.ReadByte()
|
|
if err != nil {
|
|
return nil, ErrCantReadProtocolVersionAndCommand
|
|
}
|
|
header.Command = ProtocolVersionAndCommand(b13)
|
|
if _, ok := supportedCommand[header.Command]; !ok {
|
|
return nil, ErrUnsupportedProtocolVersionAndCommand
|
|
}
|
|
|
|
// Read the 14th byte, address family and protocol
|
|
b14, err := reader.ReadByte()
|
|
if err != nil {
|
|
return nil, ErrCantReadAddressFamilyAndProtocol
|
|
}
|
|
header.TransportProtocol = AddressFamilyAndProtocol(b14)
|
|
// UNSPEC is only supported when LOCAL is set.
|
|
if header.TransportProtocol == UNSPEC && header.Command != LOCAL {
|
|
return nil, ErrUnsupportedAddressFamilyAndProtocol
|
|
}
|
|
|
|
// Make sure there are bytes available as specified in length
|
|
var length uint16
|
|
if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
|
|
return nil, ErrCantReadLength
|
|
}
|
|
if !header.validateLength(length) {
|
|
return nil, ErrInvalidLength
|
|
}
|
|
|
|
// Return early if the length is zero, which means that
|
|
// there's no address information and TLVs present for UNSPEC.
|
|
if length == 0 {
|
|
return header, nil
|
|
}
|
|
|
|
if _, err := reader.Peek(int(length)); err != nil {
|
|
return nil, ErrInvalidLength
|
|
}
|
|
|
|
// Length-limited reader for payload section
|
|
payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader)
|
|
|
|
// Read addresses and ports for protocols other than UNSPEC.
|
|
// Ignore address information for UNSPEC, and skip straight to read TLVs,
|
|
// since the length is greater than zero.
|
|
if header.TransportProtocol != UNSPEC {
|
|
if header.TransportProtocol.IsIPv4() {
|
|
var addr _addr4
|
|
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
|
|
return nil, ErrInvalidAddress
|
|
}
|
|
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
|
|
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
|
|
} else if header.TransportProtocol.IsIPv6() {
|
|
var addr _addr6
|
|
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
|
|
return nil, ErrInvalidAddress
|
|
}
|
|
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
|
|
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
|
|
} else if header.TransportProtocol.IsUnix() {
|
|
var addr _addrUnix
|
|
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
|
|
return nil, ErrInvalidAddress
|
|
}
|
|
|
|
network := "unix"
|
|
if header.TransportProtocol.IsDatagram() {
|
|
network = "unixgram"
|
|
}
|
|
|
|
header.SourceAddr = &net.UnixAddr{
|
|
Net: network,
|
|
Name: parseUnixName(addr.Src[:]),
|
|
}
|
|
header.DestinationAddr = &net.UnixAddr{
|
|
Net: network,
|
|
Name: parseUnixName(addr.Dst[:]),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Copy bytes for optional Type-Length-Value vector
|
|
header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice
|
|
if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF {
|
|
return nil, err
|
|
}
|
|
|
|
return header, nil
|
|
}
|
|
|
|
func (header *Header) formatVersion2() ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
buf.Write(SIGV2)
|
|
buf.WriteByte(header.Command.toByte())
|
|
buf.WriteByte(header.TransportProtocol.toByte())
|
|
if header.TransportProtocol.IsUnspec() {
|
|
// For UNSPEC, write no addresses and ports but only TLVs if they are present
|
|
hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
buf.Write(hdrLen)
|
|
} else {
|
|
var addrSrc, addrDst []byte
|
|
if header.TransportProtocol.IsIPv4() {
|
|
hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
buf.Write(hdrLen)
|
|
sourceIP, destIP, _ := header.IPs()
|
|
addrSrc = sourceIP.To4()
|
|
addrDst = destIP.To4()
|
|
} else if header.TransportProtocol.IsIPv6() {
|
|
hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
buf.Write(hdrLen)
|
|
sourceIP, destIP, _ := header.IPs()
|
|
addrSrc = sourceIP.To16()
|
|
addrDst = destIP.To16()
|
|
} else if header.TransportProtocol.IsUnix() {
|
|
buf.Write(lengthUnixBytes)
|
|
sourceAddr, destAddr, ok := header.UnixAddrs()
|
|
if !ok {
|
|
return nil, ErrInvalidAddress
|
|
}
|
|
addrSrc = formatUnixName(sourceAddr.Name)
|
|
addrDst = formatUnixName(destAddr.Name)
|
|
}
|
|
|
|
if addrSrc == nil || addrDst == nil {
|
|
return nil, ErrInvalidAddress
|
|
}
|
|
buf.Write(addrSrc)
|
|
buf.Write(addrDst)
|
|
|
|
if sourcePort, destPort, ok := header.Ports(); ok {
|
|
portBytes := make([]byte, 2)
|
|
|
|
binary.BigEndian.PutUint16(portBytes, uint16(sourcePort))
|
|
buf.Write(portBytes)
|
|
|
|
binary.BigEndian.PutUint16(portBytes, uint16(destPort))
|
|
buf.Write(portBytes)
|
|
}
|
|
}
|
|
|
|
if len(header.rawTLVs) > 0 {
|
|
buf.Write(header.rawTLVs)
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func (header *Header) validateLength(length uint16) bool {
|
|
if header.TransportProtocol.IsIPv4() {
|
|
return length >= lengthV4
|
|
} else if header.TransportProtocol.IsIPv6() {
|
|
return length >= lengthV6
|
|
} else if header.TransportProtocol.IsUnix() {
|
|
return length >= lengthUnix
|
|
} else if header.TransportProtocol.IsUnspec() {
|
|
return length >= lengthUnspec
|
|
}
|
|
return false
|
|
}
|
|
|
|
// addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow.
|
|
func addTLVLen(cur []byte, tlvLen int) ([]byte, error) {
|
|
if tlvLen == 0 {
|
|
return cur, nil
|
|
}
|
|
curLen := binary.BigEndian.Uint16(cur)
|
|
newLen := int(curLen) + tlvLen
|
|
if newLen >= 1<<16 {
|
|
return nil, errUint16Overflow
|
|
}
|
|
a := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(a, uint16(newLen))
|
|
return a, nil
|
|
}
|
|
|
|
func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr {
|
|
if transport.IsStream() {
|
|
return &net.TCPAddr{IP: ip, Port: int(port)}
|
|
} else if transport.IsDatagram() {
|
|
return &net.UDPAddr{IP: ip, Port: int(port)}
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// parseUnixName converts a fixed-width, NUL-padded socket path field into a
// string. It stops at the first NUL byte and uses the whole field when no
// NUL is present.
func parseUnixName(b []byte) string {
	if end := bytes.IndexByte(b, 0); end >= 0 {
		b = b[:end]
	}
	return string(b)
}
|
|
|
|
func formatUnixName(name string) []byte {
|
|
n := int(lengthUnix) / 2
|
|
if len(name) >= n {
|
|
return []byte(name[:n])
|
|
}
|
|
pad := make([]byte, n-len(name))
|
|
return append([]byte(name), pad...)
|
|
}
|