Dataset Loading
Wanqian Yang edited this page Sep 1, 2019 · 2 revisions
A dataset is loaded into the BNN object by calling `self.load(**datafunc())`. The data comes from a single loading function, `datafunc()`: the double asterisks `**` are Python syntax for unpacking the dictionary returned by `datafunc()` into keyword arguments for the method (similar to `**kwargs`).
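A minimal sketch of how the `**` unpacking behaves, using a hypothetical stand-in class `ToyBNN` whose `load` method takes the dictionary keys listed below as keyword parameters. This is an illustration of the Python mechanism only; the real BNN's `load` signature may differ, and plain lists stand in for tensors here.

```python
class ToyBNN:
    """Hypothetical stand-in for the BNN object; it only mimics load()."""

    def load(self, dataset_name, X_train, Y_train, X_test=None, Y_test=None):
        # Each key in the dictionary returned by the loading function
        # arrives here as the keyword argument with the same name.
        self.dataset_name = dataset_name
        self.X_train, self.Y_train = X_train, Y_train
        # Optional keys fall back to None when they are not returned.
        self.X_test, self.Y_test = X_test, Y_test


def datafunc():
    # Hypothetical loading function: no arguments, returns a dict.
    return {
        "dataset_name": "toy",
        "X_train": [[0.0], [1.0]],
        "Y_train": [[0.0], [2.0]],
    }


bnn = ToyBNN()
# Equivalent to bnn.load(dataset_name="toy", X_train=..., Y_train=...)
bnn.load(**datafunc())
print(bnn.dataset_name)  # toy
print(bnn.X_test)        # None
```

Because a misspelled key would not match any parameter name, Python raises a `TypeError` for it at call time, which makes typos in the returned dictionary easy to spot.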
A valid loading function `datafunc()` should accept no arguments and return a dictionary with the following key/value pairs:
- `"dataset_name"`: String. Identifier for the dataset.
- `"X_train"`: `torch.Tensor`. Input features of the train set. A `(Ntr, Dx)` tensor, where `Ntr` is the size of the train set and `Dx` is the input dimensionality.
- `"Y_train"`: `torch.Tensor`. Output/target of the train set. For regression, this is a `(Ntr, Dy)` tensor, where `Ntr` is the size of the train set and `Dy` is the output dimensionality. For classification, this is an `(Ntr,)` tensor where each value is the target class of the corresponding data point.
- [Optional] `"X_test"`: `torch.Tensor`. Input features of the test set. A `(Nte, Dx)` tensor, where `Nte` is the size of the test set and `Dx` is the input dimensionality. Define only if you want to run evaluation metrics over the test set.
- [Optional] `"Y_test"`: `torch.Tensor`. Output/target of the test set. For regression, this is a `(Nte, Dy)` tensor, where `Nte` is the size of the test set and `Dy` is the output dimensionality. For classification, this is an `(Nte,)` tensor where each value is the target class of the corresponding data point. Define only if you want to run evaluation metrics over the test set.
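A minimal sketch of a regression loading function that follows the key/shape contract above. The function name `sine_regression`, the synthetic noisy-sine data, and the set sizes are illustrative assumptions; `check_dataset` is a hypothetical helper (not part of the library) that asserts only the requirements stated on this page.

```python
import torch


def sine_regression():
    """Hypothetical loading function: noisy sine regression with Dx = Dy = 1."""
    torch.manual_seed(0)
    Ntr, Nte = 100, 50
    X_train = torch.linspace(-3.0, 3.0, Ntr).unsqueeze(1)     # (Ntr, Dx)
    Y_train = torch.sin(X_train) + 0.1 * torch.randn(Ntr, 1)  # (Ntr, Dy)
    X_test = torch.linspace(-4.0, 4.0, Nte).unsqueeze(1)      # (Nte, Dx)
    Y_test = torch.sin(X_test)                                 # (Nte, Dy)
    return {
        "dataset_name": "sine",
        "X_train": X_train,
        "Y_train": Y_train,
        # Optional: include only if test-set metrics are wanted.
        "X_test": X_test,
        "Y_test": Y_test,
    }


def check_dataset(data, classification=False):
    """Hypothetical helper: asserts the key/shape contract described above."""
    assert isinstance(data["dataset_name"], str), "dataset_name must be a string"
    # The test keys are optional, but they only make sense together.
    has_test = "X_test" in data
    assert has_test == ("Y_test" in data), "define both X_test and Y_test, or neither"
    pairs = [("X_train", "Y_train")] + ([("X_test", "Y_test")] if has_test else [])
    for x_key, y_key in pairs:
        X, Y = data[x_key], data[y_key]
        assert isinstance(X, torch.Tensor) and isinstance(Y, torch.Tensor)
        assert X.dim() == 2, f"{x_key} must have shape (N, Dx)"
        assert Y.shape[0] == X.shape[0], f"{y_key} must have as many rows as {x_key}"
        # Regression targets are (N, Dy); classification targets are (N,).
        expected_dim = 1 if classification else 2
        assert Y.dim() == expected_dim, f"{y_key} has the wrong number of dimensions"
    if has_test:
        # Train and test inputs share the input dimensionality Dx.
        assert data["X_test"].shape[1] == data["X_train"].shape[1]


data = sine_regression()
check_dataset(data)
print(data["X_train"].shape, data["Y_train"].shape)
```

A common slip in regression is returning targets of shape `(Ntr,)` instead of `(Ntr, 1)`; the dimension check in the helper catches that case.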
Note that for classification tasks, classes should be 0-indexed.
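If raw labels are not already 0-indexed integers (for example, string names), they can be remapped before being returned. A minimal sketch, assuming hypothetical string labels and a hypothetical function name `toy_classification`. Storing targets as `torch.long` is an assumption based on what PyTorch's cross-entropy loss expects for class-index targets, not something this page specifies.

```python
import torch


def toy_classification():
    """Hypothetical loading function: 3-class toy data with string labels."""
    torch.manual_seed(0)
    raw_labels = ["cat", "dog", "bird", "dog", "cat", "bird"]
    # Map each distinct label to an integer in 0..C-1 (sorted for determinism).
    classes = sorted(set(raw_labels))  # ['bird', 'cat', 'dog']
    to_index = {name: i for i, name in enumerate(classes)}
    # (Ntr,) tensor of 0-indexed class labels.
    Y_train = torch.tensor([to_index[y] for y in raw_labels], dtype=torch.long)
    # (Ntr, Dx) features with Dx = 4; random values for illustration only.
    X_train = torch.randn(len(raw_labels), 4)
    return {"dataset_name": "toy_animals", "X_train": X_train, "Y_train": Y_train}


data = toy_classification()
print(data["Y_train"])  # tensor([1, 2, 0, 2, 1, 0])
```

If the classes are already integers numbered from 1, subtracting 1 from the label tensor gives the required 0-indexing.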