Saving and loading trained classifiers #43

Closed
anayurg opened this issue Aug 29, 2023 · 3 comments

anayurg commented Aug 29, 2023

Hey there,

I'm playing around with the Tsetlin classifier and I'm having trouble figuring out how to save a trained classifier so that it can be loaded later. Is there a go-to method you can recommend? Saving the classifier with pickle or joblib didn't work. I then tried pickling the clause banks and weight banks separately, but that only works if I pickle them before the classifier is used for any prediction, which is not an option in my case.

I'm attaching a minimal example just in case, but I imagine there must be another way to do it that I haven't figured out.

saving.py

import pickle
from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data import TMUDatasetSource

data = TMUDatasetSource().get_dataset(
        "XOR_biased",
        cache=True,
        cache_max_age=1,
        features=["X", "Y"],
        labels=["xor"],
        shuffle=True,
        train_ratio=1000,
        test_ratio=1000,
        return_type=dict
    )

tm = TMClassifier(20,10,2)
tm.fit(data["x_train"], data["y_train"])

# This method works only if these two lines are commented out:
# y_pred = tm.predict(data["x_test"])
# print((y_pred == data["y_test"]).mean())

# Save every weight bank and clause bank to its own numbered file.
for i, weight_bank in enumerate(tm.weight_banks):
    with open(f'wb{i}.pkl', 'wb') as f:
        pickle.dump(weight_bank, f)
for i, clause_bank in enumerate(tm.clause_banks):
    with open(f'cb{i}.pkl', 'wb') as f:
        pickle.dump(clause_bank, f)

loading.py

from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data import TMUDatasetSource
import os
import pickle

data = TMUDatasetSource().get_dataset(
        "XOR_biased",
        cache=True,
        cache_max_age=1,
        features=["X", "Y"],
        labels=["xor"],
        shuffle=True,
        train_ratio=1000,
        test_ratio=1000,
        return_type=dict
    )

tm_new = TMClassifier(20,10,2)
tm_new.clause_banks = []
tm_new.number_of_classes = 2
tm_new.weight_banks = []

# Count the saved bank files. The open() calls below use relative paths,
# so this assumes the script is run from the folder that is listed here.
files = os.listdir('/home/folder/')
cb_count = sum(1 for file in files if file.startswith('cb'))
wb_count = sum(1 for file in files if file.startswith('wb'))

# Restore the banks in the same order in which they were saved.
for i in range(wb_count):
    with open(f'wb{i}.pkl', 'rb') as f:
        tm_new.weight_banks.append(pickle.load(f))
for i in range(cb_count):
    with open(f'cb{i}.pkl', 'rb') as f:
        tm_new.clause_banks.append(pickle.load(f))

y_pred_new = tm_new.predict(data["x_test"])
print((y_pred_new == data["y_test"]).mean())

The error that I get when I first predict something and then save the clause and weight banks is the following:

Traceback (most recent call last):
  File "loading.py", line 36, in <module>
    y_pred_new = tm_new.predict(data["x_test"])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/envs/tmtest/lib/python3.11/site-packages/tmu/models/classification/vanilla_classifier.py", line 311, in predict
    self.clause_banks[i].calculate_clause_outputs_predict(self.encoded_X_test,
  File "/home/envs/tmtest/lib/python3.11/site-packages/tmu/clause_bank/clause_bank.py", line 168, in calculate_clause_outputs_predict
    self.lcm_p,
    ^^^^^^^^^^
AttributeError: 'ClauseBank' object has no attribute 'lcm_p'. Did you mean: 'lcc_p'?

Thanks for the help!
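
An `AttributeError` on a freshly unpickled object, like the one in the traceback above, generally means that an attribute present on the live object did not make it through the pickle round trip. A minimal sketch of one way to spot such attributes follows. `ToyBank` and `missing_after_roundtrip` are hypothetical names used for illustration only; `ToyBank` is a stand-in and does not reproduce how tmu's `ClauseBank` is actually implemented.

```python
import pickle


class ToyBank:
    """Hypothetical stand-in for a class that drops an attribute when pickled."""

    def __init__(self):
        self.weights = [1, 2, 3]
        self.scratch = object()  # runtime-only helper, excluded from the pickle

    def __getstate__(self):
        # Copy the instance dict and leave out the runtime-only attribute.
        state = self.__dict__.copy()
        state.pop("scratch", None)
        return state


def missing_after_roundtrip(obj):
    """Return names of instance attributes that are lost by pickling obj."""
    restored = pickle.loads(pickle.dumps(obj))
    return sorted(set(vars(obj)) - set(vars(restored)))


print(missing_after_roundtrip(ToyBank()))  # ['scratch']
```

Running the same helper on a real object before and after calling a method such as `predict` could show whether that call changes which attributes survive, assuming the object can be pickled at all.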

satunheim commented

Please also see Issue #46 which is related to this.

perara (Member) commented Sep 2, 2023

Dear @anayurg,

May I suggest just loading the whole TM directly?

saving:

import pickle
from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data.tmu_datasource import TMUDatasetSource

data = TMUDatasetSource().get_dataset(
    "XOR_biased",
    cache=True,
    cache_max_age=1,
    features=["X", "Y"],
    labels=["xor"],
    shuffle=True,
    train_ratio=1000,
    test_ratio=1000,
    return_type=dict
)

tm = TMClassifier(20,10,2)
tm.fit(data["x_train"], data["y_train"])

with open('tm.pkl', 'wb') as f:
    pickle.dump(tm, f)

loading:

import pickle

from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data.tmu_datasource import TMUDatasetSource

data = TMUDatasetSource().get_dataset(
    "XOR_biased",
    cache=True,
    cache_max_age=1,
    features=["X", "Y"],
    labels=["xor"],
    shuffle=True,
    train_ratio=1000,
    test_ratio=1000,
    return_type=dict
)

# Load the whole pickled classifier in one step.
with open('tm.pkl', 'rb') as f:
    tm_new = pickle.load(f)

y_pred_new = tm_new.predict(data["x_test"])
print((y_pred_new == data["y_test"]).mean())
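
The whole-object approach above can be wrapped in two small helpers, plus a check that predictions are unchanged after a save and load. This is a minimal sketch using only the standard library: `save_model`, `load_model`, `predictions_survive_roundtrip`, and `ToyModel` are hypothetical names, and `ToyModel` merely stands in for any picklable object with a `predict` method. Note that the check calls `predict` before saving, which is the order that triggered the error reported in this issue.

```python
import os
import pickle
import tempfile


def save_model(model, path):
    """Pickle the whole model object to path."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    """Load an object previously written by save_model."""
    with open(path, "rb") as f:
        return pickle.load(f)


def predictions_survive_roundtrip(model, X, path):
    """Predict, save, reload, predict again, and compare the two results."""
    before = list(model.predict(X))
    save_model(model, path)
    after = list(load_model(path).predict(X))
    return before == after


class ToyModel:
    """Hypothetical stand-in for a trained classifier."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, X):
        return [int(x > self.threshold) for x in X]


with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "tm.pkl")
    print(predictions_survive_roundtrip(ToyModel(0.5), [0.1, 0.9], path))  # True
```

Pickle files can execute arbitrary code when loaded, so `load_model` should only be used on files from a trusted source.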

perara (Member) commented Sep 2, 2023

I've also corrected the error you reported:
AttributeError: 'ClauseBank' object has no attribute 'lcm_p'. Did you mean: 'lcc_p'?

Thanks for finding this :)

Reopen if this does not satisfy your requirements.

perara closed this as completed Sep 2, 2023