Skip to content

Commit

Permalink
Add tests for MarkovChainNeuralNetwork
Browse files Browse the repository at this point in the history
  • Loading branch information
aadya940 committed Feb 14, 2024
1 parent a34f4a7 commit 5a4b2a2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 6 deletions.
6 changes: 3 additions & 3 deletions chainopy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .markov_chain import MarkovChain
from .nn import MarkovChainNeuralNetwork, analysis
from .nn import MarkovChainNeuralNetwork, divergance_analysis

__all__ = [
"MarkovChain",
"MarkovChainNeuralNetwork",
"analysis",
]
"divergance_analysis",
]
6 changes: 3 additions & 3 deletions chainopy/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MarkovChainNeuralNetwork(nn.Module):
num_layers : int
Number of layers in the neural network.
Raises:
ValueError: If markov_chain is not of type MarkovChain.
Expand Down Expand Up @@ -243,7 +243,7 @@ def simulate_random_walk(self, start_state, steps):
return markov_walk


def analysis(mc: MarkovChain, nn: MarkovChainNeuralNetwork) -> float:
def divergance_analysis(mc: MarkovChain, nn: MarkovChainNeuralNetwork) -> float:
"""
KL Divergance between `MarkovChain.tpm` and
`MarkovChain().fit(MarkovChainNeuralNetwork.simulate_random_walk).tpm`.
Expand All @@ -269,7 +269,7 @@ def _generate_fit_string():
_observed_seq_list = nn.simulate_random_walk(
random.choice(mc.states), len(mc.states) * 200
)
_estimated_tpm = _learn_matrix(_observed_seq_list, epsilon=_epsilon)
_estimated_tpm = _learn_matrix.learn_matrix_cython(_observed_seq_list, epsilon=_epsilon)
return _estimated_tpm

_est_tpm = _generate_fit_string().flatten()
Expand Down
47 changes: 47 additions & 0 deletions chainopy/test_nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import numpy as np
import torch

from .nn import MarkovChainNeuralNetwork, divergance_analysis
from .markov_chain import MarkovChain


@pytest.fixture
def mock_markov_chain():
tpm = np.array([[0.5, 0.5], [0.3, 0.7]])
states = ['Rain', 'No-Rain']
mc = MarkovChain(tpm, states)
return mc


def test_markov_chain_neural_network_init(mock_markov_chain):
with pytest.raises(ValueError):
MarkovChainNeuralNetwork("invalid_type", 2)

with pytest.raises(ValueError):
mc = MarkovChain(None, None)
MarkovChainNeuralNetwork(mc, 2)


def test_markov_chain_neural_network_forward(mock_markov_chain):
mc_nn = MarkovChainNeuralNetwork(mock_markov_chain, 2)
input_data = torch.tensor([[torch.rand(1), 0.1, 0.9]])
output = mc_nn(input_data)
assert output.shape == (1, 2)


def test_markov_chain_neural_network_training(mock_markov_chain):
mc_nn = MarkovChainNeuralNetwork(mock_markov_chain, 2)
mc_nn.train_model(1000, 10, 0.01, verbose=False)
assert mc_nn.optimizer is not None
assert mc_nn.scheduler is not None
assert mc_nn.loss_function is not None
assert mc_nn.input_data is not None
assert mc_nn.output_data is not None


def test_divergance_analysis(mock_markov_chain):
mc_nn = MarkovChainNeuralNetwork(mock_markov_chain, 2)
mc_nn.train_model(1000, 10, 0.01, verbose=False)
kl_divergence = divergance_analysis(mock_markov_chain, mc_nn)
assert isinstance(kl_divergence, float)

0 comments on commit 5a4b2a2

Please sign in to comment.