This commit is contained in:
2019-12-22 01:17:18 +05:00
parent da4bc379d8
commit a52b18ffe4
31 changed files with 202 additions and 82 deletions

View File

@@ -27,7 +27,7 @@ package postgresql
import (
// local
"go.dev.pztrn.name/fastpastebin/internal/context"
"go.dev.pztrn.name/fastpastebin/internal/database/dialects/interface"
dialectinterface "go.dev.pztrn.name/fastpastebin/internal/database/dialects/interface"
)
var (
@@ -38,5 +38,6 @@ var (
func New(cc *context.Context) {
c = cc
d = &Database{}
c.Database.RegisterDialect(dialectinterface.Interface(Handler{}))
}

View File

@@ -36,6 +36,7 @@ import (
// other
"github.com/jmoiron/sqlx"
// postgresql adapter
_ "github.com/lib/pq"
)
@@ -70,7 +71,9 @@ func (db *Database) GetDatabaseConnection() *sql.DB {
// GetPaste returns a single paste by ID.
func (db *Database) GetPaste(pasteID int) (*structs.Paste, error) {
db.check()
p := &structs.Paste{}
err := db.db.Get(p, db.db.Rebind("SELECT * FROM pastes WHERE id=$1"), pasteID)
if err != nil {
return nil, err
@@ -88,8 +91,11 @@ func (db *Database) GetPaste(pasteID int) (*structs.Paste, error) {
func (db *Database) GetPagedPastes(page int) ([]structs.Paste, error) {
db.check()
var pastesRaw []structs.Paste
var pastes []structs.Paste
var (
pastesRaw []structs.Paste
pastes []structs.Paste
)
// Pagination.
var startPagination = 0
@@ -119,8 +125,12 @@ func (db *Database) GetPagedPastes(page int) ([]structs.Paste, error) {
func (db *Database) GetPastesPages() int {
db.check()
var pastesRaw []structs.Paste
var pastes []structs.Paste
var (
pastesRaw []structs.Paste
pastes []structs.Paste
)
err := db.db.Get(&pastesRaw, "SELECT * FROM pastes WHERE private != true")
if err != nil {
return 1
@@ -164,6 +174,7 @@ func (db *Database) Initialize() {
}
c.Logger.Info().Msg("Database connection established")
db.db = dbConn
// Perform migrations.
@@ -173,12 +184,14 @@ func (db *Database) Initialize() {
func (db *Database) SavePaste(p *structs.Paste) (int64, error) {
db.check()
stmt, err := db.db.PrepareNamed("INSERT INTO pastes (title, data, created_at, keep_for, keep_for_unit_type, language, private, password, password_salt) VALUES (:title, :data, :created_at, :keep_for, :keep_for_unit_type, :language, :private, :password, :password_salt) RETURNING id")
if err != nil {
return 0, err
}
var id int64
err = stmt.Get(&id, p)
if err != nil {
return 0, err