diff --git a/x/emissions/keeper/inference_synthesis/weights_test.go b/x/emissions/keeper/inference_synthesis/weights_test.go index 708bdc12a..9f2bd5485 100644 --- a/x/emissions/keeper/inference_synthesis/weights_test.go +++ b/x/emissions/keeper/inference_synthesis/weights_test.go @@ -498,3 +498,167 @@ func decPtr(s string) *alloraMath.Dec { dec := alloraMath.MustNewDecFromString(s) return &dec } + +func (s *WeightsTestSuite) TestCalcWeightsGivenWorkers() { + testCases := []struct { + name string + args synth.CalcWeightsGivenWorkersArgs + expectedError bool + checkResult func(result synth.RegretInformedWeights) + }{ + { + name: "basic calculation with single inferer and forecaster", + args: synth.CalcWeightsGivenWorkersArgs{ + Logger: s.ctx.Logger(), + Inferers: []string{s.addrsStr[0]}, + Forecasters: []string{s.addrsStr[1]}, + InfererToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[0]: decPtr("1.0"), + }, + ForecasterToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[1]: decPtr("2.0"), + }, + EpsilonTopic: alloraMath.MustNewDecFromString("0.01"), + PNorm: alloraMath.MustNewDecFromString("3.0"), + CNorm: alloraMath.MustNewDecFromString("0.75"), + StdDevPlusEpsilon: alloraMath.MustNewDecFromString("1.0"), + }, + expectedError: false, + checkResult: func(result synth.RegretInformedWeights) { + s.Require().Equal(1, len(result.Inferers)) + s.Require().Equal(1, len(result.Forecasters)) + s.Require().True(result.Inferers[s.addrsStr[0]].Lt(result.Forecasters[s.addrsStr[1]])) + }, + }, + { + name: "basic calculation with negative inferer and positive forecaster", + args: synth.CalcWeightsGivenWorkersArgs{ + Logger: s.ctx.Logger(), + Inferers: []string{s.addrsStr[0]}, + Forecasters: []string{s.addrsStr[1]}, + InfererToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[0]: decPtr("-1.0"), + }, + ForecasterToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[1]: decPtr("2.0"), + }, + EpsilonTopic: alloraMath.MustNewDecFromString("0.01"), + PNorm: alloraMath.MustNewDecFromString("3.0"), + CNorm: alloraMath.MustNewDecFromString("0.75"), + StdDevPlusEpsilon: alloraMath.MustNewDecFromString("1.0"), + }, + expectedError: false, + checkResult: func(result synth.RegretInformedWeights) { + s.T().Logf("Single worker test results:") + s.Require().Equal(1, len(result.Inferers)) + s.Require().Equal(1, len(result.Forecasters)) + s.Require().True(result.Inferers[s.addrsStr[0]].Lt(result.Forecasters[s.addrsStr[1]])) + }, + }, + { + name: "basic calculation with positive inferer and negative forecaster", + args: synth.CalcWeightsGivenWorkersArgs{ + Logger: s.ctx.Logger(), + Inferers: []string{s.addrsStr[0]}, + Forecasters: []string{s.addrsStr[1]}, + InfererToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[0]: decPtr("1.0"), + }, + ForecasterToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[1]: decPtr("-2.0"), + }, + EpsilonTopic: alloraMath.MustNewDecFromString("0.01"), + PNorm: alloraMath.MustNewDecFromString("3.0"), + CNorm: alloraMath.MustNewDecFromString("0.75"), + StdDevPlusEpsilon: alloraMath.MustNewDecFromString("1.0"), + }, + expectedError: false, + checkResult: func(result synth.RegretInformedWeights) { + s.Require().Equal(1, len(result.Inferers)) + s.Require().Equal(1, len(result.Forecasters)) + s.Require().True(result.Inferers[s.addrsStr[0]].Gt(result.Forecasters[s.addrsStr[1]])) + }, + }, + { + name: "calculation with multiple workers and mixed positive and negative regrets", + args: synth.CalcWeightsGivenWorkersArgs{ + Logger: s.ctx.Logger(), + Inferers: []string{s.addrsStr[0], s.addrsStr[1]}, + Forecasters: []string{s.addrsStr[2], s.addrsStr[3]}, + InfererToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[0]: decPtr("-1.0"), + s.addrsStr[1]: decPtr("2.0"), + }, + ForecasterToRegret: map[string]*alloraMath.Dec{ + s.addrsStr[2]: decPtr("1.5"), + s.addrsStr[3]: decPtr("-0.5"), + }, + EpsilonTopic: alloraMath.MustNewDecFromString("0.01"), + PNorm: alloraMath.MustNewDecFromString("3.0"), + CNorm: alloraMath.MustNewDecFromString("0.75"), + StdDevPlusEpsilon: alloraMath.MustNewDecFromString("1.0"), + }, + expectedError: false, + checkResult: func(result synth.RegretInformedWeights) { + s.Require().Equal(2, len(result.Inferers)) + s.Require().Equal(2, len(result.Forecasters)) + + // Check that worker with higher regret has a higher weight + s.Require().True(result.Inferers[s.addrsStr[0]].Lt(result.Inferers[s.addrsStr[1]])) + s.Require().True(result.Forecasters[s.addrsStr[2]].Gt(result.Forecasters[s.addrsStr[3]])) + // compare mixed + s.Require().True(result.Forecasters[s.addrsStr[3]].Gt(result.Inferers[s.addrsStr[0]])) + s.Require().True(result.Forecasters[s.addrsStr[2]].Lt(result.Inferers[s.addrsStr[1]])) + + }, + }, + { + name: "empty workers should error", + args: synth.CalcWeightsGivenWorkersArgs{ + Logger: s.ctx.Logger(), + Inferers: []string{}, + Forecasters: []string{}, + InfererToRegret: map[string]*alloraMath.Dec{}, + ForecasterToRegret: map[string]*alloraMath.Dec{}, + EpsilonTopic: alloraMath.MustNewDecFromString("0.01"), + PNorm: alloraMath.MustNewDecFromString("3.0"), + CNorm: alloraMath.MustNewDecFromString("0.75"), + StdDevPlusEpsilon: alloraMath.MustNewDecFromString("1.0"), + }, + expectedError: true, + }, + { + name: "missing regret values should error", + args: synth.CalcWeightsGivenWorkersArgs{ + Logger: s.ctx.Logger(), + Inferers: []string{s.addrsStr[0]}, + Forecasters: []string{s.addrsStr[1]}, + InfererToRegret: map[string]*alloraMath.Dec{}, + ForecasterToRegret: map[string]*alloraMath.Dec{}, + EpsilonTopic: alloraMath.MustNewDecFromString("0.01"), + PNorm: alloraMath.MustNewDecFromString("3.0"), + CNorm: alloraMath.MustNewDecFromString("0.75"), + StdDevPlusEpsilon: alloraMath.MustNewDecFromString("1.0"), + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + result, err := synth.CalcWeightsGivenWorkers(tc.args) + + if tc.expectedError { + s.Require().Error(err) + return + } + + s.Require().NoError(err) + s.Require().NotNil(result) + + if tc.checkResult != nil { + tc.checkResult(result) + } + }) + } +}