diff --git a/conf.sample/destination b/conf.sample/destination index faa3a15..0af60f3 100644 --- a/conf.sample/destination +++ b/conf.sample/destination @@ -1 +1 @@ -stdout +127.0.0.1:4443 diff --git a/d4-goclient.go b/d4-goclient.go index eb11556..86f8feb 100644 --- a/d4-goclient.go +++ b/d4-goclient.go @@ -4,13 +4,18 @@ import ( "bytes" "crypto/hmac" "crypto/sha256" + "crypto/tls" + "crypto/x509" "encoding/binary" "flag" "fmt" "io" + "io/ioutil" "log" + "net" "os" "strconv" + "strings" "time" //BSD 3 @@ -49,6 +54,12 @@ type ( src io.Reader dst d4Writer confdir string + cka time.Duration + ct time.Duration + ce bool + retry time.Duration + cc bool + ca x509.CertPool d4error uint8 errnoCopy uint8 debug bool @@ -74,8 +85,17 @@ var ( logger.Output(2, info) } + tmpct, _ = time.ParseDuration("5mn") + tmpcka, _ = time.ParseDuration("2h") + tmpretry, _ = time.ParseDuration("30s") + confdir = flag.String("c", "", "configuration directory") debug = flag.Bool("v", false, "Set to True, true, TRUE, 1, or t to enable verbose output on stdout") + ce = flag.Bool("ce", true, "Set to True, true, TRUE, 1, or t to enable TLS on network destination") + ct = flag.Duration("ct", tmpct, "Set timeout in human format") + cka = flag.Duration("cka", tmpcka, "Keep Alive time human format, 0 to disable") + retry = flag.Duration("rt", tmpretry, "Time in human format before retry after connection failure, set to 0 to exit on failure") + cc = flag.Bool("cc", false, "Check TLS certificate againt rootCA.crt") ) func main() { @@ -104,6 +124,11 @@ func main() { fmt.Printf("destination - the destination where the data is written to\n") fmt.Printf("\n") fmt.Printf("-v [TRUE] for verbose output on stdout") + fmt.Printf("-ce [TRUE] if destination is set to ip:port, use of tls") + fmt.Printf("-cc [FALSE] if destination is set to ip:port, verification of server's tls certificate againt rootCA.crt") + fmt.Printf("-ct [300] if destination is set to ip:port, timeout") + fmt.Printf("-cka [3600] if destination is set to ip:port, keepalive") + fmt.Printf("-retry [5] if destination is set to ip:port, retry period ") flag.PrintDefaults() } @@ -113,6 +138,11 @@ func main() { os.Exit(1) } d4.confdir = *confdir + d4.ce = *ce + d4.ct = *ct + d4.cc = *cc + d4.cka = *cka + d4.retry = *retry // Output logging before closing if debug is enabled if *debug == true { @@ -120,15 +150,44 @@ func main() { defer fmt.Print(&buf) } - if d4loadConfig(d4p) == true { - if d4.dst.initHeader(d4p) == true { - io.CopyBuffer(&d4.dst, d4.src, d4.dst.pb) + c := make(chan string) + for { + if set(d4p) { + go d4Copy(d4p, c) + } else if d4.retry > 0 { + go func() { + time.Sleep(d4.retry) + infof(fmt.Sprintf("Sleeping for %f seconds before retry.\n", d4.retry.Seconds())) + c <- "done waiting" + }() + } else { + panic("Unrecoverable error without retry.") } + <-c + } +} + +func set(d4 *d4S) bool { + if d4loadConfig(d4) { + if setReaderWriters(d4) { + if d4.dst.initHeader(d4) { + return true + } + } + } + return false +} + +func d4Copy(d4 *d4S, c chan string) { + _, err := io.CopyBuffer(&d4.dst, d4.src, d4.dst.pb) + if err != nil { + c <- fmt.Sprintf("%s", err) } } func readConfFile(d4 *d4S, fileName string) []byte { f, err := os.Open((*d4).confdir + "/" + fileName) + defer f.Close() if err != nil { log.Fatal(err) } @@ -143,7 +202,7 @@ func readConfFile(d4 *d4S, fileName string) []byte { if err := f.Close(); err != nil { log.Fatal(err) } - // removes 1 for \n + // trim \n if present return bytes.TrimSuffix(data[:count], []byte("\n")) } @@ -158,12 +217,12 @@ func d4loadConfig(d4 *d4S) bool { (*d4).conf.uuid = generateUUIDv4() // And push it into the conf file f, err := os.OpenFile((*d4).confdir+"/uuid", os.O_WRONLY|os.O_CREATE, 0666) + defer f.Close() if err != nil { log.Fatal(err) } // store as canonical representation f.WriteString(fmt.Sprintf("%s", uuid.FromBytesOrNil((*d4).conf.uuid)) + "\n") - f.Close() } else { (*d4).conf.uuid = tmpu.Bytes() } @@ -177,7 +236,16 @@ func d4loadConfig(d4 *d4S) bool { // parse type to uint8 tmp, _ = strconv.ParseUint(string(readConfFile(d4, "type")), 10, 8) (*d4).conf.ttype = uint8(tmp) - return d4checkConfig(d4) + // Add the custom CA cert in D4 certpool + if (*d4).cc { + certb, _ := ioutil.ReadFile((*d4).confdir + "rootCA.crt") + (*d4).ca = *x509.NewCertPool() + ok := (*d4).ca.AppendCertsFromPEM(certb) + if !ok { + panic("Failed to parse provided root certificate.") + } + } + return true } func newD4Writer(writer io.Writer, key []byte) d4Writer { @@ -185,7 +253,7 @@ func newD4Writer(writer io.Writer, key []byte) d4Writer { } // TODO QUICK IMPLEM, REVISE -func d4checkConfig(d4 *d4S) bool { +func setReaderWriters(d4 *d4S) bool { //TODO implement other destination file, fifo unix_socket ... switch (*d4).conf.source { @@ -195,13 +263,47 @@ func d4checkConfig(d4 *d4S) bool { f, _ := os.Open("capture.pcap") (*d4).src = f } - - switch (*d4).conf.destination { - case "stdout": - (*d4).dst = newD4Writer(os.Stdout, (*d4).conf.key) - case "file": - f, _ := os.Create("test.txt") - (*d4).dst = newD4Writer(f, (*d4).conf.key) + isn, dstnet := isNet((*d4).conf.destination) + if isn { + dial := net.Dialer{ + DualStack: true, + Timeout: (*d4).ct, + KeepAlive: (*d4).cka, + FallbackDelay: 0, + } + tlsc := tls.Config{ + InsecureSkipVerify: true, + } + if (*d4).cc { + tlsc = tls.Config{ + InsecureSkipVerify: false, + RootCAs: &(*d4).ca, + } + } + if (*d4).ce == true { + conn, errc := tls.DialWithDialer(&dial, "tcp", dstnet[0]+":"+dstnet[1], &tlsc) + if errc != nil { + fmt.Println(errc) + return false + } + (*d4).dst = newD4Writer(conn, (*d4).conf.key) + } else { + conn, errc := dial.Dial("tcp", dstnet[0]+":"+dstnet[1]) + if errc != nil { + return false + } + (*d4).dst = newD4Writer(conn, (*d4).conf.key) + } + } else { + switch (*d4).conf.destination { + case "stdout": + (*d4).dst = newD4Writer(os.Stdout, (*d4).conf.key) + case "file": + f, _ := os.Create("test.txt") + (*d4).dst = newD4Writer(f, (*d4).conf.key) + default: + panic(fmt.Sprintf("No suitable destination found, given :%q", (*d4).conf.destination)) + } } // Create the copy buffer @@ -211,6 +313,17 @@ func d4checkConfig(d4 *d4S) bool { return true } +func isNet(d string) (bool, []string) { + ss := strings.Split(string(d), ":") + if len(ss) != 1 { + if net.ParseIP(ss[0]) != nil { + infof(fmt.Sprintf("Server IP: %s, Server Port: %s\n", ss[0], ss[1])) + return true, ss + } + } + return false, make([]string, 0) +} + func generateUUIDv4() []byte { uuid, err := uuid.NewV4() if err != nil { @@ -233,10 +346,7 @@ func (d4w *d4Writer) Write(bs []byte) (int, error) { d4w.updateHMAC(len(bs)) // Eventually write binary in the sink err := binary.Write(d4w.w, binary.LittleEndian, d4w.fb[:62+len(bs)]) - if err != nil { - log.Fatal(err) - } - return len(bs), nil + return len(bs), err } // TODO write go idiomatic err return values