diff --git a/morty.go b/morty.go
index 5a61d28..6579c20 100644
--- a/morty.go
+++ b/morty.go
@@ -17,7 +17,8 @@ import (
"github.com/valyala/fasthttp"
"golang.org/x/net/html"
- "golang.org/x/text/encoding/charmap"
+ "golang.org/x/net/html/charset"
+ "golang.org/x/text/encoding"
)
const (
@@ -122,6 +123,8 @@ input[type=checkbox]#mortytoggle:checked ~ div { display: none; }
`
+var HTML_META_CONTENT_TYPE string = ""
+
func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) {
if appRequestHandler(ctx) {
@@ -236,13 +239,17 @@ func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) {
var responseBody []byte
- if len(contentInfo) == 2 && bytes.Contains(contentInfo[1], []byte("ISO-8859-2")) && bytes.Contains(contentInfo[0], []byte("text")) {
- var err error
- responseBody, err = charmap.ISO8859_2.NewDecoder().Bytes(resp.Body())
- if err != nil {
- // HTTP status code 503 : Service Unavailable
- p.serveMainPage(ctx, 503, err)
- return
+ if len(contentInfo) == 2 && bytes.Contains(contentInfo[0], []byte("text")) {
+ e, ename, _ := charset.DetermineEncoding(resp.Body(), string(contentType))
+ if (e != encoding.Nop) && (!strings.EqualFold("utf-8", ename)) {
+ responseBody, err = e.NewDecoder().Bytes(resp.Body())
+ if err != nil {
+ // HTTP status code 503 : Service Unavailable
+ p.serveMainPage(ctx, 503, err)
+ return
+ }
+ } else {
+ responseBody = resp.Body()
}
} else {
responseBody = resp.Body()
@@ -325,7 +332,6 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) {
unsafeElements := make([][]byte, 0, 8)
state := STATE_DEFAULT
-
for {
token := decoder.Next()
if token == html.ErrorToken {
@@ -353,11 +359,12 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) {
if bytes.Equal(tag, []byte("base")) {
for {
attrName, attrValue, moreAttr := decoder.TagAttr()
- if bytes.Equal(attrName, []byte("href")) {
- parsedURI, err := url.Parse(string(attrValue))
- if err == nil {
- rc.BaseURL = parsedURI
- }
+ if !bytes.Equal(attrName, []byte("href")) {
+ continue
+ }
+ parsedURI, err := url.Parse(string(attrValue))
+ if err == nil {
+ rc.BaseURL = parsedURI
}
if !moreAttr {
break
@@ -388,14 +395,15 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) {
break
}
+ if bytes.Equal(tag, []byte("meta")) {
+ sanitizeMetaTag(rc, out, attrs)
+ break
+ }
+
fmt.Fprintf(out, "<%s", tag)
if hasAttrs {
- if bytes.Equal(tag, []byte("meta")) {
- sanitizeMetaAttrs(rc, out, attrs)
- } else {
- sanitizeAttrs(rc, out, attrs)
- }
+ sanitizeAttrs(rc, out, attrs)
}
if token == html.SelfClosingTagToken {
@@ -407,6 +415,10 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) {
}
}
+ if bytes.Equal(tag, []byte("head")) {
+ fmt.Fprintf(out, HTML_META_CONTENT_TYPE)
+ }
+
if bytes.Equal(tag, []byte("form")) {
var formURL *url.URL
for _, attr := range attrs {
@@ -504,7 +516,7 @@ func sanitizeLinkTag(rc *RequestConfig, out io.Writer, attrs [][][]byte) {
}
}
-func sanitizeMetaAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) {
+func sanitizeMetaTag(rc *RequestConfig, out io.Writer, attrs [][][]byte) {
var http_equiv []byte
var content []byte
@@ -517,8 +529,17 @@ func sanitizeMetaAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) {
if bytes.Equal(attrName, []byte("content")) {
content = attrValue
}
+ if bytes.Equal(attrName, []byte("charset")) {
+ // exclude: skip any meta tag that declares a charset
+ return
+ }
}
+ if bytes.Equal(http_equiv, []byte("content-type")) {
+ return
+ }
+
+ out.Write([]byte(""))
}
func sanitizeAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) {