]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - main.go
More reliable Habr hack
[tofuproxy.git] / main.go
diff --git a/main.go b/main.go
index 16fae3d6c0742f5ab38723ce1dfa4ae6f704cf35..a48ec246ffc51f8eafe46741d4cc65e10fdd013f 100644 (file)
--- a/main.go
+++ b/main.go
@@ -19,18 +19,19 @@ package main
 import (
        "context"
        "crypto"
-       "crypto/sha256"
        "crypto/tls"
        "crypto/x509"
-       "encoding/hex"
        "flag"
        "fmt"
        "io"
+       "io/ioutil"
        "log"
        "net"
        "net/http"
+       "os"
+       "os/exec"
+       "path/filepath"
        "strings"
-       "sync"
        "time"
 
        "github.com/dustin/go-humanize"
@@ -39,47 +40,34 @@ import (
 
 var (
        tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
-       tlsNextProtoC = make(map[string]func(string, *tls.Conn) http.RoundTripper)
        caCert        *x509.Certificate
        caPrv         crypto.PrivateKey
-       certs         *string
-       dnsSrv        *string
        transport     = http.Transport{
-               ForceAttemptHTTP2:   false,
-               DisableKeepAlives:   true,
-               MaxIdleConnsPerHost: 2,
-               TLSNextProto:        tlsNextProtoC,
-               DialTLSContext:      dialTLS,
+               DialTLSContext:    dialTLS,
+               ForceAttemptHTTP2: true,
        }
+       sessionCache = tls.NewLRUClientSessionCache(1024)
 
-       accepted  = make(map[string]string)
-       acceptedM sync.RWMutex
-       rejected  = make(map[string]string)
-       rejectedM sync.RWMutex
-)
-
-func spkiHash(cert *x509.Certificate) string {
-       hsh := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
-       return hex.EncodeToString(hsh[:])
-}
-
-func acceptedAdd(addr, h string) {
-       acceptedM.Lock()
-       accepted[addr] = h
-       acceptedM.Unlock()
-}
-
-func rejectedAdd(addr, h string) {
-       rejectedM.Lock()
-       rejected[addr] = h
-       rejectedM.Unlock()
-}
-
-type ErrRejected struct {
-       addr string
-}
+       CmdDWebP = "dwebp"
+       CmdDJXL  = "djxl"
 
-func (err ErrRejected) Error() string { return err.addr + " was rejected" }
+       imageExts = map[string]struct{}{
+               ".apng": {},
+               ".avif": {},
+               ".gif":  {},
+               ".heic": {},
+               ".jp2":  {},
+               ".jpeg": {},
+               ".jpg":  {},
+               ".jxl":  {},
+               ".mng":  {},
+               ".png":  {},
+               ".svg":  {},
+               ".tif":  {},
+               ".tiff": {},
+               ".webp": {},
+       }
+)
 
 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
        host := strings.TrimSuffix(addr, ":443")
@@ -90,6 +78,8 @@ func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
                ) error {
                        return verifyCert(host, nil, rawCerts, verifiedChains)
                },
+               ClientSessionCache: sessionCache,
+               NextProtos:         []string{"h2", "http/1.1"},
        }
        conn, dialErr := tls.Dial(network, addr, &cfg)
        if dialErr != nil {
@@ -111,13 +101,16 @@ func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
                }
        }
        connState := conn.ConnectionState()
-       sinkTLS <- fmt.Sprintf(
-               "%s\t%s %s\t%s",
-               strings.TrimSuffix(addr, ":443"),
-               ucspi.TLSVersion(connState.Version),
-               tls.CipherSuiteName(connState.CipherSuite),
-               spkiHash(connState.PeerCertificates[0]),
-       )
+       if connState.DidResume {
+               sinkTLS <- fmt.Sprintf(
+                       "%s\t%s %s\t%s\t%s",
+                       strings.TrimSuffix(addr, ":443"),
+                       ucspi.TLSVersion(connState.Version),
+                       tls.CipherSuiteName(connState.CipherSuite),
+                       spkiHash(connState.PeerCertificates[0]),
+                       connState.NegotiatedProtocol,
+               )
+       }
        return conn, nil
 }
 
@@ -128,6 +121,7 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
        }
        sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
        host := strings.TrimSuffix(req.URL.Host, ":443")
+
        for _, spy := range SpyDomains {
                if strings.HasSuffix(host, spy) {
                        http.NotFound(w, req)
@@ -140,11 +134,19 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                        return
                }
        }
