diff --git a/x/emissions/keeper/inference_synthesis/network_inference_builder.go b/x/emissions/keeper/inference_synthesis/network_inference_builder.go index f46fc2a09..6be2ba045 100644 --- a/x/emissions/keeper/inference_synthesis/network_inference_builder.go +++ b/x/emissions/keeper/inference_synthesis/network_inference_builder.go @@ -7,102 +7,68 @@ import ( "cosmossdk.io/log" alloraMath "github.com/allora-network/allora-chain/math" emissions "github.com/allora-network/allora-chain/x/emissions/types" - sdk "github.com/cosmos/cosmos-sdk/types" ) -type NetworkInferenceBuilder struct { - ctx sdk.Context - logger log.Logger - palette SynthPalette - // Network Inferences Properties - inferences []*emissions.WorkerAttributedValue - forecastImpliedInferences []*emissions.WorkerAttributedValue - weights RegretInformedWeights - combinedInference InferenceValue - naiveInference InferenceValue - oneOutInfererInferences []*emissions.WithheldWorkerAttributedValue - oneOutForecasterInferences []*emissions.WithheldWorkerAttributedValue - oneInInferences []*emissions.WorkerAttributedValue -} - -func NewNetworkInferenceBuilderFromSynthRequest( - req SynthRequest, -) (*NetworkInferenceBuilder, error) { - paletteFactory := SynthPaletteFactory{} - palette, err := paletteFactory.BuildPaletteFromRequest(req) - if err != nil { - return nil, errorsmod.Wrapf(err, "Error building palette from request") - } - return &NetworkInferenceBuilder{ - ctx: req.Ctx, - logger: Logger(req.Ctx), - palette: palette, - }, nil -} - // Calculates the network combined inference I_i, Equation 9 -func (b *NetworkInferenceBuilder) SetCombinedValue() *NetworkInferenceBuilder { - b.logger.Debug(fmt.Sprintf("Calculating combined inference for topic %v", b.palette.TopicId)) - palette := b.palette.Clone() +func GetCombinedInference(logger log.Logger, inPalette SynthPalette) ( + weights RegretInformedWeights, combinedInference InferenceValue, err error) { + logger.Debug(fmt.Sprintf("Calculating combined inference for topic %v", inPalette.TopicId)) + palette := inPalette.Clone() - weights, err := palette.CalcWeightsGivenWorkers() + weights, err = palette.CalcWeightsGivenWorkers() if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating weights for combined inference: %s", err.Error())) - return b + errorsmod.Wrap(err, "Error calculating weights for combined inference") + return RegretInformedWeights{}, InferenceValue{}, err } - combinedInference, err := palette.CalcWeightedInference(weights) + combinedInference, err = palette.CalcWeightedInference(weights) if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating combined inference: %s", err.Error())) - return b + errorsmod.Wrap(err, "Error calculating combined inference") + return RegretInformedWeights{}, InferenceValue{}, err } - b.logger.Debug(fmt.Sprintf("Combined inference calculated for topic %v is %v", b.palette.TopicId, combinedInference)) - b.combinedInference = combinedInference - b.weights = weights - return b + logger.Debug(fmt.Sprintf("Combined inference calculated for topic %v is %v", inPalette.TopicId, combinedInference)) + return weights, combinedInference, nil } // Map inferences to a WorkerAttributedValue array and set -func (b *NetworkInferenceBuilder) SetInfererValues() *NetworkInferenceBuilder { - infererValues := make([]*emissions.WorkerAttributedValue, 0) - for _, inferer := range b.palette.Inferers { +func GetInferences(palette SynthPalette) (infererValues []*emissions.WorkerAttributedValue) { + infererValues = make([]*emissions.WorkerAttributedValue, 0) + for _, inferer := range palette.Inferers { infererValues = append(infererValues, &emissions.WorkerAttributedValue{ Worker: inferer, - Value: b.palette.InferenceByWorker[inferer].Value, + Value: palette.InferenceByWorker[inferer].Value, }) } - b.inferences = infererValues - return b + return infererValues } // Map forecast-implied inferences to a WorkerAttributedValue array and set -func (b *NetworkInferenceBuilder) SetForecasterValues() *NetworkInferenceBuilder { - forecastImpliedValues := make([]*emissions.WorkerAttributedValue, 0) - for _, forecaster := range b.palette.Forecasters { - if b.palette.ForecastImpliedInferenceByWorker[forecaster] == nil { - b.logger.Warn(fmt.Sprintf("No forecast-implied inference for forecaster %s", forecaster)) +func GetForecastImpliedInferences(logger log.Logger, palette SynthPalette) ( + forecastImpliedInferences []*emissions.WorkerAttributedValue) { + forecastImpliedInferences = make([]*emissions.WorkerAttributedValue, 0) + for _, forecaster := range palette.Forecasters { + if palette.ForecastImpliedInferenceByWorker[forecaster] == nil { + logger.Warn(fmt.Sprintf("No forecast-implied inference for forecaster %s", forecaster)) continue } - forecastImpliedValues = append(forecastImpliedValues, &emissions.WorkerAttributedValue{ - Worker: b.palette.ForecastImpliedInferenceByWorker[forecaster].Inferer, - Value: b.palette.ForecastImpliedInferenceByWorker[forecaster].Value, + forecastImpliedInferences = append(forecastImpliedInferences, &emissions.WorkerAttributedValue{ + Worker: palette.ForecastImpliedInferenceByWorker[forecaster].Inferer, + Value: palette.ForecastImpliedInferenceByWorker[forecaster].Value, }) } - b.forecastImpliedInferences = forecastImpliedValues - return b + return forecastImpliedInferences } // Calculates the network naive inference I^-_i -func (b *NetworkInferenceBuilder) SetNaiveValue() *NetworkInferenceBuilder { - b.logger.Debug(fmt.Sprintf("Calculating naive inference for topic %v", b.palette.TopicId)) - palette := b.palette.Clone() +func GetNaiveInference(logger log.Logger, inPalette SynthPalette) (naiveInference alloraMath.Dec, err error) { + logger.Debug(fmt.Sprintf("Calculating naive inference for topic %v", inPalette.TopicId)) + palette := inPalette.Clone() // Update the forecasters info to exclude all forecasters - err := palette.UpdateForecastersInfo(make([]string, 0)) + err = palette.UpdateForecastersInfo(make([]string, 0)) if err != nil { - b.logger.Warn(fmt.Sprintf("Error updating forecasters info for naive inference: %s", err.Error())) - return b + return alloraMath.Dec{}, errorsmod.Wrap(err, "Error updating forecasters info for naive inference") } // Get inferer naive regrets @@ -110,33 +76,32 @@ func (b *NetworkInferenceBuilder) SetNaiveValue() *NetworkInferenceBuilder { for _, inferer := range palette.Inferers { regret, _, err := palette.K.GetNaiveInfererNetworkRegret(palette.Ctx, palette.TopicId, inferer) if err != nil { - b.logger.Warn(fmt.Sprintf("Error getting naive regret for inferer %s: %s", inferer, err.Error())) - return b + return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error getting naive regret for inferer %s", inferer) } palette.InfererRegrets[inferer] = ®ret.Value } weights, err := palette.CalcWeightsGivenWorkers() if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating weights for naive inference: %s", err.Error())) - return b + return alloraMath.Dec{}, errorsmod.Wrap(err, "Error calculating weights for naive inference") } - naiveInference, err := palette.CalcWeightedInference(weights) + naiveInference, err = palette.CalcWeightedInference(weights) if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating naive inference: %s", err.Error())) - return b + return alloraMath.Dec{}, errorsmod.Wrap(err, "Error calculating naive inference") } - b.logger.Debug(fmt.Sprintf("Naive inference calculated for topic %v is %v", b.palette.TopicId, naiveInference)) - b.naiveInference = naiveInference - return b + logger.Debug(fmt.Sprintf("Naive inference calculated for topic %v is %v", inPalette.TopicId, naiveInference)) + return naiveInference, nil } // Calculate the one-out inference given a withheld inferer -func (b *NetworkInferenceBuilder) calcOneOutInfererInference(withheldInferer Worker) (alloraMath.Dec, error) { - b.logger.Debug(fmt.Sprintf("Calculating one-out inference for topic %v withheld inferer %s", b.palette.TopicId, withheldInferer)) - palette := b.palette.Clone() +func calcOneOutInfererInference( + logger log.Logger, inPalette SynthPalette, withheldInferer Worker) ( + oneOutNetworkInferenceWithoutInferer alloraMath.Dec, err error) { + logger.Debug(fmt.Sprintf( + "Calculating one-out inference for topic %v withheld inferer %s", inPalette.TopicId, withheldInferer)) + palette := inPalette.Clone() totalInferers := palette.Inferers // Remove the inferer from the palette's inferers @@ -147,7 +112,7 @@ func (b *NetworkInferenceBuilder) calcOneOutInfererInference(withheldInferer Wor } } - err := palette.UpdateInferersInfo(remainingInferers) + err = palette.UpdateInferersInfo(remainingInferers) if err != nil { return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error updating inferers") } @@ -187,12 +152,12 @@ func (b *NetworkInferenceBuilder) calcOneOutInfererInference(withheldInferer Wor return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error calculating one-out inference for forecaster") } - oneOutNetworkInferenceWithoutInferer, err := paletteCopy.CalcWeightedInference(weights) + oneOutNetworkInferenceWithoutInferer, err = paletteCopy.CalcWeightedInference(weights) if err != nil { return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error calculating one-out inference for inferer") } - b.logger.Debug(fmt.Sprintf("One-out inference calculated for topic %v withheld inferer %s is %v", b.palette.TopicId, withheldInferer, oneOutNetworkInferenceWithoutInferer)) + logger.Debug(fmt.Sprintf("One-out inference calculated for topic %v withheld inferer %s is %v", inPalette.TopicId, withheldInferer, oneOutNetworkInferenceWithoutInferer)) return oneOutNetworkInferenceWithoutInferer, nil } @@ -200,16 +165,15 @@ func (b *NetworkInferenceBuilder) calcOneOutInfererInference(withheldInferer Wor // Assumed that there is at most 1 inference per inferer // Loop over all inferences and withhold one, then calculate the network inference less that withheld inference // This involves recalculating the forecast-implied inferences for each withheld inferer -func (b *NetworkInferenceBuilder) SetOneOutInfererValues() *NetworkInferenceBuilder { - b.logger.Debug(fmt.Sprintf("Calculating one-out inferer inferences for topic %v with %v inferers", b.palette.TopicId, len(b.palette.Inferers))) +func GetOneOutInfererInferences(logger log.Logger, palette SynthPalette) ( + oneOutInfererInferences []*emissions.WithheldWorkerAttributedValue, err error) { + logger.Debug(fmt.Sprintf("Calculating one-out inferer inferences for topic %v with %v inferers", palette.TopicId, len(palette.Inferers))) // Calculate the one-out inferences per inferer oneOutInferences := make([]*emissions.WithheldWorkerAttributedValue, 0) - for _, worker := range b.palette.Inferers { - oneOutInference, err := b.calcOneOutInfererInference(worker) + for _, worker := range palette.Inferers { + oneOutInference, err := calcOneOutInfererInference(logger, palette, worker) if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating one-out inferer inferences: %s", err.Error())) - b.oneOutInfererInferences = make([]*emissions.WithheldWorkerAttributedValue, 0) - return b + return []*emissions.WithheldWorkerAttributedValue{}, errorsmod.Wrapf(err, "Error calculating one-out inferer inferences") } oneOutInferences = append(oneOutInferences, &emissions.WithheldWorkerAttributedValue{ @@ -218,15 +182,16 @@ func (b *NetworkInferenceBuilder) SetOneOutInfererValues() *NetworkInferenceBuil }) } - b.logger.Debug(fmt.Sprintf("One-out inferer inferences calculated for topic %v", b.palette.TopicId)) - b.oneOutInfererInferences = oneOutInferences - return b + logger.Debug(fmt.Sprintf("One-out inferer inferences calculated for topic %v", palette.TopicId)) + return oneOutInferences, nil } // Calculate the one-out inference given a withheld forecaster -func (b *NetworkInferenceBuilder) calcOneOutForecasterInference(withheldForecaster Worker) (alloraMath.Dec, error) { - b.logger.Debug(fmt.Sprintf("Calculating one-out inference for topic %v withheld forecaster %s", b.palette.TopicId, withheldForecaster)) - palette := b.palette.Clone() +func calcOneOutForecasterInference( + logger log.Logger, inPalette SynthPalette, withheldForecaster Worker) ( + oneOutNetworkInferenceWithoutInferer alloraMath.Dec, err error) { + logger.Debug(fmt.Sprintf("Calculating one-out inference for topic %v withheld forecaster %s", inPalette.TopicId, withheldForecaster)) + palette := inPalette.Clone() totalForecasters := palette.Forecasters // Remove the withheldForecaster from the palette's forecasters @@ -237,7 +202,7 @@ func (b *NetworkInferenceBuilder) calcOneOutForecasterInference(withheldForecast } } - err := palette.UpdateForecastersInfo(remainingForecasters) + err = palette.UpdateForecastersInfo(remainingForecasters) if err != nil { return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error updating forecasters") } @@ -266,45 +231,46 @@ func (b *NetworkInferenceBuilder) calcOneOutForecasterInference(withheldForecast return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error calculating one-out inference for forecaster") } - oneOutNetworkInferenceWithoutInferer, err := palette.CalcWeightedInference(weights) + oneOutNetworkInferenceWithoutInferer, err = palette.CalcWeightedInference(weights) if err != nil { return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error calculating one-out inference for inferer") } - b.logger.Debug(fmt.Sprintf("One-out inference calculated for topic %v withheld forecaster %s is %v", b.palette.TopicId, withheldForecaster, oneOutNetworkInferenceWithoutInferer)) + logger.Debug(fmt.Sprintf("One-out inference calculated for topic %v withheld forecaster %s is %v", inPalette.TopicId, withheldForecaster, oneOutNetworkInferenceWithoutInferer)) return oneOutNetworkInferenceWithoutInferer, nil } // Set all one-out-forecaster inferences that are possible given the provided input // Assume that there is at most 1 forecast-implied inference per forecaster // Loop over all forecast-implied inferences and withhold one, then calculate the network inference less that withheld value -func (b *NetworkInferenceBuilder) SetOneOutForecasterValues() *NetworkInferenceBuilder { - b.logger.Debug(fmt.Sprintf("Calculating one-out forecaster inferences for topic %v with %v forecasters", b.palette.TopicId, len(b.palette.Forecasters))) +func GetOneOutForecasterInferences( + logger log.Logger, palette SynthPalette) ( + oneOutForecasterInferences []*emissions.WithheldWorkerAttributedValue, err error) { + logger.Debug(fmt.Sprintf("Calculating one-out forecaster inferences for topic %v with %v forecasters", palette.TopicId, len(palette.Forecasters))) // Calculate the one-out forecast-implied inferences per forecaster oneOutImpliedInferences := make([]*emissions.WithheldWorkerAttributedValue, 0) - // If there is only one forecaster, thre's no need to calculate one-out inferences - if len(b.palette.Forecasters) > 1 { - for _, worker := range b.palette.Forecasters { - oneOutInference, err := b.calcOneOutForecasterInference(worker) + // If there is only one forecaster, there's no need to calculate one-out inferences + if len(palette.Forecasters) > 1 { + for _, worker := range palette.Forecasters { + oneOutInference, err := calcOneOutForecasterInference(logger, palette, worker) if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating one-out forecaster inferences: %s", err.Error())) - b.oneOutForecasterInferences = make([]*emissions.WithheldWorkerAttributedValue, 0) - return b + return []*emissions.WithheldWorkerAttributedValue{}, errorsmod.Wrapf(err, "Error calculating one-out forecaster inferences") } oneOutImpliedInferences = append(oneOutImpliedInferences, &emissions.WithheldWorkerAttributedValue{ Worker: worker, Value: oneOutInference, }) } - b.logger.Debug(fmt.Sprintf("One-out forecaster inferences calculated for topic %v", b.palette.TopicId)) + logger.Debug(fmt.Sprintf("One-out forecaster inferences calculated for topic %v", palette.TopicId)) } - b.oneOutForecasterInferences = oneOutImpliedInferences - return b + return oneOutForecasterInferences, nil } -func (b *NetworkInferenceBuilder) calcOneInValue(oneInForecaster Worker) (alloraMath.Dec, error) { - b.logger.Debug(fmt.Sprintf("Calculating one-in inference for forecaster: %s", oneInForecaster)) - palette := b.palette.Clone() +func calcOneInValue( + logger log.Logger, inPalette SynthPalette, oneInForecaster Worker) ( + oneInInference alloraMath.Dec, err error) { + logger.Debug(fmt.Sprintf("Calculating one-in inference for forecaster: %s", oneInForecaster)) + palette := inPalette.Clone() // In each loop, remove all forecast-implied inferences except one forecastImpliedInferencesWithForecaster := make(map[Worker]*emissions.Inference) @@ -345,7 +311,7 @@ func (b *NetworkInferenceBuilder) calcOneInValue(oneInForecaster Worker) (allora return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error calculating weights for one-in inferences") } // Calculate the network inference with just this forecaster's forecast-implied inference - oneInInference, err := palette.CalcWeightedInference(weights) + oneInInference, err = palette.CalcWeightedInference(weights) if err != nil { return alloraMath.Dec{}, errorsmod.Wrapf(err, "Error calculating one-in inference") } @@ -356,17 +322,16 @@ func (b *NetworkInferenceBuilder) calcOneInValue(oneInForecaster Worker) (allora // Set all one-in inferences that are possible given the provided input // Assumed that there is at most 1 inference per worker. // Also assume that there is at most 1 forecast-implied inference per worker. -func (b *NetworkInferenceBuilder) SetOneInValues() *NetworkInferenceBuilder { +func GetOneInForecasterInferences(logger log.Logger, palette SynthPalette) (oneInInferences []*emissions.WorkerAttributedValue, err error) { // Loop over all forecast-implied inferences and set it as the only forecast-implied inference // one at a time, then calculate the network inference given that one held out - oneInInferences := make([]*emissions.WorkerAttributedValue, 0) + oneInInferences = make([]*emissions.WorkerAttributedValue, 0) // If there is only one forecaster, thre's no need to calculate one-in inferences - if len(b.palette.Forecasters) > 1 { - for _, oneInForecaster := range b.palette.Forecasters { - oneInValue, err := b.calcOneInValue(oneInForecaster) + if len(palette.Forecasters) > 1 { + for _, oneInForecaster := range palette.Forecasters { + oneInValue, err := calcOneInValue(logger, palette, oneInForecaster) if err != nil { - b.logger.Warn(fmt.Sprintf("Error calculating one-in inferences: %s", err.Error())) - return b + return []*emissions.WorkerAttributedValue{}, errorsmod.Wrapf(err, "Error calculating one-in inferences") } oneInInferences = append(oneInInferences, &emissions.WorkerAttributedValue{ Worker: oneInForecaster, @@ -374,33 +339,48 @@ func (b *NetworkInferenceBuilder) SetOneInValues() *NetworkInferenceBuilder { }) } } - b.oneInInferences = oneInInferences - return b -} - -func (b *NetworkInferenceBuilder) CalcAndSetNetworkInferences() *NetworkInferenceBuilder { - return b.SetCombinedValue(). - SetInfererValues(). - SetForecasterValues(). - SetNaiveValue(). - SetOneOutInfererValues(). - SetOneOutForecasterValues(). - SetOneInValues() + return oneInInferences, err } // Calculates all network inferences in the set I_i given historical state (e.g. regrets) // and data from workers (e.g. inferences, forecast-implied inferences). // Could improve this with Builder pattern, as for other instances of generated ValueBundles. -func (b *NetworkInferenceBuilder) Build() *emissions.ValueBundle { +func CalcNetworkInferences( + logger log.Logger, + palette SynthPalette, +) (inferenceBundle *emissions.ValueBundle, weights RegretInformedWeights, err error) { + weights, combinedInference, err := GetCombinedInference(logger, palette) + if err != nil { + return &emissions.ValueBundle{}, RegretInformedWeights{}, errorsmod.Wrap(err, "Error calculating combined inference") + } + inferences := GetInferences(palette) + forecastImpliedInferences := GetForecastImpliedInferences(logger, palette) + naiveInference, err := GetNaiveInference(logger, palette) + if err != nil { + return &emissions.ValueBundle{}, RegretInformedWeights{}, errorsmod.Wrap(err, "Error calculating naive inference") + } + oneOutInfererInferences, err := GetOneOutInfererInferences(logger, palette) + if err != nil { + return &emissions.ValueBundle{}, RegretInformedWeights{}, errorsmod.Wrap(err, "Error calculating one-out inferer inferences") + } + oneOutForecasterInferences, err := GetOneOutForecasterInferences(logger, palette) + if err != nil { + return &emissions.ValueBundle{}, RegretInformedWeights{}, errorsmod.Wrap(err, "Error calculating one-out forecaster inferences") + } + oneInForecasterInferences, err := GetOneInForecasterInferences(logger, palette) + if err != nil { + return &emissions.ValueBundle{}, RegretInformedWeights{}, errorsmod.Wrap(err, "Error calculating one-in inferences") + } + // Build value bundle to return all the calculated inferences return &emissions.ValueBundle{ - TopicId: b.palette.TopicId, - CombinedValue: b.combinedInference, - InfererValues: b.inferences, - ForecasterValues: b.forecastImpliedInferences, - NaiveValue: b.naiveInference, - OneOutInfererValues: b.oneOutInfererInferences, - OneOutForecasterValues: b.oneOutForecasterInferences, - OneInForecasterValues: b.oneInInferences, - } + TopicId: palette.TopicId, + CombinedValue: combinedInference, + InfererValues: inferences, + ForecasterValues: forecastImpliedInferences, + NaiveValue: naiveInference, + OneOutInfererValues: oneOutInfererInferences, + OneOutForecasterValues: oneOutForecasterInferences, + OneInForecasterValues: oneInForecasterInferences, + }, weights, err } diff --git a/x/emissions/keeper/inference_synthesis/network_inference_builder_test.go b/x/emissions/keeper/inference_synthesis/network_inference_builder_test.go index cfeef43d1..7e6bb4476 100644 --- a/x/emissions/keeper/inference_synthesis/network_inference_builder_test.go +++ b/x/emissions/keeper/inference_synthesis/network_inference_builder_test.go @@ -149,7 +149,7 @@ func TestModuleTestSuite(t *testing.T) { } func (s *InferenceSynthesisTestSuite) getEpochValueBundleByEpoch(epochNumber int) ( - *inferencesynthesis.NetworkInferenceBuilder, + inferencesynthesis.SynthPalette, map[int]func(header string) alloraMath.Dec, ) { k := s.emissionsKeeper @@ -393,7 +393,8 @@ func (s *InferenceSynthesisTestSuite) getEpochValueBundleByEpoch(epochNumber int }) } - networkInferenceBuilder, err := inferencesynthesis.NewNetworkInferenceBuilderFromSynthRequest( + paletteFactory := inferencesynthesis.SynthPaletteFactory{} + synthPalette, err := paletteFactory.BuildPaletteFromRequest( inferencesynthesis.SynthRequest{ Ctx: ctx, K: k, @@ -409,14 +410,14 @@ func (s *InferenceSynthesisTestSuite) getEpochValueBundleByEpoch(epochNumber int ) require.NoError(s.T(), err) - return networkInferenceBuilder, epochGetters + return synthPalette, epochGetters } func (s *InferenceSynthesisTestSuite) testCorrectCombinedInitialValueForEpoch(epoch int) { - networkInferenceBuilder, epochGet := s.getEpochValueBundleByEpoch(epoch) - valueBundle := networkInferenceBuilder.SetCombinedValue().Build() - s.Require().NotNil(valueBundle.CombinedValue) - alloratestutil.InEpsilon5(s.T(), valueBundle.CombinedValue, epochGet[epoch]("network_inference").String()) + synthPalette, epochGet := s.getEpochValueBundleByEpoch(epoch) + _, combinedValue, err := inferencesynthesis.GetCombinedInference(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) + alloratestutil.InEpsilon5(s.T(), combinedValue, epochGet[epoch]("network_inference").String()) } func (s *InferenceSynthesisTestSuite) TestCorrectCombinedValueEpoch2() { @@ -432,10 +433,10 @@ func (s *InferenceSynthesisTestSuite) TestCorrectCombinedValueEpoch4() { } func (s *InferenceSynthesisTestSuite) testCorrectNaiveValueForEpoch(epoch int) { - networkInferenceBuilder, epochGet := s.getEpochValueBundleByEpoch(epoch) - valueBundle := networkInferenceBuilder.SetNaiveValue().Build() - s.Require().NotNil(valueBundle.NaiveValue) - alloratestutil.InEpsilon5(s.T(), valueBundle.NaiveValue, epochGet[epoch]("network_naive_inference").String()) + synthPalette, epochGet := s.getEpochValueBundleByEpoch(epoch) + naiveValue, err := inferencesynthesis.GetNaiveInference(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) + alloratestutil.InEpsilon5(s.T(), naiveValue, epochGet[epoch]("network_naive_inference").String()) } func (s *InferenceSynthesisTestSuite) TestCorrectNaiveValueEpoch2() { @@ -447,7 +448,7 @@ func (s *InferenceSynthesisTestSuite) TestCorrectNaiveValueEpoch3() { } func (s *InferenceSynthesisTestSuite) testCorrectOneOutInfererValuesForEpoch(epoch int) { - networkInferenceBuilder, epochGet := s.getEpochValueBundleByEpoch(epoch) + synthPalette, epochGet := s.getEpochValueBundleByEpoch(epoch) expectedValues := map[string]alloraMath.Dec{ "worker0": epochGet[epoch]("network_inference_oneout_0"), @@ -457,11 +458,12 @@ func (s *InferenceSynthesisTestSuite) testCorrectOneOutInfererValuesForEpoch(epo "worker4": epochGet[epoch]("network_inference_oneout_4"), } - valueBundle := networkInferenceBuilder.SetOneOutInfererValues().Build() + oneOutInfererValues, err := inferencesynthesis.GetOneOutInfererInferences(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) for worker, expectedValue := range expectedValues { found := false - for _, workerAttributedValue := range valueBundle.OneOutInfererValues { + for _, workerAttributedValue := range oneOutInfererValues { if workerAttributedValue.Worker == worker { found = true alloratestutil.InEpsilon5(s.T(), expectedValue, workerAttributedValue.Value.String()) @@ -480,8 +482,9 @@ func (s *InferenceSynthesisTestSuite) TestCorrectOneOutInfererValuesEpoch3() { } func (s *InferenceSynthesisTestSuite) testCorrectOneOutForecasterValuesForEpoch(epoch int) { - networkInferenceBuilder, epochGet := s.getEpochValueBundleByEpoch(epoch) - valueBundle := networkInferenceBuilder.SetOneOutForecasterValues().Build() + synthPalette, epochGet := s.getEpochValueBundleByEpoch(epoch) + oneOutForecasterValues, err := inferencesynthesis.GetOneOutForecasterInferences(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) expectedValues := map[string]alloraMath.Dec{ "forecaster0": epochGet[epoch]("network_inference_oneout_5"), @@ -491,7 +494,7 @@ func (s *InferenceSynthesisTestSuite) testCorrectOneOutForecasterValuesForEpoch( for worker, expectedValue := range expectedValues { found := false - for _, workerAttributedValue := range valueBundle.OneOutForecasterValues { + for _, workerAttributedValue := range oneOutForecasterValues { if workerAttributedValue.Worker == worker { found = true alloratestutil.InEpsilon5(s.T(), expectedValue, workerAttributedValue.Value.String()) @@ -514,8 +517,9 @@ func (s *InferenceSynthesisTestSuite) TestCorrectOneOutForecasterValuesEpoch4() } func (s *InferenceSynthesisTestSuite) testCorrectOneInForecasterValuesForEpoch(epoch int) { - networkInferenceBuilder, epochGet := s.getEpochValueBundleByEpoch(epoch) - valueBundle := networkInferenceBuilder.SetOneInValues().Build() + synthPalette, epochGet := s.getEpochValueBundleByEpoch(epoch) + oneInForecasterValues, err := inferencesynthesis.GetOneInForecasterInferences(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) expectedValues := map[string]alloraMath.Dec{ "forecaster0": epochGet[epoch]("network_naive_inference_onein_0"), @@ -525,7 +529,7 @@ func (s *InferenceSynthesisTestSuite) testCorrectOneInForecasterValuesForEpoch(e for worker, expectedValue := range expectedValues { found := false - for _, workerAttributedValue := range valueBundle.OneInForecasterValues { + for _, workerAttributedValue := range oneInForecasterValues { if workerAttributedValue.Worker == worker { found = true alloratestutil.InEpsilon5(s.T(), expectedValue, workerAttributedValue.Value.String()) @@ -589,7 +593,8 @@ func (s *InferenceSynthesisTestSuite) TestBuildNetworkInferencesIncompleteData() } // Call the function without setting regrets - networkInferenceBuilder, err := inferencesynthesis.NewNetworkInferenceBuilderFromSynthRequest( + paletteFactory := inferencesynthesis.SynthPaletteFactory{} + synthPalette, err := paletteFactory.BuildPaletteFromRequest( inferencesynthesis.SynthRequest{ Ctx: ctx, K: k, @@ -604,7 +609,8 @@ func (s *InferenceSynthesisTestSuite) TestBuildNetworkInferencesIncompleteData() }, ) s.Require().NoError(err) - valueBundle := networkInferenceBuilder.CalcAndSetNetworkInferences().Build() + valueBundle, _, err := inferencesynthesis.CalcNetworkInferences(s.ctx.Logger(), synthPalette) + s.Require().Error(err) s.Require().NotNil(valueBundle) s.Require().NotNil(valueBundle.CombinedValue) @@ -675,7 +681,8 @@ func (s *InferenceSynthesisTestSuite) TestCalcNetworkInferencesTwoWorkerTwoForec err = k.SetOneInForecasterNetworkRegret(ctx, topicId, worker4, worker2, emissionstypes.TimestampedValue{Value: alloraMath.MustNewDecFromString("0.4")}) s.Require().NoError(err) - networkInferenceBuilder, err := inferencesynthesis.NewNetworkInferenceBuilderFromSynthRequest( + paletteFactory := inferencesynthesis.SynthPaletteFactory{} + synthPalette, err := paletteFactory.BuildPaletteFromRequest( inferencesynthesis.SynthRequest{ Ctx: ctx, K: k, @@ -690,7 +697,8 @@ func (s *InferenceSynthesisTestSuite) TestCalcNetworkInferencesTwoWorkerTwoForec }, ) s.Require().NoError(err) - valueBundle := networkInferenceBuilder.CalcAndSetNetworkInferences().Build() + valueBundle, _, err := inferencesynthesis.CalcNetworkInferences(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) // Check the results s.Require().NotNil(valueBundle) @@ -791,7 +799,8 @@ func (s *InferenceSynthesisTestSuite) TestCalcNetworkInferencesThreeWorkerThreeF err = k.SetOneInForecasterNetworkRegret(ctx, topicId, forecaster3, worker3, emissionstypes.TimestampedValue{Value: alloraMath.MustNewDecFromString("0.006")}) s.Require().NoError(err) - networkInferenceBuilder, err := inferencesynthesis.NewNetworkInferenceBuilderFromSynthRequest( + paletteFactory := inferencesynthesis.SynthPaletteFactory{} + synthPalette, err := paletteFactory.BuildPaletteFromRequest( inferencesynthesis.SynthRequest{ Ctx: ctx, K: k, @@ -806,7 +815,8 @@ func (s *InferenceSynthesisTestSuite) TestCalcNetworkInferencesThreeWorkerThreeF }, ) s.Require().NoError(err) - valueBundle := networkInferenceBuilder.CalcAndSetNetworkInferences().Build() + valueBundle, _, err := inferencesynthesis.CalcNetworkInferences(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) // Check the results s.Require().NotNil(valueBundle) @@ -871,7 +881,8 @@ func (s *InferenceSynthesisTestSuite) TestCalc0neInInferencesTwoForecastersOldTw err = k.SetOneInForecasterNetworkRegret(ctx, topicId, worker2, worker1, emissionstypes.TimestampedValue{Value: alloraMath.MustNewDecFromString("0.008")}) s.Require().NoError(err) - networkInferenceBuilder, err := inferencesynthesis.NewNetworkInferenceBuilderFromSynthRequest( + paletteFactory := inferencesynthesis.SynthPaletteFactory{} + synthPalette, err := paletteFactory.BuildPaletteFromRequest( inferencesynthesis.SynthRequest{ Ctx: ctx, K: k, @@ -886,7 +897,8 @@ func (s *InferenceSynthesisTestSuite) TestCalc0neInInferencesTwoForecastersOldTw }, ) s.Require().NoError(err) - valueBundle := networkInferenceBuilder.SetOneInValues().Build() + valueBundle, _, err := inferencesynthesis.CalcNetworkInferences(s.ctx.Logger(), synthPalette) + s.Require().NoError(err) // Check the results s.Require().NotNil(valueBundle) diff --git a/x/emissions/keeper/inference_synthesis/network_inferences.go b/x/emissions/keeper/inference_synthesis/network_inferences.go index cccc6c9cd..99b424dad 100644 --- a/x/emissions/keeper/inference_synthesis/network_inferences.go +++ b/x/emissions/keeper/inference_synthesis/network_inferences.go @@ -115,7 +115,8 @@ func GetNetworkInferences( } else { Logger(ctx).Debug(fmt.Sprintf("Creating network inferences for topic %v with %v inferences and %v forecasts", topicId, len(inferences.Inferences), len(forecasts.Forecasts))) - networkInferenceBuilder, err := NewNetworkInferenceBuilderFromSynthRequest( + paletteFactory := SynthPaletteFactory{} + synthPalette, err := paletteFactory.BuildPaletteFromRequest( SynthRequest{ Ctx: ctx, K: k, @@ -133,10 +134,11 @@ func GetNetworkInferences( Logger(ctx).Warn(fmt.Sprintf("Error constructing network inferences builder topic: %s", err.Error())) return networkInferences, nil, infererWeights, forecasterWeights, inferenceBlockHeight, lossBlockHeight, err } - networkInferences = networkInferenceBuilder.CalcAndSetNetworkInferences().Build() - forecastImpliedInferencesByWorker = networkInferenceBuilder.palette.ForecastImpliedInferenceByWorker - infererWeights = networkInferenceBuilder.weights.inferers - forecasterWeights = networkInferenceBuilder.weights.forecasters + var weights RegretInformedWeights + networkInferences, weights, err = CalcNetworkInferences(Logger(ctx), synthPalette) + forecastImpliedInferencesByWorker = synthPalette.ForecastImpliedInferenceByWorker + infererWeights = weights.inferers + forecasterWeights = weights.forecasters } } else { // Single valid inference case