diff --git a/morty.go b/morty.go index 3c886d3..795e051 100644 --- a/morty.go +++ b/morty.go @@ -103,7 +103,7 @@ type Proxy struct { type RequestConfig struct { Key []byte - baseURL *url.URL + BaseURL *url.URL } var HTML_FORM_EXTENSION string = `` @@ -196,7 +196,8 @@ func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) { case 301, 302, 303, 307, 308: loc := resp.Header.Peek("Location") if loc != nil { - url, err := proxifyURI(&RequestConfig{p.Key, parsedURI}, string(loc)) + rc := &RequestConfig{Key: p.Key, BaseURL: parsedURI} + url, err := rc.ProxifyURI(string(loc)) if err == nil { ctx.SetStatusCode(resp.StatusCode()) ctx.Response.Header.Add("Location", url) @@ -240,9 +241,9 @@ func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) { switch { case bytes.Contains(contentType, []byte("css")): - sanitizeCSS(&RequestConfig{p.Key, parsedURI}, ctx, responseBody) + sanitizeCSS(&RequestConfig{Key: p.Key, BaseURL: parsedURI}, ctx, responseBody) case bytes.Contains(contentType, []byte("html")): - sanitizeHTML(&RequestConfig{p.Key, parsedURI}, ctx, responseBody) + sanitizeHTML(&RequestConfig{Key: p.Key, BaseURL: parsedURI}, ctx, responseBody) default: ctx.Write(responseBody) } @@ -290,7 +291,7 @@ func sanitizeCSS(rc *RequestConfig, out io.Writer, css []byte) { urlStart := s[4] urlEnd := s[5] - if uri, err := proxifyURI(rc, string(css[urlStart:urlEnd])); err == nil { + if uri, err := rc.ProxifyURI(string(css[urlStart:urlEnd])); err == nil { out.Write(css[startIndex:urlStart]) out.Write([]byte(uri)) startIndex = urlEnd @@ -382,12 +383,12 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) { for _, attr := range attrs { if bytes.Equal(attr[0], []byte("action")) { formURL, _ = url.Parse(string(attr[1])) - mergeURIs(rc.baseURL, formURL) + mergeURIs(rc.BaseURL, formURL) break } } if formURL == nil { - formURL = rc.baseURL + formURL = rc.BaseURL } urlStr := formURL.String() var key string @@ -403,7 +404,7 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) { writeEndTag := true switch string(tag) { case "body": - fmt.Fprintf(out, HTML_BODY_EXTENSION, rc.baseURL.String()) + fmt.Fprintf(out, HTML_BODY_EXTENSION, rc.BaseURL.String()) case "style": state = STATE_DEFAULT case "noscript": @@ -492,7 +493,7 @@ func sanitizeMetaAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) { urlIndex := bytes.Index(bytes.ToLower(content), []byte("url=")) if bytes.Equal(http_equiv, []byte("refresh")) && urlIndex != -1 { contentUrl := content[urlIndex+4:] - if uri, err := proxifyURI(rc, string(contentUrl)); err == nil { + if uri, err := rc.ProxifyURI(string(contentUrl)); err == nil { fmt.Fprintf(out, ` http-equiv="refresh" content="%surl=%s"`, content[:urlIndex], uri) } } else { @@ -514,7 +515,7 @@ func sanitizeAttr(rc *RequestConfig, out io.Writer, attrName, attrValue, escaped } switch string(attrName) { case "src", "href", "action": - if uri, err := proxifyURI(rc, string(attrValue)); err == nil { + if uri, err := rc.ProxifyURI(string(attrValue)); err == nil { fmt.Fprintf(out, " %s=\"%s\"", attrName, uri) } else { log.Println("cannot proxify uri:", attrValue) @@ -538,7 +539,7 @@ func mergeURIs(u1, u2 *url.URL) { } } -func proxifyURI(rc *RequestConfig, uri string) (string, error) { +func (rc *RequestConfig) ProxifyURI(uri string) (string, error) { // TODO check malicious data: - e.g. data:script if strings.HasPrefix(uri, "data:") { return uri, nil @@ -552,7 +553,7 @@ func proxifyURI(rc *RequestConfig, uri string) (string, error) { if err != nil { return "", err } - mergeURIs(rc.baseURL, u) + mergeURIs(rc.BaseURL, u) uri = u.String()