Skip to content

Commit

Permalink
Merge pull request RedHatInsights#2633 from lzap/mem-cache-updatetrans
Browse files Browse the repository at this point in the history
Trivial memory cache for update trans
  • Loading branch information
loadtheaccumulator authored Jul 8, 2024
2 parents b2360f0 + ff2a234 commit af4846c
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 17 deletions.
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ func main() {
jobs.Worker().Start(ctx)
defer jobs.Worker().Stop(ctx)

defer routes.UpdateTransCache.Stop()

consumers := []services.ConsumerService{
services.NewKafkaConsumerService(cfg.KafkaConfig, kafkacommon.TopicPlaybookDispatcherRuns),
services.NewKafkaConsumerService(cfg.KafkaConfig, kafkacommon.TopicInventoryEvents),
Expand Down
129 changes: 129 additions & 0 deletions pkg/cache/mem_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// A simple memory-only thread-safe cache with TTL. Taken from the provisioning package:
// https://github.com/RHEnVision/provisioning-backend/commit/decfb331a2e5642e904bed0dcd3ac41b319eb732
package cache

import (
"runtime"
"sync"
"time"
)

// Cache stores arbitrary data with expiration time.
type Cache[K comparable, V any] struct {
items map[K]*item[V]
mu sync.Mutex
done chan any
clean chan bool
once sync.Once
cleanWG sync.WaitGroup
}

// An item represents arbitrary data with expiration time.
type item[V any] struct {
data V
expires int64
}

// New creates a new cache that asynchronously cleans
// expired entries after the given time passes. If cleaningInterval
// is zero, no background cleanup goroutine is scheduled.
func NewMemoryCache[K comparable, V any](cleaningInterval time.Duration) *Cache[K, V] {
cache := &Cache[K, V]{
items: make(map[K]*item[V]),
clean: make(chan bool),
done: make(chan any),
}

if cleaningInterval != 0 {
go func() {
ticker := time.NewTicker(cleaningInterval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
cache.cleanup()
case <-cache.clean:
cache.cleanup()
cache.cleanWG.Done()
case <-cache.done:
return
}
}
}()
}

// Shutdown the goroutine when GC wants to clean this up
runtime.SetFinalizer(cache, func(c *Cache[K, V]) {
c.Stop()
})

return cache
}

// cleanup function is called from the background goroutine
func (cache *Cache[K, V]) cleanup() {
cache.mu.Lock()
defer cache.mu.Unlock()

now := time.Now().UnixNano()
for key, item := range cache.items {
if item.expires > 0 && now > item.expires {
delete(cache.items, key)
}
}
}

// Get gets the value for the given key.
func (cache *Cache[K, V]) Get(key K) (V, bool) {
cache.mu.Lock()
defer cache.mu.Unlock()

item, exists := cache.items[key]
if !exists || (item.expires > 0 && time.Now().UnixNano() > item.expires) {
var nothing V
return nothing, false
}

return item.data, true
}

// Set sets a value for the given key with an expiration duration.
// If the duration is 0 or less, it will be stored forever.
func (cache *Cache[K, V]) Set(key K, value V, duration time.Duration) {
cache.mu.Lock()
defer cache.mu.Unlock()

var expires int64
if duration > 0 {
expires = time.Now().Add(duration).UnixNano()
}
cache.items[key] = &item[V]{
data: value,
expires: expires,
}
}

// Count contains count of cached items.
func (cache *Cache[K, V]) Count() int {
cache.mu.Lock()
defer cache.mu.Unlock()

return len(cache.items)
}

// ExpireNow schedules immediate expiration cycle. It blocks, until cleanup is completed.
// If cleanup interval is zero, this will block forever.
func (cache *Cache[K, V]) ExpireNow() {
cache.cleanWG.Add(1)
cache.clean <- true
cache.cleanWG.Wait()
}

// Stop frees up resources and stops the cleanup goroutine
func (cache *Cache[K, V]) Stop() {
cache.once.Do(func() {
cache.items = make(map[K]*item[V])
close(cache.done)
})
}
82 changes: 82 additions & 0 deletions pkg/cache/mem_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package cache

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestSetAndGetFound(t *testing.T) {
c := NewMemoryCache[string, string](0)
c.Set("hello", "Hello", 0)
hello, found := c.Get("hello")
assert.True(t, found)
assert.Equal(t, "Hello", hello)
}

func TestSetAndGetNotFound(t *testing.T) {
c := NewMemoryCache[string, string](0)
_, found := c.Get("does not exist")
assert.False(t, found)
}

func TestManualExpiration(t *testing.T) {
c := NewMemoryCache[string, string](time.Minute)
c.Set("short", "expiration", time.Nanosecond)
c.ExpireNow()

_, found := c.Get("short")
assert.False(t, found)
}

func TestExpiration(t *testing.T) {
c := NewMemoryCache[string, string](10 * time.Millisecond)
c.Set("short", "expiration", time.Nanosecond)
defer c.Stop()

// hope for the best
time.Sleep(100 * time.Millisecond)

_, found := c.Get("short")
assert.False(t, found)
}

func BenchmarkNew(b *testing.B) {
b.ReportAllocs()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
NewMemoryCache[string, string](5 * time.Second).Stop()
}
})
}

