]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - tls.go
Refactoring
[tofuproxy.git] / tls.go
diff --git a/tls.go b/tls.go
new file mode 100644 (file)
index 0000000..4c29a6f
--- /dev/null
+++ b/tls.go
@@ -0,0 +1,148 @@
+/*
+tofuproxy -- HTTP proxy with TLS certificates management
+Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, version 3 of the License.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package tofuproxy
+
+import (
+       "context"
+       "crypto"
+       "crypto/tls"
+       "crypto/x509"
+       "fmt"
+       "log"
+       "net"
+       "net/http"
+       "strings"
+       "time"
+
+       "go.cypherpunks.ru/ucspi"
+       "go.stargrave.org/tofuproxy/fifos"
+)
+
+var (
+       TLSNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
+       CACert        *x509.Certificate
+       CAPrv         crypto.PrivateKey
+       sessionCache  = tls.NewLRUClientSessionCache(1024)
+)
+
+type Handler struct{}
+
+func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+       if req.Method != http.MethodConnect {
+               roundTrip(w, req)
+               return
+       }
+       hj, ok := w.(http.Hijacker)
+       if !ok {
+               log.Fatalln("no hijacking")
+       }
+       conn, _, err := hj.Hijack()
+       if err != nil {
+               log.Fatalln(err)
+       }
+       defer conn.Close()
+       conn.Write([]byte(fmt.Sprintf(
+               "%s %d %s\r\n\r\n",
+               req.Proto,
+               http.StatusOK, http.StatusText(http.StatusOK),
+       )))
+       host := strings.Split(req.Host, ":")[0]
+       hostCertsM.Lock()
+       keypair, ok := hostCerts[host]
+       if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
+               keypair = newKeypair(host, CACert, CAPrv)
+               hostCerts[host] = keypair
+       }
+       hostCertsM.Unlock()
+       tlsConn := tls.Server(conn, &tls.Config{
+               Certificates: []tls.Certificate{{
+                       Certificate: [][]byte{keypair.cert.Raw},
+                       PrivateKey:  keypair.prv,
+               }},
+       })
+       if err = tlsConn.Handshake(); err != nil {
+               log.Printf("TLS error %s: %+v\n", host, err)
+               return
+       }
+       srv := http.Server{
+               Handler:      &HTTPSHandler{host: req.Host},
+               TLSNextProto: TLSNextProtoS,
+       }
+       err = srv.Serve(&SingleListener{conn: tlsConn})
+       if err != nil {
+               if _, ok := err.(AlreadyAccepted); !ok {
+                       log.Printf("TLS serve error %s: %+v\n", host, err)
+                       return
+               }
+       }
+}
+
+type HTTPSHandler struct {
+       host string
+}
+
+func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+       req.URL.Scheme = "https"
+       req.URL.Host = h.host
+       roundTrip(w, req)
+}
+
+func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
+       host := strings.TrimSuffix(addr, ":443")
+       cfg := tls.Config{
+               VerifyPeerCertificate: func(
+                       rawCerts [][]byte,
+                       verifiedChains [][]*x509.Certificate,
+               ) error {
+                       return verifyCert(host, nil, rawCerts, verifiedChains)
+               },
+               ClientSessionCache: sessionCache,
+               NextProtos:         []string{"h2", "http/1.1"},
+       }
+       conn, dialErr := tls.Dial(network, addr, &cfg)
+       if dialErr != nil {
+               if _, ok := dialErr.(ErrRejected); ok {
+                       return nil, dialErr
+               }
+               cfg.InsecureSkipVerify = true
+               cfg.VerifyPeerCertificate = func(
+                       rawCerts [][]byte,
+                       verifiedChains [][]*x509.Certificate,
+               ) error {
+                       return verifyCert(host, dialErr, rawCerts, verifiedChains)
+               }
+               var err error
+               conn, err = tls.Dial(network, addr, &cfg)
+               if err != nil {
+                       fifos.SinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
+                       return nil, err
+               }
+       }
+       connState := conn.ConnectionState()
+       if connState.DidResume {
+               fifos.SinkTLS <- fmt.Sprintf(
+                       "%s\t%s %s\t%s\t%s",
+                       strings.TrimSuffix(addr, ":443"),
+                       ucspi.TLSVersion(connState.Version),
+                       tls.CipherSuiteName(connState.CipherSuite),
+                       spkiHash(connState.PeerCertificates[0]),
+                       connState.NegotiatedProtocol,
+               )
+       }
+       return conn, nil
+}