Skip to content

Commit

Permalink
Added Train Model Script
Browse files Browse the repository at this point in the history
  • Loading branch information
silverlightning926 committed Jun 12, 2024
1 parent 5efcbb2 commit 8428729
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/train_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tensorflow as tf
from keras.api.models import Sequential


def fit_model(model, train_data: tf.data.Dataset, epochs=10, batch_size=32, verbose=1):
model.fit(train_data, epochs=epochs,
batch_size=batch_size, verbose=verbose)

return model


def evaluate_model(model, test_data: tf.data.Dataset):
return model.evaluate(test_data)


def save_model(model: Sequential, path: str):
model.save(path)

0 comments on commit 8428729

Please sign in to comment.