Skip to content

Commit

Permalink
ORA1294 unit test misc other functions (#233)
Browse files Browse the repository at this point in the history
Update test TestGetAllReputersOutput to make sure that it checks that
the values get closer to the correct value with more iterations of
gradient descent. This way the test doesn't time out.

---------

Co-authored-by: Kenny P <17100641+kpeluso@users.noreply.github.com>
  • Loading branch information
br4e and kpeluso authored May 11, 2024
1 parent afd8feb commit f4c69ff
Showing 1 changed file with 93 additions and 19 deletions.
112 changes: 93 additions & 19 deletions x/emissions/module/rewards/rewards_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,12 @@ func TestGetAllConsensusScores(t *testing.T) {
}
}

// something about this test takes too long and hangs
// must investigate further
/*
func TestGetAllReputersOutput(t *testing.T) {
func (s *RewardsTestSuite) TestGetAllReputersOutput() {
require := s.Require()

params, err := s.emissionsKeeper.GetParams(s.ctx)
require.NoError(err)

allLosses := [][]alloraMath.Dec{
{alloraMath.MustNewDecFromString("0.0112"), alloraMath.MustNewDecFromString("0.00231"), alloraMath.MustNewDecFromString("0.02274"), alloraMath.MustNewDecFromString("0.01299"), alloraMath.MustNewDecFromString("0.02515"), alloraMath.MustNewDecFromString("0.0185"), alloraMath.MustNewDecFromString("0.01018"), alloraMath.MustNewDecFromString("0.02105"), alloraMath.MustNewDecFromString("0.01041"), alloraMath.MustNewDecFromString("0.0183"), alloraMath.MustNewDecFromString("0.01022"), alloraMath.MustNewDecFromString("0.01333"), alloraMath.MustNewDecFromString("0.01298"), alloraMath.MustNewDecFromString("0.01023"), alloraMath.MustNewDecFromString("0.01268"), alloraMath.MustNewDecFromString("0.01381"), alloraMath.MustNewDecFromString("0.01731"), alloraMath.MustNewDecFromString("0.01238"), alloraMath.MustNewDecFromString("0.01168"), alloraMath.MustNewDecFromString("0.00929"), alloraMath.MustNewDecFromString("0.01212"), alloraMath.MustNewDecFromString("0.01806"), alloraMath.MustNewDecFromString("0.01901"), alloraMath.MustNewDecFromString("0.01828"), alloraMath.MustNewDecFromString("0.01522"), alloraMath.MustNewDecFromString("0.01833"), alloraMath.MustNewDecFromString("0.0101"), alloraMath.MustNewDecFromString("0.01224"), alloraMath.MustNewDecFromString("0.01226"), alloraMath.MustNewDecFromString("0.01474"), alloraMath.MustNewDecFromString("0.01218"), alloraMath.MustNewDecFromString("0.01604"), alloraMath.MustNewDecFromString("0.01149"), alloraMath.MustNewDecFromString("0.02075"), alloraMath.MustNewDecFromString("0.00818"), alloraMath.MustNewDecFromString("0.0116"), alloraMath.MustNewDecFromString("0.01127"), alloraMath.MustNewDecFromString("0.01495"), alloraMath.MustNewDecFromString("0.00689"), alloraMath.MustNewDecFromString("0.0108"), alloraMath.MustNewDecFromString("0.01417"), alloraMath.MustNewDecFromString("0.0124"), alloraMath.MustNewDecFromString("0.01588"), alloraMath.MustNewDecFromString("0.01012"), alloraMath.MustNewDecFromString("0.01467"), alloraMath.MustNewDecFromString("0.0128"), alloraMath.MustNewDecFromString("0.01234"), alloraMath.MustNewDecFromString("0.0148"), alloraMath.MustNewDecFromString("0.01046"), alloraMath.MustNewDecFromString("0.01192"), alloraMath.MustNewDecFromString("0.01381"), alloraMath.MustNewDecFromString("0.01687"), alloraMath.MustNewDecFromString("0.01136"), alloraMath.MustNewDecFromString("0.01185"), alloraMath.MustNewDecFromString("0.01568"), alloraMath.MustNewDecFromString("0.00949"), alloraMath.MustNewDecFromString("0.01339")},
{alloraMath.MustNewDecFromString("0.01635"), alloraMath.MustNewDecFromString("0.00179"), alloraMath.MustNewDecFromString("0.03396"), alloraMath.MustNewDecFromString("0.0153"), alloraMath.MustNewDecFromString("0.01988"), alloraMath.MustNewDecFromString("0.00962"), alloraMath.MustNewDecFromString("0.01191"), alloraMath.MustNewDecFromString("0.01616"), alloraMath.MustNewDecFromString("0.01417"), alloraMath.MustNewDecFromString("0.01216"), alloraMath.MustNewDecFromString("0.01292"), alloraMath.MustNewDecFromString("0.01564"), alloraMath.MustNewDecFromString("0.01323"), alloraMath.MustNewDecFromString("0.01261"), alloraMath.MustNewDecFromString("0.01145"), alloraMath.MustNewDecFromString("0.0163"), alloraMath.MustNewDecFromString("0.014"), alloraMath.MustNewDecFromString("0.01373"), alloraMath.MustNewDecFromString("0.01453"), alloraMath.MustNewDecFromString("0.01207"), alloraMath.MustNewDecFromString("0.01641"), alloraMath.MustNewDecFromString("0.01601"), alloraMath.MustNewDecFromString("0.01114"), alloraMath.MustNewDecFromString("0.01259"), alloraMath.MustNewDecFromString("0.01589"), alloraMath.MustNewDecFromString("0.01229"), alloraMath.MustNewDecFromString("0.01309"), alloraMath.MustNewDecFromString("0.0138"), alloraMath.MustNewDecFromString("0.01162"), alloraMath.MustNewDecFromString("0.01145"), alloraMath.MustNewDecFromString("0.01013"), alloraMath.MustNewDecFromString("0.01208"), alloraMath.MustNewDecFromString("0.0111"), alloraMath.MustNewDecFromString("0.0118"), alloraMath.MustNewDecFromString("0.01374"), alloraMath.MustNewDecFromString("0.01428"), alloraMath.MustNewDecFromString("0.01791"), alloraMath.MustNewDecFromString("0.01288"), alloraMath.MustNewDecFromString("0.01161"), alloraMath.MustNewDecFromString("0.01151"), alloraMath.MustNewDecFromString("0.01148"), alloraMath.MustNewDecFromString("0.01284"), alloraMath.MustNewDecFromString("0.01239"), alloraMath.MustNewDecFromString("0.01023"), alloraMath.MustNewDecFromString("0.01712"), alloraMath.MustNewDecFromString("0.0116"), alloraMath.MustNewDecFromString("0.01639"), alloraMath.MustNewDecFromString("0.01043"), alloraMath.MustNewDecFromString("0.01308"), alloraMath.MustNewDecFromString("0.01455"), alloraMath.MustNewDecFromString("0.01607"), alloraMath.MustNewDecFromString("0.01205"), alloraMath.MustNewDecFromString("0.01357"), alloraMath.MustNewDecFromString("0.01108"), alloraMath.MustNewDecFromString("0.01633"), alloraMath.MustNewDecFromString("0.01208"), alloraMath.MustNewDecFromString("0.01278")},
Expand All @@ -671,28 +673,100 @@ func TestGetAllReputersOutput(t *testing.T) {
}
var numReputers int64 = 5
wantScores := []alloraMath.Dec{
alloraMath.MustNewDecFromString("17.536755245164326"),
alloraMath.MustNewDecFromString("20.302662649273707"),
alloraMath.MustNewDecFromString("24.278413872561256"),
alloraMath.MustNewDecFromString("11.365030585937692"),
alloraMath.MustNewDecFromString("15.211816727558011"),
alloraMath.MustNewDecFromString("17.46894"),
alloraMath.MustNewDecFromString("20.19617"),
alloraMath.MustNewDecFromString("24.15073"),
alloraMath.MustNewDecFromString("11.39661"),
alloraMath.MustNewDecFromString("15.29052"),
}
wantCoefficients := []alloraMath.Dec{
alloraMath.MustNewDecFromString("0.99942"),
alloraMath.MustNewDecFromString("0.99987"),
alloraMath.OneDec(),
alloraMath.OneDec(),
alloraMath.MustNewDecFromString("0.96574"),
alloraMath.MustNewDecFromString("0.95346"),
alloraMath.MustNewDecFromString("0.98634"),
alloraMath.MustNewDecFromString("0.98154"),
}
gotScores, gotCoefficients, err := rewards.GetAllReputersOutput(allLosses, stakes, initialCoefficients, numReputers)
require.NoError(t, err, "GetAllReputersOutput() error = %v, wantErr %v", err, false)
gotScores1, gotCoefficients1, err := rewards.GetAllReputersOutput(
allLosses,
stakes,
initialCoefficients,
numReputers,
params.LearningRate,
params.Sharpness,
1,
)
require.NoError(err)

gotScores2, gotCoefficients2, err := rewards.GetAllReputersOutput(
allLosses,
stakes,
initialCoefficients,
numReputers,
params.LearningRate,
params.Sharpness,
2,
)
require.NoError(err)

if !alloraMath.SlicesInDelta(gotScores, wantScores, alloraMath.MustNewDecFromString("0.00001")) {
t.Errorf("GetAllReputersOutput() gotScores = %v, want %v", gotScores, wantScores)
gotScores3, gotCoefficients3, err := rewards.GetAllReputersOutput(
allLosses,
stakes,
initialCoefficients,
numReputers,
params.LearningRate,
params.Sharpness,
3,
)
require.NoError(err)

getAbsoluteDifferences := func(gotScores []alloraMath.Dec, wantScores []alloraMath.Dec) ([]alloraMath.Dec, error) {
differences := []alloraMath.Dec{}
for i, score := range gotScores {
diff, err := score.Sub(wantScores[i])
if err != nil {
return nil, err
}
if diff.IsNegative() {
diff, err = diff.Mul(alloraMath.MustNewDecFromString("-1"))
if err != nil {
return nil, err
}
}

differences = append(differences, diff)
}
return differences, nil
}

if !alloraMath.SlicesInDelta(gotCoefficients, wantCoefficients, alloraMath.MustNewDecFromString("0.0001")) {
t.Errorf("GetAllReputersOutput() gotCoefficients = %v, want %v", gotCoefficients, wantCoefficients)
require.True(len(gotScores1) == len(wantScores))
require.True(len(gotScores2) == len(wantScores))
require.True(len(gotScores3) == len(wantScores))

scores1DifferenceAbs, err := getAbsoluteDifferences(gotScores1, wantScores)
require.NoError(err)
scores2DifferenceAbs, err := getAbsoluteDifferences(gotScores2, wantScores)
require.NoError(err)
scores3DifferenceAbs, err := getAbsoluteDifferences(gotScores3, wantScores)
require.NoError(err)

for i := 0; i < len(wantScores); i++ {
require.True(scores2DifferenceAbs[i].Lt(scores1DifferenceAbs[i]))
require.True(scores3DifferenceAbs[i].Lt(scores2DifferenceAbs[i]))
}

require.True(len(gotCoefficients1) == len(wantCoefficients))
require.True(len(gotCoefficients2) == len(wantCoefficients))
require.True(len(gotCoefficients3) == len(wantCoefficients))

coefficients1DifferenceAbs, err := getAbsoluteDifferences(gotCoefficients1, wantCoefficients)
require.NoError(err)
coefficients2DifferenceAbs, err := getAbsoluteDifferences(gotCoefficients2, wantCoefficients)
require.NoError(err)
coefficients3DifferenceAbs, err := getAbsoluteDifferences(gotCoefficients3, wantCoefficients)
require.NoError(err)

for i := 0; i < len(wantCoefficients); i++ {
require.True(coefficients2DifferenceAbs[i].Lte(coefficients1DifferenceAbs[i]))
require.True(coefficients3DifferenceAbs[i].Lte(coefficients2DifferenceAbs[i]))
}
}
*/

0 comments on commit f4c69ff

Please sign in to comment.