From d903d6846868198cfec06bb5ccdf79f4a9856f2c Mon Sep 17 00:00:00 2001 From: Yasuyuki Takeo Date: Fri, 9 Aug 2024 17:04:03 +0900 Subject: [PATCH] Fix assosiation errors --- .github/workflows/backend.yml | 9 +- backend/Makefile | 3 +- backend/cmd/gqlgenerate/main_test.go | 5 +- backend/graph/services/card.go | 70 +++++++++----- backend/graph/services/cardgroup.go | 38 ++++---- backend/graph/services/cardgroup_test.go | 15 ++- backend/graph/services/swiperecord.go | 33 ++++--- backend/graph/services/swiperecord_test.go | 103 ++++++++++++++++----- backend/graph/services/user_test.go | 1 + backend/testutils/database.go | 5 +- 10 files changed, 192 insertions(+), 90 deletions(-) diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 1a67e8a..61cf0c7 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -24,12 +24,17 @@ jobs: - name: Install Goose run: go install github.com/pressly/goose/v3/cmd/goose@latest + # Alternatively, install using go install + - name: Set up gotestfmt + run: go install github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + - name: Fetch dependant Go modules run: |- go get -v -t -d ./... working-directory: ./backend - name: Test code - run: |- - go test -v ./... + run: | + set -euo pipefail + go test -json -v ./... 2>&1 | tee /tmp/gotest.log | gotestfmt working-directory: ./backend diff --git a/backend/Makefile b/backend/Makefile index e01ceb5..d61180f 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -39,7 +39,8 @@ fmt: ## Format code test: ## Run tests printf "${GREEN}Run all tests\n\n${WHITE}"; \ go clean -cache -testcache -i -r; \ - go test -race -run=./... -bench=./... ./...; \ + go install github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest; \ + go test -race -json -v -coverprofile=coverage.txt ./... 2>&1 | tee /tmp/gotest.log | gotestfmt; \ printf "${GREEN}Done\n"; \ .PHONY: init diff --git a/backend/cmd/gqlgenerate/main_test.go b/backend/cmd/gqlgenerate/main_test.go index 986a22c..e805961 100644 --- a/backend/cmd/gqlgenerate/main_test.go +++ b/backend/cmd/gqlgenerate/main_test.go @@ -1,8 +1,6 @@ package main import ( - "log" - "os" "testing" "github.com/99designs/gqlgen/plugin/modelgen" @@ -17,13 +15,12 @@ func TestLoadGraphQLConfig(t *testing.T) { } func TestGenerateGraphQLCode(t *testing.T) { - logger := log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) cfg, err := loadGraphQLConfig() if err != nil { t.Fatalf("Failed to load config: %v", err) } - err = generateGraphQLCode(cfg, logger) + err = generateGraphQLCode(cfg) if err != nil { t.Fatalf("Expected no error, got %v", err) } diff --git a/backend/graph/services/card.go b/backend/graph/services/card.go index 4e99a99..838f0ec 100644 --- a/backend/graph/services/card.go +++ b/backend/graph/services/card.go @@ -48,6 +48,29 @@ func convertToCard(card repository.Card) *model.Card { } } +func convertCardConnection(cards []repository.Card, hasPrevPage, hasNextPage bool) *model.CardConnection { + var result model.CardConnection + + for _, dbc := range cards { + card := convertToCard(dbc) + + // Use the ID directly as it is already of type int64 + result.Edges = append(result.Edges, &model.CardEdge{Cursor: card.ID, Node: card}) + result.Nodes = append(result.Nodes, card) + } + result.TotalCount = len(cards) + + result.PageInfo = &model.PageInfo{} + if result.TotalCount != 0 { + result.PageInfo.StartCursor = &result.Nodes[0].ID + result.PageInfo.EndCursor = &result.Nodes[result.TotalCount-1].ID + } + result.PageInfo.HasPreviousPage = hasPrevPage + result.PageInfo.HasNextPage = hasNextPage + + return &result +} + func (s *cardService) GetCardByID(ctx context.Context, id int64) (*model.Card, error) { var card repository.Card if err := s.db.WithContext(ctx).First(&card, id).Error; err != nil { @@ -167,33 +190,32 @@ func (s *cardService) PaginatedCardsByCardGroup(ctx context.Context, cardGroupID return nil, err } - var edges []*model.CardEdge - var nodes []*model.Card - for _, card := range cards { - node := convertToCard(card) - edges = append(edges, &model.CardEdge{ - Cursor: card.ID, - Node: node, - }) - nodes = append(nodes, node) - } - - pageInfo := &model.PageInfo{} - if len(cards) > 0 { - pageInfo.HasNextPage = len(cards) == s.defaultLimit - pageInfo.HasPreviousPage = len(cards) == s.defaultLimit - if len(edges) > 0 { - pageInfo.StartCursor = &edges[0].Cursor - pageInfo.EndCursor = &edges[len(edges)-1].Cursor + var hasNextPage, hasPrevPage bool + var count int64 + + if len(cards) != 0 { + startCursor, endCursor := cards[0].ID, cards[len(cards)-1].ID + + err := s.db.WithContext(ctx).Model(&repository.Card{}). + Where("cardgroup_id = ?", cardGroupID). + Where("id < ?", startCursor). + Count(&count).Error + if err != nil { + return nil, err + } + hasPrevPage = count > 0 + + err = s.db.WithContext(ctx).Model(&repository.Card{}). + Where("cardgroup_id = ?", cardGroupID). + Where("id > ?", endCursor). + Count(&count).Error + if err != nil { + return nil, err } + hasNextPage = count > 0 } - return &model.CardConnection{ - Edges: edges, - Nodes: nodes, - PageInfo: pageInfo, - TotalCount: len(cards), - }, nil + return convertCardConnection(cards, hasPrevPage, hasNextPage), nil } func (s *cardService) GetCardsByIDs(ctx context.Context, ids []int64) ([]*model.Card, error) { diff --git a/backend/graph/services/cardgroup.go b/backend/graph/services/cardgroup.go index fc2a432..bc6277a 100644 --- a/backend/graph/services/cardgroup.go +++ b/backend/graph/services/cardgroup.go @@ -36,7 +36,7 @@ func convertToCardGroup(cardGroup repository.Cardgroup) *model.CardGroup { func (s *cardGroupService) GetCardGroupByID(ctx context.Context, id int64) (*model.CardGroup, error) { var cardGroup repository.Cardgroup if err := s.db.First(&cardGroup, id).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to get card group by ID", err) + logger.Logger.ErrorContext(ctx, "Failed to get cardgroup by ID", err) return nil, err } return convertToCardGroup(cardGroup), nil @@ -44,9 +44,9 @@ func (s *cardGroupService) GetCardGroupByID(ctx context.Context, id int64) (*mod func (s *cardGroupService) CreateCardGroup(ctx context.Context, input model.NewCardGroup) (*model.CardGroup, error) { gormCardGroup := convertToGormCardGroup(input) - result := s.db.WithContext(ctx).Create(gormCardGroup) + result := s.db.WithContext(ctx).Create(&gormCardGroup) if result.Error != nil { - logger.Logger.ErrorContext(ctx, "Failed to create card group", result.Error) + logger.Logger.ErrorContext(ctx, "Failed to create cardgroup", result.Error) return nil, result.Error } return convertToCardGroup(*gormCardGroup), nil @@ -55,7 +55,7 @@ func (s *cardGroupService) CreateCardGroup(ctx context.Context, input model.NewC func (s *cardGroupService) CardGroups(ctx context.Context) ([]*model.CardGroup, error) { var cardGroups []repository.Cardgroup if err := s.db.WithContext(ctx).Find(&cardGroups).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to retrieve card groups", err) + logger.Logger.ErrorContext(ctx, "Failed to retrieve cardgroups", err) return nil, err } var gqlCardGroups []*model.CardGroup @@ -68,13 +68,13 @@ func (s *cardGroupService) CardGroups(ctx context.Context) ([]*model.CardGroup, func (s *cardGroupService) UpdateCardGroup(ctx context.Context, id int64, input model.NewCardGroup) (*model.CardGroup, error) { var cardGroup repository.Cardgroup if err := s.db.WithContext(ctx).First(&cardGroup, id).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to find card group for update", err) + logger.Logger.ErrorContext(ctx, "Failed to find cardgroup for update", err) return nil, err } cardGroup.Name = input.Name cardGroup.Updated = time.Now() if err := s.db.WithContext(ctx).Save(&cardGroup).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to update card group", err) + logger.Logger.ErrorContext(ctx, "Failed to update cardgroup", err) return nil, err } return convertToCardGroup(cardGroup), nil @@ -84,12 +84,12 @@ func (s *cardGroupService) DeleteCardGroup(ctx context.Context, id int64) (*bool success := false result := s.db.WithContext(ctx).Delete(&repository.Cardgroup{}, id) if result.Error != nil { - logger.Logger.ErrorContext(ctx, "Failed to delete card group", result.Error) + logger.Logger.ErrorContext(ctx, "Failed to delete cardgroup", result.Error) return &success, result.Error } if result.RowsAffected == 0 { err := fmt.Errorf("record not found") - logger.Logger.ErrorContext(ctx, "Card group not found for deletion", err) + logger.Logger.ErrorContext(ctx, "Cardgroup not found for deletion", err) return &success, err } success = true @@ -100,15 +100,15 @@ func (s *cardGroupService) AddUserToCardGroup(ctx context.Context, userID int64, var user repository.User var cardGroup repository.Cardgroup if err := s.db.WithContext(ctx).First(&user, userID).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to find user for adding to card group", err) + logger.Logger.ErrorContext(ctx, "Failed to find user for adding to cardgroup", err) return nil, err } if err := s.db.WithContext(ctx).First(&cardGroup, cardGroupID).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to find card group for adding user", err) + logger.Logger.ErrorContext(ctx, "Failed to find cardgroup for adding user", err) return nil, err } if err := s.db.Model(&cardGroup).Association("Users").Append(&user); err != nil { - logger.Logger.ErrorContext(ctx, "Failed to add user to card group", err) + logger.Logger.ErrorContext(ctx, "Failed to add user to cardgroup", err) return nil, err } return convertToCardGroup(cardGroup), nil @@ -118,15 +118,15 @@ func (s *cardGroupService) RemoveUserFromCardGroup(ctx context.Context, userID i var user repository.User var cardGroup repository.Cardgroup if err := s.db.WithContext(ctx).First(&user, userID).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to find user for removing from card group", err) + logger.Logger.ErrorContext(ctx, "Failed to find user for removing from cardgroup", err) return nil, err } if err := s.db.WithContext(ctx).First(&cardGroup, cardGroupID).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to find card group for removing user", err) + logger.Logger.ErrorContext(ctx, "Failed to find cardgroup for removing user", err) return nil, err } if err := s.db.Model(&cardGroup).Association("Users").Delete(&user); err != nil { - logger.Logger.ErrorContext(ctx, "Failed to remove user from card group", err) + logger.Logger.ErrorContext(ctx, "Failed to remove user from cardgroup", err) return nil, err } return convertToCardGroup(cardGroup), nil @@ -135,7 +135,7 @@ func (s *cardGroupService) RemoveUserFromCardGroup(ctx context.Context, userID i func (s *cardGroupService) GetCardGroupsByUser(ctx context.Context, userID int64) ([]*model.CardGroup, error) { var user repository.User if err := s.db.WithContext(ctx).Preload("CardGroups").First(&user, userID).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to get card groups by user ID", err) + logger.Logger.ErrorContext(ctx, "Failed to get cardgroups by user ID", err) return nil, err } var gqlCardGroups []*model.CardGroup @@ -149,7 +149,7 @@ func (s *cardGroupService) PaginatedCardGroupsByUser(ctx context.Context, userID var user repository.User var cardGroups []repository.Cardgroup - // Fetch the user and preload the card groups with pagination conditions + // Fetch the user and preload the cardgroups with pagination conditions query := s.db.WithContext(ctx).Model(&user).Where("id = ?", userID).Preload("CardGroups", func(db *gorm.DB) *gorm.DB { if after != nil { db = db.Where("cardgroups.id > ?", *after) @@ -168,8 +168,8 @@ func (s *cardGroupService) PaginatedCardGroupsByUser(ctx context.Context, userID }) if err := query.Find(&user).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to get paginated card groups by user", err) - return nil, fmt.Errorf("error fetching paginated card groups by user: %+v", err) + logger.Logger.ErrorContext(ctx, "Failed to get paginated cardgroups by user", err) + return nil, fmt.Errorf("error fetching paginated cardgroups by user: %+v", err) } cardGroups = user.CardGroups @@ -206,7 +206,7 @@ func (s *cardGroupService) PaginatedCardGroupsByUser(ctx context.Context, userID func (s *cardGroupService) GetCardGroupsByIDs(ctx context.Context, ids []int64) ([]*model.CardGroup, error) { var cardGroups []*repository.Cardgroup if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&cardGroups).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to retrieve card groups by IDs", err) + logger.Logger.ErrorContext(ctx, "Failed to retrieve cardgroups by IDs", err) return nil, err } diff --git a/backend/graph/services/cardgroup_test.go b/backend/graph/services/cardgroup_test.go index 755372a..f57efe9 100644 --- a/backend/graph/services/cardgroup_test.go +++ b/backend/graph/services/cardgroup_test.go @@ -63,7 +63,9 @@ func (suite *CardGroupTestSuite) TestCardGroupService() { suite.Run("Normal_CreateCardGroup", func() { - createdGroup, err := testutils.CreateUserAndCardGroup(ctx, userService, cardGroupService, roleService) + input := model.NewCardGroup{Name: "Test Group"} + createdGroup, err := cardGroupService.CreateCardGroup(context.Background(), input) + assert.NoError(t, err) assert.Equal(t, "Test Group", createdGroup.Name) }) @@ -219,12 +221,21 @@ func (suite *CardGroupTestSuite) TestCardGroupService() { suite.Run("Normal_RemoveUserFromCardGroup", func() { + // Create a role + newRole := model.NewRole{ + Name: "Test Role", + } + createdRole, err := roleService.CreateRole(ctx, newRole) + if err != nil { + suite.T().Fatalf("Failed at CreateRole: %+v", err) + } + // Create a user newUser := model.NewUser{ Name: "Test User", Created: time.Now(), Updated: time.Now(), - RoleIds: []int64{}, // Add any required roles here + RoleIds: []int64{createdRole.ID}, // Add any required roles here } createdUser, err := userService.CreateUser(ctx, newUser) assert.NoError(t, err) diff --git a/backend/graph/services/swiperecord.go b/backend/graph/services/swiperecord.go index 2606693..971a47e 100644 --- a/backend/graph/services/swiperecord.go +++ b/backend/graph/services/swiperecord.go @@ -41,11 +41,11 @@ func (s *swipeRecordService) GetSwipeRecordByID(ctx context.Context, id int64) ( var swipeRecord repository.SwipeRecord if err := s.db.WithContext(ctx).First(&swipeRecord, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - err := fmt.Errorf("swipe record not found") - logger.Logger.ErrorContext(ctx, "Swipe record not found:", "id", id) + err := fmt.Errorf("swipe record not found: id=%d", id) + logger.Logger.ErrorContext(ctx, err.Error()) return nil, err } - logger.Logger.ErrorContext(ctx, "Failed to get swipe record by ID", err) + logger.Logger.ErrorContext(ctx, "Failed to get swipe record by ID:", err) return nil, err } return convertToSwipeRecord(swipeRecord), nil @@ -57,10 +57,10 @@ func (s *swipeRecordService) CreateSwipeRecord(ctx context.Context, input model. if result.Error != nil { if strings.Contains(result.Error.Error(), "foreign key constraint") { err := fmt.Errorf("invalid swipe ID or card ID") - logger.Logger.ErrorContext(ctx, "Failed to create swipe record: invalid swipe ID or card ID", err) + logger.Logger.ErrorContext(ctx, "Failed to create swipe record:", err) return nil, err } - logger.Logger.ErrorContext(ctx, "Failed to create swipe record", result.Error) + logger.Logger.ErrorContext(ctx, "Failed to create swipe record:", result.Error) return nil, result.Error } return convertToSwipeRecord(*gormSwipeRecord), nil @@ -69,14 +69,14 @@ func (s *swipeRecordService) CreateSwipeRecord(ctx context.Context, input model. func (s *swipeRecordService) UpdateSwipeRecord(ctx context.Context, id int64, input model.NewSwipeRecord) (*model.SwipeRecord, error) { var swipeRecord repository.SwipeRecord if err := s.db.WithContext(ctx).First(&swipeRecord, id).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Swipe record does not exist: ", "id", id) + logger.Logger.ErrorContext(ctx, fmt.Sprintf("Swipe record does not exist: id=%d", id), err) return nil, err } swipeRecord.Direction = input.Direction swipeRecord.Updated = time.Now() if err := s.db.WithContext(ctx).Save(&swipeRecord).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to save swipe record", err) + logger.Logger.ErrorContext(ctx, "Failed to update swipe record:", err) return nil, err } return convertToSwipeRecord(swipeRecord), nil @@ -85,14 +85,14 @@ func (s *swipeRecordService) UpdateSwipeRecord(ctx context.Context, id int64, in func (s *swipeRecordService) DeleteSwipeRecord(ctx context.Context, id int64) (*bool, error) { result := s.db.WithContext(ctx).Delete(&repository.SwipeRecord{}, id) if result.Error != nil { - logger.Logger.ErrorContext(ctx, "Failed to delete swipe record", result.Error) + logger.Logger.ErrorContext(ctx, "Failed to delete swipe record:", result.Error) return nil, result.Error } success := result.RowsAffected > 0 if !success { - err := fmt.Errorf("record not found") - logger.Logger.ErrorContext(ctx, "Swipe record not found for deletion", err) + err := fmt.Errorf("record not found: id=%d", id) + logger.Logger.ErrorContext(ctx, "Swipe record not found for deletion:", err) return &success, err } @@ -102,7 +102,7 @@ func (s *swipeRecordService) DeleteSwipeRecord(ctx context.Context, id int64) (* func (s *swipeRecordService) SwipeRecords(ctx context.Context) ([]*model.SwipeRecord, error) { var swipeRecords []repository.SwipeRecord if err := s.db.WithContext(ctx).Find(&swipeRecords).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to retrieve swipe records", err) + logger.Logger.ErrorContext(ctx, "Failed to retrieve swipe records:", err) return nil, err } var gqlSwipeRecords []*model.SwipeRecord @@ -113,9 +113,14 @@ func (s *swipeRecordService) SwipeRecords(ctx context.Context) ([]*model.SwipeRe } func (s *swipeRecordService) SwipeRecordsByUser(ctx context.Context, userID int64) ([]*model.SwipeRecord, error) { + if userID <= 0 { + logger.Logger.ErrorContext(ctx, "User ID must be larger than 0", "user_id", userID) + return nil, fmt.Errorf("user ID must be larger than 0. It's %d", userID) + } + var swipeRecords []repository.SwipeRecord if err := s.db.WithContext(ctx).Where("user_id = ?", userID).Find(&swipeRecords).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to retrieve swipe records by user ID", err) + logger.Logger.ErrorContext(ctx, "Failed to retrieve swipe records by user ID:", err) return nil, err } var gqlSwipeRecords []*model.SwipeRecord @@ -144,7 +149,7 @@ func (s *swipeRecordService) PaginatedSwipeRecordsByUser(ctx context.Context, us } if err := query.Find(&swipeRecords).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to retrieve paginated swipe records by user ID", err) + logger.Logger.ErrorContext(ctx, "Failed to retrieve paginated swipe records by user ID:", err) return nil, err } @@ -180,7 +185,7 @@ func (s *swipeRecordService) PaginatedSwipeRecordsByUser(ctx context.Context, us func (s *swipeRecordService) GetSwipeRecordsByIDs(ctx context.Context, ids []int64) ([]*model.SwipeRecord, error) { var swipeRecords []*repository.SwipeRecord if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&swipeRecords).Error; err != nil { - logger.Logger.ErrorContext(ctx, "Failed to retrieve swipe records by IDs", err) + logger.Logger.ErrorContext(ctx, "Failed to retrieve swipe records by IDs:", err) return nil, err } diff --git a/backend/graph/services/swiperecord_test.go b/backend/graph/services/swiperecord_test.go index 3aa38b5..c879f1a 100644 --- a/backend/graph/services/swiperecord_test.go +++ b/backend/graph/services/swiperecord_test.go @@ -5,6 +5,7 @@ import ( "backend/graph/services" "backend/testutils" "context" + "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "gorm.io/gorm" @@ -14,11 +15,9 @@ import ( type SwipeRecordTestSuite struct { suite.Suite - db *gorm.DB - sv services.SwipeRecordService - userID int64 - cardGroup *model.CardGroup - cleanup func() + db *gorm.DB + sv services.SwipeRecordService + cleanup func() } func (suite *SwipeRecordTestSuite) SetupSuite() { @@ -43,17 +42,34 @@ func (suite *SwipeRecordTestSuite) SetupSuite() { suite.db = pg.GetDB() suite.sv = services.New(suite.db) - // Create a user and card group +} + +func (suite *SwipeRecordTestSuite) createTestUserAndRole(ctx context.Context) (int64, error) { userService := suite.sv.(services.UserService) - cardGroupService := suite.sv.(services.CardGroupService) roleService := suite.sv.(services.RoleService) - createdGroup, err := testutils.CreateUserAndCardGroup(ctx, userService, cardGroupService, roleService) + // Create a role + newRole := model.NewRole{ + Name: "Test Role", + } + createdRole, err := roleService.CreateRole(ctx, newRole) + if err != nil { + return 0, fmt.Errorf("Failed to create Role: %w", err) + } + + // Create a user + newUser := model.NewUser{ + Name: "Test User", + Created: time.Now(), + Updated: time.Now(), + RoleIds: []int64{createdRole.ID}, // Assign the new role to the user + } + createdUser, err := userService.CreateUser(ctx, newUser) if err != nil { - suite.T().Fatalf("Failed to create user and card group: %+v", err) + return 0, fmt.Errorf("Failed to create User: %w", err) } - suite.userID = createdGroup.Users.Nodes[0].ID - suite.cardGroup = createdGroup + + return createdUser.ID, nil } func (suite *SwipeRecordTestSuite) TearDownSuite() { @@ -73,8 +89,14 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { t.Helper() suite.Run("Normal_CreateSwipeRecord", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + newSwipeRecord := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "left", Created: time.Now(), Updated: time.Now(), @@ -87,6 +109,7 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Error_CreateSwipeRecord", func() { + newSwipeRecord := model.NewSwipeRecord{ UserID: 0, // Invalid UserID Direction: "", @@ -99,8 +122,14 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Normal_GetSwipeRecordByID", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + newSwipeRecord := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "left", Created: time.Now(), Updated: time.Now(), @@ -121,8 +150,14 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Normal_UpdateSwipeRecord", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + newSwipeRecord := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "left", Created: time.Now(), Updated: time.Now(), @@ -130,7 +165,7 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { createdSwipeRecord, _ := swipeRecordService.CreateSwipeRecord(ctx, newSwipeRecord) updateSwipeRecord := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "right", } @@ -141,8 +176,14 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Error_UpdateSwipeRecord", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + updateSwipeRecord := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "right", } @@ -153,8 +194,14 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Normal_DeleteSwipeRecord", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + newSwipeRecord := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "left", Created: time.Now(), Updated: time.Now(), @@ -175,14 +222,20 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Normal_ListSwipeRecords", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + newSwipeRecord1 := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "left", Created: time.Now(), Updated: time.Now(), } newSwipeRecord2 := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "right", Created: time.Now(), Updated: time.Now(), @@ -197,14 +250,20 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { }) suite.Run("Normal_ListSwipeRecordsByUser", func() { + // Use the helper function to create the user and role + userID, err := suite.createTestUserAndRole(ctx) + if err != nil { + suite.T().Fatal(err) + } + newSwipeRecord1 := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "left", Created: time.Now(), Updated: time.Now(), } newSwipeRecord2 := model.NewSwipeRecord{ - UserID: suite.userID, + UserID: userID, Direction: "right", Created: time.Now(), Updated: time.Now(), @@ -212,7 +271,7 @@ func (suite *SwipeRecordTestSuite) TestSwipeRecordService() { swipeRecordService.CreateSwipeRecord(ctx, newSwipeRecord1) swipeRecordService.CreateSwipeRecord(ctx, newSwipeRecord2) - swipeRecords, err := swipeRecordService.SwipeRecordsByUser(ctx, suite.userID) + swipeRecords, err := swipeRecordService.SwipeRecordsByUser(ctx, userID) assert.NoError(t, err) assert.Len(t, swipeRecords, 2) diff --git a/backend/graph/services/user_test.go b/backend/graph/services/user_test.go index 52b00b1..b15d39b 100644 --- a/backend/graph/services/user_test.go +++ b/backend/graph/services/user_test.go @@ -62,6 +62,7 @@ func (suite *UserTestSuite) TestUserService() { t.Helper() suite.Run("Normal_CreateUser", func() { + // Create a role newRole := model.NewRole{ Name: "Test Role", diff --git a/backend/testutils/database.go b/backend/testutils/database.go index dcea3f4..cfb8107 100644 --- a/backend/testutils/database.go +++ b/backend/testutils/database.go @@ -91,10 +91,11 @@ func RunServersTest(t *testing.T, db *gorm.DB, fn func(*testing.T)) { } // Delete records from tables - tx.Where("1 = 1").Delete(&repo.Role{}) - tx.Where("1 = 1").Delete(&repo.User{}) + tx.Where("1 = 1").Delete(&repo.SwipeRecord{}) tx.Where("1 = 1").Delete(&repo.Card{}) tx.Where("1 = 1").Delete(&repo.Cardgroup{}) + tx.Where("1 = 1").Delete(&repo.User{}) + tx.Where("1 = 1").Delete(&repo.Role{}) // Call the provided test function if fn != nil {