diff --git a/morty.go b/morty.go index 5a61d28..6579c20 100644 --- a/morty.go +++ b/morty.go @@ -17,7 +17,8 @@ import ( "github.com/valyala/fasthttp" "golang.org/x/net/html" - "golang.org/x/text/encoding/charmap" + "golang.org/x/net/html/charset" + "golang.org/x/text/encoding" ) const ( @@ -122,6 +123,8 @@ input[type=checkbox]#mortytoggle:checked ~ div { display: none; } ` +var HTML_META_CONTENT_TYPE string = "" + func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) { if appRequestHandler(ctx) { @@ -236,13 +239,17 @@ func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) { var responseBody []byte - if len(contentInfo) == 2 && bytes.Contains(contentInfo[1], []byte("ISO-8859-2")) && bytes.Contains(contentInfo[0], []byte("text")) { - var err error - responseBody, err = charmap.ISO8859_2.NewDecoder().Bytes(resp.Body()) - if err != nil { - // HTTP status code 503 : Service Unavailable - p.serveMainPage(ctx, 503, err) - return + if len(contentInfo) == 2 && bytes.Contains(contentInfo[0], []byte("text")) { + e, ename, _ := charset.DetermineEncoding(resp.Body(), string(contentType)) + if (e != encoding.Nop) && (!strings.EqualFold("utf-8", ename)) { + responseBody, err = e.NewDecoder().Bytes(resp.Body()) + if err != nil { + // HTTP status code 503 : Service Unavailable + p.serveMainPage(ctx, 503, err) + return + } + } else { + responseBody = resp.Body() } } else { responseBody = resp.Body() @@ -325,7 +332,6 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) { unsafeElements := make([][]byte, 0, 8) state := STATE_DEFAULT - for { token := decoder.Next() if token == html.ErrorToken { @@ -353,11 +359,12 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) { if bytes.Equal(tag, []byte("base")) { for { attrName, attrValue, moreAttr := decoder.TagAttr() - if bytes.Equal(attrName, []byte("href")) { - parsedURI, err := url.Parse(string(attrValue)) - if err == nil { - rc.BaseURL = parsedURI - } + if !bytes.Equal(attrName, []byte("href")) { + continue + } + parsedURI, err := url.Parse(string(attrValue)) + if err == nil { + rc.BaseURL = parsedURI } if !moreAttr { break @@ -388,14 +395,15 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) { break } + if bytes.Equal(tag, []byte("meta")) { + sanitizeMetaTag(rc, out, attrs) + break + } + fmt.Fprintf(out, "<%s", tag) if hasAttrs { - if bytes.Equal(tag, []byte("meta")) { - sanitizeMetaAttrs(rc, out, attrs) - } else { - sanitizeAttrs(rc, out, attrs) - } + sanitizeAttrs(rc, out, attrs) } if token == html.SelfClosingTagToken { @@ -407,6 +415,10 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) { } } + if bytes.Equal(tag, []byte("head")) { + fmt.Fprintf(out, HTML_META_CONTENT_TYPE) + } + if bytes.Equal(tag, []byte("form")) { var formURL *url.URL for _, attr := range attrs { @@ -504,7 +516,7 @@ func sanitizeLinkTag(rc *RequestConfig, out io.Writer, attrs [][][]byte) { } } -func sanitizeMetaAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) { +func sanitizeMetaTag(rc *RequestConfig, out io.Writer, attrs [][][]byte) { var http_equiv []byte var content []byte @@ -517,8 +529,17 @@ func sanitizeMetaAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) { if bytes.Equal(attrName, []byte("content")) { content = attrValue } + if bytes.Equal(attrName, []byte("charset")) { + // exclude + return + } } + if bytes.Equal(http_equiv, []byte("content-type")) { + return + } + + out.Write([]byte("")) } func sanitizeAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) {