Add context args to Database interface

This is a mecanical change, which just lifts up the context.TODO()
calls from inside the DB implementations to the callers.

Future work involves properly wiring up the contexts when it makes
sense.

git-svn-id: file:///srv/svn/repo/suika/trunk@652 f0ae65fe-ee39-954e-97ec-027ff2717ef4
This commit is contained in:
contact 2021-10-18 17:15:15 +00:00
parent 9455e04afb
commit 2748715183
11 changed files with 110 additions and 101 deletions

View File

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"flag"
"fmt"
"io"
@ -75,7 +76,7 @@ func main() {
Password: string(hashed),
Admin: *admin,
}
if err := db.StoreUser(&user); err != nil {
if err := db.StoreUser(context.TODO(), &user); err != nil {
log.Fatalf("failed to create user: %v", err)
}
case "change-password":
@ -85,7 +86,7 @@ func main() {
os.Exit(1)
}
user, err := db.GetUser(username)
user, err := db.GetUser(context.TODO(), username)
if err != nil {
log.Fatalf("failed to get user: %v", err)
}
@ -101,7 +102,7 @@ func main() {
}
user.Password = string(hashed)
if err := db.StoreUser(user); err != nil {
if err := db.StoreUser(context.TODO(), user); err != nil {
log.Fatalf("failed to update password: %v", err)
}
default:

View File

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"flag"
"fmt"
"io"
@ -79,7 +80,7 @@ func main() {
log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
}
l, err := db.ListUsers()
l, err := db.ListUsers(context.TODO())
if err != nil {
log.Fatalf("failed to list users in DB: %v", err)
}
@ -111,12 +112,12 @@ func main() {
u.Admin = section.Values.Get("Admin") == "true"
if err := db.StoreUser(u); err != nil {
if err := db.StoreUser(context.TODO(), u); err != nil {
log.Fatalf("failed to store user %q: %v", username, err)
}
userID := u.ID
l, err := db.ListNetworks(userID)
l, err := db.ListNetworks(context.TODO(), userID)
if err != nil {
log.Fatalf("failed to list networks for user %q: %v", username, err)
}
@ -183,11 +184,11 @@ func main() {
n.Pass = pass
n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
if err := db.StoreNetwork(userID, n); err != nil {
if err := db.StoreNetwork(context.TODO(), userID, n); err != nil {
logger.Fatalf("failed to store network: %v", err)
}
l, err := db.ListChannels(n.ID)
l, err := db.ListChannels(context.TODO(), n.ID)
if err != nil {
logger.Fatalf("failed to list channels: %v", err)
}
@ -217,7 +218,7 @@ func main() {
ch.Key = section.Values.Get("Key")
ch.Detached = section.Values.Get("Detached") == "true"
if err := db.StoreChannel(n.ID, ch); err != nil {
if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil {
logger.Printf("channel %q: failed to store channel: %v", chName, err)
}
})

27
db.go
View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"fmt"
"net/url"
"strings"
@ -9,22 +10,22 @@ import (
type Database interface {
Close() error
Stats() (*DatabaseStats, error)
Stats(ctx context.Context) (*DatabaseStats, error)
ListUsers() ([]User, error)
GetUser(username string) (*User, error)
StoreUser(user *User) error
DeleteUser(id int64) error
ListUsers(ctx context.Context) ([]User, error)
GetUser(ctx context.Context, username string) (*User, error)
StoreUser(ctx context.Context, user *User) error
DeleteUser(ctx context.Context, id int64) error
ListNetworks(userID int64) ([]Network, error)
StoreNetwork(userID int64, network *Network) error
DeleteNetwork(id int64) error
ListChannels(networkID int64) ([]Channel, error)
StoreChannel(networKID int64, ch *Channel) error
DeleteChannel(id int64) error
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
StoreNetwork(ctx context.Context, userID int64, network *Network) error
DeleteNetwork(ctx context.Context, id int64) error
ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
DeleteChannel(ctx context.Context, id int64) error
ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error)
StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
}
func OpenDB(driver, source string) (Database, error) {

View File

@ -147,8 +147,8 @@ func (db *PostgresDB) Close() error {
return db.db.Close()
}
func (db *PostgresDB) Stats() (*DatabaseStats, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
var stats DatabaseStats
@ -163,8 +163,8 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) {
return &stats, nil
}
func (db *PostgresDB) ListUsers() ([]User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
@ -192,8 +192,8 @@ func (db *PostgresDB) ListUsers() ([]User, error) {
return users, nil
}
func (db *PostgresDB) GetUser(username string) (*User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
user := &User{Username: username}
@ -210,8 +210,8 @@ func (db *PostgresDB) GetUser(username string) (*User, error) {
return user, nil
}
func (db *PostgresDB) StoreUser(user *User) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
password := toNullString(user.Password)
@ -234,16 +234,16 @@ func (db *PostgresDB) StoreUser(user *User) error {
return err
}
func (db *PostgresDB) DeleteUser(id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
return err
}
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -286,8 +286,8 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
return networks, nil
}
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
netName := toNullString(network.Name)
@ -338,16 +338,16 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
return err
}
func (db *PostgresDB) DeleteNetwork(id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
return err
}
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -380,8 +380,8 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil
}
func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
key := toNullString(ch.Key)
@ -408,16 +408,16 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
return err
}
func (db *PostgresDB) DeleteChannel(id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
return err
}
func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -444,8 +444,8 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt,
return receipts, nil
}
func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout)
func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
tx, err := db.db.Begin()

View File

@ -208,11 +208,11 @@ func (db *SqliteDB) upgrade() error {
return tx.Commit()
}
func (db *SqliteDB) Stats() (*DatabaseStats, error) {
func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
var stats DatabaseStats
@ -234,11 +234,11 @@ func toNullString(s string) sql.NullString {
}
}
func (db *SqliteDB) ListUsers() ([]User, error) {
func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
@ -266,11 +266,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) {
return users, nil
}
func (db *SqliteDB) GetUser(username string) (*User, error) {
func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
user := &User{Username: username}
@ -287,11 +287,11 @@ func (db *SqliteDB) GetUser(username string) (*User, error) {
return user, nil
}
func (db *SqliteDB) StoreUser(user *User) error {
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
args := []interface{}{
@ -323,11 +323,11 @@ func (db *SqliteDB) StoreUser(user *User) error {
return err
}
func (db *SqliteDB) DeleteUser(id int64) error {
func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
tx, err := db.db.Begin()
@ -371,11 +371,11 @@ func (db *SqliteDB) DeleteUser(id int64) error {
return tx.Commit()
}
func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -420,11 +420,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
return networks, nil
}
func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
@ -490,11 +490,11 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
return err
}
func (db *SqliteDB) DeleteNetwork(id int64) error {
func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
tx, err := db.db.Begin()
@ -521,11 +521,11 @@ func (db *SqliteDB) DeleteNetwork(id int64) error {
return tx.Commit()
}
func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `SELECT
@ -558,11 +558,11 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil
}
func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
args := []interface{}{
@ -598,22 +598,22 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
return err
}
func (db *SqliteDB) DeleteChannel(id int64) error {
func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
_, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
return err
}
func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx, `
@ -642,11 +642,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er
return receipts, nil
}
func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
db.lock.Lock()
defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout)
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
tx, err := db.db.Begin()

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
@ -976,7 +977,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
func (dc *downstreamConn) authenticate(username, password string) error {
username, clientName, networkName := unmarshalUsername(username)
u, err := dc.srv.db.GetUser(username)
u, err := dc.srv.db.GetUser(context.TODO(), username)
if err != nil {
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
return errAuthFailed
@ -1377,7 +1378,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return
}
n.Nick = nick
err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network)
err = dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network)
})
if err != nil {
return err
@ -1427,7 +1428,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
})
n.Realname = storeRealname
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to store network realname: %v", err)
storeErr = err
}
@ -1516,7 +1517,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
uc.network.channels.SetValue(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
}
}
@ -1548,7 +1549,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
uc.network.channels.SetValue(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
}
} else {
@ -2445,7 +2446,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
n.SASL.Mechanism = "PLAIN"
n.SASL.Plain.Username = username
n.SASL.Plain.Password = password
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to save NickServ credentials: %v", err)
}
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"fmt"
"log"
"mime"
@ -85,7 +86,7 @@ func (s *Server) prefix() *irc.Prefix {
}
func (s *Server) Start() error {
users, err := s.db.ListUsers()
users, err := s.db.ListUsers(context.TODO())
if err != nil {
return err
}
@ -126,7 +127,7 @@ func (s *Server) createUser(user *User) (*user, error) {
return nil, fmt.Errorf("user %q already exists", user.Username)
}
err := s.db.StoreUser(user)
err := s.db.StoreUser(context.TODO(), user)
if err != nil {
return nil, fmt.Errorf("could not create user in db: %v", err)
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"net"
"testing"
@ -43,7 +44,7 @@ func createTestUser(t *testing.T, db Database) *User {
}
record := &User{Username: testUsername, Password: string(hashed)}
if err := db.StoreUser(record); err != nil {
if err := db.StoreUser(context.TODO(), record); err != nil {
t.Fatalf("failed to store test user: %v", err)
}
@ -68,7 +69,7 @@ func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Li
Nick: user.Username,
Enabled: true,
}
if err := db.StoreNetwork(user.ID, network); err != nil {
if err := db.StoreNetwork(context.TODO(), user.ID, network); err != nil {
t.Fatalf("failed to store test network: %v", err)
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
@ -657,7 +658,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = privKey
net.SASL.Mechanism = "EXTERNAL"
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err
}
@ -698,7 +699,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
net.SASL.Plain.Password = params[2]
net.SASL.Mechanism = "PLAIN"
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err
}
@ -722,7 +723,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = nil
net.SASL.Mechanism = ""
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err
}
@ -860,7 +861,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
u.stop()
if err := dc.srv.db.DeleteUser(u.ID); err != nil {
if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil {
return fmt.Errorf("failed to delete user: %v", err)
}
@ -1015,7 +1016,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
uc.updateChannelAutoDetach(upstreamName)
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
return fmt.Errorf("failed to update channel: %v", err)
}
@ -1024,7 +1025,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
}
func handleServiceServerStatus(dc *downstreamConn, params []string) error {
dbStats, err := dc.user.srv.db.Stats()
dbStats, err := dc.user.srv.db.Stats(context.TODO())
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto"
"crypto/sha256"
"crypto/tls"
@ -1516,7 +1517,7 @@ func (uc *upstreamConn) handleDetachedMessage(ch *Channel, msg *irc.Message) {
}
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
uc.network.attach(ch)
if err := uc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
}
}

