]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - tls/verify.go
bytes.Equal instead of bytes.Compare
[tofuproxy.git] / tls / verify.go
index 5857b84c0bed94602ffb84b591c3247e23913b6b..2df0c24e5bcf7bc3984244df2676fae6bab71f72 100644 (file)
@@ -1,7 +1,7 @@
 /*
 tofuproxy -- flexible HTTP/HTTPS proxy, TLS terminator, X.509 TOFU
              manager, WARC/geminispace browser
-Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2021-2023 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -24,7 +24,9 @@ import (
        "crypto/x509"
        "encoding/hex"
        "encoding/pem"
+       "errors"
        "fmt"
+       "io/fs"
        "log"
        "os"
        "os/exec"
@@ -198,7 +200,7 @@ func verifyCert(
        if certTheirHash == certOurHash {
                return ErrRejected{host}
        }
-       daneExists, daneMatched := dane(host, certTheir)
+       daneExists, daneMatched := DANE(host, certTheir)
        if daneExists {
                if daneMatched {
                        fifos.LogDANE <- fmt.Sprintf("%s\tACK", host)
@@ -206,6 +208,30 @@ func verifyCert(
                        fifos.LogDANE <- fmt.Sprintf("%s\tNAK", host)
                }
        }
+       if len(verifiedChains) > 0 {
+               caHashes := make(map[string]struct{})
+               for _, certs := range verifiedChains {
+                       for _, cert := range certs {
+                               caHashes[spkiHash(cert)] = struct{}{}
+                       }
+               }
+               var restrictedHosts []string
+               caches.RestrictedM.RLock()
+               for h := range caHashes {
+                       restrictedHosts = append(restrictedHosts, caches.Restricted[h]...)
+               }
+               caches.RestrictedM.RUnlock()
+               if len(restrictedHosts) > 0 {
+                       for _, h := range restrictedHosts {
+                               if host == h || strings.HasSuffix(host, "."+h) {
+                                       goto HostIsNotRestricted
+                               }
+                       }
+                       fifos.LogCert <- fmt.Sprintf("Restricted\t%s", host)
+                       return ErrRejected{host}
+               }
+       }
+HostIsNotRestricted:
        fn := filepath.Join(Certs, host)
        certsOur, _, err := ucspi.CertPoolFromFile(fn)
        if err == nil || dialErr != nil || (daneExists && !daneMatched) {
@@ -213,18 +239,18 @@ func verifyCert(
                        caches.AcceptedM.Lock()
                        caches.Accepted[host] = certTheirHash
                        caches.AcceptedM.Unlock()
-                       if bytes.Compare(certsOur[0].Raw, rawCerts[0]) != 0 {
+                       if !bytes.Equal(certsOur[0].Raw, rawCerts[0]) {
                                fifos.LogCert <- fmt.Sprintf("Refresh\t%s\t%s", host, certTheirHash)
                                goto CertUpdate
                        }
                        return nil
                }
                var b bytes.Buffer
-               b.WriteString(fmt.Sprintf("set host \"%s\"\n", host))
+               fmt.Fprintf(&b, "set host \"%s\"\n", host)
                if dialErr == nil {
-                       b.WriteString(fmt.Sprintf("set err \"\"\n"))
+                       fmt.Fprintf(&b, "set err \"\"\n")
                } else {
-                       b.WriteString(fmt.Sprintf("set err \"%s\"\n", dialErr.Error()))
+                       fmt.Fprintf(&b, "set err \"%s\"\n", dialErr.Error())
                }
                var daneStatus string
                if daneExists {
@@ -234,24 +260,20 @@ func verifyCert(
                                daneStatus = "bad"
                        }
                }
-               b.WriteString(fmt.Sprintf("set daneStatus \"%s\"\n", daneStatus))
+               fmt.Fprintf(&b, "set daneStatus \"%s\"\n", daneStatus)
                hexCerts := make([]string, 0, len(rawCerts))
                for _, rawCert := range rawCerts {
                        hexCerts = append(hexCerts, hex.EncodeToString([]byte(certInfo(rawCert))))
                }
-               b.WriteString(fmt.Sprintf(
-                       "set certsTheir \"%s\"\n", strings.Join(hexCerts, " "),
-               ))
+               fmt.Fprintf(&b, "set certsTheir \"%s\"\n", strings.Join(hexCerts, " "))
                hexCerts = make([]string, 0, len(certsOur))
                for _, cert := range certsOur {
                        hexCerts = append(hexCerts, hex.EncodeToString([]byte(certInfo(cert.Raw))))
                }
-               b.WriteString(fmt.Sprintf(
-                       "set certsOur \"%s\"\n", strings.Join(hexCerts, " "),
-               ))
+               fmt.Fprintf(&b, "set certsOur \"%s\"\n", strings.Join(hexCerts, " "))
                b.WriteString(VerifyDialog)
                cmd := exec.Command(CmdWish)
-               // ioutil.WriteFile("/tmp/verify-dialog.tcl", b.Bytes(), 0666)
+               // os.WriteFile("/tmp/verify-dialog.tcl", b.Bytes(), 0666)
                cmd.Stdin = &b
                err = cmd.Run()
                exitError, ok := err.(*exec.ExitError)
@@ -279,7 +301,7 @@ func verifyCert(
                        return ErrRejected{host}
                }
        } else {
-               if !os.IsNotExist(err) {
+               if !errors.Is(err, fs.ErrNotExist) {
                        return err
                }
                fifos.LogCert <- fmt.Sprintf("TOFU\t%s\t%s", host, certTheirHash)