Common code for implementing Twitch bots.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

220 lines
5.2 KiB

// Package ctxrl implements context-based rate limiting. It is intended to be
// used on TMI for rate limiting messages to channels in slow mode. It
// implements a message queue in order to integrate with a blocking rate
// limiter that ensures a global maximum rate.
package ctxrl
import (
"container/list"
"context"
"sync"
"time"
"raccatta.cc/bot/internal/signal"
"raccatta.cc/bot/unstable/log"
)
// defaultInterval is the interval that is used if no interval is set
// explicitly.
//
// NOTE(review): 1100ms looks chosen to sit just above a one-second
// per-channel limit (presumably TMI slow mode) with some margin — confirm
// against the rate limits the bot actually targets.
const defaultInterval = time.Millisecond * 1100
// A ContextMessage describes a message to-be-enqueued.
type ContextMessage struct {
	Context   string    // Usually a channel; keys the per-context rate-limit state.
	Publisher Publisher // Invoked to actually emit the message once it is due.
}
// A pendingMsg is a message which we've scheduled and is ready to be sent to
// the writer (to enforce a global rate limit).
type pendingMsg struct {
	ContextMessage
	// scheduled is the earliest time the message may be published; the zero
	// time means "immediately" (see Send and outputOne).
	scheduled time.Time
}
// A Publisher pushes the output to the desired connection behind the scenes.
// The call should block until it is reasonably certain the message will be
// sent as soon as possible, to avoid conflicting with the rate limiting.
type Publisher interface {
	// Publish emits the message. It is called from the Limiter's Output
	// routine once the message's scheduled time has been reached.
	Publish()
}
// A Limiter rate-limits messages to a context, such as a Twitch channel.
type Limiter struct {
	mu      sync.Mutex           // Guards pending and cs.
	pending *list.List           // Globally scheduled *pendingMsg values; front is the one due soonest.
	cs      map[string]*ctxstate // Per-context state, keyed by ContextMessage.Context.
	wakeup  chan struct{}        // Signals Output that new work may have been scheduled.
}
// A ctxstate contains the state for a given context. The limiter may
// periodically clean up unused/expired channel states.
type ctxstate struct {
	pending     *list.List    // Queued ContextMessage values not yet promoted to the global pending list.
	last        time.Time     // When the last message for this context finished publishing.
	interval    time.Duration // Minimum spacing between messages in this context.
	outstanding int           // Queued plus in-flight messages for this context.
}
// New creates a new Limiter.
func New() *Limiter {
	l := &Limiter{}
	l.pending = list.New()
	l.cs = map[string]*ctxstate{}
	l.wakeup = signal.New()
	return l
}
// Pending returns the number of outstanding (to be sent) messages for the
// given context.
func (b *Limiter) Pending(ctx string) int {
	b.mu.Lock()
	defer b.mu.Unlock()
	if state, ok := b.cs[ctx]; ok {
		return state.outstanding
	}
	return 0
}
// TODO: helper to update interval of an existing context (e.g. slowmode
// or user state changes)

