"io"
"log"
"mime"
+ "mime/multipart"
"net/mail"
"net/textproto"
"os"
"strings"
)
+const (
+ ContentType = "Content-Type"
+ MultipartEncrypted = "multipart/encrypted"
+ PGPEncrypted = "application/pgp-encrypted"
+ OctetStream = "application/octet-stream"
+ StatusPrefix = "[GNUPG:] "
+ DecryptionOK = "DECRYPTION_OKAY"
+)
+
+func readHdrLines(data []byte) (hdrs []string, values [][]string) {
+ var line string
+ s := bufio.NewScanner(bytes.NewReader(data))
+ for s.Scan() {
+ line = s.Text()
+ if line == "" {
+ break
+ }
+ if line[0] == ' ' || line[0] == '\t' {
+ values[len(values)-1] = append(values[len(values)-1], line)
+ continue
+ }
+ hdrs = append(hdrs, textproto.CanonicalMIMEHeaderKey(
+ line[:strings.Index(line, ":")]))
+ values = append(values, []string{line})
+ }
+ return
+}
+
func main() {
log.SetFlags(log.Lshortfile)
- rawEnc, err := io.ReadAll(bufio.NewReader(os.Stdin))
+ encRaw, err := io.ReadAll(bufio.NewReader(os.Stdin))
if err != nil {
log.Fatal(err)
}
- msgEnc, err := mail.ReadMessage(bytes.NewReader(rawEnc))
+ msgEnc, err := mail.ReadMessage(bytes.NewReader(encRaw))
if err != nil {
log.Fatal(err)
}
+ var part *multipart.Part
{
- ct := msgEnc.Header.Get("Content-Type")
- if ct == "" {
- log.Fatal("no Content-Type")
+ var boundary string
+ {
+ ct := msgEnc.Header.Get(ContentType)
+ if ct == "" {
+ log.Fatal("no", ContentType)
+ }
+ mt, params, errParse := mime.ParseMediaType(ct)
+ if errParse != nil {
+ log.Fatal(errParse)
+ }
+ if mt != MultipartEncrypted {
+ log.Fatal("not", MultipartEncrypted)
+ }
+ if proto := params["protocol"]; proto != PGPEncrypted {
+ log.Fatal("not", PGPEncrypted)
+ }
+ boundary = params["boundary"]
}
- mt, params, errParse := mime.ParseMediaType(ct)
- if errParse != nil {
- log.Fatal(errParse)
+ mpr := multipart.NewReader(msgEnc.Body, boundary)
+ part, err = mpr.NextPart()
+ if err != nil {
+ log.Fatal(err)
}
- if mt != "multipart/encrypted" {
- log.Fatal("not multipart/encrypted")
+ if part.Header.Get(ContentType) != PGPEncrypted {
+ log.Fatal("wrong first part's", ContentType)
}
- if proto := params["protocol"]; proto != "application/pgp-encrypted" {
- log.Fatal("not application/pgp-encrypted")
+ part, err = mpr.NextPart()
+ if err != nil {
+ log.Fatal(err)
+ }
+ if part.Header.Get(ContentType) != OctetStream {
+ log.Fatal("wrong second part's", ContentType)
}
}
- var rawDec []byte
+ var line string
+ var decRaw []byte
{
args := []string{"--batch", "--decrypt", "--status-fd", "3"}
- args = append(args, os.Args[1:]...)
- cmd := exec.Command("gpg", args...)
+ cmd := exec.Command("gpg", append(args, os.Args[1:]...)...)
cmd.Stderr = os.Stderr
var stdin io.WriteCloser
stdin, err = cmd.StdinPipe()
log.Fatal(err)
}
cmd.ExtraFiles = append(cmd.ExtraFiles, statusW)
-
err = cmd.Start()
if err != nil {
log.Fatal(err)
}
go func() {
- _, errStdin := io.Copy(stdin, msgEnc.Body)
+ _, errStdin := io.Copy(stdin, part)
if err != nil {
log.Print(errStdin)
}
var goodDec bool
go func() {
s := bufio.NewScanner(statusR)
- var t string
for s.Scan() {
- t = s.Text()
- if t == "[GNUPG:] DECRYPTION_OKAY" {
+ line = s.Text()
+ if line == StatusPrefix+DecryptionOK {
goodDec = true
}
- log.Print(t)
+ log.Print(line)
}
if s.Err() != nil {
log.Print(s.Err())
}
}()
- rawDec, err = io.ReadAll(stdout)
+ decRaw, err = io.ReadAll(stdout)
if err != nil {
log.Fatal(err)
}
log.Fatal(err)
}
if !goodDec {
- log.Fatal("no DECRYPTION_OKAY received")
+ log.Fatal("no", DecryptionOK, "received")
}
}
- msgDec, err := mail.ReadMessage(bytes.NewReader(rawDec))
- if err != nil {
- log.Fatal(err)
- }
stdout := bufio.NewWriter(os.Stdout)
- r := textproto.NewReader(bufio.NewReader(bytes.NewReader(rawEnc)))
- var t string
- var hdr string
- for {
- t, err = r.ReadContinuedLine()
- if err != nil {
- log.Fatal(err)
- }
- if t == "" {
- break
- }
- hdr = t[:strings.Index(t, ":")]
- if msgDec.Header.Get(hdr) != "" {
- continue
- }
- if _, err = stdout.WriteString(t); err != nil {
- log.Fatal(err)
- }
- if _, err = stdout.WriteString("\n"); err != nil {
- log.Fatal(err)
+ {
+ encHdrs, encValues := readHdrLines(encRaw)
+ decHdrs, _ := readHdrLines(decRaw)
+ EncHdr:
+ for idxEnc, encHdr := range encHdrs {
+ for _, decHdr := range decHdrs {
+ if decHdr == encHdr {
+ continue EncHdr
+ }
+ }
+ for _, line := range encValues[idxEnc] {
+ if _, err = stdout.WriteString(line + "\n"); err != nil {
+ log.Fatal(err)
+ }
+ }
}
}
- if _, err = io.Copy(stdout, bytes.NewReader(rawDec)); err != nil {
+ if _, err = io.Copy(stdout, bytes.NewReader(decRaw)); err != nil {
log.Fatal(err)
}
if err = stdout.Flush(); err != nil {