]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - main.go
More reliable Habr hack
[tofuproxy.git] / main.go
diff --git a/main.go b/main.go
index ac80e2a80b92ee73ca76ef779532783bb0c1239a..a48ec246ffc51f8eafe46741d4cc65e10fdd013f 100644 (file)
--- a/main.go
+++ b/main.go
@@ -30,6 +30,7 @@ import (
        "net/http"
        "os"
        "os/exec"
+       "path/filepath"
        "strings"
        "time"
 
@@ -42,13 +43,30 @@ var (
        caCert        *x509.Certificate
        caPrv         crypto.PrivateKey
        transport     = http.Transport{
-               ForceAttemptHTTP2: false,
-               TLSNextProto:      make(map[string]func(string, *tls.Conn) http.RoundTripper),
                DialTLSContext:    dialTLS,
+               ForceAttemptHTTP2: true,
        }
+       sessionCache = tls.NewLRUClientSessionCache(1024)
 
        CmdDWebP = "dwebp"
        CmdDJXL  = "djxl"
+
+       imageExts = map[string]struct{}{
+               ".apng": {},
+               ".avif": {},
+               ".gif":  {},
+               ".heic": {},
+               ".jp2":  {},
+               ".jpeg": {},
+               ".jpg":  {},
+               ".jxl":  {},
+               ".mng":  {},
+               ".png":  {},
+               ".svg":  {},
+               ".tif":  {},
+               ".tiff": {},
+               ".webp": {},
+       }
 )
 
 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -60,6 +78,8 @@ func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
                ) error {
                        return verifyCert(host, nil, rawCerts, verifiedChains)
                },
+               ClientSessionCache: sessionCache,
+               NextProtos:         []string{"h2", "http/1.1"},
        }
        conn, dialErr := tls.Dial(network, addr, &cfg)
        if dialErr != nil {
@@ -81,13 +101,16 @@ func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
                }
        }
        connState := conn.ConnectionState()
-       sinkTLS <- fmt.Sprintf(
-               "%s\t%s %s\t%s",
-               strings.TrimSuffix(addr, ":443"),
-               ucspi.TLSVersion(connState.Version),
-               tls.CipherSuiteName(connState.CipherSuite),
-               spkiHash(connState.PeerCertificates[0]),
-       )
+       if connState.DidResume {
+               sinkTLS <- fmt.Sprintf(
+                       "%s\t%s %s\t%s\t%s",
+                       strings.TrimSuffix(addr, ":443"),
+                       ucspi.TLSVersion(connState.Version),
+                       tls.CipherSuiteName(connState.CipherSuite),
+                       spkiHash(connState.PeerCertificates[0]),
+                       connState.NegotiatedProtocol,
+               )
+       }
        return conn, nil
 }
 
@@ -112,12 +135,18 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                }
        }
 
-       if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
+       if host == "www.reddit.com" {
                req.URL.Host = "old.reddit.com"
                http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
                return
        }
 
+       if host == "habrastorage.org" && strings.Contains(req.URL.Path, "r/w780q1") {
+               req.URL.Path = strings.Replace(req.URL.Path, "r/w780q1/", "", 1)
+               http.Redirect(w, req, req.URL.String(), http.StatusFound)
+               return
+       }
+
        resp, err := transport.RoundTrip(req)
        if err != nil {
                sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
@@ -232,6 +261,9 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                        if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
                                goto NoRedir
                        }
+                       if _, ok := imageExts[filepath.Ext(req.URL.Path)]; ok {
+                               goto NoRedir
+                       }
                        redirType = "temporary"
                default:
                        goto NoRedir
@@ -243,10 +275,8 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                location := resp.Header.Get("Location")
                w.Write([]byte(
                        fmt.Sprintf(
-                               `<html>
-<head><title>%d %s: %s redirection</title></head>
-<body>Redirection to <a href="%s">%s</a>
-</body></html>`,
+                               `<html><head><title>%d %s: %s redirection</title></head>
+<body>Redirection to <a href="%s">%s</a></body></html>`,
                                resp.StatusCode, http.StatusText(resp.StatusCode),
                                redirType, location, location,
                        )))