-       if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
+
+       if host == "www.reddit.com" {
                req.URL.Host = "old.reddit.com"
                http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
                return
        }
+
+       if host == "habrastorage.org" && strings.Contains(req.URL.Path, "r/w780q1") {
+               req.URL.Path = strings.Replace(req.URL.Path, "r/w780q1/", "", 1)
+               http.Redirect(w, req, req.URL.String(), http.StatusFound)
+               return
+       }
+
        resp, err := transport.RoundTrip(req)
        if err != nil {
                sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
@@ -152,8 +154,17 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                w.Write([]byte(err.Error()))
                return
        }
-       contentType := resp.Header.Get("Content-Type")
-       switch contentType {
+
+       for k, vs := range resp.Header {
+               if k == "Location" || k == "Content-Type" || k == "Content-Length" {
+                       continue
+               }
+               for _, v := range vs {
+                       w.Header().Add(k, v)
+               }
+       }
+
+       switch resp.Header.Get("Content-Type") {
        case "application/font-woff", "application/font-sfnt":
                // Those are deprecated types
                fallthrough
@@ -167,15 +178,79 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                )
                resp.Body.Close()
                return
-       }
-       for k, vs := range resp.Header {
-               if k == "Location" || k == "Content-Type" || k == "Content-Length" {
-                       continue
+       case "image/webp":
+               if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
+                       // My Xombrero
+                       break
                }
-               for _, v := range vs {
-                       w.Header().Add(k, v)
+               tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
+               if err != nil {
+                       log.Fatalln(err)
+               }
+               defer tmpFd.Close()
+               defer os.Remove(tmpFd.Name())
+               defer resp.Body.Close()
+               if _, err = io.Copy(tmpFd, resp.Body); err != nil {
+                       log.Printf("Error during %s: %+v\n", req.URL, err)
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               tmpFd.Close()
+               cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
+               data, err := cmd.Output()
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               w.Header().Add("Content-Type", "image/png")
+               w.WriteHeader(http.StatusOK)
+               w.Write(data)
+               sinkOther <- fmt.Sprintf(
+                       "%s %s\t%d\tWebP transcoded to PNG",
+                       req.Method,
+                       req.URL.String(),
+                       http.StatusOK,
+               )
+               return
+       case "image/jxl":
+               tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
+               if err != nil {
+                       log.Fatalln(err)
+               }
+               defer tmpFd.Close()
+               defer os.Remove(tmpFd.Name())
+               defer resp.Body.Close()
+               if _, err = io.Copy(tmpFd, resp.Body); err != nil {
+                       log.Printf("Error during %s: %+v\n", req.URL, err)
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
                }
+               tmpFd.Close()
+               dstFn := tmpFd.Name() + ".png"
+               cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
+               err = cmd.Run()
+               defer os.Remove(dstFn)
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               data, err := ioutil.ReadFile(dstFn)
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               w.Header().Add("Content-Type", "image/png")
+               w.WriteHeader(http.StatusOK)
+               w.Write(data)
+               sinkOther <- fmt.Sprintf(
+                       "%s %s\t%d\tJPEG XL transcoded to PNG",
+                       req.Method,
+                       req.URL.String(),
+                       http.StatusOK,
+               )
+               return
        }
+
        if req.Method == http.MethodGet {
                var redirType string
                switch resp.StatusCode {
@@ -186,6 +261,9 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                        if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
                                goto NoRedir
                        }
+                       if _, ok := imageExts[filepath.Ext(req.URL.Path)]; ok {
+                               goto NoRedir
+                       }
                        redirType = "temporary"
                default:
                        goto NoRedir
@@ -197,10 +275,8 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                location := resp.Header.Get("Location")
                w.Write([]byte(
                        fmt.Sprintf(
-                               `<html>
-<head><title>%d %s: %s redirection</title></head>
-<body>Redirection to <a href="%s">%s</a>
-</body></html>`,
+                               `<html><head><title>%d %s: %s redirection</title></head>
+<body>Redirection to <a href="%s">%s</a></body></html>`,
                                resp.StatusCode, http.StatusText(resp.StatusCode),
                                redirType, location, location,
                        )))
@@ -209,6 +285,7 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                )
                return
        }
+
 NoRedir:
        for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
                if v := resp.Header.Get(h); v != "" {
@@ -328,7 +405,6 @@ func main() {
                Handler:      &Handler{},
                TLSNextProto: tlsNextProtoS,
        }
-       srv.SetKeepAlivesEnabled(false)
        log.Println("listening:", *bind)
        if err := srv.Serve(ln); err != nil {
                log.Fatalln(err)