// Send schedules a message to be sent in a context. It is up to the caller to
// limit the total number of messages that can be queued.
//
// A negative `ival` means it is not known, in which case the previous interval
// will be assumed to be correct.
//
// It returns the number of outstanding messages for the context, including
// the one just enqueued.
func (b *Limiter) Send(msg ContextMessage, ival time.Duration) int {
	b.mu.Lock()
	defer b.mu.Unlock()
	cstate := b.cs[msg.Context]
	if cstate == nil {
		// First message for this context: create fresh state, falling back to
		// the default interval when the caller doesn't know it.
		if ival < 0 {
			ival = defaultInterval
		}
		cstate = &ctxstate{pending: list.New(), interval: ival}
		b.cs[msg.Context] = cstate
	} else if ival < 0 {
		// Unknown interval: keep using the last known one.
		ival = cstate.interval
	} else {
		// Caller provided a (possibly updated) interval; remember it.
		cstate.interval = ival
	}
	// Schedule this message immediately if it is the only message for this
	// channel.
	// TODO: add an optimization when outstanding > 0 but interval was 0s to
	// skip queueing
	if cstate.outstanding == 0 {
		// With a positive interval, the message must wait until one interval
		// after the previous publish; otherwise t stays zero ("send now").
		var t time.Time
		if ival > 0 {
			t = cstate.last.Add(ival)
		}
		cstate.outstanding++
		b.schedulePending(&pendingMsg{
			ContextMessage: msg,
			scheduled:      t,
		})
		// Nudge Output in case it is blocked waiting for work.
		signal.Wakeup(b.wakeup)
	} else {
		// Another message for this context is already in the global queue;
		// park this one in the per-context queue. scheduleNextForContext will
		// promote it after the earlier message is published.
		cstate.outstanding++
		cstate.pending.PushBack(msg)
	}
	return cstate.outstanding
}
// Output should be used as a background routine that writes message according
// to the rate limit of each context. It returns when ctx is canceled.
func (b *Limiter) Output(ctx context.Context) {
	// TODO: memory management: periodically clean up cstates that haven't been
	// used
	empty := false
	var wakeUpIn <-chan time.Time
	for {
		if empty {
			// Nothing queued at all on the last pass: block until Send signals
			// new work or we're shut down.
			select {
			case <-ctx.Done():
				return
			case <-b.wakeup:
			}
		} else if wakeUpIn != nil {
			// The front message isn't due yet: sleep until its scheduled time,
			// but also wake early if something new gets scheduled (it may be
			// due sooner) or on shutdown.
			select {
			case <-ctx.Done():
				return
			case <-b.wakeup:
			case <-wakeUpIn:
			}
			wakeUpIn = nil
		}
		// If neither branch above ran, the previous iteration published a
		// message and we immediately try to drain the next one.
		wakeUpIn, empty = b.outputOne()
	}
}
// outputOne attempts to publish the message at the front of the global queue.
// If the queue is empty it reports empty=true; if the front message is not due
// yet it returns a timer channel that fires at its scheduled time; otherwise
// it publishes the message and removes it from the queue.
func (b *Limiter) outputOne() (wakeUpIn <-chan time.Time, empty bool) {
	e := b.nextPending()
	if e == nil {
		return nil, true
	}
	msg := e.Value.(*pendingMsg)
	if d := time.Until(msg.scheduled); d > 0 {
		// Not due yet; tell the caller when to try again.
		return time.After(d), false
	}
	// Should block until sent/buffered
	msg.Publisher.Publish()
	b.completed(e, msg)
	return nil, false
}
// schedulePending inserts pm into the global pending list, keeping the list
// ordered by ascending scheduled time so that nextPending (the list front) is
// always the message due soonest. Messages with equal scheduled times keep
// FIFO order.
//
// Must be called with b.mu held.
func (b *Limiter) schedulePending(pm *pendingMsg) {
	// TODO: consider using a skiplist
	// Walk from the back: most new messages are due at or after everything
	// already queued. Insert after the first element whose time is <= pm's
	// (using !Before rather than After so equal times stay FIFO).
	for e := b.pending.Back(); e != nil; e = e.Prev() {
		if !pm.scheduled.Before(e.Value.(*pendingMsg).scheduled) {
			b.pending.InsertAfter(pm, e)
			return
		}
	}
	// pm is due strictly before every queued message (or the list is empty),
	// so it belongs at the front. The previous PushBack here put the
	// earliest-due message at the back, which let a later-due message at the
	// front delay it (outputOne only ever looks at the front).
	b.pending.PushFront(pm)
}
// nextPending returns the element at the front of the global pending list,
// i.e. the scheduled message that is due soonest, or nil if none is queued.
func (b *Limiter) nextPending() *list.Element {
	b.mu.Lock()
	front := b.pending.Front()
	b.mu.Unlock()
	return front
}
// completed removes a published message from the global pending list and, if
// its context still has state, promotes that context's next queued message.
func (b *Limiter) completed(e *list.Element, pm *pendingMsg) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.pending.Remove(e)
	// TODO: cache cstate
	cstate, ok := b.cs[pm.Context]
	if !ok {
		return
	}
	b.scheduleNextForContext(cstate)
}
// scheduleNextForContext schedules the next message that's queued for a
// context, if any. It is called right after a message for that context has
// been published; must be called with b.mu held.
func (b *Limiter) scheduleNextForContext(cstate *ctxstate) {
	// Account for the message that just completed.
	if cstate.outstanding > 0 {
		cstate.outstanding--
	} else {
		// Should be unreachable: every completed message incremented
		// outstanding in Send.
		log.Error().Msg("scheduleNextForContext bug in cstate.outstanding")
	}
	now := time.Now()
	// Record the publish time; Send uses it to space out the next "first"
	// message for this context.
	cstate.last = now
	f := cstate.pending.Front()
	if f != nil {
		cm := f.Value.(ContextMessage)
		// The next queued message may go out one interval after the one that
		// just finished.
		pm := &pendingMsg{
			ContextMessage: cm,
			scheduled:      now.Add(cstate.interval),
		}
		cstate.pending.Remove(f)
		b.schedulePending(pm)
		// Let Output know a newly scheduled message exists.
		signal.Wakeup(b.wakeup)
	}
}