Archived
1
0
This repository has been archived on 2023-08-12. You can view files and clone it, but cannot push or open issues or pull requests.
Jeff Becker 3a6cbf9de6 move srndv2 to nntpchan repo with vendored deps so that nothing breaks ever again
this deprecates the github.com/majestrate/srndv2 repo
2017-04-03 10:00:38 -04:00

1912 lines
66 KiB
Go

//
// postgres db backend
//
package srnd
/**
* TODO:
* ~ caching of board settings
* ~ caching of encrypted address info
* ~ multithreading check
* ~ checking for duplicate articles
*/
import (
"database/sql"
"encoding/hex"
"errors"
"fmt"
"github.com/lib/pq"
"log"
"math"
"net"
"os"
"strconv"
"strings"
"time"
)
// postgres database driver implementation
type PostgresDatabase struct {
	// handle that every query in this backend is issued through
	conn *sql.DB
	// connection string handed to sql.Open; also echoed in some fatal log messages
	db_str string
	// SQL text keyed by statement name, filled in by prepareStatements
	stmt map[string]string
}
// create postgres database driver
//
// The connection string is built from whichever of host, port, user and
// password are non-empty. password is only used when user is set.
// Empty parameters are left out entirely: writing `key=` followed by the
// next parameter would make the next parameter be read as the value.
// Every value is single-quoted with backslashes and single quotes escaped so
// that credentials containing spaces or quotes do not corrupt the string.
//
// Exits the process via log.Fatalf if the database handle cannot be opened.
// The pool is configured with a 30 second connection lifetime, 30 open
// connections and 10 idle connections.
func NewPostgresDatabase(host, port, user, password string) Database {
	// quote renders a value in the quoted form of the key/value connection string format
	quote := func(val string) string {
		return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(val) + "'"
	}

	var params []string
	if len(user) > 0 {
		params = append(params, "user="+quote(user))
		if len(password) > 0 {
			params = append(params, "password="+quote(password))
		}
	}
	if len(host) > 0 {
		params = append(params, "host="+quote(host))
	}
	if len(port) > 0 {
		params = append(params, "port="+quote(port))
	}
	params = append(params, "client_encoding='UTF8'")

	db := new(PostgresDatabase)
	db.db_str = strings.Join(params, " ")

	log.Println("Connecting to postgres...")
	var err error
	db.conn, err = sql.Open("postgres", db.db_str)
	if err != nil {
		log.Fatalf("can`not open connection to db: %s", err)
	}
	db.SetConnectionLifetime(30)
	db.SetMaxOpenConns(30)
	db.SetMaxIdleConns(10)
	return db
}
// SetConnectionLifetime sets, in seconds, the maximum lifetime of a
// connection held by the underlying database handle
func (db *PostgresDatabase) SetConnectionLifetime(seconds int) {
	lifetime := time.Duration(seconds) * time.Second
	db.conn.SetConnMaxLifetime(lifetime)
}
// SetMaxOpenConns forwards n as the open connection limit of the underlying
// database handle
func (db *PostgresDatabase) SetMaxOpenConns(n int) {
	pool := db.conn
	pool.SetMaxOpenConns(n)
}
// SetMaxIdleConns forwards n as the idle connection limit of the underlying
// database handle
func (db *PostgresDatabase) SetMaxIdleConns(n int) {
	pool := db.conn
	pool.SetMaxIdleConns(n)
}
// close the database handle and forget it
// calling this again after the handle was dropped does nothing
func (self *PostgresDatabase) Close() {
	if self.conn == nil {
		// already closed
		return
	}
	self.conn.Close()
	self.conn = nil
}
// names of the SQL statements kept in PostgresDatabase.stmt
// each constant's value is its own identifier
const (
	NewsgroupBanned                 = "NewsgroupBanned"
	ArticleBanned                   = "ArticleBanned"
	GetAllNewsgroups                = "GetAllNewsgroups"
	GetPostsInGroup                 = "GetPostsInGroup"
	GetPostModel                    = "GetPostModel"
	GetArticlePubkey                = "GetArticlePubkey"
	GetThreadModel                  = "GetThreadModel"
	GetThreadModelPubkeys           = "GetThreadModelPubkeys"
	GetThreadModelAttachments       = "GetThreadModelAttachments"
	DeleteArticle_1                 = "DeleteArticle_1"
	DeleteArticle_2                 = "DeleteArticle_2"
	DeleteArticle_3                 = "DeleteArticle_3"
	DeleteArticle_4                 = "DeleteArticle_4"
	DeleteArticle_5                 = "DeleteArticle_5"
	DeleteThread                    = "DeleteThread"
	GetThreadReplyPostModels_1      = "GetThreadReplyPostModels_1"
	GetThreadReplyPostModels_2      = "GetThreadReplyPostModels_2"
	GetThreadReplies_1              = "GetThreadReplies_1"
	GetThreadReplies_2              = "GetThreadReplies_2"
	GetGroupThreads                 = "GetGroupThreads"
	GetLastBumpedThreadsPaginated_1 = "GetLastBumpedThreadsPaginated_1"
	GetLastBumpedThreadsPaginated_2 = "GetLastBumpedThreadsPaginated_2"
	HasNewsgroup                    = "HasNewsgroup"
	HasArticle                      = "HasArticle"
	HasArticleLocal                 = "HasArticleLocal"
	GetPostAttachments              = "GetPostAttachments"
	GetPostAttachmentModels         = "GetPostAttachmentModels"
	RegisterArticle_1               = "RegisterArticle_1"
	RegisterArticle_2               = "RegisterArticle_2"
	RegisterArticle_3               = "RegisterArticle_3"
	RegisterArticle_4               = "RegisterArticle_4"
	RegisterArticle_5               = "RegisterArticle_5"
	RegisterArticle_6               = "RegisterArticle_6"
	RegisterArticle_7               = "RegisterArticle_7"
	RegisterArticle_8               = "RegisterArticle_8"
	GetMessageIDByHeader            = "GetMessageIDByHeader"
	RegisterSigned                  = "RegisterSigned"
	GetAllArticlesInGroup           = "GetAllArticlesInGroup"
	GetAllArticles                  = "GetAllArticles"
	GetMessageIDByHash              = "GetMessageIDByHash"
	CheckEncIPBanned                = "CheckEncIPBanned"
	GetFirstAndLastForGroup         = "GetFirstAndLastForGroup"
	GetMessageIDForNNTPID           = "GetMessageIDForNNTPID"
	GetNNTPIDForMessageID           = "GetNNTPIDForMessageID"
	IsExpired                       = "IsExpired"
	GetLastDaysPostsForGroup        = "GetLastDaysPostsForGroup"
	GetLastDaysPosts                = "GetLastDaysPosts"
	GetLastPostedPostModels         = "GetLastPostedPostModels"
	GetMonthlyPostHistory           = "GetMonthlyPostHistory"
	CheckNNTPLogin                  = "CheckNNTPLogin"
	CheckNNTPUserExists             = "CheckNNTPUserExists"
	GetHeadersForMessage            = "GetHeadersForMessage"
	CountAllArticlesInGroup         = "CountAllArticlesInGroup"
	GetMessageIDByCIDR              = "GetMessageIDByCIDR"
	GetMessageIDByEncryptedIP       = "GetMessageIDByEncryptedIP"
	GetPostsBefore                  = "GetPostsBefore"
	SearchQuery_1                   = "SearchQuery_1"
	SearchQuery_2                   = "SearchQuery_2"
	SearchByHash_1                  = "SearchByHash_1"
	SearchByHash_2                  = "SearchByHash_2"
	GetNNTPPostsInGroup             = "GetNNTPPostsInGroup"
	GetCitesByPostHashLike          = "GetCitesByPostHashLike"
)
// prepareStatements (re)builds the statement table: one SQL string per
// statement name constant, stored in self.stmt
// nothing is sent to the database here, only the lookup table is filled
func (self *PostgresDatabase) prepareStatements() {
	q := make(map[string]string, 62)

	// bans
	q[NewsgroupBanned] = "SELECT COUNT(newsgroup) FROM BannedGroups WHERE newsgroup = $1"
	q[ArticleBanned] = "SELECT COUNT(message_id) FROM BannedArticles WHERE message_id = $1"

	// newsgroup / post / thread models
	q[GetAllNewsgroups] = "SELECT name FROM Newsgroups"
	q[GetPostsInGroup] = "SELECT newsgroup, message_id, ref_id, name, subject, path, time_posted, message, addr FROM ArticlePosts WHERE newsgroup = $1 ORDER BY time_posted"
	q[GetPostModel] = "SELECT newsgroup, message_id, ref_id, name, subject, path, time_posted, message, addr FROM ArticlePosts WHERE message_id = $1 LIMIT 1"
	q[GetArticlePubkey] = "SELECT pubkey FROM ArticleKeys WHERE message_id = $1"
	q[GetThreadModel] = "SELECT ArticlePosts.newsgroup, ArticlePosts.message_id, ArticlePosts.name, ArticlePosts.subject, ArticlePosts.time_posted, ArticlePosts.message, ArticlePosts.addr FROM ArticlePosts WHERE ArticlePosts.message_id = $1 OR ArticlePosts.ref_id = $1 ORDER BY ArticlePosts.time_posted"
	q[GetThreadModelPubkeys] = "SELECT pubkey, message_id from ArticleKeys WHERE message_id IN ( SELECT message_id FROM ArticlePosts WHERE ref_id = $1 OR message_id = $1 )"
	q[GetThreadModelAttachments] = "SELECT filename, filepath, message_id from ArticleAttachments WHERE message_id IN ( SELECT message_id FROM ArticlePosts WHERE ref_id = $1 OR message_id = $1 )"

	// deletion
	q[DeleteArticle_1] = "DELETE FROM NNTPHeaders WHERE header_article_message_id = $1"
	q[DeleteArticle_2] = "DELETE FROM ArticleNumbers WHERE message_id = $1"
	q[DeleteArticle_3] = "DELETE FROM ArticlePosts WHERE message_id = $1"
	q[DeleteArticle_4] = "DELETE FROM ArticleKeys WHERE message_id = $1"
	q[DeleteArticle_5] = "DELETE FROM ArticleAttachments WHERE message_id = $1"
	q[DeleteThread] = "DELETE FROM ArticleThreads WHERE root_message_id = $1"

	// thread replies and listings
	q[GetThreadReplyPostModels_1] = "SELECT newsgroup, message_id, ref_id, name, subject, path, time_posted, message, addr FROM ArticlePosts WHERE message_id IN ( SELECT message_id FROM ArticlePosts WHERE ref_id = $1 ORDER BY time_posted DESC LIMIT $2 ) ORDER BY time_posted ASC"
	q[GetThreadReplyPostModels_2] = "SELECT newsgroup, message_id, ref_id, name, subject, path, time_posted, message, addr FROM ArticlePosts WHERE message_id IN ( SELECT message_id FROM ArticlePosts WHERE ref_id = $1 ) ORDER BY time_posted ASC"
	q[GetThreadReplies_1] = "SELECT message_id FROM ArticlePosts WHERE message_id IN ( SELECT message_id FROM ArticlePosts WHERE ref_id = $1 ORDER BY time_posted DESC LIMIT $2 ) ORDER BY time_posted ASC"
	q[GetThreadReplies_2] = "SELECT message_id FROM ArticlePosts WHERE message_id IN ( SELECT message_id FROM ArticlePosts WHERE ref_id = $1 ) ORDER BY time_posted ASC"
	q[GetGroupThreads] = "SELECT message_id FROM ArticlePosts WHERE newsgroup = $1 AND ref_id = '' "
	q[GetLastBumpedThreadsPaginated_1] = "SELECT root_message_id, newsgroup FROM ArticleThreads WHERE newsgroup = $1 ORDER BY last_bump DESC LIMIT $2"
	q[GetLastBumpedThreadsPaginated_2] = "SELECT root_message_id, newsgroup FROM ArticleThreads WHERE newsgroup != 'ctl' ORDER BY last_bump DESC LIMIT $1"

	// existence checks and attachments
	q[HasNewsgroup] = "SELECT COUNT(name) FROM Newsgroups WHERE name = $1"
	q[HasArticle] = "SELECT COUNT(message_id) FROM Articles WHERE message_id = $1"
	q[HasArticleLocal] = "SELECT COUNT(message_id) FROM ArticlePosts WHERE message_id = $1"
	q[GetPostAttachments] = "SELECT filepath FROM ArticleAttachments WHERE message_id = $1"
	q[GetPostAttachmentModels] = "SELECT filepath, filename FROM ArticleAttachments WHERE message_id = $1"

	// article registration
	q[RegisterArticle_1] = "INSERT INTO Articles (message_id, message_id_hash, message_newsgroup, time_obtained, message_ref_id) VALUES($1, $2, $3, $4, $5)"
	q[RegisterArticle_2] = "UPDATE Newsgroups SET last_post = $1 WHERE name = $2"
	q[RegisterArticle_3] = "INSERT INTO ArticlePosts(newsgroup, message_id, ref_id, name, subject, path, time_posted, message, addr) VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)"
	q[RegisterArticle_4] = "INSERT INTO ArticleThreads(root_message_id, last_bump, last_post, newsgroup) VALUES($1, $2, $2, $3)"
	q[RegisterArticle_5] = "SELECT COUNT(*) FROM ArticlePosts WHERE ref_id = $1"
	q[RegisterArticle_6] = "UPDATE ArticleThreads SET last_bump = $2 WHERE root_message_id = $1"
	q[RegisterArticle_7] = "UPDATE ArticleThreads SET last_post = $2 WHERE root_message_id = $1"
	q[RegisterArticle_8] = "INSERT INTO ArticleAttachments(message_id, sha_hash, filename, filepath) VALUES($1, $2, $3, $4)"
	q[GetMessageIDByHeader] = "SELECT header_article_message_id FROM NNTPHeaders WHERE header_name = $1 AND header_value = $2"
	q[RegisterSigned] = "INSERT INTO ArticleKeys(message_id, pubkey) VALUES ($1, $2)"

	// article lookups
	q[GetAllArticlesInGroup] = "SELECT message_id FROM ArticlePosts WHERE newsgroup = $1"
	q[GetAllArticles] = "SELECT message_id, newsgroup FROM ArticlePosts"
	q[GetMessageIDByHash] = "SELECT message_id, message_newsgroup FROM Articles WHERE message_id_hash = $1 LIMIT 1"
	q[CheckEncIPBanned] = "SELECT COUNT(*) FROM EncIPBans WHERE encaddr = $1"
	q[GetFirstAndLastForGroup] = "WITH x(min_no, max_no) AS ( SELECT MIN(message_no) AS min_no, MAX(message_no) AS max_no FROM ArticleNumbers WHERE newsgroup = $1) SELECT CASE WHEN min_no IS NULL THEN 0 ELSE min_no END AS min_no FROM x UNION SELECT CASE WHEN max_no IS NULL THEN 1 ELSE max_no END AS max_no FROM x"
	q[GetMessageIDForNNTPID] = "SELECT message_id FROM ArticleNumbers WHERE newsgroup = $1 AND message_no = $2 LIMIT 1"
	q[GetNNTPIDForMessageID] = "SELECT message_no FROM ArticleNumbers WHERE newsgroup = $1 AND message_id = $2 LIMIT 1"
	q[IsExpired] = "WITH x(msgid) AS ( SELECT message_id FROM Articles WHERE message_id = $1 INTERSECT ( SELECT message_id FROM ArticlePosts WHERE message_id = $1 ) ) SELECT COUNT(*) FROM x"

	// statistics
	q[GetLastDaysPostsForGroup] = "SELECT COUNT(*) FROM ArticlePosts WHERE time_posted < $1 AND time_posted > $2 AND newsgroup = $3"
	q[GetLastDaysPosts] = "SELECT COUNT(*) FROM ArticlePosts WHERE time_posted < $1 AND time_posted > $2"
	q[GetLastPostedPostModels] = "SELECT newsgroup, message_id, ref_id, name, subject, path, time_posted, message, addr FROM ArticlePosts WHERE newsgroup != 'ctl' ORDER BY time_posted DESC LIMIT $1"
	q[GetMonthlyPostHistory] = "SELECT time_posted FROM ArticlePosts WHERE time_posted > 0 ORDER BY time_posted ASC LIMIT 1"

	// nntp users and headers
	q[CheckNNTPLogin] = "SELECT login_hash, login_salt FROM NNTPUsers WHERE username = $1"
	q[CheckNNTPUserExists] = "SELECT COUNT(username) FROM NNTPUsers WHERE username = $1"
	q[GetHeadersForMessage] = "SELECT header_name, header_value FROM NNTPHeaders WHERE header_article_message_id = $1"
	q[CountAllArticlesInGroup] = "SELECT COUNT(message_id) FROM ArticlePosts WHERE newsgroup = $1"

	// lookups by address and time
	q[GetMessageIDByCIDR] = "SELECT message_id FROM ArticlePosts WHERE addr IN ( SELECT encaddr FROM EncryptedAddrs WHERE addr_cidr <<= cidr($1) )"
	q[GetMessageIDByEncryptedIP] = "SELECT message_id FROM ArticlePosts WHERE addr = $1"
	q[GetPostsBefore] = "SELECT message_id FROM ArticlePosts WHERE time_posted < $1"

	// search
	q[SearchQuery_1] = "SELECT newsgroup, message_id, ref_id FROM ArticlePosts WHERE message LIKE $1 ORDER BY time_posted DESC"
	q[SearchQuery_2] = "SELECT newsgroup, message_id, ref_id FROM ArticlePosts WHERE newsgroup = $1 AND message LIKE $2 ORDER BY time_posted DESC"
	q[SearchByHash_1] = "SELECT message_newsgroup, message_id, message_ref_id FROM Articles WHERE message_id_hash LIKE $1 ORDER BY time_obtained DESC"
	q[SearchByHash_2] = "SELECT message_newsgroup, message_id, message_ref_id FROM Articles WHERE message_newsgroup = $2 AND message_id_hash LIKE $1 ORDER BY time_obtained DESC"
	q[GetNNTPPostsInGroup] = "SELECT message_no, ArticlePosts.message_id, subject, time_posted, ref_id, name, path FROM ArticleNumbers INNER JOIN ArticlePosts ON ArticleNumbers.message_id = ArticlePosts.message_id WHERE ArticlePosts.newsgroup = $1 ORDER BY message_no"
	q[GetCitesByPostHashLike] = "SELECT message_id, message_ref_id FROM Articles WHERE message_id_hash LIKE $1"

	self.stmt = q
}
// CreateTables brings the database schema up to the newest version (7) and
// then fills the statement table.
//
// The stored version is re-read after every step, so each pass of the loop
// applies exactly one migration. A version of -1 (no version could be read)
// creates the version 0 tables and immediately applies the 0 -> 1 upgrade.
//
// A version this code does not know about aborts the process: no branch
// would ever change it, so continuing would loop forever.
func (self *PostgresDatabase) CreateTables() {
migrate:
	for {
		version := self.getDBVersion()
		switch version {
		case -1:
			// no tables
			self.createTablesV0()
			self.upgrade0to1()
		case 0:
			// upgrade to version 1
			self.upgrade0to1()
		case 1:
			// upgrade to version 2
			self.upgrade1to2()
		case 2:
			// upgrade to version 3
			self.upgrade2to3()
		case 3:
			// upgrade to version 4
			self.upgrade3to4()
		case 4:
			// upgrade to version 5
			self.upgrade4to5()
		case 5:
			// upgrade to version 6
			self.upgrade5to6()
		case 6:
			// upgrade to version 7
			self.upgrade6to7()
		case 7:
			// we are up to date
			log.Println("we are up to date at version", version)
			break migrate
		default:
			log.Fatalf("unsupported database version %d, refusing to migrate", version)
		}
	}
	self.prepareStatements()
}
// upgrade1to2 migrates the schema from version 1 to version 2 by creating
// the NNTPUsers table.
//
// Any failure, including failing to record the new version, aborts the
// process. Carrying on with an unrecorded version would make the caller's
// migration loop run migrations against a schema in an unknown state.
func (self *PostgresDatabase) upgrade1to2() {
	log.Println("migrating... 1 -> 2")
	var err error
	tables := make(map[string]string)
	tables["NNTPUsers"] = `(
		username VARCHAR(255) PRIMARY KEY,
		login_hash VARCHAR(255) NOT NULL,
		login_salt VARCHAR(255) NOT NULL
	)`
	table_order := []string{"NNTPUsers"}
	for _, table := range table_order {
		q := tables[table]
		// create table
		_, err = self.conn.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s", table, q))
		if err != nil {
			log.Fatalf("cannot create table %s, %s, login was '%s'", table, err, self.db_str)
		}
	}
	if err = self.setDBVersion(2); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 2, err)
	}
}
// upgrade2to3 migrates the schema from version 2 to version 3 by creating
// the NNTPHeaders table (one row per header of an article, tied to
// ArticlePosts by message id) and an index on the header name.
//
// Table creation and version recording failures abort the process; index
// creation errors are handed to checkError.
func (self *PostgresDatabase) upgrade2to3() {
	log.Println("migrating... 2 -> 3")
	var err error
	tables := make(map[string]string)
	tables["NNTPHeaders"] = `(
		header_name VARCHAR(255) NOT NULL,
		header_value TEXT NOT NULL,
		header_article_message_id VARCHAR(255) NOT NULL,
		FOREIGN KEY(header_article_message_id) REFERENCES ArticlePosts(message_id)
	)`
	table_order := []string{"NNTPHeaders"}
	for _, table := range table_order {
		q := tables[table]
		// create table
		_, err = self.conn.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s", table, q))
		if err != nil {
			log.Fatalf("cannot create table %s, %s, login was '%s'", table, err, self.db_str)
		}
	}
	cmds := []string{
		"CREATE INDEX ON NNTPHeaders(header_name)",
	}
	for _, cmd := range cmds {
		_, err = self.conn.Exec(cmd)
		checkError(err)
	}
	// an unrecorded version would make the migration loop act on a stale version
	if err = self.setDBVersion(3); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 3, err)
	}
}
// upgrade5to6 migrates the schema from version 5 to version 6 by creating
// the PubkeyProperties and PubkeyModifyEvents tables.
//
// Any failure, including failing to record the new version, aborts the
// process.
func (self *PostgresDatabase) upgrade5to6() {
	log.Println("migrating... 5 -> 6")
	tables := make(map[string]string)
	// public key properties, key value pair: pubkey -> status
	tables["PubkeyProperties"] = `(
		pubkey VARCHAR(255) PRIMARY KEY,
		status VARCHAR(255) NOT NULL
	)`
	// ledger of public key property modification events
	tables["PubkeyModifyEvents"] = `(
		source_pubkey VARCHAR(255) NOT NULL,
		target_pubkey VARCHAR(255) NOT NULL,
		event_time BIGINT NOT NULL,
		status VARCHAR(255) NOT NULL,
		id BIGSERIAL PRIMARY KEY
	)`
	table_order := []string{"PubkeyProperties", "PubkeyModifyEvents"}
	for _, t := range table_order {
		q := tables[t]
		// create table
		_, err := self.conn.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s", t, q))
		if err != nil {
			log.Fatalf("cannot create table %s, %s", t, err)
		}
	}
	// an unrecorded version would make the migration loop act on a stale version
	if err := self.setDBVersion(6); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 6, err)
	}
}
// upgrade4to5 migrates the schema from version 4 to version 5: it
// (re)creates the addr_cidr column on EncryptedAddrs and fills it by casting
// every stored addr to cidr.
//
// Any failure, including failing to record the new version, aborts the
// process.
func (self *PostgresDatabase) upgrade4to5() {
	log.Println("migrating... 4 -> 5")
	cmds := []string{
		"ALTER TABLE EncryptedAddrs DROP COLUMN IF EXISTS addr_cidr",
		"ALTER TABLE EncryptedAddrs ADD COLUMN addr_cidr cidr",
		"UPDATE EncryptedAddrs AS a SET addr_cidr = e.cidr FROM ( SELECT cidr(addr), addr FROM EncryptedAddrs) AS e WHERE e.addr = a.addr",
	}
	for _, cmd := range cmds {
		_, err := self.conn.Exec(cmd)
		if err != nil {
			log.Fatalf("failed to execute query `%s`, %s", cmd, err.Error())
		}
	}
	// an unrecorded version would make the migration loop act on a stale version
	if err := self.setDBVersion(5); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 5, err)
	}
}
// upgrade3to4 migrates the schema from version 3 to version 4: it creates
// the ArticleNumbers table with an index on message_no, then walks every row
// of ArticlePosts (newest first) and registers a number for it via
// registerNNTPNumber.
//
// Any failure aborts the process. That includes an error that ends the row
// iteration early and a failure to record the new version; either would
// otherwise leave a partially numbered database that looks migrated, or
// cause the numbering to be run a second time.
func (self *PostgresDatabase) upgrade3to4() {
	log.Println("migrating... 3 -> 4")
	tables := make(map[string]string)
	tables["ArticleNumbers"] = `(
		newsgroup VARCHAR(255) NOT NULL,
		message_id VARCHAR(255) NOT NULL,
		message_no BIGINT NOT NULL,
		FOREIGN KEY (newsgroup) REFERENCES Newsgroups(name),
		FOREIGN KEY (message_id) REFERENCES ArticlePosts(message_id)
	)`
	table_order := []string{"ArticleNumbers"}
	cmds := []string{"CREATE INDEX ON ArticleNumbers(message_no)"}
	for _, table := range table_order {
		_, err := self.conn.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s", table, tables[table]))
		if err != nil {
			log.Fatalf("cannot create table %s: %s", table, err.Error())
		}
	}
	for _, cmd := range cmds {
		_, err := self.conn.Exec(cmd)
		if err != nil {
			log.Fatalf("failed to execute query: %s, %s", cmd, err.Error())
		}
	}
	log.Println("migrating post numbers, this can take a bit DO NOT INTERRUPT")
	rows, err := self.conn.Query("SELECT message_id, newsgroup FROM ArticlePosts ORDER BY time_posted DESC")
	if err != nil {
		log.Fatalf("could not query ArticlePosts table: %s", err.Error())
	}
	// release the result set once iteration is over
	defer rows.Close()
	counter := int64(0)
	for rows.Next() {
		counter++
		var msgid, group string
		err = rows.Scan(&msgid, &group)
		if err != nil {
			log.Fatalf("could not scan row: %s", err.Error())
		}
		err = self.registerNNTPNumber(group, msgid)
		if err != nil {
			log.Fatalf("could not migrate article %s in %s, %s", msgid, group, err.Error())
		}
		if counter%100 == 0 {
			log.Println("migrated ", counter)
		}
	}
	// Next returning false can also mean the iteration broke off early
	if err = rows.Err(); err != nil {
		log.Fatalf("could not read all rows of ArticlePosts: %s", err.Error())
	}
	log.Println("total migrated posts: ", counter)
	if err = self.setDBVersion(4); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 4, err)
	}
}
// upgrade0to1 migrates the schema from version 0 to version 1.
//
// It adds indexes, (re)creates the addr column on ArticlePosts, makes
// message_id the primary key of ArticlePosts, removes rows of ArticleKeys,
// ArticleThreads and ArticleAttachments that reference no existing post, and
// adds ON DELETE CASCADE foreign keys so dependent rows follow their post or
// newsgroup.
//
// Statement errors are handed to checkError. Failing to record the new
// version aborts the process: this migration drops and re-adds the addr
// column, so it must not be run again against an already migrated database.
func (self *PostgresDatabase) upgrade0to1() {
	// begin >:D
	log.Println("migrating... 0 -> 1")
	var err error
	cmds := []string{
		// newsgroups table
		"CREATE INDEX ON Newsgroups(name)",
		// article posts table
		"ALTER TABLE ArticlePosts DROP COLUMN IF EXISTS addr",
		"ALTER TABLE ArticlePosts ADD COLUMN addr VARCHAR(255)",
		"ALTER TABLE ArticlePosts DROP CONSTRAINT IF EXISTS group_depend",
		"ALTER TABLE ArticlePosts ADD CONSTRAINT group_depend FOREIGN KEY(newsgroup) REFERENCES Newsgroups(name) ON DELETE CASCADE",
		"ALTER TABLE ArticlePosts DROP CONSTRAINT IF EXISTS msgid_pk",
		"ALTER TABLE ArticlePosts ADD CONSTRAINT msgid_pk PRIMARY KEY(message_id)",
		"CREATE INDEX ON ArticlePosts(ref_id)",
		// article keys table
		"DELETE FROM ArticleKeys WHERE message_id NOT IN ( SELECT message_id FROM ArticlePosts )",
		"ALTER TABLE ArticleKeys DROP CONSTRAINT IF EXISTS msgid_depend",
		"ALTER TABLE ArticleKeys ADD CONSTRAINT msgid_depend FOREIGN KEY(message_id) REFERENCES ArticlePosts(message_id) ON DELETE CASCADE",
		// article threads table
		"ALTER TABLE ArticleThreads DROP CONSTRAINT IF EXISTS msgid_depend",
		"ALTER TABLE ArticleThreads DROP CONSTRAINT IF EXISTS group_depend",
		"DELETE FROM ArticleThreads WHERE root_message_id NOT IN ( SELECT message_id FROM ArticlePosts )",
		"ALTER TABLE ArticleThreads ADD CONSTRAINT msgid_depend FOREIGN KEY(root_message_id) REFERENCES ArticlePosts(message_id) ON DELETE CASCADE",
		"ALTER TABLE ArticleThreads ADD CONSTRAINT group_depend FOREIGN KEY(newsgroup) REFERENCES Newsgroups(name) ON DELETE CASCADE",
		// article attachments table
		"ALTER TABLE ArticleAttachments DROP CONSTRAINT IF EXISTS msgid_depend",
		"DELETE FROM ArticleAttachments WHERE message_id NOT IN ( SELECT message_id FROM ArticlePosts )",
		"ALTER TABLE ArticleAttachments ADD CONSTRAINT msgid_depend FOREIGN KEY(message_id) REFERENCES ArticlePosts(message_id) ON DELETE CASCADE",
	}
	for _, cmd := range cmds {
		_, err = self.conn.Exec(cmd)
		checkError(err)
	}
	if err = self.setDBVersion(1); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 1, err)
	}
}
// upgrade6to7 migrates the schema from version 6 to version 7 by creating
// the Thumbnails and Cites tables and an index on each.
//
// The Cites table is left empty: the code that would rebuild every cite from
// existing posts is kept below but commented out.
//
// Table creation and version recording failures abort the process; index
// creation errors are handed to checkError.
func (self *PostgresDatabase) upgrade6to7() {
	tables := make(map[string]string)
	log.Println("migrating... 6 -> 7")
	// table for thumbnail info
	tables["Thumbnails"] = `(
		sha_hash VARCHAR(128) PRIMARY KEY,
		width INTEGER NOT NULL,
		height INTEGER NOT NULL
	)`
	// table of (citing post, cited post) message id pairs
	tables["Cites"] = `(
		post_msgid VARCHAR(255) NOT NULL,
		cite_msgid VARCHAR(255) NOT NULL
	)`
	var err error
	table_order := []string{"Thumbnails", "Cites"}
	for _, table := range table_order {
		q := tables[table]
		// create table
		_, err = self.conn.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s", table, q))
		if err != nil {
			log.Fatalf("cannot create table %s, %s, login was '%s'", table, err, self.db_str)
		}
	}
	// make indexes
	cmds := []string{
		"CREATE INDEX ON Thumbnails(sha_hash)",
		"CREATE INDEX ON Cites(cite_msgid)",
	}
	for _, cmd := range cmds {
		_, err = self.conn.Exec(cmd)
		checkError(err)
	}
	/*
		// rebuild ALL cites
		log.Println("!!! Building Cites table, this will take a long time. Do NOT interrupt !!!")
		post_counter := 0
		cite_counter := 0
		var rows *sql.Rows
		rows, err = self.conn.Query("SELECT message, message_id FROM ArticlePosts")
		if err != nil {
		log.Fatalf("error migrating: %s", err)
		}
		cites := make(map[string][]string)
		for rows.Next() {
		var msg, msgid string
		rows.Scan(&msg, &msgid)
		c := findBacklinks(msg)
		cite_counter += len(c)
		cites[msgid] = c
		post_counter++
		if post_counter%100 == 0 {
		log.Printf("selected %d messages %d cites", post_counter, cite_counter)
		}
		}
		rows.Close()
		log.Printf("calculating %d cites ...", cite_counter)
		cites_insert := make(map[string][2]string)
		citemap_counter := 0
		for msgid, citelist := range cites {
		for _, cite := range citelist {
		cite = cite[2:]
		citeLike := cite + "%"
		var cite_msgid string
		err = self.conn.QueryRow("SELECT message_id FROM Articles WHERE message_id_hash LIKE $1 LIMIT 1", citeLike).Scan(&cite_msgid)
		if err != nil {
		continue
		//log.Fatalf("failed to select cite like %s: %s", citeLike, err)
		}
		cites_insert[msgid+cite_msgid] = [2]string{msgid, cite_msgid}
		citemap_counter++
		if cite_counter%100 == 0 {
		log.Printf("calculated %d cites", cite_counter)
		}
		}
		}
		log.Printf("inserting %d cites ...", cite_counter)
		txn, err := self.conn.Begin()
		if err != nil {
		log.Fatalf("failed to begin insert: %s", err)
		}
		st, err := txn.Prepare(pq.CopyIn("Cites", "post_msgid", "cite_msgid"))
		if err != nil {
		log.Fatalf("failed to prepare statement: %s", err)
		}
		for _, ct := range cites_insert {
		_, err = st.Exec(ct[0], ct[1])
		if err != nil {
		log.Fatalf("failed to insert with prepared statement: %s", err)
		}
		}
		_, err = st.Exec()
		if err != nil {
		log.Fatalf("failed to excute statement: %s", err)
		}
		err = st.Close()
		if err != nil {
		log.Fatalf("failed to close statement: %s", err)
		}
		log.Println("committing...")
		err = txn.Commit()
		if err != nil {
		log.Fatalf("failed to commit transaction: %s", err)
		}
		log.Println("insertion done")
	*/
	// an unrecorded version would make the migration loop act on a stale version
	if err = self.setDBVersion(7); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 7, err)
	}
}
// create all tables for database version 0
//
// Tables are created with IF NOT EXISTS in a fixed order so that tables
// referenced by a foreign key exist first. Failing to create a table, or to
// record version 0 afterwards, aborts the process.
//
// Index creation is best effort: a failure is logged and otherwise ignored.
// Each index is given an explicit name because PostgreSQL requires one
// whenever IF NOT EXISTS is used with CREATE INDEX.
func (self *PostgresDatabase) createTablesV0() {
	tables := make(map[string]string)
	// table of active newsgroups
	tables["Newsgroups"] = `(
		name VARCHAR(255) PRIMARY KEY,
		last_post INTEGER NOT NULL,
		restricted BOOLEAN
	)`
	// table for ip and their encryption key
	tables["EncryptedAddrs"] = `(
		enckey VARCHAR(255) NOT NULL,
		addr VARCHAR(255) NOT NULL,
		encaddr VARCHAR(255) NOT NULL
	)`
	// table for articles that have been banned
	tables["BannedArticles"] = `(
		message_id VARCHAR(255) PRIMARY KEY,
		time_banned INTEGER NOT NULL,
		ban_reason TEXT NOT NULL
	)`
	// table for banned newsgroups
	tables["BannedGroups"] = `(
		newsgroup VARCHAR(255) PRIMARY KEY,
		time_banned INTEGER NOT NULL
	)`
	// table for storing nntp article meta data
	tables["Articles"] = `(
		message_id VARCHAR(255) PRIMARY KEY,
		message_id_hash VARCHAR(40) UNIQUE NOT NULL,
		message_newsgroup VARCHAR(255),
		message_ref_id VARCHAR(255),
		time_obtained INTEGER NOT NULL,
		FOREIGN KEY(message_newsgroup) REFERENCES Newsgroups(name)
	)`
	// table for storing nntp article post content
	tables["ArticlePosts"] = `(
		newsgroup VARCHAR(255),
		message_id VARCHAR(255),
		ref_id VARCHAR(255),
		name TEXT NOT NULL,
		subject TEXT NOT NULL,
		path TEXT NOT NULL,
		time_posted INTEGER NOT NULL,
		message TEXT NOT NULL
	)`
	// table for storing nntp article posts to pubkey mapping
	tables["ArticleKeys"] = `(
		message_id VARCHAR(255) NOT NULL,
		pubkey VARCHAR(255) NOT NULL
	)`
	// table for thread state
	tables["ArticleThreads"] = `(
		newsgroup VARCHAR(255) NOT NULL,
		root_message_id VARCHAR(255) NOT NULL,
		last_bump INTEGER NOT NULL,
		last_post INTEGER NOT NULL
	)`
	// table for storing nntp article attachment info
	tables["ArticleAttachments"] = `(
		message_id VARCHAR(255),
		sha_hash VARCHAR(128) NOT NULL,
		filename TEXT NOT NULL,
		filepath TEXT NOT NULL
	)`
	// table for storing current permissions of mod pubkeys
	tables["ModPrivs"] = `(
		pubkey VARCHAR(255),
		newsgroup VARCHAR(255),
		permission VARCHAR(255)
	)`
	// table for storing moderation events
	tables["ModLogs"] = `(
		pubkey VARCHAR(255),
		action VARCHAR(255),
		target VARCHAR(255),
		time INTEGER
	)`
	// ip range bans
	tables["IPBans"] = `(
		addr cidr NOT NULL,
		made INTEGER NOT NULL,
		expires INTEGER NOT NULL
	)`
	// bans for encrypted addresses that we don't have the ip for
	tables["EncIPBans"] = `(
		encaddr VARCHAR(255) NOT NULL,
		made INTEGER NOT NULL,
		expires INTEGER NOT NULL
	)`
	// key value settings, holds the schema version
	tables["Settings"] = `(
		name VARCHAR(255) NOT NULL,
		value VARCHAR(255) NOT NULL
	)`
	var err error
	table_order := []string{"Newsgroups", "BannedGroups", "BannedArticles", "IPBans", "EncIPBans", "Settings", "Articles", "ArticlePosts", "ArticleKeys", "ArticleThreads", "ArticleAttachments", "ModPrivs", "ModLogs", "EncryptedAddrs"}
	for _, table := range table_order {
		q := tables[table]
		// create table
		_, err = self.conn.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s", table, q))
		if err != nil {
			log.Fatalf("cannot create table %s, %s, login was '%s'", table, err, self.db_str)
		}
	}
	// create indexes
	indexes := []string{
		"CREATE INDEX IF NOT EXISTS articlethreads_root_message_id_idx ON ArticleThreads(root_message_id)",
		"CREATE INDEX IF NOT EXISTS articleattachments_message_id_idx ON ArticleAttachments(message_id)",
		"CREATE INDEX IF NOT EXISTS articleposts_message_id_idx ON ArticlePosts(message_id)",
		"CREATE INDEX IF NOT EXISTS articles_message_id_idx ON Articles(message_id)",
		"CREATE INDEX IF NOT EXISTS newsgroups_name_idx ON Newsgroups(name)",
	}
	for _, cmd := range indexes {
		if _, err = self.conn.Exec(cmd); err != nil {
			// an index only affects speed, so keep going without it
			log.Println("could not create index:", cmd, err)
		}
	}
	if err = self.setDBVersion(0); err != nil {
		log.Fatalf("cannot set db version to %d: %s", 0, err)
	}
}
// set what the current database version is
//
// The old version row is removed and the new one inserted inside a single
// transaction, so a failure part way through cannot leave the Settings table
// without a version row. Returns the first error hit; on error the
// transaction is rolled back and the stored version is left as it was.
func (self *PostgresDatabase) setDBVersion(version int) (err error) {
	log.Println("set db version to", version)
	tx, err := self.conn.Begin()
	if err != nil {
		return err
	}
	_, err = tx.Exec("DELETE FROM Settings WHERE name = $1", "version")
	if err == nil {
		_, err = tx.Exec("INSERT INTO Settings(name, value) VALUES($1, $2)", "version", strconv.Itoa(version))
	}
	if err != nil {
		// the statement error is the one worth reporting, not the rollback's
		tx.Rollback()
		return err
	}
	return tx.Commit()
}
// read the schema version stored in the Settings table
// returns -1 when the version row cannot be read for any reason
// exits the process when the stored value is not a base 10 integer
func (self *PostgresDatabase) getDBVersion() (version int) {
	var stored string
	row := self.conn.QueryRow("SELECT value FROM Settings WHERE name = $1", "version")
	if scanErr := row.Scan(&stored); scanErr != nil {
		return -1
	}
	parsed, err := strconv.ParseInt(stored, 10, 32)
	if err != nil {
		log.Fatal("cannot figure out db version", err)
	}
	return int(parsed)
}
// BanNewsgroup records group in BannedGroups together with the current time
func (self *PostgresDatabase) BanNewsgroup(group string) (err error) {
	const insert = "INSERT INTO BannedGroups(newsgroup, time_banned) VALUES($1, $2)"
	_, err = self.conn.Exec(insert, group, timeNow())
	return err
}
// UnbanNewsgroup removes group from BannedGroups
func (self *PostgresDatabase) UnbanNewsgroup(group string) (err error) {
	const remove = "DELETE FROM BannedGroups WHERE newsgroup = $1"
	_, err = self.conn.Exec(remove, group)
	return err
}
// NewsgroupBanned reports whether group has at least one row in BannedGroups
// on a query error banned is false and the error is returned
func (self *PostgresDatabase) NewsgroupBanned(group string) (banned bool, err error) {
	var matches int64
	row := self.conn.QueryRow(self.stmt[NewsgroupBanned], group)
	err = row.Scan(&matches)
	return matches > 0, err
}
// NukeNewsgroup removes every trace of group: its thread rows, then for each
// article the stored article file, its attachment and thumbnail files, and
// finally its database rows
// errors along the way are not reported
func (self *PostgresDatabase) NukeNewsgroup(group string, store ArticleStore) {
	// thread presences go first
	_, _ = self.conn.Exec("DELETE FROM ArticleThreads WHERE newsgroup = $1", group)

	// feed every article of the group through a channel
	articles := make(chan ArticleEntry, 24)
	go func() {
		self.GetAllArticlesInGroup(group, articles)
		close(articles)
	}()

	// fully delete each article as it arrives
	for entry := range articles {
		msgid := entry.MessageID()
		log.Println("delete", msgid)
		// article file in the store
		os.Remove(store.GetFilename(msgid))
		// attachment files and their thumbnails
		for _, attachment := range self.GetPostAttachments(msgid) {
			log.Println("delete attachment", attachment)
			os.Remove(store.ThumbnailFilepath(attachment))
			os.Remove(store.AttachmentFilepath(attachment))
		}
		// database rows
		self.DeleteArticle(msgid)
	}
	log.Println("nuke of", group, "done")
}
// AddModPubkey gives pubkey the "login" permission on "ctl"
// a pubkey that CheckModPubkey already knows is left untouched and nil is returned
func (self *PostgresDatabase) AddModPubkey(pubkey string) error {
	if !self.CheckModPubkey(pubkey) {
		_, err := self.conn.Exec("INSERT INTO ModPrivs(pubkey, newsgroup, permission) VALUES ( $1, $2, $3 )", pubkey, "ctl", "login")
		return err
	}
	log.Println("did not add pubkey", pubkey, "already exists")
	return nil
}
// GetGroupForMessage looks up the newsgroup of the post with the given message id
// the scan error is returned as is when no such post can be read
func (self *PostgresDatabase) GetGroupForMessage(message_id string) (group string, err error) {
	row := self.conn.QueryRow("SELECT newsgroup FROM ArticlePosts WHERE message_id = $1", message_id)
	err = row.Scan(&group)
	return group, err
}
// GetPageForRootMessage finds the newsgroup a thread lives in and the board
// page it currently sits on.
//
// The page is the number of threads in the same newsgroup whose last bump is
// at least as recent as this thread's, divided by the value returned by
// GetPagesPerBoard for the group.
//
// Returns an error when the thread cannot be found, when the counting query
// fails, or when the per-board value is not positive. The last case used to
// panic with an integer division by zero.
func (self *PostgresDatabase) GetPageForRootMessage(root_message_id string) (group string, page int64, err error) {
	err = self.conn.QueryRow("SELECT newsgroup FROM ArticleThreads WHERE root_message_id = $1", root_message_id).Scan(&group)
	if err != nil {
		return
	}
	// the lookup error is not propagated, as before; only an unusable divisor is rejected
	perpage, _ := self.GetPagesPerBoard(group)
	if perpage <= 0 {
		err = errors.New("invalid pages per board value for newsgroup " + group)
		return
	}
	err = self.conn.QueryRow("WITH thread(bump) AS (SELECT last_bump FROM ArticleThreads WHERE root_message_id = $1 ) SELECT COUNT(*) FROM ( SELECT last_bump FROM ArticleThreads INNER JOIN thread ON (thread.bump <= ArticleThreads.last_bump AND newsgroup = $2 ) ) AS amount", root_message_id, group).Scan(&page)
	page /= int64(perpage)
	return
}
// GetInfoForMessage returns, for any post, the root message id of its
// thread, its newsgroup and the board page the thread currently sits on.
//
// A post with an empty ref_id is itself the root. The page is the number of
// threads in the newsgroup bumped at least as recently as this one, divided
// by the value returned by GetPagesPerBoard for the group.
//
// Returns an error when the post cannot be found, when the counting query
// fails, or when the per-board value is not positive. The last case used to
// panic with an integer division by zero.
func (self *PostgresDatabase) GetInfoForMessage(msgid string) (root string, newsgroup string, page int64, err error) {
	err = self.conn.QueryRow("SELECT newsgroup, ref_id FROM ArticlePosts WHERE message_id = $1", msgid).Scan(&newsgroup, &root)
	if err != nil {
		return
	}
	if root == "" {
		root = msgid
	}
	// the lookup error is not propagated, as before; only an unusable divisor is rejected
	perpage, _ := self.GetPagesPerBoard(newsgroup)
	if perpage <= 0 {
		err = errors.New("invalid pages per board value for newsgroup " + newsgroup)
		return
	}
	err = self.conn.QueryRow("WITH thread(bump) AS (SELECT last_bump FROM ArticleThreads WHERE root_message_id = $1 ) SELECT COUNT(*) FROM ( SELECT last_bump FROM ArticleThreads INNER JOIN thread ON (thread.bump <= ArticleThreads.last_bump AND newsgroup = $2 ) ) AS amount", root, newsgroup).Scan(&page)
	page /= int64(perpage)
	return
}
// CheckModPubkeyGlobal reports whether pubkey holds the "all" permission on "overchan"
// a failed query reads as false
func (self *PostgresDatabase) CheckModPubkeyGlobal(pubkey string) bool {
	var matches int64
	row := self.conn.QueryRow("SELECT COUNT(*) FROM ModPrivs WHERE pubkey = $1 AND newsgroup = $2 AND permission = $3", pubkey, "overchan", "all")
	// on error matches stays zero
	_ = row.Scan(&matches)
	return matches > 0
}
// CheckModPubkeyCanModGroup reports whether pubkey has any permission row for newsgroup
// a failed query reads as false
func (self *PostgresDatabase) CheckModPubkeyCanModGroup(pubkey, newsgroup string) bool {
	var matches int64
	row := self.conn.QueryRow("SELECT COUNT(*) FROM ModPrivs WHERE pubkey = $1 AND newsgroup = $2", pubkey, newsgroup)
	// on error matches stays zero
	_ = row.Scan(&matches)
	return matches > 0
}
// CountPostsInGroup counts the posts in newsgroup posted after a cutoff
// a positive time_frame puts the cutoff that far before timeNow()
// zero or a negative time_frame puts the cutoff at 0
// a failed query yields 0
func (self *PostgresDatabase) CountPostsInGroup(newsgroup string, time_frame int64) (result int64) {
	cutoff := int64(0)
	if time_frame > 0 {
		cutoff = timeNow() - time_frame
	}
	row := self.conn.QueryRow("SELECT COUNT(*) FROM ArticlePosts WHERE time_posted > $2 AND newsgroup = $1", newsgroup, cutoff)
	row.Scan(&result)
	return result
}
// CheckModPubkey reports whether pubkey has any row in ModPrivs
// a failed query reads as false
func (self *PostgresDatabase) CheckModPubkey(pubkey string) bool {
	var matches int64
	row := self.conn.QueryRow("SELECT COUNT(*) FROM ModPrivs WHERE pubkey = $1", pubkey)
	row.Scan(&matches)
	return matches > 0
}
// BanArticle records messageID in BannedArticles with the current time and reason
// an article that ArticleBanned already reports as banned is left untouched and nil is returned
func (self *PostgresDatabase) BanArticle(messageID, reason string) error {
	if !self.ArticleBanned(messageID) {
		_, err := self.conn.Exec("INSERT INTO BannedArticles(message_id, time_banned, ban_reason) VALUES($1, $2, $3)", messageID, timeNow(), reason)
		return err
	}
	log.Println(messageID, "already banned")
	return nil
}
// ArticleBanned reports whether messageID has a row in BannedArticles
// a failed query is logged and reads as false
func (self *PostgresDatabase) ArticleBanned(messageID string) (result bool) {
	var matches int64
	row := self.conn.QueryRow(self.stmt[ArticleBanned], messageID)
	if err := row.Scan(&matches); err != nil {
		log.Println("error checking if article is banned", err)
		return false
	}
	return matches > 0
}
// GetEncAddress returns the encrypted form of addr.
//
// If addr already has a row in EncryptedAddrs its stored encrypted address
// is returned. Otherwise a new key and encrypted address are produced by
// newAddrEnc and stored together with addr and a single-host cidr for it.
//
// The single-host mask is /128 for an address that parses as IPv6 and /32
// otherwise. A fixed /32 on an IPv6 address leaves bits set to the right of
// the mask, which the cidr type rejects.
//
// Returns an error when a query fails or when newAddrEnc yields an empty
// encrypted address.
func (self *PostgresDatabase) GetEncAddress(addr string) (encaddr string, err error) {
	var count int64
	err = self.conn.QueryRow("SELECT COUNT(addr) FROM EncryptedAddrs WHERE addr = $1", addr).Scan(&count)
	if err != nil {
		return
	}
	if count != 0 {
		// already known
		err = self.conn.QueryRow("SELECT encAddr FROM EncryptedAddrs WHERE addr = $1 LIMIT 1", addr).Scan(&encaddr)
		return
	}
	// needs to be inserted
	var key string
	key, encaddr = newAddrEnc(addr)
	if len(encaddr) == 0 {
		err = errors.New("failed to generate new encryption key")
		return
	}
	mask := "/32"
	if ip := net.ParseIP(addr); ip != nil && ip.To4() == nil {
		mask = "/128"
	}
	_, err = self.conn.Exec("INSERT INTO EncryptedAddrs(enckey, encaddr, addr, addr_cidr) VALUES($1, $2, $3, cidr($4))", key, encaddr, addr, addr+mask)
	return
}
// GetEncKey looks up the enckey column of the first EncryptedAddrs row
// whose encaddr equals encAddr. The Scan error is returned unchanged.
func (self *PostgresDatabase) GetEncKey(encAddr string) (enckey string, err error) {
	const query = "SELECT enckey FROM EncryptedAddrs WHERE encaddr = $1 LIMIT 1"
	row := self.conn.QueryRow(query, encAddr)
	err = row.Scan(&enckey)
	return enckey, err
}
// CheckIPBanned counts the IPBans rows whose addr contains or equals addr
// (postgres ">>=" operator) and reports whether there is at least one.
// The query error, if any, is returned alongside the result.
func (self *PostgresDatabase) CheckIPBanned(addr string) (banned bool, err error) {
	var hits int64
	row := self.conn.QueryRow("SELECT COUNT(*) FROM IPBans WHERE addr >>= $1 ", addr)
	err = row.Scan(&hits)
	return hits > 0, err
}
// GetIPAddress returns the addr column stored for encaddr in
// EncryptedAddrs. When no row matches, it returns an empty address and a
// nil error.
func (self *PostgresDatabase) GetIPAddress(encaddr string) (addr string, err error) {
	var known int64
	err = self.conn.QueryRow("SELECT COUNT(encAddr) FROM EncryptedAddrs WHERE encAddr = $1", encaddr).Scan(&known)
	if err != nil || known <= 0 {
		return
	}
	err = self.conn.QueryRow("SELECT addr FROM EncryptedAddrs WHERE encAddr = $1 LIMIT 1", encaddr).Scan(&addr)
	return
}
// MarkModPubkeyGlobal grants pubkey the "all" permission on "overchan" by
// inserting a ModPrivs row. A pubkey that is not exactly 64 bytes long is
// rejected. If the key is already marked, the call only logs and returns nil.
func (self *PostgresDatabase) MarkModPubkeyGlobal(pubkey string) (err error) {
	if len(pubkey) != 64 {
		return errors.New("invalid pubkey length")
	}
	if self.CheckModPubkeyGlobal(pubkey) {
		// already marked
		log.Println("pubkey already marked as global", pubkey)
		return nil
	}
	_, err = self.conn.Exec("INSERT INTO ModPrivs(pubkey, newsgroup, permission) VALUES ( $1, $2, $3 )", pubkey, "overchan", "all")
	return err
}
// MarkPubkeyAdmin inserts an "admin" ModPrivs row on "overchan" for pubkey
// unless CheckAdminPubkey says one already exists. An error from that
// check is returned without attempting the insert.
func (self *PostgresDatabase) MarkPubkeyAdmin(pubkey string) (err error) {
	isAdmin, err := self.CheckAdminPubkey(pubkey)
	if err != nil || isAdmin {
		return err
	}
	// add as admin since it's not already there
	_, err = self.conn.Exec("INSERT INTO ModPrivs(pubkey, newsgroup, permission) VALUES ( $1, $2, $3 )", pubkey, "overchan", "admin")
	return err
}
// UnmarkPubkeyAdmin deletes every ModPrivs row that gives pubkey the
// "admin" permission.
func (self *PostgresDatabase) UnmarkPubkeyAdmin(pubkey string) (err error) {
	const del = "DELETE FROM ModPrivs WHERE pubkey = $1 AND permission = $2"
	_, err = self.conn.Exec(del, pubkey, "admin")
	return err
}
// CheckAdminPubkey reports whether ModPrivs has an "admin" row for pubkey.
// On a query error it returns false together with that error.
func (self *PostgresDatabase) CheckAdminPubkey(pubkey string) (admin bool, err error) {
	var n int64
	row := self.conn.QueryRow("SELECT COUNT(pubkey) FROM ModPrivs WHERE pubkey = $1 AND permission = $2", pubkey, "admin")
	if err = row.Scan(&n); err != nil {
		return false, err
	}
	return n > 0, nil
}
// UnMarkModPubkeyGlobal removes the "all" permission on "overchan" from
// pubkey. It returns an error when the key is not currently marked global.
func (self *PostgresDatabase) UnMarkModPubkeyGlobal(pubkey string) (err error) {
	if !self.CheckModPubkeyGlobal(pubkey) {
		return errors.New("public key not marked as global")
	}
	// key is marked, drop the row
	_, err = self.conn.Exec("DELETE FROM ModPrivs WHERE pubkey = $1 AND newsgroup = $2 AND permission = $3", pubkey, "overchan", "all")
	return err
}
// CountThreadReplies counts the ArticlePosts rows whose ref_id is
// root_message_id. Query errors are ignored and leave the count at 0.
func (self *PostgresDatabase) CountThreadReplies(root_message_id string) (repls int64) {
	const query = "SELECT COUNT(message_id) FROM ArticlePosts WHERE ref_id = $1"
	row := self.conn.QueryRow(query, root_message_id)
	_ = row.Scan(&repls)
	return repls
}
// GetRootPostsForExpiration lists the root message ids of every thread in
// newsgroup that is NOT among the threadcount most recently bumped ones.
// Each id found is also written to the log. A query error is logged and
// produces an empty result.
func (self *PostgresDatabase) GetRootPostsForExpiration(newsgroup string, threadcount int) (roots []string) {
	const query = "SELECT root_message_id FROM ArticleThreads WHERE newsgroup = $1 AND root_message_id NOT IN ( SELECT root_message_id FROM ArticleThreads WHERE newsgroup = $1 ORDER BY last_bump DESC LIMIT $2)"
	rows, err := self.conn.Query(query, newsgroup, threadcount)
	if err != nil {
		log.Println("failed to get root posts for expiration", err)
		return
	}
	defer rows.Close()
	// collect the expired roots
	for rows.Next() {
		var root string
		_ = rows.Scan(&root)
		roots = append(roots, root)
		log.Println(root)
	}
	return
}
// registerNNTPNumber adds msgid to the ArticleNumbers table for group.
// The row's message_no is one more than the group's current
// MAX(message_no), with an empty group counting as 0.
func (self *PostgresDatabase) registerNNTPNumber(group, msgid string) (err error) {
	const insert = "WITH x(msg_no) AS ( SELECT MAX(message_no) AS msg_no FROM ArticleNumbers WHERE newsgroup = $1 ) INSERT INTO ArticleNumbers(newsgroup, message_id, message_no) VALUES($1, $2, (SELECT CASE WHEN msg_no IS NULL THEN 0 ELSE msg_no END FROM x) + 1 )"
	_, err = self.conn.Exec(insert, group, msgid)
	return err
}
// GetAllNewsgroups returns one string per row produced by the query stored
// under the GetAllNewsgroups key. A query error yields a nil slice and is
// not reported.
func (self *PostgresDatabase) GetAllNewsgroups() (groups []string) {
	rows, err := self.conn.Query(self.stmt[GetAllNewsgroups])
	if err != nil {
		return
	}
	defer rows.Close()
	for rows.Next() {
		var name string
		_ = rows.Scan(&name)
		groups = append(groups, name)
	}
	return
}
// GetGroupPageCount derives a page count for newsgroup from the number of
// rows it has in ArticleThreads, using a fixed 10 threads per page. A
// query error is logged and treated as zero threads.
func (self *PostgresDatabase) GetGroupPageCount(newsgroup string) int64 {
	var threads int64
	row := self.conn.QueryRow("SELECT COUNT(*) FROM ArticleThreads WHERE newsgroup = $1", newsgroup)
	if err := row.Scan(&threads); err != nil {
		log.Println("failed to count pages in group", newsgroup, err)
	}
	// divide by threads per page
	// NOTE(review): threads/10 is an integer division, so Ceil receives an
	// already truncated value and has no effect; the result is threads/10+1.
	// Confirm whether exact multiples of 10 are meant to get an extra page.
	full := math.Ceil(float64(threads / 10))
	return int64(full + 1)
}
// GetGroupForPage builds the board model for one page of newsgroup.
//
// Only root posts are fetched: the threads are ordered by last_bump
// descending, skipping pageno*perpage of them and taking perpage. Each
// root gets its pubkey (if any) and attachment models filled in; replies
// are not loaded. A failed query is logged and gives a board without
// threads.
func (self *PostgresDatabase) GetGroupForPage(prefix, frontend, newsgroup string, pageno, perpage int) BoardModel {
	const rootsQuery = "WITH roots(root_message_id, last_bump) AS ( SELECT root_message_id, last_bump FROM ArticleThreads WHERE newsgroup = $1 ORDER BY last_bump DESC OFFSET $2 LIMIT $3 ) SELECT p.newsgroup, p.message_id, p.name, p.subject, p.path, p.time_posted, p.message, p.addr FROM ArticlePosts p INNER JOIN roots ON ( roots.root_message_id = p.message_id ) ORDER BY roots.last_bump DESC"

	pageCount := self.GetGroupPageCount(newsgroup)
	var threads []ThreadModel

	rows, err := self.conn.Query(rootsQuery, newsgroup, pageno*perpage, perpage)
	if err != nil {
		log.Println("failed to fetch board model for", newsgroup, "page", pageno, err)
	} else {
		for rows.Next() {
			root := &post{prefix: prefix}
			_ = rows.Scan(&root.board, &root.Message_id, &root.PostName, &root.PostSubject, &root.MessagePath, &root.Posted, &root.PostMessage, &root.addr)
			// a root post is its own parent
			root.Parent = root.Message_id
			root.op = true
			// a missing pubkey is fine, the Scan error is ignored
			_ = self.conn.QueryRow("SELECT pubkey FROM ArticleKeys WHERE message_id = $1", root.Message_id).Scan(&root.Key)
			root.sage = isSage(root.PostSubject)
			if files := self.GetPostAttachmentModels(prefix, root.Message_id); files != nil {
				root.Files = append(root.Files, files...)
			}
			threads = append(threads, createThreadModel(root))
		}
		rows.Close()
	}

	return &boardModel{
		prefix:   prefix,
		frontend: frontend,
		board:    newsgroup,
		page:     pageno,
		pages:    int(pageCount),
		threads:  threads,
	}
}
// GetNNTPPostsInGroup runs the query stored under the GetNNTPPostsInGroup
// key for newsgroup and returns one post model per row, each tagged with
// newsgroup. Only the Query error is returned; per-row Scan errors are
// ignored.
func (self *PostgresDatabase) GetNNTPPostsInGroup(newsgroup string) (models []PostModel, err error) {
	rows, err := self.conn.Query(self.stmt[GetNNTPPostsInGroup], newsgroup)
	if err != nil {
		return
	}
	defer rows.Close()
	for rows.Next() {
		m := new(post)
		m.Newsgroup = newsgroup
		_ = rows.Scan(&m.nntp_id, &m.Message_id, &m.PostSubject, &m.Posted, &m.Parent, &m.PostName, &m.MessagePath)
		models = append(models, m)
	}
	return
}
// GetPostsInGroup runs the query stored under the GetPostsInGroup key for
// newsgroup and returns one post model per row. Only the Query error is
// returned; per-row Scan errors are ignored.
func (self *PostgresDatabase) GetPostsInGroup(newsgroup string) (models []PostModel, err error) {
	rows, err := self.conn.Query(self.stmt[GetPostsInGroup], newsgroup)
	if err != nil {
		return
	}
	defer rows.Close()
	for rows.Next() {
		m := new(post)
		_ = rows.Scan(&m.board, &m.Message_id, &m.Parent, &m.PostName, &m.PostSubject, &m.MessagePath, &m.Posted, &m.PostMessage, &m.addr)
		models = append(models, m)
	}
	return
}
// GetPostModel loads the post identified by messageID, including its
// attachment models and, when present, its pubkey. A post with an empty
// parent is flagged as op and made its own parent. If the main query
// fails, the error is logged and nil is returned.
func (self *PostgresDatabase) GetPostModel(prefix, messageID string) PostModel {
	m := new(post)
	row := self.conn.QueryRow(self.stmt[GetPostModel], messageID)
	err := row.Scan(&m.board, &m.Message_id, &m.Parent, &m.PostName, &m.PostSubject, &m.MessagePath, &m.Posted, &m.PostMessage, &m.addr)
	if err != nil {
		log.Println("failed to prepare query for geting post model for", messageID, err)
		return nil
	}
	isRoot := len(m.Parent) == 0
	m.op = isRoot
	if isRoot {
		m.Parent = m.Message_id
	}
	m.sage = isSage(m.PostSubject)
	if files := self.GetPostAttachmentModels(prefix, messageID); files != nil {
		m.Files = append(m.Files, files...)
	}
	// quiet fail
	_ = self.conn.QueryRow(self.stmt[GetArticlePubkey], messageID).Scan(&m.Key)
	return m
}
// GetCitesByPostHashLike runs the query stored under the
// GetCitesByPostHashLike key with like followed by a "%" wildcard and
// returns one two-element tuple per row. A Query error other than
// sql.ErrNoRows is logged; every Query error is returned.
func (self *PostgresDatabase) GetCitesByPostHashLike(like string) (cites []MessageIDTuple, err error) {
	var rows *sql.Rows
	rows, err = self.conn.Query(self.stmt[GetCitesByPostHashLike], like+"%")
	if err != nil {
		if err != sql.ErrNoRows {
			log.Println("error getting post models like", like, err)
		}
		return
	}
	defer rows.Close()
	for rows.Next() {
		var pair MessageIDTuple
		_ = rows.Scan(&pair[0], &pair[1])
		cites = append(cites, pair)
	}
	return
}
// GetThreadModel assembles the thread rooted at msgid.
//
// Three queries are run in order: the posts of the thread, the
// attachments of those posts and the pubkeys of those posts. Attachments
// and pubkeys are matched to posts by message id. Every post gets msgid as
// its Parent.
//
// If any of the three queries fails, or a post row cannot be scanned, the
// error is returned and th is left unset. Previously a failed query left
// rows nil and the following rows.Close() panicked, and earlier errors
// were overwritten by later queries. An attachment or pubkey row that
// cannot be scanned is logged and skipped.
func (self *PostgresDatabase) GetThreadModel(prefix, msgid string) (th ThreadModel, err error) {
	var posts []PostModel
	var rows *sql.Rows
	pmap := make(map[string]*post)

	// posts
	rows, err = self.conn.Query(self.stmt[GetThreadModel], msgid)
	if err != nil {
		return
	}
	for rows.Next() {
		p := new(post)
		p.Parent = msgid
		err = rows.Scan(&p.board, &p.Message_id, &p.PostName, &p.PostSubject, &p.Posted, &p.PostMessage, &p.addr)
		if err != nil {
			break
		}
		pmap[p.Message_id] = p
		posts = append(posts, p)
	}
	rows.Close()
	if err != nil {
		return
	}

	// attachments
	rows, err = self.conn.Query(self.stmt[GetThreadModelAttachments], msgid)
	if err != nil {
		return
	}
	for rows.Next() {
		att := &attachment{
			prefix: prefix,
		}
		var att_msgid string
		if scanErr := rows.Scan(&att.Name, &att.Path, &att_msgid); scanErr != nil {
			log.Println(scanErr)
			continue
		}
		if p, ok := pmap[att_msgid]; ok {
			p.Files = append(p.Files, att)
		}
	}
	rows.Close()

	// pubkeys
	rows, err = self.conn.Query(self.stmt[GetThreadModelPubkeys], msgid)
	if err != nil {
		log.Println(err)
		return
	}
	for rows.Next() {
		var key_msgid, key string
		if scanErr := rows.Scan(&key, &key_msgid); scanErr != nil {
			log.Println(scanErr)
			continue
		}
		if p, ok := pmap[key_msgid]; ok {
			p.Key = key
		}
	}
	rows.Close()

	th = createThreadModel(posts...)
	return
}
// DeleteThread executes the statement stored under the DeleteThread key
// with msgid as its only argument.
func (self *PostgresDatabase) DeleteThread(msgid string) (err error) {
	query := self.stmt[DeleteThread]
	_, err = self.conn.Exec(query, msgid)
	return err
}
// DeleteArticle executes the five DeleteArticle_N statements in order,
// each with msgid as its only argument, and stops at the first failure.
// The statements are issued individually on the pool, not in a
// transaction.
func (self *PostgresDatabase) DeleteArticle(msgid string) (err error) {
	steps := [...]string{DeleteArticle_1, DeleteArticle_2, DeleteArticle_3, DeleteArticle_4, DeleteArticle_5}
	for i := 0; i < len(steps) && err == nil; i++ {
		_, err = self.conn.Exec(self.stmt[steps[i]], msgid)
	}
	return err
}
// GetThreadReplyPostModels loads post models for the replies of rootpost.
//
// When limit is positive the GetThreadReplyPostModels_1 query is used with
// that limit, otherwise GetThreadReplyPostModels_2 is used. The first
// start rows of the result are skipped in Go. Each model gets its
// attachment models and, when one exists, its pubkey. A failed query is
// logged and yields a nil slice.
func (self *PostgresDatabase) GetThreadReplyPostModels(prefix, rootpost string, start, limit int) (repls []PostModel) {
	var rows *sql.Rows
	var err error
	if limit > 0 {
		rows, err = self.conn.Query(self.stmt[GetThreadReplyPostModels_1], rootpost, limit)
	} else {
		rows, err = self.conn.Query(self.stmt[GetThreadReplyPostModels_2], rootpost)
	}
	if err != nil {
		log.Println("failed to get thread replies", rootpost, err)
		return
	}
	defer rows.Close()

	offset := start
	for rows.Next() {
		// TODO: this is a hack, optimize queries plz
		if offset > 0 {
			offset--
			continue
		}
		model := new(post)
		model.prefix = prefix
		rows.Scan(&model.board, &model.Message_id, &model.Parent, &model.PostName, &model.PostSubject, &model.MessagePath, &model.Posted, &model.PostMessage, &model.addr)
		model.op = len(model.Parent) == 0
		if len(model.Parent) == 0 {
			model.Parent = model.Message_id
		}
		model.sage = isSage(model.PostSubject)
		atts := self.GetPostAttachmentModels(prefix, model.Message_id)
		if atts != nil {
			model.Files = append(model.Files, atts...)
		}
		// get pubkey if it exists
		// quiet fail
		// Scan needs a pointer destination: passing model.Key by value made
		// every lookup fail, so the key was never filled in.
		self.conn.QueryRow(self.stmt[GetArticlePubkey], model.Message_id).Scan(&model.Key)
		repls = append(repls, model)
	}
	return
}
// GetThreadReplies returns the message ids of the replies to rootpost.
// A positive limit selects the GetThreadReplies_1 query (with that limit),
// otherwise GetThreadReplies_2 is used. The first start rows are dropped
// in Go. A failed query is logged and yields a nil slice.
func (self *PostgresDatabase) GetThreadReplies(rootpost string, start, limit int) (repls []string) {
	var (
		rows *sql.Rows
		err  error
	)
	if limit > 0 {
		rows, err = self.conn.Query(self.stmt[GetThreadReplies_1], rootpost, limit)
	} else {
		rows, err = self.conn.Query(self.stmt[GetThreadReplies_2], rootpost)
	}
	if err != nil {
		log.Println("failed to get thread replies", rootpost, err)
		return
	}
	defer rows.Close()
	// TODO: this is a hack, optimize queries plz
	for skip := start; rows.Next(); {
		if skip > 0 {
			skip--
			continue
		}
		var id string
		_ = rows.Scan(&id)
		repls = append(repls, id)
	}
	return
}
// ThreadHasReplies reports whether any ArticlePosts row has rootpost as
// its ref_id. A query error is logged and reads as false.
func (self *PostgresDatabase) ThreadHasReplies(rootpost string) bool {
	var replies int64
	row := self.conn.QueryRow("SELECT COUNT(message_id) FROM ArticlePosts WHERE ref_id = $1", rootpost)
	if err := row.Scan(&replies); err != nil {
		log.Println("failed to count thread replies", err)
	}
	return replies > 0
}
// GetGroupThreads runs the query stored under the GetGroupThreads key for
// group and sends one ArticleEntry{msgid, group} per row on recv. The
// channel is not closed here. A Query error other than sql.ErrNoRows is
// logged.
func (self *PostgresDatabase) GetGroupThreads(group string, recv chan ArticleEntry) {
	rows, err := self.conn.Query(self.stmt[GetGroupThreads], group)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Println("failed to get group threads", err)
		}
		return
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		_ = rows.Scan(&id)
		recv <- ArticleEntry{id, group}
	}
}
// GetLastBumpedThreads is GetLastBumpedThreadsPaginated with an offset of
// zero.
func (self *PostgresDatabase) GetLastBumpedThreads(newsgroups string, threads int) []ArticleEntry {
	const noOffset = 0
	return self.GetLastBumpedThreadsPaginated(newsgroups, threads, noOffset)
}
// GetLastBumpedThreadsPaginated returns thread entries from the
// GetLastBumpedThreadsPaginated_1 query (when newsgroup is non-empty) or
// the _2 query (all groups). The query is asked for threads+offset rows
// and the first offset rows are then discarded in Go. A failed query is
// logged and yields a nil slice.
func (self *PostgresDatabase) GetLastBumpedThreadsPaginated(newsgroup string, threads, offset int) (roots []ArticleEntry) {
	var (
		rows *sql.Rows
		err  error
	)
	want := threads + offset
	if len(newsgroup) > 0 {
		rows, err = self.conn.Query(self.stmt[GetLastBumpedThreadsPaginated_1], newsgroup, want)
	} else {
		rows, err = self.conn.Query(self.stmt[GetLastBumpedThreadsPaginated_2], want)
	}
	if err != nil {
		log.Println("failed to get last bumped", err)
		return
	}
	defer rows.Close()
	for skip := offset; rows.Next(); {
		var entry ArticleEntry
		_ = rows.Scan(&entry[0], &entry[1])
		if skip > 0 {
			skip--
			continue
		}
		roots = append(roots, entry)
	}
	return
}
// GroupHasPosts reports whether ArticlePosts holds at least one row for
// group. A query error is logged and reads as false.
func (self *PostgresDatabase) GroupHasPosts(group string) bool {
	var n int64
	row := self.conn.QueryRow("SELECT COUNT(message_id) FROM ArticlePosts WHERE newsgroup = $1", group)
	if err := row.Scan(&n); err != nil {
		log.Println("error counting posts in group", group, err)
	}
	return n > 0
}
// HasNewsgroup reports whether the count returned by the query stored
// under the HasNewsgroup key is positive for group. A query error is
// logged and reads as false.
func (self *PostgresDatabase) HasNewsgroup(group string) bool {
	var n int64
	if err := self.conn.QueryRow(self.stmt[HasNewsgroup], group).Scan(&n); err != nil {
		log.Println("failed to check for newsgroup", group, err)
	}
	return n > 0
}
// HasArticle reports whether the count returned by the query stored under
// the HasArticle key is positive for message_id. A query error is logged
// and reads as false.
func (self *PostgresDatabase) HasArticle(message_id string) bool {
	var n int64
	if err := self.conn.QueryRow(self.stmt[HasArticle], message_id).Scan(&n); err != nil {
		log.Println("failed to check for article", message_id, err)
	}
	return n > 0
}
// HasArticleLocal reports whether the count returned by the query stored
// under the HasArticleLocal key is positive for message_id. A query error
// is logged and reads as false.
func (self *PostgresDatabase) HasArticleLocal(message_id string) bool {
	var n int64
	if err := self.conn.QueryRow(self.stmt[HasArticleLocal], message_id).Scan(&n); err != nil {
		log.Println("failed to check for local article", message_id, err)
	}
	return n > 0
}
// ArticleCount returns the number of rows in ArticlePosts. A query error
// is logged and leaves the count at 0.
func (self *PostgresDatabase) ArticleCount() (count int64) {
	row := self.conn.QueryRow("SELECT COUNT(message_id) FROM ArticlePosts")
	if err := row.Scan(&count); err != nil {
		log.Println("failed to count articles", err)
	}
	return count
}
// RegisterNewsgroup inserts group into Newsgroups with last_post set to
// the current time. A failure is only logged.
func (self *PostgresDatabase) RegisterNewsgroup(group string) {
	const insert = "INSERT INTO Newsgroups (name, last_post) VALUES($1, $2)"
	if _, err := self.conn.Exec(insert, group, timeNow()); err != nil {
		log.Println("failed to register newsgroup", group, err)
	}
}
// GetPostAttachments returns one string per row of the query stored under
// the GetPostAttachments key for messageID. A failed query is logged and
// yields a nil slice.
func (self *PostgresDatabase) GetPostAttachments(messageID string) (atts []string) {
	rows, err := self.conn.Query(self.stmt[GetPostAttachments], messageID)
	if err != nil {
		log.Println("cannot find attachments for", messageID, err)
		return
	}
	defer rows.Close()
	for rows.Next() {
		var entry string
		_ = rows.Scan(&entry)
		atts = append(atts, entry)
	}
	return
}
// GetPostAttachmentModels returns an attachment model, carrying prefix,
// for every row of the query stored under the GetPostAttachmentModels key
// for messageID. The first column of a row is used as the path and the
// second as the name. A failed query is logged and yields a nil slice.
func (self *PostgresDatabase) GetPostAttachmentModels(prefix, messageID string) (atts []AttachmentModel) {
	rows, err := self.conn.Query(self.stmt[GetPostAttachmentModels], messageID)
	if err != nil {
		log.Println("failed to get attachment models for", messageID, err)
		return
	}
	defer rows.Close()
	for rows.Next() {
		att := &attachment{prefix: prefix}
		_ = rows.Scan(&att.Path, &att.Name)
		atts = append(atts, att)
	}
	return
}
// RegisterArticle registers message with the database.
//
// Steps, in order:
//   - register the newsgroup if HasNewsgroup says it is unknown
//   - return immediately (nil error) if HasArticle says the article exists
//   - insert article metadata, update the newsgroup, insert the post
//   - for an OP insert a thread row; for a reply bump the thread (unless
//     the message is sage or the thread is past BumpLimit) and update its
//     last posted time
//   - bulk-copy the headers into nntpheaders inside one transaction
//   - assign the NNTP article number
//   - register every attachment
//
// The steps are separate statements, so a failure part way through leaves
// the earlier steps in place. Header registration is best effort: a
// failure there is logged and the transaction is rolled back, but
// processing continues; only a failed commit aborts. The returned error is
// that of the step that aborted, or of the last attachment insert.
func (self *PostgresDatabase) RegisterArticle(message NNTPMessage) (err error) {
	msgid := message.MessageID()
	group := message.Newsgroup()

	if !self.HasNewsgroup(group) {
		self.RegisterNewsgroup(group)
	}
	if self.HasArticle(msgid) {
		// already registered, nothing to do
		return
	}

	now := timeNow()
	// insert article metadata
	_, err = self.conn.Exec(self.stmt[RegisterArticle_1], msgid, HashMessageID(msgid), group, now, message.Reference())
	if err != nil {
		log.Println("failed to insert article metadata", err)
		return
	}
	// update newsgroup
	_, err = self.conn.Exec(self.stmt[RegisterArticle_2], now, group)
	if err != nil {
		log.Println("failed to update newsgroup last post", err)
		return
	}
	// insert article post
	_, err = self.conn.Exec(self.stmt[RegisterArticle_3], group, msgid, message.Reference(), message.Name(), message.Subject(), message.Path(), message.Posted(), message.Message(), message.Addr())
	if err != nil {
		log.Println("cannot insert article post", err)
		return
	}

	// set / update thread state
	if message.OP() {
		// insert new thread for op
		_, err = self.conn.Exec(self.stmt[RegisterArticle_4], message.MessageID(), message.Posted(), group)
		if err != nil {
			log.Println("cannot register thread", msgid, err)
			return
		}
	} else {
		ref := message.Reference()
		if !message.Sage() {
			// TODO: this could be 1 query possibly?
			var posts int64
			err = self.conn.QueryRow(self.stmt[RegisterArticle_5], ref).Scan(&posts)
			if err == nil && posts <= BumpLimit {
				// bump the thread
				_, err = self.conn.Exec(self.stmt[RegisterArticle_6], ref, message.Posted())
			}
			if err != nil {
				log.Println("failed to bump thread", ref, err)
				return
			}
		}
		// update last posted
		_, err = self.conn.Exec(self.stmt[RegisterArticle_7], ref, message.Posted())
		if err != nil {
			log.Println("failed to update post time for", ref, err)
			return
		}
	}

	// register article header key value pairs in a single COPY
	var tx *sql.Tx
	tx, err = self.conn.Begin()
	if err != nil {
		log.Println("failed to begin transaction for nntp article header values:", err)
	} else {
		var st *sql.Stmt
		st, err = tx.Prepare(pq.CopyIn("nntpheaders", "header_name", "header_value", "header_article_message_id"))
		if err != nil {
			// st is nil here, it must not be used
			log.Printf("error with copyin: %s", err)
		} else {
		headers:
			for k, val := range message.Headers() {
				k = strings.ToLower(k)
				for _, v := range val {
					_, err = st.Exec(k, v, msgid)
					if err != nil {
						log.Println("failed to register nntp article header in transaction", err)
						// stop queueing rows, the transaction is rolled back below
						break headers
					}
				}
			}
			if err == nil {
				// an Exec without arguments flushes the buffered rows
				_, err = st.Exec()
				if err != nil {
					log.Println("failed to execute prepared statement for nntp article header values:", err)
				}
			}
			st.Close()
		}
		if err == nil {
			err = tx.Commit()
			if err != nil {
				log.Println("failed to commit nntp article header values:", err)
				return
			}
		} else {
			// release the connection held by the failed transaction
			_ = tx.Rollback()
		}
	}

	err = self.registerNNTPNumber(group, msgid)
	if err != nil {
		log.Println("failed to register nntp number for", msgid, err)
		return
	}

	// register all attachments
	for _, att := range message.Attachments() {
		h := hex.EncodeToString(att.Hash())
		_, err = self.conn.Exec(self.stmt[RegisterArticle_8], msgid, h, att.Filename(), att.Filepath())
		if err != nil {
			log.Println("failed to register attachment", err)
		}
	}
	return
}
// GetMessageIDByHeader returns the message ids produced by the query
// stored under the GetMessageIDByHeader key for the given header name and
// value. The name is lowercased before the lookup. Only the Query error is
// returned; per-row Scan errors are ignored.
func (self *PostgresDatabase) GetMessageIDByHeader(name, val string) (msgids []string, err error) {
	rows, err := self.conn.Query(self.stmt[GetMessageIDByHeader], strings.ToLower(name), val)
	if err != nil {
		return
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		_ = rows.Scan(&id)
		msgids = append(msgids, id)
	}
	return
}
// RegisterSigned executes the statement stored under the RegisterSigned
// key with message_id and pubkey, in that order.
func (self *PostgresDatabase) RegisterSigned(message_id, pubkey string) (err error) {
	query := self.stmt[RegisterSigned]
	_, err = self.conn.Exec(query, message_id, pubkey)
	return err
}
// GetAllArticlesInGroup runs the query stored under the
// GetAllArticlesInGroup key for group and sends one
// ArticleEntry{msgid, group} per row on recv. The channel is not closed
// here. A failed query is logged and nothing is sent.
func (self *PostgresDatabase) GetAllArticlesInGroup(group string, recv chan ArticleEntry) {
	rows, err := self.conn.Query(self.stmt[GetAllArticlesInGroup], group)
	if err == nil {
		defer rows.Close()
		for rows.Next() {
			var id string
			_ = rows.Scan(&id)
			recv <- ArticleEntry{id, group}
		}
		return
	}
	log.Printf("failed to get all articles in %s: %s", group, err)
}
// GetAllArticles returns one ArticleEntry per row of the query stored
// under the GetAllArticles key, as a slice. A failed query is logged and
// yields a nil slice.
func (self *PostgresDatabase) GetAllArticles() (articles []ArticleEntry) {
	rows, err := self.conn.Query(self.stmt[GetAllArticles])
	if err != nil {
		log.Println("failed to get all articles", err)
		return articles
	}
	defer rows.Close()
	for rows.Next() {
		var e ArticleEntry
		_ = rows.Scan(&e[0], &e[1])
		articles = append(articles, e)
	}
	return articles
}
// GetPagesPerBoard returns the number of pages a board has.
// XXX: hardcoded to 10, the group argument is ignored.
func (self *PostgresDatabase) GetPagesPerBoard(group string) (int, error) {
	const pagesPerBoard = 10
	return pagesPerBoard, nil
}
// GetThreadsPerPage returns the number of threads shown on one page.
// XXX: hardcoded to 10, the group argument is ignored.
func (self *PostgresDatabase) GetThreadsPerPage(group string) (int, error) {
	const threadsPerPage = 10
	return threadsPerPage, nil
}
// GetMessageIDByHash scans the single row of the query stored under the
// GetMessageIDByHash key into the two elements of article. The Scan error
// is returned unchanged.
func (self *PostgresDatabase) GetMessageIDByHash(hash string) (article ArticleEntry, err error) {
	row := self.conn.QueryRow(self.stmt[GetMessageIDByHash], hash)
	err = row.Scan(&article[0], &article[1])
	return article, err
}
// BanAddr inserts addr into IPBans with made set to the current time and
// expires set to -1.
func (self *PostgresDatabase) BanAddr(addr string) (err error) {
	const insert = "INSERT INTO IPBans(addr, made, expires) VALUES($1, $2, $3)"
	_, err = self.conn.Exec(insert, addr, timeNow(), -1)
	return err
}
// UnbanAddr deletes every IPBans row whose addr contains or equals addr
// (postgres ">>=" operator). It does not check first that such a ban
// exists.
func (self *PostgresDatabase) UnbanAddr(addr string) (err error) {
	const del = "DELETE FROM IPBans WHERE addr >>= $1"
	_, err = self.conn.Exec(del, addr)
	return err
}
// CheckEncIPBanned reports whether the count returned by the query stored
// under the CheckEncIPBanned key is positive for encaddr. The query error,
// if any, is returned alongside the result.
func (self *PostgresDatabase) CheckEncIPBanned(encaddr string) (banned bool, err error) {
	var hits int64
	row := self.conn.QueryRow(self.stmt[CheckEncIPBanned], encaddr)
	err = row.Scan(&hits)
	return hits > 0, err
}
// BanEncAddr inserts encaddr into EncIPBans with made set to the current
// time and expires set to -1.
func (self *PostgresDatabase) BanEncAddr(encaddr string) (err error) {
	const insert = "INSERT INTO EncIPBans(encaddr, made, expires) VALUES($1, $2, $3)"
	_, err = self.conn.Exec(insert, encaddr, timeNow(), -1)
	return err
}
// GetLastAndFirstForGroup runs the query stored under the
// GetFirstAndLastForGroup key for group. The first row of the result is
// scanned into first and the second row, if there is one, into last.
// With no rows both stay 0 and the error is nil.
func (self *PostgresDatabase) GetLastAndFirstForGroup(group string) (last, first int64, err error) {
	var rows *sql.Rows
	rows, err = self.conn.Query(self.stmt[GetFirstAndLastForGroup], group)
	if err != nil {
		return
	}
	defer rows.Close()
	if !rows.Next() {
		return
	}
	if err = rows.Scan(&first); err != nil {
		return
	}
	if rows.Next() {
		err = rows.Scan(&last)
	}
	return
}
// GetMessageIDForNNTPID scans the single row of the query stored under the
// GetMessageIDForNNTPID key, called with group and id, into msgid. The
// Scan error is returned unchanged.
func (self *PostgresDatabase) GetMessageIDForNNTPID(group string, id int64) (msgid string, err error) {
	row := self.conn.QueryRow(self.stmt[GetMessageIDForNNTPID], group, id)
	err = row.Scan(&msgid)
	return msgid, err
}
// GetNNTPIDForMessageID scans the single row of the query stored under the
// GetNNTPIDForMessageID key, called with group and msgid, into id. The
// Scan error is returned unchanged.
func (self *PostgresDatabase) GetNNTPIDForMessageID(group, msgid string) (id int64, err error) {
	row := self.conn.QueryRow(self.stmt[GetNNTPIDForMessageID], group, msgid)
	err = row.Scan(&id)
	return id, err
}
// MarkModPubkeyCanModGroup inserts a ModPrivs row giving pubkey the "all"
// permission on group. No check for an existing row is made.
func (self *PostgresDatabase) MarkModPubkeyCanModGroup(pubkey, group string) (err error) {
	const insert = "INSERT INTO ModPrivs(pubkey, newsgroup, permission) VALUES($1, $2, $3)"
	_, err = self.conn.Exec(insert, pubkey, group, "all")
	return err
}
// UnMarkModPubkeyCanModGroup deletes every ModPrivs row for pubkey on
// group, whatever the permission.
func (self *PostgresDatabase) UnMarkModPubkeyCanModGroup(pubkey, group string) (err error) {
	const del = "DELETE FROM ModPrivs WHERE pubkey = $1 AND newsgroup = $2"
	_, err = self.conn.Exec(del, pubkey, group)
	return err
}
// IsExpired reports whether the count returned by the query stored under
// the IsExpired key is zero for root_message_id. A query error is logged;
// the count then stays 0, so the result is true.
func (self *PostgresDatabase) IsExpired(root_message_id string) bool {
	var n int
	row := self.conn.QueryRow(self.stmt[IsExpired], root_message_id)
	if err := row.Scan(&n); err != nil {
		log.Println("error checking for expired article:", err)
	}
	return n == 0
}
// GetLastDaysPostsForGroup returns n entries, one per UTC day, starting
// with today and walking backwards. Each entry pairs the unix time of the
// day's midnight with the value returned by the GetLastDaysPostsForGroup
// query, which is called with (next midnight, this midnight, newsgroup).
// On any query error the error is logged and nil is returned.
func (self *PostgresDatabase) GetLastDaysPostsForGroup(newsgroup string, n int64) (posts []PostEntry) {
	const day = 24 * time.Hour
	today := time.Now().UTC()
	midnight := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC)
	for ; n > 0; n-- {
		var num int64
		row := self.conn.QueryRow(self.stmt[GetLastDaysPostsForGroup], midnight.Add(day).Unix(), midnight.Unix(), newsgroup)
		if err := row.Scan(&num); err != nil {
			log.Println("error counting last n days posts", err)
			return nil
		}
		posts = append(posts, PostEntry{midnight.Unix(), num})
		midnight = midnight.Add(-day)
	}
	return
}
// GetLastDaysPosts returns n entries, one per UTC day, starting with today
// and walking backwards. Each entry pairs the unix time of the day's
// midnight with the value returned by the GetLastDaysPosts query, which is
// called with (next midnight, this midnight). On any query error the
// error is logged and nil is returned.
func (self *PostgresDatabase) GetLastDaysPosts(n int64) (posts []PostEntry) {
	const day = 24 * time.Hour
	today := time.Now().UTC()
	midnight := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC)
	for ; n > 0; n-- {
		var num int64
		row := self.conn.QueryRow(self.stmt[GetLastDaysPosts], midnight.Add(day).Unix(), midnight.Unix())
		if err := row.Scan(&num); err != nil {
			log.Println("error counting last n days posts", err)
			return nil
		}
		posts = append(posts, PostEntry{midnight.Unix(), num})
		midnight = midnight.Add(-day)
	}
	return
}
// GetLastPostedPostModels returns post models for the rows of the query
// stored under the GetLastPostedPostModels key, called with n. Each model
// gets its attachment models and, when one exists, its pubkey; a post with
// an empty parent is flagged as op and made its own parent. A failed query
// is logged and yields nil.
func (self *PostgresDatabase) GetLastPostedPostModels(prefix string, n int64) (posts []PostModel) {
	rows, err := self.conn.Query(self.stmt[GetLastPostedPostModels], n)
	if err != nil {
		log.Println("failed to prepare query for geting last post models", err)
		return nil
	}
	defer rows.Close()
	for rows.Next() {
		m := new(post)
		_ = rows.Scan(&m.board, &m.Message_id, &m.Parent, &m.PostName, &m.PostSubject, &m.MessagePath, &m.Posted, &m.PostMessage, &m.addr)
		isRoot := len(m.Parent) == 0
		m.op = isRoot
		if isRoot {
			m.Parent = m.Message_id
		}
		m.sage = isSage(m.PostSubject)
		if files := self.GetPostAttachmentModels(prefix, m.Message_id); files != nil {
			m.Files = append(m.Files, files...)
		}
		// quiet fail
		_ = self.conn.QueryRow(self.stmt[GetArticlePubkey], m.Message_id).Scan(&m.Key)
		posts = append(posts, m)
	}
	return
}
// GetMonthlyPostHistory returns one entry per calendar month, from the
// month of the oldest post (as reported by the GetMonthlyPostHistory
// query) up to and including the current month. Each entry pairs the unix
// time of the first of the month with the value returned by the
// GetLastDaysPosts query for that month's bounds.
//
// On any query error the error is logged and nil is returned.
func (self *PostgresDatabase) GetMonthlyPostHistory() (posts []PostEntry) {
	var oldest int64
	now := time.Now()
	now = time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
	err := self.conn.QueryRow(self.stmt[GetMonthlyPostHistory]).Scan(&oldest)
	if err != nil {
		log.Println("failed getting monthly post history", err)
		return nil
	}
	// we got the oldest
	// convert it to the first day of its month
	old := time.Unix(oldest, 0)
	old = time.Date(old.Year(), old.Month(), 1, 0, 0, 0, 0, time.UTC)
	// count up from oldest to newest
	for now.Unix() >= old.Unix() {
		// Step exactly one month from old. time.Date normalizes month 13
		// to January of the following year, so no December special case is
		// needed. The previous code branched on now.Month() instead of
		// old.Month() and, when run in December, jumped every step to
		// January of the next year.
		next_month := time.Date(old.Year(), old.Month()+1, 1, 0, 0, 0, 0, time.UTC)
		// get the post count in that month
		// NOTE(review): GetLastDaysPosts passes the later bound first for
		// this same query, while the order here is (earlier, later).
		// Confirm against the query text which order it expects.
		var count int64
		err = self.conn.QueryRow(self.stmt[GetLastDaysPosts], old.Unix(), next_month.Unix()).Scan(&count)
		if err != nil {
			log.Println("failed getting monthly post history", err)
			return nil
		}
		posts = append(posts, PostEntry{old.Unix(), count})
		old = next_month
	}
	return
}
// CheckNNTPLogin fetches the stored hash and salt for username through the
// query stored under the CheckNNTPLogin key and reports whether
// nntpLoginCredHash(passwd, salt) equals the stored hash. An empty hash or
// salt never validates. A query error is returned with valid false.
func (self *PostgresDatabase) CheckNNTPLogin(username, passwd string) (valid bool, err error) {
	var storedHash, salt string
	row := self.conn.QueryRow(self.stmt[CheckNNTPLogin], username)
	if err = row.Scan(&storedHash, &salt); err != nil {
		return
	}
	if len(storedHash) == 0 || len(salt) == 0 {
		return
	}
	valid = nntpLoginCredHash(passwd, salt) == storedHash
	return
}
// AddNNTPLogin generates a salt with genLoginCredSalt, hashes passwd with
// it through nntpLoginCredHash and inserts username, hash and salt into
// NNTPUsers.
func (self *PostgresDatabase) AddNNTPLogin(username, passwd string) (err error) {
	salt := genLoginCredSalt()
	hash := nntpLoginCredHash(passwd, salt)
	const insert = "INSERT INTO NNTPUsers(username, login_hash, login_salt) VALUES($1, $2, $3)"
	_, err = self.conn.Exec(insert, username, hash, salt)
	return err
}
// RemoveNNTPLogin deletes the NNTPUsers rows whose username matches.
func (self *PostgresDatabase) RemoveNNTPLogin(username string) (err error) {
	const del = "DELETE FROM NNTPUsers WHERE username = $1"
	_, err = self.conn.Exec(del, username)
	return err
}
// CheckNNTPUserExists reports whether the count returned by the query
// stored under the CheckNNTPUserExists key is positive for username. The
// query error, if any, is returned alongside the result.
func (self *PostgresDatabase) CheckNNTPUserExists(username string) (exists bool, err error) {
	var n int64
	row := self.conn.QueryRow(self.stmt[CheckNNTPUserExists], username)
	err = row.Scan(&n)
	return n > 0, err
}
// GetHeadersForMessage collects the key/value rows of the query stored
// under the GetHeadersForMessage key for msgid into a fresh ArticleHeaders.
// When the query fails, hdr is left nil and the error is returned.
func (self *PostgresDatabase) GetHeadersForMessage(msgid string) (hdr ArticleHeaders, err error) {
	rows, err := self.conn.Query(self.stmt[GetHeadersForMessage], msgid)
	if err != nil {
		return
	}
	defer rows.Close()
	hdr = make(ArticleHeaders)
	for rows.Next() {
		var name, value string
		_ = rows.Scan(&name, &value)
		hdr.Add(name, value)
	}
	return
}
// CountAllArticlesInGroup scans the single row of the query stored under
// the CountAllArticlesInGroup key for group into count. The Scan error is
// returned unchanged.
func (self *PostgresDatabase) CountAllArticlesInGroup(group string) (count int64, err error) {
	row := self.conn.QueryRow(self.stmt[CountAllArticlesInGroup], group)
	err = row.Scan(&count)
	return count, err
}
// GetMessageIDByCIDR returns every message id produced by the query stored
// under the GetMessageIDByCIDR key for the textual form of cidr.
//
// Iteration stops at the first row that cannot be scanned; the ids read so
// far are returned together with that error. An error met while iterating
// is returned as well.
func (self *PostgresDatabase) GetMessageIDByCIDR(cidr *net.IPNet) (msgids []string, err error) {
	var rows *sql.Rows
	rows, err = self.conn.Query(self.stmt[GetMessageIDByCIDR], cidr.String())
	if err != nil {
		return
	}
	// Close once, after the loop. Closing inside the loop ended iteration
	// after the first row, so at most one id was ever returned.
	defer rows.Close()
	for rows.Next() {
		var msgid string
		if err = rows.Scan(&msgid); err != nil {
			return
		}
		msgids = append(msgids, msgid)
	}
	err = rows.Err()
	return
}
// GetMessageIDByEncryptedIP returns the message ids produced by the query
// stored under the GetMessageIDByEncryptedIP key for encaddr. Iteration
// stops at the first Scan error, which is returned with the ids read so
// far.
func (self *PostgresDatabase) GetMessageIDByEncryptedIP(encaddr string) (msgids []string, err error) {
	var rows *sql.Rows
	rows, err = self.conn.Query(self.stmt[GetMessageIDByEncryptedIP], encaddr)
	if rows != nil {
		defer rows.Close()
	}
	if err != nil {
		return
	}
	for rows.Next() {
		var id string
		if err = rows.Scan(&id); err != nil {
			break
		}
		msgids = append(msgids, id)
	}
	return
}
// BanPubkey is not implemented for postgres and always returns an error.
func (self *PostgresDatabase) BanPubkey(pubkey string) (err error) {
	// TODO: implement
	return errors.New("ban pubkey not implemented")
}
// PubkeyIsBanned is not implemented for postgres: it always reports false
// with a nil error.
func (self *PostgresDatabase) PubkeyIsBanned(pubkey string) (bool, error) {
	// TODO: implement
	const banned = false
	return banned, nil
}
// GetPostsBefore returns the message ids produced by the query stored
// under the GetPostsBefore key, called with the unix time of t. Only the
// Query error is returned; per-row Scan errors are ignored.
func (self *PostgresDatabase) GetPostsBefore(t time.Time) (msgids []string, err error) {
	rows, err := self.conn.Query(self.stmt[GetPostsBefore], t.Unix())
	if err != nil {
		return
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		_ = rows.Scan(&id)
		msgids = append(msgids, id)
	}
	return
}
// GetPostingStats is a stub: it ignores its arguments and returns the zero
// PostingStats with a nil error.
func (self *PostgresDatabase) GetPostingStats(gran, begin, end int64) (st PostingStats, err error) {
	return st, nil
}
// SearchQuery sends a post model on chnl for every row of a search query
// and always closes chnl before returning.
//
// Nothing is searched when text is empty or already contains a "%".
// Otherwise text is wrapped in "%" wildcards and passed to SearchQuery_1
// (all groups, when group is empty) or SearchQuery_2 (called with group,
// then the pattern). Only board, message id and parent are filled in on
// each model. The Query error, if any, is returned.
func (self *PostgresDatabase) SearchQuery(prefix, group string, text string, chnl chan PostModel) (err error) {
	defer close(chnl)
	if text == "" || strings.Contains(text, "%") {
		return
	}
	pattern := "%" + text + "%"
	var rows *sql.Rows
	if group == "" {
		rows, err = self.conn.Query(self.stmt[SearchQuery_1], pattern)
	} else {
		rows, err = self.conn.Query(self.stmt[SearchQuery_2], group, pattern)
	}
	if err != nil {
		return
	}
	// deferred after close(chnl), so the rows are closed first
	defer rows.Close()
	for rows.Next() {
		hit := new(post)
		_ = rows.Scan(&hit.board, &hit.Message_id, &hit.Parent)
		chnl <- hit
	}
	return
}
// SearchByHash sends a post model on chnl for every row of a hash search
// and always closes chnl before returning.
//
// Nothing is searched when text is empty or already contains a "%".
// Otherwise text is wrapped in "%" wildcards and passed to SearchByHash_1
// (all groups, when group is empty) or SearchByHash_2 (called with the
// pattern, then group). Only board, message id and parent are filled in on
// each model. The Query error, if any, is returned.
func (self *PostgresDatabase) SearchByHash(prefix, group, text string, chnl chan PostModel) (err error) {
	defer close(chnl)
	if text == "" || strings.Contains(text, "%") {
		return
	}
	pattern := "%" + text + "%"
	var rows *sql.Rows
	if group == "" {
		rows, err = self.conn.Query(self.stmt[SearchByHash_1], pattern)
	} else {
		rows, err = self.conn.Query(self.stmt[SearchByHash_2], pattern, group)
	}
	if err != nil {
		return
	}
	// deferred after close(chnl), so the rows are closed first
	defer rows.Close()
	for rows.Next() {
		hit := new(post)
		_ = rows.Scan(&hit.board, &hit.Message_id, &hit.Parent)
		chnl <- hit
	}
	return
}