diff --git a/contrib/backends/srndv2/src/srnd/line.go b/contrib/backends/srndv2/src/srnd/line.go index 297f00c..26a2dfa 100644 --- a/contrib/backends/srndv2/src/srnd/line.go +++ b/contrib/backends/srndv2/src/srnd/line.go @@ -7,10 +7,10 @@ import ( type LineWriter struct { w io.Writer - limit int + limit int64 } -func NewLineWriter(w io.Writer, limit int) *LineWriter { +func NewLineWriter(w io.Writer, limit int64) *LineWriter { return &LineWriter{ w: w, limit: limit, @@ -18,23 +18,12 @@ func NewLineWriter(w io.Writer, limit int) *LineWriter { } func (l *LineWriter) Write(data []byte) (n int, err error) { - dl := len(data) + n = len(data) data = bytes.Replace(data, []byte{13, 10}, []byte{10}, -1) - parts := bytes.Split(data, []byte{10}) - for _, part := range parts { - for len(part) > l.limit { - d := make([]byte, l.limit) - copy(d, part[:l.limit]) - d = append(d, 10) - _, err = l.w.Write(d) - part = part[l.limit:] - if err != nil { - return - } - } - part = append(part, 10) - _, err = l.w.Write(part) + _, err = l.w.Write(data) + l.limit -= int64(n) + if l.limit <= 0 { + err = ErrOversizedMessage } - n = dl return } diff --git a/contrib/backends/srndv2/src/srnd/message.go b/contrib/backends/srndv2/src/srnd/message.go index ae88720..944e8fb 100644 --- a/contrib/backends/srndv2/src/srnd/message.go +++ b/contrib/backends/srndv2/src/srnd/message.go @@ -209,18 +209,14 @@ func signArticle(nntp NNTPMessage, seed []byte) (signed *nntpArticle, err error) func (self *nntpArticle) BodyReader() io.Reader { if self.Pubkey() == "" { buff := new(bytes.Buffer) - self.WriteBody(buff, 80) + self.WriteBody(buff, MaxMessageSize) return buff } else { return self.signedPart.body } } -func (self *nntpArticle) WriteTo(wr io.Writer, limit int64) error { - return self.writeTo(wr, limit, false) -} - -func (self *nntpArticle) writeTo(wr io.Writer, limit int64, ignoreLimit bool) (err error) { +func (self *nntpArticle) WriteTo(wr io.Writer, limit int64) (err error) { // write headers var n int hdrs := self.headers @@ -248,8 +244,8 @@ func (self *nntpArticle) writeTo(wr io.Writer, limit int64, ignoreLimit bool) (e return } - if limit > 0 || ignoreLimit { - err = self.WriteBody(wr, 80) + if limit > 0 { + err = self.WriteBody(wr, limit) } else { err = ErrOversizedMessage } @@ -414,7 +410,7 @@ func (self *nntpArticle) WriteBody(wr io.Writer, limit int64) (err error) { boundary, ok := params["boundary"] if ok { - nlw := NewLineWriter(wr, 80) + nlw := NewLineWriter(wr, limit) w := multipart.NewWriter(nlw) err = w.SetBoundary(boundary) @@ -452,7 +448,7 @@ func (self *nntpArticle) WriteBody(wr io.Writer, limit int64) (err error) { err = w.Close() w = nil } else { - nlw := NewLineWriter(wr, 80) + nlw := NewLineWriter(wr, limit) // write out message _, err = io.WriteString(nlw, self.message) } diff --git a/contrib/backends/srndv2/src/srnd/spam.go b/contrib/backends/srndv2/src/srnd/spam.go index 6f67438..21773a7 100644 --- a/contrib/backends/srndv2/src/srnd/spam.go +++ b/contrib/backends/srndv2/src/srnd/spam.go @@ -1,10 +1,13 @@ package srnd import ( + "errors" "io" "net" ) +var ErrSpamFilterNotEnabled = errors.New("spam filter access attempted when disabled") + type SpamFilter struct { addr string enabled bool @@ -15,25 +18,27 @@ func (sp *SpamFilter) Configure(c SpamConfig) { sp.addr = c.addr } +func (sp *SpamFilter) Enabled() bool { + return sp.enabled +} + func (sp *SpamFilter) Rewrite(msg io.Reader, out io.WriteCloser) error { var buff [65636]byte - if sp.enabled { - addr, err := net.ResolveTCPAddr("tcp", sp.addr) - if err != nil { - return err - } - c, err := net.DialTCP("tcp", nil, addr) - if err != nil { - return err - } - io.CopyBuffer(c, msg, buff[:]) - c.CloseWrite() - _, err = io.CopyBuffer(out, c, buff[:]) - c.Close() - out.Close() + if !sp.Enabled() { + return ErrSpamFilterNotEnabled + } + addr, err := net.ResolveTCPAddr("tcp", sp.addr) + if err != nil { return err } - io.CopyBuffer(out, msg, buff[:]) + c, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return err + } + io.CopyBuffer(c, msg, buff[:]) + c.CloseWrite() + _, err = io.CopyBuffer(out, c, buff[:]) + c.Close() out.Close() - return nil + return err } diff --git a/contrib/backends/srndv2/src/srnd/store.go b/contrib/backends/srndv2/src/srnd/store.go index 33a7330..c761d8a 100644 --- a/contrib/backends/srndv2/src/srnd/store.go +++ b/contrib/backends/srndv2/src/srnd/store.go @@ -441,43 +441,19 @@ func (self *articleStore) getMIMEHeader(messageID string) (hdr textproto.MIMEHea return hdr } -func (self *articleStore) ProcessMessage(wr io.Writer, msg io.Reader, spamfilter func(string) bool) error { - pr_in, pw_in := io.Pipe() - pr_out, pw_out := io.Pipe() - go func() { - e := self.spamd.Rewrite(pr_in, pw_out) - if e != nil { - log.Println("failed to check spam", e) - } - pw_out.Close() - pr_in.Close() - }() - go func() { - var buff [65536]byte - _, e := io.CopyBuffer(pw_in, msg, buff[:]) - if e != nil { - log.Println("failed to read entire message", e) - } - pw_in.Close() - }() - r := bufio.NewReader(pr_out) - m, err := readMIMEHeader(r) - defer pr_out.Close() - if err != nil { - return err - } - writeMIMEHeader(wr, m.Header) - err = read_message_body(m.Body, m.Header, self, wr, false, func(nntp NNTPMessage) { +func (self *articleStore) ProcessMessage(wr io.Writer, msg io.Reader, spamfilter func(string) bool) (err error) { + process := func(nntp NNTPMessage) { if !spamfilter(nntp.Message()) { err = errors.New("spam message") return } + hdr := nntp.MIMEHeader() err = self.RegisterPost(nntp) if err == nil { - pk := m.Header.Get("X-PubKey-Ed25519") + pk := hdr.Get("X-PubKey-Ed25519") if len(pk) > 0 { // signed and valid - err = self.RegisterSigned(getMessageID(m.Header), pk) + err = self.RegisterSigned(getMessageID(hdr), pk) if err != nil { log.Println("register signed failed", err) } @@ -485,8 +461,47 @@ func (self *articleStore) ProcessMessage(wr io.Writer, msg io.Reader, spamfilter } else { log.Println("error procesing message body", err) } - }) - return err + } + if self.spamd.Enabled() { + pr_in, pw_in := io.Pipe() + pr_out, pw_out := io.Pipe() + go func() { + e := self.spamd.Rewrite(pr_in, pw_out) + if e != nil { + log.Println("failed to check spam", e) + } + }() + go func() { + var buff [65536]byte + _, e := io.CopyBuffer(pw_in, msg, buff[:]) + if e != nil { + log.Println("failed to read entire message", e) + } + pw_in.Close() + pr_in.Close() + }() + r := bufio.NewReader(pr_out) + m, e := readMIMEHeader(r) + err = e + defer func() { + pr_out.Close() + }() + if err != nil { + return + } + writeMIMEHeader(wr, m.Header) + read_message_body(m.Body, m.Header, self, wr, false, process) + } else { + r := bufio.NewReader(msg) + m, e := readMIMEHeader(r) + err = e + if err != nil { + return + } + writeMIMEHeader(wr, m.Header) + read_message_body(m.Body, m.Header, self, wr, false, process) + } + return } func (self *articleStore) GetMessage(msgid string) (nntp NNTPMessage) { @@ -541,7 +556,7 @@ func read_message_body(body io.Reader, hdr map[string][]string, store ArticleSto partReader := multipart.NewReader(body, boundary) for { part, err := partReader.NextPart() - if part == nil { + if part == nil && err == io.EOF { callback(nntp) return nil } else if err == nil {