Add context args to Database interface
This is a mecanical change, which just lifts up the context.TODO() calls from inside the DB implementations to the callers. Future work involves properly wiring up the contexts when it makes sense. git-svn-id: file:///srv/svn/repo/suika/trunk@652 f0ae65fe-ee39-954e-97ec-027ff2717ef4
This commit is contained in:
parent
9455e04afb
commit
2748715183
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -75,7 +76,7 @@ func main() {
|
||||
Password: string(hashed),
|
||||
Admin: *admin,
|
||||
}
|
||||
if err := db.StoreUser(&user); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), &user); err != nil {
|
||||
log.Fatalf("failed to create user: %v", err)
|
||||
}
|
||||
case "change-password":
|
||||
@ -85,7 +86,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
user, err := db.GetUser(username)
|
||||
user, err := db.GetUser(context.TODO(), username)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get user: %v", err)
|
||||
}
|
||||
@ -101,7 +102,7 @@ func main() {
|
||||
}
|
||||
|
||||
user.Password = string(hashed)
|
||||
if err := db.StoreUser(user); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), user); err != nil {
|
||||
log.Fatalf("failed to update password: %v", err)
|
||||
}
|
||||
default:
|
||||
|
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -79,7 +80,7 @@ func main() {
|
||||
log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
|
||||
}
|
||||
|
||||
l, err := db.ListUsers()
|
||||
l, err := db.ListUsers(context.TODO())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list users in DB: %v", err)
|
||||
}
|
||||
@ -111,12 +112,12 @@ func main() {
|
||||
|
||||
u.Admin = section.Values.Get("Admin") == "true"
|
||||
|
||||
if err := db.StoreUser(u); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), u); err != nil {
|
||||
log.Fatalf("failed to store user %q: %v", username, err)
|
||||
}
|
||||
userID := u.ID
|
||||
|
||||
l, err := db.ListNetworks(userID)
|
||||
l, err := db.ListNetworks(context.TODO(), userID)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list networks for user %q: %v", username, err)
|
||||
}
|
||||
@ -183,11 +184,11 @@ func main() {
|
||||
n.Pass = pass
|
||||
n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
|
||||
|
||||
if err := db.StoreNetwork(userID, n); err != nil {
|
||||
if err := db.StoreNetwork(context.TODO(), userID, n); err != nil {
|
||||
logger.Fatalf("failed to store network: %v", err)
|
||||
}
|
||||
|
||||
l, err := db.ListChannels(n.ID)
|
||||
l, err := db.ListChannels(context.TODO(), n.ID)
|
||||
if err != nil {
|
||||
logger.Fatalf("failed to list channels: %v", err)
|
||||
}
|
||||
@ -217,7 +218,7 @@ func main() {
|
||||
ch.Key = section.Values.Get("Key")
|
||||
ch.Detached = section.Values.Get("Detached") == "true"
|
||||
|
||||
if err := db.StoreChannel(n.ID, ch); err != nil {
|
||||
if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil {
|
||||
logger.Printf("channel %q: failed to store channel: %v", chName, err)
|
||||
}
|
||||
})
|
||||
|
27
db.go
27
db.go
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -9,22 +10,22 @@ import (
|
||||
|
||||
type Database interface {
|
||||
Close() error
|
||||
Stats() (*DatabaseStats, error)
|
||||
Stats(ctx context.Context) (*DatabaseStats, error)
|
||||
|
||||
ListUsers() ([]User, error)
|
||||
GetUser(username string) (*User, error)
|
||||
StoreUser(user *User) error
|
||||
DeleteUser(id int64) error
|
||||
ListUsers(ctx context.Context) ([]User, error)
|
||||
GetUser(ctx context.Context, username string) (*User, error)
|
||||
StoreUser(ctx context.Context, user *User) error
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
|
||||
ListNetworks(userID int64) ([]Network, error)
|
||||
StoreNetwork(userID int64, network *Network) error
|
||||
DeleteNetwork(id int64) error
|
||||
ListChannels(networkID int64) ([]Channel, error)
|
||||
StoreChannel(networKID int64, ch *Channel) error
|
||||
DeleteChannel(id int64) error
|
||||
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
|
||||
StoreNetwork(ctx context.Context, userID int64, network *Network) error
|
||||
DeleteNetwork(ctx context.Context, id int64) error
|
||||
ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
|
||||
StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
|
||||
DeleteChannel(ctx context.Context, id int64) error
|
||||
|
||||
ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error)
|
||||
StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
|
||||
ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
|
||||
StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
|
||||
}
|
||||
|
||||
func OpenDB(driver, source string) (Database, error) {
|
||||
|
@ -147,8 +147,8 @@ func (db *PostgresDB) Close() error {
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
func (db *PostgresDB) Stats() (*DatabaseStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var stats DatabaseStats
|
||||
@ -163,8 +163,8 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) {
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListUsers() ([]User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
@ -192,8 +192,8 @@ func (db *PostgresDB) ListUsers() ([]User, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) GetUser(username string) (*User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
user := &User{Username: username}
|
||||
@ -210,8 +210,8 @@ func (db *PostgresDB) GetUser(username string) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreUser(user *User) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
password := toNullString(user.Password)
|
||||
@ -234,16 +234,16 @@ func (db *PostgresDB) StoreUser(user *User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) DeleteUser(id int64) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -286,8 +286,8 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
return networks, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
netName := toNullString(network.Name)
|
||||
@ -338,16 +338,16 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) DeleteNetwork(id int64) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -380,8 +380,8 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
key := toNullString(ch.Key)
|
||||
@ -408,16 +408,16 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) DeleteChannel(id int64) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -444,8 +444,8 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt,
|
||||
return receipts, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
|
||||
func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
|
52
db_sqlite.go
52
db_sqlite.go
@ -208,11 +208,11 @@ func (db *SqliteDB) upgrade() error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *SqliteDB) Stats() (*DatabaseStats, error) {
|
||||
func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var stats DatabaseStats
|
||||
@ -234,11 +234,11 @@ func toNullString(s string) sql.NullString {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListUsers() ([]User, error) {
|
||||
func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
@ -266,11 +266,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) GetUser(username string) (*User, error) {
|
||||
func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
user := &User{Username: username}
|
||||
@ -287,11 +287,11 @@ func (db *SqliteDB) GetUser(username string) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreUser(user *User) error {
|
||||
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []interface{}{
|
||||
@ -323,11 +323,11 @@ func (db *SqliteDB) StoreUser(user *User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) DeleteUser(id int64) error {
|
||||
func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
@ -371,11 +371,11 @@ func (db *SqliteDB) DeleteUser(id int64) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -420,11 +420,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
|
||||
return networks, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
|
||||
func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
|
||||
@ -490,11 +490,11 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) DeleteNetwork(id int64) error {
|
||||
func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
@ -521,11 +521,11 @@ func (db *SqliteDB) DeleteNetwork(id int64) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `SELECT
|
||||
@ -558,11 +558,11 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []interface{}{
|
||||
@ -598,22 +598,22 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) DeleteChannel(id int64) error {
|
||||
func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
|
||||
func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx, `
|
||||
@ -642,11 +642,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er
|
||||
return receipts, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@ -976,7 +977,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
|
||||
func (dc *downstreamConn) authenticate(username, password string) error {
|
||||
username, clientName, networkName := unmarshalUsername(username)
|
||||
|
||||
u, err := dc.srv.db.GetUser(username)
|
||||
u, err := dc.srv.db.GetUser(context.TODO(), username)
|
||||
if err != nil {
|
||||
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
|
||||
return errAuthFailed
|
||||
@ -1377,7 +1378,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
return
|
||||
}
|
||||
n.Nick = nick
|
||||
err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network)
|
||||
err = dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1427,7 +1428,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
})
|
||||
|
||||
n.Realname = storeRealname
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
|
||||
dc.logger.Printf("failed to store network realname: %v", err)
|
||||
storeErr = err
|
||||
}
|
||||
@ -1516,7 +1517,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
}
|
||||
uc.network.channels.SetValue(upstreamName, ch)
|
||||
}
|
||||
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
|
||||
}
|
||||
}
|
||||
@ -1548,7 +1549,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
}
|
||||
uc.network.channels.SetValue(upstreamName, ch)
|
||||
}
|
||||
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
|
||||
}
|
||||
} else {
|
||||
@ -2445,7 +2446,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
|
||||
n.SASL.Mechanism = "PLAIN"
|
||||
n.SASL.Plain.Username = username
|
||||
n.SASL.Plain.Password = password
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
|
||||
dc.logger.Printf("failed to save NickServ credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"mime"
|
||||
@ -85,7 +86,7 @@ func (s *Server) prefix() *irc.Prefix {
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
users, err := s.db.ListUsers()
|
||||
users, err := s.db.ListUsers(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -126,7 +127,7 @@ func (s *Server) createUser(user *User) (*user, error) {
|
||||
return nil, fmt.Errorf("user %q already exists", user.Username)
|
||||
}
|
||||
|
||||
err := s.db.StoreUser(user)
|
||||
err := s.db.StoreUser(context.TODO(), user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create user in db: %v", err)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@ -43,7 +44,7 @@ func createTestUser(t *testing.T, db Database) *User {
|
||||
}
|
||||
|
||||
record := &User{Username: testUsername, Password: string(hashed)}
|
||||
if err := db.StoreUser(record); err != nil {
|
||||
if err := db.StoreUser(context.TODO(), record); err != nil {
|
||||
t.Fatalf("failed to store test user: %v", err)
|
||||
}
|
||||
|
||||
@ -68,7 +69,7 @@ func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Li
|
||||
Nick: user.Username,
|
||||
Enabled: true,
|
||||
}
|
||||
if err := db.StoreNetwork(user.ID, network); err != nil {
|
||||
if err := db.StoreNetwork(context.TODO(), user.ID, network); err != nil {
|
||||
t.Fatalf("failed to store test network: %v", err)
|
||||
}
|
||||
|
||||
|
13
service.go
13
service.go
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
@ -657,7 +658,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
|
||||
net.SASL.External.PrivKeyBlob = privKey
|
||||
net.SASL.Mechanism = "EXTERNAL"
|
||||
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -698,7 +699,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
|
||||
net.SASL.Plain.Password = params[2]
|
||||
net.SASL.Mechanism = "PLAIN"
|
||||
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -722,7 +723,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
|
||||
net.SASL.External.PrivKeyBlob = nil
|
||||
net.SASL.Mechanism = ""
|
||||
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
|
||||
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -860,7 +861,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
|
||||
|
||||
u.stop()
|
||||
|
||||
if err := dc.srv.db.DeleteUser(u.ID); err != nil {
|
||||
if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
@ -1015,7 +1016,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
|
||||
|
||||
uc.updateChannelAutoDetach(upstreamName)
|
||||
|
||||
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
return fmt.Errorf("failed to update channel: %v", err)
|
||||
}
|
||||
|
||||
@ -1024,7 +1025,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
|
||||
}
|
||||
|
||||
func handleServiceServerStatus(dc *downstreamConn, params []string) error {
|
||||
dbStats, err := dc.user.srv.db.Stats()
|
||||
dbStats, err := dc.user.srv.db.Stats(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
@ -1516,7 +1517,7 @@ func (uc *upstreamConn) handleDetachedMessage(ch *Channel, msg *irc.Message) {
|
||||
}
|
||||
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
|
||||
uc.network.attach(ch)
|
||||
if err := uc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
|
||||
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
|
||||
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
|
||||
}
|
||||
}
|
||||
|
21
user.go
21
user.go
@ -1,6 +1,7 @@
|
||||
package soju
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
@ -330,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
|
||||
if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
net.channels.Delete(name)
|
||||
@ -367,7 +368,7 @@ func (net *network) storeClientDeliveryReceipts(clientName string) {
|
||||
})
|
||||
})
|
||||
|
||||
if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
|
||||
if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
|
||||
net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
|
||||
}
|
||||
}
|
||||
@ -487,7 +488,7 @@ func (u *user) run() {
|
||||
close(u.done)
|
||||
}()
|
||||
|
||||
networks, err := u.srv.db.ListNetworks(u.ID)
|
||||
networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
|
||||
return
|
||||
@ -495,7 +496,7 @@ func (u *user) run() {
|
||||
|
||||
for _, record := range networks {
|
||||
record := record
|
||||
channels, err := u.srv.db.ListChannels(record.ID)
|
||||
channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
|
||||
continue
|
||||
@ -505,7 +506,7 @@ func (u *user) run() {
|
||||
u.networks = append(u.networks, network)
|
||||
|
||||
if u.hasPersistentMsgStore() {
|
||||
receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
|
||||
receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
|
||||
if err != nil {
|
||||
u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
|
||||
return
|
||||
@ -590,7 +591,7 @@ func (u *user) run() {
|
||||
continue
|
||||
}
|
||||
uc.network.detach(c)
|
||||
if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil {
|
||||
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
|
||||
u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
|
||||
}
|
||||
case eventDownstreamConnected:
|
||||
@ -779,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
|
||||
}
|
||||
|
||||
network := newNetwork(u, record, nil)
|
||||
err := u.srv.db.StoreNetwork(u.ID, &network.Network)
|
||||
err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -821,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
|
||||
panic("tried updating a non-existing network")
|
||||
}
|
||||
|
||||
if err := u.srv.db.StoreNetwork(u.ID, record); err != nil {
|
||||
if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -888,7 +889,7 @@ func (u *user) deleteNetwork(id int64) error {
|
||||
panic("tried deleting a non-existing network")
|
||||
}
|
||||
|
||||
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
|
||||
if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -914,7 +915,7 @@ func (u *user) updateUser(record *User) error {
|
||||
}
|
||||
|
||||
realnameUpdated := u.Realname != record.Realname
|
||||
if err := u.srv.db.StoreUser(record); err != nil {
|
||||
if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
|
||||
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
|
||||
}
|
||||
u.User = *record
|
||||
|
Loading…
x
Reference in New Issue
Block a user