21
user.go
View File

@ -1,6 +1,7 @@
package soju
import (
"context"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
@ -330,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
}
}
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
return err
}
net.channels.Delete(name)
@ -367,7 +368,7 @@ func (net *network) storeClientDeliveryReceipts(clientName string) {
})
})
if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
}
}
@ -487,7 +488,7 @@ func (u *user) run() {
close(u.done)
}()
networks, err := u.srv.db.ListNetworks(u.ID)
networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
if err != nil {
u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
return
@ -495,7 +496,7 @@ func (u *user) run() {
for _, record := range networks {
record := record
channels, err := u.srv.db.ListChannels(record.ID)
channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
if err != nil {
u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
continue
@ -505,7 +506,7 @@ func (u *user) run() {
u.networks = append(u.networks, network)
if u.hasPersistentMsgStore() {
receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
if err != nil {
u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
return
@ -590,7 +591,7 @@ func (u *user) run() {
continue
}
uc.network.detach(c)
if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil {
if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
}
case eventDownstreamConnected:
@ -779,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
}
network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(u.ID, &network.Network)
err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
if err != nil {
return nil, err
}
@ -821,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
panic("tried updating a non-existing network")
}
if err := u.srv.db.StoreNetwork(u.ID, record); err != nil {
if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
return nil, err
}
@ -888,7 +889,7 @@ func (u *user) deleteNetwork(id int64) error {
panic("tried deleting a non-existing network")
}
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
return err
}
@ -914,7 +915,7 @@ func (u *user) updateUser(record *User) error {
}
realnameUpdated := u.Realname != record.Realname
if err := u.srv.db.StoreUser(record); err != nil {
if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
}
u.User = *record