func BenchmarkGet(b *testing.B) {
c := NewMemoryCache[string, string](5 * time.Second)
defer c.Stop()
c.Set("Hello", "World", 0)

b.ReportAllocs()
b.ResetTimer()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c.Get("Hello")
}
})
}

func BenchmarkSet(b *testing.B) {
c := NewMemoryCache[string, string](5 * time.Second)
defer c.Stop()

b.ResetTimer()
b.ReportAllocs()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c.Set("Hello", "World", 0)
}
})
}
44 changes: 27 additions & 17 deletions pkg/routes/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
url2 "net/url"
"strconv"
"strings"
"time"

"github.com/redhatinsights/edge-api/pkg/cache"
"github.com/redhatinsights/edge-api/pkg/db"
"github.com/redhatinsights/edge-api/pkg/models"
"github.com/redhatinsights/edge-api/pkg/services"
Expand Down Expand Up @@ -250,6 +252,8 @@ func GetInstallerIsoStorageContent(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, signedURL, http.StatusSeeOther)
}

var UpdateTransCache = cache.NewMemoryCache[string, UpdateRepo](time.Duration(15 * time.Minute))

// UpdateTransactionCtx is a handler for Update transaction requests
func UpdateTransactionCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -259,33 +263,39 @@ func UpdateTransactionCtx(next http.Handler) http.Handler {
// readOrgID handle response and logging on failure
return
}

updateIDString := chi.URLParam(r, "updateID")
if updateIDString == "" {
logger.Debug("Update transaction ID was not passed to the request or it was empty")
respondWithAPIError(w, logger, errors.NewBadRequest("update transaction ID required"))
return
}
updateTransactionID, err := strconv.Atoi(updateIDString)
if err != nil {
respondWithAPIError(w, logger, errors.NewBadRequest("update transaction id must be an integer"))
return
}
cacheKey := orgID + "-" + updateIDString

var updateRepo UpdateRepo
dbQuery := db.Org(orgID, "").
Model(models.UpdateTransaction{}).
Joins("LEFT JOIN repos AS r ON r.id = update_transactions.repo_id").
Select("update_transactions.id as id, update_transactions.org_id as org_id, r.url as repo_url")
if result := dbQuery.First(&updateRepo, updateTransactionID); result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
logger.WithField("error", result.Error.Error()).Error("device update transaction not found")
respondWithAPIError(w, logger, errors.NewNotFound("device update transaction not found"))
var ok bool
if updateRepo, ok = UpdateTransCache.Get(cacheKey); !ok {
updateTransactionID, err := strconv.Atoi(updateIDString)
if err != nil {
respondWithAPIError(w, logger, errors.NewBadRequest("update transaction id must be an integer"))
return
}
logger.WithField("error", result.Error.Error()).Error("failed to retrieve update transaction")
respondWithAPIError(w, logger, errors.NewInternalServerError())
return

dbQuery := db.Org(orgID, "").
Model(models.UpdateTransaction{}).
Joins("LEFT JOIN repos AS r ON r.id = update_transactions.repo_id").
Select("update_transactions.id as id, update_transactions.org_id as org_id, r.url as repo_url")
if result := dbQuery.First(&updateRepo, updateTransactionID); result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
logger.WithField("error", result.Error.Error()).Error("device update transaction not found")
respondWithAPIError(w, logger, errors.NewNotFound("device update transaction not found"))
return
}
logger.WithField("error", result.Error.Error()).Error("failed to retrieve update transaction")
respondWithAPIError(w, logger, errors.NewInternalServerError())
return
}

UpdateTransCache.Set(cacheKey, updateRepo, time.Duration(2*time.Hour))
}

ctx := setContextUpdateTransaction(r.Context(), &updateRepo)
Expand Down

0 comments on commit af4846c

Please sign in to comment.