
for convenience. Signed-off-by: Izuru Yakumo <yakumo.izuru@chaotic.ninja> git-svn-id: https://svn.yakumo.dev/yakumo.izuru/suika/trunk@822 f0ae65fe-ee39-954e-97ec-027ff2717ef4
360 lines
7.0 KiB
Go
360 lines
7.0 KiB
Go
package bare
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"sync"
|
|
)
|
|
|
|
// A type which implements this interface will be responsible for unmarshaling
|
|
// itself when encountered.
|
|
type Unmarshalable interface {
|
|
Unmarshal(r *Reader) error
|
|
}
|
|
|
|
// Unmarshals a BARE message into val, which must be a pointer to a value of
|
|
// the message type.
|
|
func Unmarshal(data []byte, val interface{}) error {
|
|
b := bytes.NewReader(data)
|
|
r := NewReader(b)
|
|
return UnmarshalBareReader(r, val)
|
|
}
|
|
|
|
// Unmarshals a BARE message into value (val, which must be a pointer), from a
|
|
// reader. See Unmarshal for details.
|
|
func UnmarshalReader(r io.Reader, val interface{}) error {
|
|
r = newLimitedReader(r)
|
|
return UnmarshalBareReader(NewReader(r), val)
|
|
}
|
|
|
|
type decodeFunc func(r *Reader, v reflect.Value) error
|
|
|
|
var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc
|
|
|
|
func UnmarshalBareReader(r *Reader, val interface{}) error {
|
|
t := reflect.TypeOf(val)
|
|
v := reflect.ValueOf(val)
|
|
if t.Kind() != reflect.Ptr {
|
|
return errors.New("Expected val to be pointer type")
|
|
}
|
|
|
|
return getDecoder(t.Elem())(r, v.Elem())
|
|
}
|
|
|
|
// get decoder from cache
|
|
func getDecoder(t reflect.Type) decodeFunc {
|
|
if f, ok := decodeFuncCache.Load(t); ok {
|
|
return f.(decodeFunc)
|
|
}
|
|
|
|
f := decoderFunc(t)
|
|
decodeFuncCache.Store(t, f)
|
|
return f
|
|
}
|
|
|
|
// unmarshalableInterface is the reflect.Type of the Unmarshalable interface,
// used to test whether a type implements it.
var unmarshalableInterface = reflect.TypeOf(new(Unmarshalable)).Elem()
|
|
|
|
func decoderFunc(t reflect.Type) decodeFunc {
|
|
if reflect.PtrTo(t).Implements(unmarshalableInterface) {
|
|
return func(r *Reader, v reflect.Value) error {
|
|
uv := v.Addr().Interface().(Unmarshalable)
|
|
return uv.Unmarshal(r)
|
|
}
|
|
}
|
|
|
|
if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
|
|
return decodeUnion(t)
|
|
}
|
|
|
|
switch t.Kind() {
|
|
case reflect.Ptr:
|
|
return decodeOptional(t.Elem())
|
|
case reflect.Struct:
|
|
return decodeStruct(t)
|
|
case reflect.Array:
|
|
return decodeArray(t)
|
|
case reflect.Slice:
|
|
return decodeSlice(t)
|
|
case reflect.Map:
|
|
return decodeMap(t)
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
return decodeUint
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return decodeInt
|
|
case reflect.Float32, reflect.Float64:
|
|
return decodeFloat
|
|
case reflect.Bool:
|
|
return decodeBool
|
|
case reflect.String:
|
|
return decodeString
|
|
}
|
|
|
|
return func(r *Reader, v reflect.Value) error {
|
|
return &UnsupportedTypeError{v.Type()}
|
|
}
|
|
}
|
|
|
|
func decodeOptional(t reflect.Type) decodeFunc {
|
|
return func(r *Reader, v reflect.Value) error {
|
|
s, err := r.ReadU8()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if s > 1 {
|
|
return fmt.Errorf("Invalid optional value: %#x", s)
|
|
}
|
|
|
|
if s == 0 {
|
|
return nil
|
|
}
|
|
|
|
v.Set(reflect.New(t))
|
|
return getDecoder(t)(r, v.Elem())
|
|
}
|
|
}
|
|
|
|
func decodeStruct(t reflect.Type) decodeFunc {
|
|
n := t.NumField()
|
|
decoders := make([]decodeFunc, n)
|
|
for i := 0; i < n; i++ {
|
|
field := t.Field(i)
|
|
if field.Tag.Get("bare") == "-" {
|
|
continue
|
|
}
|
|
decoders[i] = getDecoder(field.Type)
|
|
}
|
|
|
|
return func(r *Reader, v reflect.Value) error {
|
|
for i := 0; i < n; i++ {
|
|
if decoders[i] == nil {
|
|
continue
|
|
}
|
|
err := decoders[i](r, v.Field(i))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func decodeArray(t reflect.Type) decodeFunc {
|
|
f := getDecoder(t.Elem())
|
|
len := t.Len()
|
|
|
|
return func(r *Reader, v reflect.Value) error {
|
|
for i := 0; i < len; i++ {
|
|
err := f(r, v.Index(i))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func decodeSlice(t reflect.Type) decodeFunc {
|
|
elem := t.Elem()
|
|
f := getDecoder(elem)
|
|
|
|
return func(r *Reader, v reflect.Value) error {
|
|
len, err := r.ReadUint()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len > maxArrayLength {
|
|
return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength)
|
|
}
|
|
|
|
v.Set(reflect.MakeSlice(t, int(len), int(len)))
|
|
|
|
for i := 0; i < int(len); i++ {
|
|
if err := f(r, v.Index(i)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func decodeMap(t reflect.Type) decodeFunc {
|
|
keyType := t.Key()
|
|
keyf := getDecoder(keyType)
|
|
|
|
valueType := t.Elem()
|
|
valf := getDecoder(valueType)
|
|
|
|
return func(r *Reader, v reflect.Value) error {
|
|
size, err := r.ReadUint()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if size > maxMapSize {
|
|
return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
|
|
}
|
|
|
|
v.Set(reflect.MakeMapWithSize(t, int(size)))
|
|
|
|
key := reflect.New(keyType).Elem()
|
|
value := reflect.New(valueType).Elem()
|
|
|
|
for i := uint64(0); i < size; i++ {
|
|
if err := keyf(r, key); err != nil {
|
|
return err
|
|
}
|
|
|
|
if v.MapIndex(key).Kind() > reflect.Invalid {
|
|
return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
|
|
}
|
|
|
|
if err := valf(r, value); err != nil {
|
|
return err
|
|
}
|
|
|
|
v.SetMapIndex(key, value)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func decodeUnion(t reflect.Type) decodeFunc {
|
|
ut, ok := unionRegistry[t]
|
|
if !ok {
|
|
return func(r *Reader, v reflect.Value) error {
|
|
return fmt.Errorf("Union type %s is not registered", t.Name())
|
|
}
|
|
}
|
|
|
|
decoders := make(map[uint64]decodeFunc)
|
|
for tag, t := range ut.types {
|
|
t := t
|
|
f := getDecoder(t)
|
|
|
|
decoders[tag] = func(r *Reader, v reflect.Value) error {
|
|
nv := reflect.New(t)
|
|
if err := f(r, nv.Elem()); err != nil {
|
|
return err
|
|
}
|
|
|
|
v.Set(nv)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return func(r *Reader, v reflect.Value) error {
|
|
tag, err := r.ReadUint()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if f, ok := decoders[tag]; ok {
|
|
return f(r, v)
|
|
}
|
|
|
|
return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
|
|
}
|
|
}
|
|
|
|
func decodeUint(r *Reader, v reflect.Value) error {
|
|
var err error
|
|
switch getIntKind(v.Type()) {
|
|
case reflect.Uint:
|
|
var u uint64
|
|
u, err = r.ReadUint()
|
|
v.SetUint(u)
|
|
|
|
case reflect.Uint8:
|
|
var u uint8
|
|
u, err = r.ReadU8()
|
|
v.SetUint(uint64(u))
|
|
|
|
case reflect.Uint16:
|
|
var u uint16
|
|
u, err = r.ReadU16()
|
|
v.SetUint(uint64(u))
|
|
case reflect.Uint32:
|
|
var u uint32
|
|
u, err = r.ReadU32()
|
|
v.SetUint(uint64(u))
|
|
|
|
case reflect.Uint64:
|
|
var u uint64
|
|
u, err = r.ReadU64()
|
|
v.SetUint(uint64(u))
|
|
|
|
default:
|
|
panic("not an uint")
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func decodeInt(r *Reader, v reflect.Value) error {
|
|
var err error
|
|
switch getIntKind(v.Type()) {
|
|
case reflect.Int:
|
|
var i int64
|
|
i, err = r.ReadInt()
|
|
v.SetInt(i)
|
|
|
|
case reflect.Int8:
|
|
var i int8
|
|
i, err = r.ReadI8()
|
|
v.SetInt(int64(i))
|
|
|
|
case reflect.Int16:
|
|
var i int16
|
|
i, err = r.ReadI16()
|
|
v.SetInt(int64(i))
|
|
case reflect.Int32:
|
|
var i int32
|
|
i, err = r.ReadI32()
|
|
v.SetInt(int64(i))
|
|
|
|
case reflect.Int64:
|
|
var i int64
|
|
i, err = r.ReadI64()
|
|
v.SetInt(int64(i))
|
|
|
|
default:
|
|
panic("not an int")
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func decodeFloat(r *Reader, v reflect.Value) error {
|
|
var err error
|
|
switch v.Type().Kind() {
|
|
case reflect.Float32:
|
|
var f float32
|
|
f, err = r.ReadF32()
|
|
v.SetFloat(float64(f))
|
|
case reflect.Float64:
|
|
var f float64
|
|
f, err = r.ReadF64()
|
|
v.SetFloat(f)
|
|
default:
|
|
panic("not a float")
|
|
}
|
|
return err
|
|
}
|
|
|
|
func decodeBool(r *Reader, v reflect.Value) error {
|
|
b, err := r.ReadBool()
|
|
v.SetBool(b)
|
|
return err
|
|
}
|
|
|
|
func decodeString(r *Reader, v reflect.Value) error {
|
|
s, err := r.ReadString()
|
|
v.SetString(s)
|
|
return err
|
|
}
|