diff --git a/morty.go b/morty.go
index 3c886d3..795e051 100644
--- a/morty.go
+++ b/morty.go
@@ -103,7 +103,7 @@ type Proxy struct {
type RequestConfig struct {
Key []byte
- baseURL *url.URL
+ BaseURL *url.URL
}
var HTML_FORM_EXTENSION string = ``
@@ -196,7 +196,8 @@ func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) {
case 301, 302, 303, 307, 308:
loc := resp.Header.Peek("Location")
if loc != nil {
- url, err := proxifyURI(&RequestConfig{p.Key, parsedURI}, string(loc))
+ rc := &RequestConfig{Key: p.Key, BaseURL: parsedURI}
+ url, err := rc.ProxifyURI(string(loc))
if err == nil {
ctx.SetStatusCode(resp.StatusCode())
ctx.Response.Header.Add("Location", url)
@@ -240,9 +241,9 @@ func (p *Proxy) RequestHandler(ctx *fasthttp.RequestCtx) {
switch {
case bytes.Contains(contentType, []byte("css")):
- sanitizeCSS(&RequestConfig{p.Key, parsedURI}, ctx, responseBody)
+ sanitizeCSS(&RequestConfig{Key: p.Key, BaseURL: parsedURI}, ctx, responseBody)
case bytes.Contains(contentType, []byte("html")):
- sanitizeHTML(&RequestConfig{p.Key, parsedURI}, ctx, responseBody)
+ sanitizeHTML(&RequestConfig{Key: p.Key, BaseURL: parsedURI}, ctx, responseBody)
default:
ctx.Write(responseBody)
}
@@ -290,7 +291,7 @@ func sanitizeCSS(rc *RequestConfig, out io.Writer, css []byte) {
urlStart := s[4]
urlEnd := s[5]
- if uri, err := proxifyURI(rc, string(css[urlStart:urlEnd])); err == nil {
+ if uri, err := rc.ProxifyURI(string(css[urlStart:urlEnd])); err == nil {
out.Write(css[startIndex:urlStart])
out.Write([]byte(uri))
startIndex = urlEnd
@@ -382,12 +383,12 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) {
for _, attr := range attrs {
if bytes.Equal(attr[0], []byte("action")) {
formURL, _ = url.Parse(string(attr[1]))
- mergeURIs(rc.baseURL, formURL)
+ mergeURIs(rc.BaseURL, formURL)
break
}
}
if formURL == nil {
- formURL = rc.baseURL
+ formURL = rc.BaseURL
}
urlStr := formURL.String()
var key string
@@ -403,7 +404,7 @@ func sanitizeHTML(rc *RequestConfig, out io.Writer, htmlDoc []byte) {
writeEndTag := true
switch string(tag) {
case "body":
- fmt.Fprintf(out, HTML_BODY_EXTENSION, rc.baseURL.String())
+ fmt.Fprintf(out, HTML_BODY_EXTENSION, rc.BaseURL.String())
case "style":
state = STATE_DEFAULT
case "noscript":
@@ -492,7 +493,7 @@ func sanitizeMetaAttrs(rc *RequestConfig, out io.Writer, attrs [][][]byte) {
urlIndex := bytes.Index(bytes.ToLower(content), []byte("url="))
if bytes.Equal(http_equiv, []byte("refresh")) && urlIndex != -1 {
contentUrl := content[urlIndex+4:]
- if uri, err := proxifyURI(rc, string(contentUrl)); err == nil {
+ if uri, err := rc.ProxifyURI(string(contentUrl)); err == nil {
fmt.Fprintf(out, ` http-equiv="refresh" content="%surl=%s"`, content[:urlIndex], uri)
}
} else {
@@ -514,7 +515,7 @@ func sanitizeAttr(rc *RequestConfig, out io.Writer, attrName, attrValue, escaped
}
switch string(attrName) {
case "src", "href", "action":
- if uri, err := proxifyURI(rc, string(attrValue)); err == nil {
+ if uri, err := rc.ProxifyURI(string(attrValue)); err == nil {
fmt.Fprintf(out, " %s=\"%s\"", attrName, uri)
} else {
log.Println("cannot proxify uri:", attrValue)
@@ -538,7 +539,7 @@ func mergeURIs(u1, u2 *url.URL) {
}
}
-func proxifyURI(rc *RequestConfig, uri string) (string, error) {
+func (rc *RequestConfig) ProxifyURI(uri string) (string, error) {
// TODO check malicious data: - e.g. data:script
if strings.HasPrefix(uri, "data:") {
return uri, nil
@@ -552,7 +553,7 @@ func proxifyURI(rc *RequestConfig, uri string) (string, error) {
if err != nil {
return "", err
}
- mergeURIs(rc.baseURL, u)
+ mergeURIs(rc.BaseURL, u)
uri = u.String()