
Dataset Loading

Wanqian Yang edited this page Sep 1, 2019 · 2 revisions

A dataset is loaded into the BNN object by calling self.load(**datafunc()). The only argument to this method is the output of a loading function datafunc(). That output is a dictionary, and the double asterisks ** are Python syntax for unpacking it so that each key/value pair is passed into the method as a keyword argument (similar to `**kwargs`).
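The unpacking step can be illustrated without the BNN class itself. The sketch below uses a hypothetical stand-in class `ToyBNN` and a hypothetical loading function `toy_datafunc`. The real load() signature is an assumption here, and plain lists stand in for tensors only to keep the demo dependency-free:

```python
# Hypothetical stand-in for the BNN class; the real load() signature may differ.
class ToyBNN:
    def load(self, dataset_name, X_train, Y_train, X_test=None, Y_test=None):
        # Store whatever the loading function returned.
        self.dataset_name = dataset_name
        self.X_train, self.Y_train = X_train, Y_train
        self.X_test, self.Y_test = X_test, Y_test


def toy_datafunc():
    # A loading function takes no arguments and returns a dict of keyword arguments.
    # Plain lists stand in for torch.Tensor objects in this illustration only.
    return {"dataset_name": "toy", "X_train": [[0.0], [1.0]], "Y_train": [[0.5], [1.5]]}


bnn = ToyBNN()
# ** unpacks the dict, so this call is equivalent to
# bnn.load(dataset_name="toy", X_train=[[0.0], [1.0]], Y_train=[[0.5], [1.5]])
bnn.load(**toy_datafunc())
print(bnn.dataset_name)  # prints "toy"
print(bnn.X_test)        # prints "None" (optional key was not provided)
```

Because the optional test-set keys are simply absent from the dictionary, they fall back to whatever defaults the method defines (here assumed to be None).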

A valid loading function datafunc() should accept no arguments, and return a dictionary with the following key/value pairs:

  • "dataset_name": String. Identifier for the dataset.
  • "X_train": torch.Tensor. Input features of the train set. A (Ntr, Dx) tensor where Ntr is the size of the train set, and Dx is the input dimensionality.
  • "Y_train": torch.Tensor. Outputs/targets of the train set. For regression, this is a (Ntr, Dy) tensor where Ntr is the size of the train set, and Dy is the output dimensionality. For classification, this is a (Ntr,) tensor where each entry is the target class of the corresponding data point.
  • [Optional] "X_test": torch.Tensor. Input features of the test set. A (Nte, Dx) tensor where Nte is the size of the test set, and Dx is the input dimensionality. Include only if you want to compute evaluation metrics on the test set.
  • [Optional] "Y_test": torch.Tensor. Outputs/targets of the test set. For regression, this is a (Nte, Dy) tensor where Nte is the size of the test set, and Dy is the output dimensionality. For classification, this is a (Nte,) tensor where each entry is the target class of the corresponding data point. Include only if you want to compute evaluation metrics on the test set.
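As a rough illustration of the required keys and shapes, here is a sketch of a loading function for a toy 1-D regression problem. The function name `sine_regression_data` and the synthetic sine data are hypothetical and not part of the library; the data is generated on the fly only to show the (N, Dx) and (N, Dy) layout:

```python
import math

import torch


def sine_regression_data():
    """Hypothetical loading function: noisy 1-D sine regression (synthetic toy data)."""
    torch.manual_seed(0)  # make the synthetic noise reproducible
    n_train, n_test = 100, 50

    # Inputs: (N, Dx) tensors with Dx = 1
    X_train = torch.linspace(-math.pi, math.pi, n_train).unsqueeze(1)
    X_test = torch.linspace(-math.pi, math.pi, n_test).unsqueeze(1)

    # Targets: (N, Dy) tensors with Dy = 1, as required for regression
    Y_train = torch.sin(X_train) + 0.1 * torch.randn(n_train, 1)
    Y_test = torch.sin(X_test)

    return {
        "dataset_name": "sine_regression",
        "X_train": X_train,
        "Y_train": Y_train,
        # Optional keys: include only if test-set evaluation metrics are wanted.
        "X_test": X_test,
        "Y_test": Y_test,
    }


data = sine_regression_data()
print(data["X_train"].shape, data["Y_train"].shape)  # torch.Size([100, 1]) torch.Size([100, 1])
```

Such a function would then be passed as self.load(**sine_regression_data()). Dropping the two optional keys gives a train-only dataset.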

Note that for classification tasks, class labels should be 0-indexed, i.e. integers 0, 1, ..., C-1 for a problem with C classes.
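If the raw labels of a dataset do not start at 0 (for example, classes numbered 1 to 3), one way to remap them is torch.unique with return_inverse=True, which assigns each label its position among the sorted distinct labels. The sketch below assumes that approach; the helper `to_zero_indexed` and the loading function `toy_classification_data` are hypothetical names used only for illustration:

```python
import torch


def to_zero_indexed(labels):
    """Map arbitrary integer class labels to 0, 1, ..., C-1 (in sorted label order)."""
    classes, inverse = torch.unique(labels, sorted=True, return_inverse=True)
    # `classes` maps the new indices back to the original labels, if needed later.
    return inverse, classes


def toy_classification_data():
    """Hypothetical loading function for a 3-class toy problem."""
    torch.manual_seed(0)
    X_train = torch.randn(6, 2)                     # (Ntr, Dx) = (6, 2)
    raw_labels = torch.tensor([1, 3, 2, 3, 1, 2])   # labels start at 1, not 0
    Y_train, _ = to_zero_indexed(raw_labels)        # (Ntr,) tensor of 0-indexed labels
    return {"dataset_name": "toy_classification", "X_train": X_train, "Y_train": Y_train}


data = toy_classification_data()
print(data["Y_train"])  # tensor([0, 2, 1, 2, 0, 1])
```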
