4 changed files with 164 additions and 5 deletions
@ -0,0 +1,72 @@
|
||||
package db |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"errors" |
||||
"log" |
||||
"time" |
||||
|
||||
"go.etcd.io/bbolt" |
||||
) |
||||
|
||||
var ( |
||||
communityBucket = []byte("community") |
||||
) |
||||
|
||||
type community struct { |
||||
Name string |
||||
CreationDate time.Time |
||||
} |
||||
|
||||
// AddCommunity stores in the db the community that was created by host
|
||||
func (db *DB) AddCommunity(name string, host string) error { |
||||
var communities []community |
||||
err := db.get(communityBucket, host, &communities) |
||||
if err != nil && !errors.Is(err, notFoundError{}) { |
||||
return err |
||||
} |
||||
|
||||
communities = append(communities, community{name, time.Now()}) |
||||
return db.put(communityBucket, host, communities) |
||||
} |
||||
|
||||
// CountCommunities returns the nubmer of communities created by host
|
||||
func (db *DB) CountCommunities(host string) (int, error) { |
||||
var communities []community |
||||
err := db.get(communityBucket, host, &communities) |
||||
if errors.Is(err, notFoundError{}) { |
||||
return 0, nil |
||||
} |
||||
return len(communities), err |
||||
} |
||||
|
||||
// ExpireCommunities older than duration
|
||||
func (db *DB) ExpireCommunities(duration time.Duration) error { |
||||
return db.bolt.Update(func(tx *bbolt.Tx) error { |
||||
b := tx.Bucket(communityBucket) |
||||
return b.ForEach(func(k, v []byte) error { |
||||
var communities []community |
||||
err := json.Unmarshal(v, &communities) |
||||
if err != nil { |
||||
log.Printf("Error unmarshalling %s: %v", string(k), err) |
||||
return nil |
||||
} |
||||
|
||||
newCommunities := []community{} |
||||
for _, c := range communities { |
||||
if c.CreationDate.Add(duration).After(time.Now()) { |
||||
newCommunities = append(newCommunities, c) |
||||
} |
||||
} |
||||
if len(newCommunities) == len(communities) { |
||||
return nil |
||||
} |
||||
|
||||
encodedValue, err := json.Marshal(newCommunities) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return b.Put(k, encodedValue) |
||||
}) |
||||
}) |
||||
} |
@ -0,0 +1,85 @@
|
||||
package db |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
communityName = "community" |
||||
communityHost = "host" |
||||
) |
||||
|
||||
func TestAddCountCommunity(t *testing.T) { |
||||
db := initTestDB(t) |
||||
defer delTestDB(db) |
||||
|
||||
count, err := db.CountCommunities(communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error counting communities: %v", err) |
||||
} |
||||
if count != 0 { |
||||
t.Errorf("Got an unexpected number of communities: %d", count) |
||||
} |
||||
|
||||
err = db.AddCommunity(communityName, communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error adding a community: %v", err) |
||||
} |
||||
|
||||
count, err = db.CountCommunities(communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error counting communities: %v", err) |
||||
} |
||||
if count != 1 { |
||||
t.Errorf("Got an unexpected number of communities: %d", count) |
||||
} |
||||
|
||||
err = db.AddCommunity(communityName+"1", communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error adding a community: %v", err) |
||||
} |
||||
err = db.AddCommunity(communityName+"2", communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error adding a community: %v", err) |
||||
} |
||||
|
||||
count, err = db.CountCommunities(communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error counting communities: %v", err) |
||||
} |
||||
if count != 3 { |
||||
t.Errorf("Got an unexpected number of communities: %d", count) |
||||
} |
||||
} |
||||
|
||||
func TestExpireCommunities(t *testing.T) { |
||||
db := initTestDB(t) |
||||
defer delTestDB(db) |
||||
|
||||
err := db.AddCommunity(communityName, communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error adding a community: %v", err) |
||||
} |
||||
|
||||
count, err := db.CountCommunities(communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error counting communities: %v", err) |
||||
} |
||||
if count != 1 { |
||||
t.Errorf("Got an unexpected number of communities: %d", count) |
||||
} |
||||
|
||||
err = db.ExpireCommunities(time.Microsecond) |
||||
if err != nil { |
||||
t.Fatalf("Got an error expiring invites: %v", err) |
||||
} |
||||
|
||||
count, err = db.CountCommunities(communityHost) |
||||
if err != nil { |
||||
t.Fatalf("Got an error counting communities: %v", err) |
||||
} |
||||
if count != 0 { |
||||
t.Errorf("Got an unexpected number of communities: %d", count) |
||||
} |
||||
} |
Loading…
Reference in new issue