theanets.main.Experiment

class theanets.main.Experiment(network, *args, **kwargs)

This class encapsulates tasks for training and evaluating a network.

Parameters:

network : Network, str, or callable

A specification for obtaining a network. If a string is given, it is assumed to name a file containing a pickled model; this file is loaded and used. If a network instance is given, it is used as the model. If a callable (such as a Network subclass) is given, it is invoked with the provided keyword arguments to create a network instance.
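The three cases above can be sketched as a small dispatch function. This is a minimal illustration, not theanets code: `resolve_model` and the stand-in `Network` class are hypothetical names. The instance check is assumed to run before the callable check, so an existing network object is never called.

```python
import pickle


class Network:
    """Hypothetical stand-in for a network class (not theanets.Network)."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def resolve_model(spec, **kwargs):
    """Hypothetical helper: turn a model specification into a network instance."""
    if isinstance(spec, str):
        # A string names a file holding a pickled model (raw pickle assumed here).
        with open(spec, 'rb') as handle:
            return pickle.load(handle)
    if isinstance(spec, Network):
        # An existing network instance is used as-is.
        return spec
    # Anything else is treated as a callable, such as a Network subclass,
    # and invoked with the keyword arguments to build a new instance.
    return spec(**kwargs)


net = resolve_model(Network, layers=(4, 8, 2))
same = resolve_model(net)
```

Here `net` is a fresh instance built from the class, while `same` is the very object that was passed in.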


Methods

__init__(network, *args, **kwargs)
create_dataset(data, **kwargs)          Create a dataset for this experiment.
create_trainer(train[, algo])           Create a trainer.
itertrain(train[, valid, algorithm])    Train our network, one batch at a time.
load(path)                              Load a saved network from a pickle file on disk.
save(path)                              Save the current network to a pickle file on disk.
train(*args, **kwargs)                  Train the network until the trainer converges.

create_dataset(data, **kwargs)

Create a dataset for this experiment.

Parameters:

data : sequence of ndarray or callable

The values you provide for data are wrapped in a Dataset instance; see that class for documentation on the types it accepts. Currently you can pass either a list/array/etc. of data, or a callable that generates data dynamically.

Returns:

data : Dataset

A dataset capable of providing mini-batches of data to a training algorithm.
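To make the idea of providing mini-batches concrete, here is a toy batcher. It is a hypothetical sketch, not the Dataset class itself: `MiniBatchDataset` and its `batch_size` parameter are assumptions, it does not shuffle, and it treats a callable as returning one batch (a list of arrays) per call.

```python
import numpy as np


class MiniBatchDataset:
    """Hypothetical toy mini-batch provider (not downhill.Dataset)."""

    def __init__(self, data, batch_size=32):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        if callable(self.data):
            # Assumed convention: each call returns one batch as a list of arrays.
            yield self.data()
            return
        n = len(self.data[0])
        for start in range(0, n, self.batch_size):
            # Slice every array with the same window so rows stay aligned.
            yield [arr[start:start + self.batch_size] for arr in self.data]


x = np.arange(20).reshape(10, 2)
y = np.arange(10)
batches = list(MiniBatchDataset([x, y], batch_size=4))
```

With ten rows and `batch_size=4`, this yields three batches of 4, 4, and 2 rows, and the input and target arrays are always sliced together.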

create_trainer(train, algo='rmsprop')

Create a trainer.

Additional keyword arguments are passed directly to the trainer.

Parameters:

train : str

A string describing a trainer to use.

algo : str

A string describing an optimization algorithm.

Returns:

trainer : Trainer

A trainer instance to alter the parameters of our network.

itertrain(train, valid=None, algorithm='rmsprop', **kwargs)

Train our network, one batch at a time.

This method yields a series of (train, valid) monitor pairs. The train value is a dictionary mapping names to monitor values evaluated on the training dataset. The valid value is also a dictionary mapping names to values, but these values are evaluated on the validation dataset.

Because validation might not occur every training iteration, the validation monitors might be repeated for multiple training iterations. It is probably most helpful to think of the validation monitors as being the “most recent” values that have been computed.

After training completes, the network attribute of this class will contain the trained network parameters.
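The "most recent" behavior of the validation monitors can be simulated without a real network. In this hypothetical sketch, `iter_monitor_pairs` replays precomputed losses and runs validation only every `validate_every` iterations (an assumed schedule), so the same validation dictionary is yielded until the next validation step.

```python
def iter_monitor_pairs(train_losses, valid_losses, validate_every=3):
    """Hypothetical sketch of the (train, valid) pairs an itertrain-style loop yields."""
    valid = {}
    for i, loss in enumerate(train_losses):
        if i % validate_every == 0:
            # Validation runs only on these iterations...
            valid = {'loss': valid_losses[i // validate_every]}
        # ...so in between, the most recent validation monitors are repeated.
        yield {'loss': loss}, valid


for train, valid in iter_monitor_pairs([0.9, 0.7, 0.6, 0.5, 0.45], [0.95, 0.62]):
    print(train['loss'], valid['loss'])
```

Iterations 0 to 2 all report the first validation loss, and iterations 3 and 4 report the second.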

Parameters:

train : sequence of ndarray or downhill.Dataset

A dataset to use when training the network. If this is a downhill.Dataset instance, it is used directly as the training dataset. Otherwise (for example, a numpy array), it is first converted to a downhill.Dataset and then used as the training set.

valid : sequence of ndarray or downhill.Dataset, optional

If provided, this is used as the validation dataset. If omitted, the training set is also used for validation. This is not recommended, because validation monitors computed on the training data cannot reveal overfitting.

algorithm : str or list of str, optional

One or more optimization algorithms to use for training our network. If not provided, RMSProp will be used.
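Accepting either one name or a list usually means normalizing the argument first. The sketch below assumes that multiple algorithms are applied one after another; `normalize_algorithms`, `run_in_sequence`, and the `train_one` callback are hypothetical names, not theanets functions.

```python
def normalize_algorithms(algorithm='rmsprop'):
    """Hypothetical helper: accept a single algorithm name or a sequence of names."""
    if isinstance(algorithm, str):
        return [algorithm]
    return list(algorithm)


def run_in_sequence(algorithm, train_one):
    """Apply each algorithm in order (assumed semantics) and collect the results."""
    return [train_one(name) for name in normalize_algorithms(algorithm)]


results = run_in_sequence(['sgd', 'rmsprop'], lambda name: name + ' done')
```

Checking for `str` first matters: a string is itself a sequence, so `list('sgd')` would otherwise split it into characters.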

load(path)

Load a saved network from a pickle file on disk.

This method sets the network attribute of the experiment to the loaded network model.

Parameters:

path : str

Path to a pickle file containing the keyword arguments and parameters of a network. If the path ends in “.gz”, the file is automatically gunzipped; otherwise it is treated as a “raw” pickle.

Returns:

network : Network

A newly-constructed network, with topology and parameters loaded from the given pickle file.
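The “.gz” rule can be expressed with the standard library's gzip and pickle modules. This hypothetical `load_pickle` helper is not the theanets implementation; it only shows how the file extension selects the opener, using a plain dictionary in place of a network.

```python
import gzip
import os
import pickle
import tempfile


def load_pickle(path):
    """Hypothetical helper: unpickle path, gunzipping it first if it ends in ".gz"."""
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rb') as handle:
        return pickle.load(handle)


# Usage: write the same object raw and gzipped, then read both back.
params = {'layers': [4, 8, 2], 'weights': [0.1, -0.2]}
tmpdir = tempfile.mkdtemp()
raw_path = os.path.join(tmpdir, 'model.pkl')
gz_path = os.path.join(tmpdir, 'model.pkl.gz')
with open(raw_path, 'wb') as handle:
    pickle.dump(params, handle)
with gzip.open(gz_path, 'wb') as handle:
    pickle.dump(params, handle)
restored_raw = load_pickle(raw_path)
restored_gz = load_pickle(gz_path)
```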

save(path)

Save the current network to a pickle file on disk.

Parameters:

path : str

Path of the file in which to save the network.
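A matching save routine is sketched below. The description above does not say whether save() compresses its output; this hypothetical `save_pickle` helper assumes it mirrors load() and gzips when the path ends in “.gz”.

```python
import gzip
import os
import pickle
import tempfile


def save_pickle(obj, path):
    """Hypothetical helper: pickle obj to path, gzipping if path ends in ".gz" (assumed)."""
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'wb') as handle:
        pickle.dump(obj, handle)


params = {'layers': [4, 8, 2]}
gz_path = os.path.join(tempfile.mkdtemp(), 'model.pkl.gz')
save_pickle(params, gz_path)
```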

train(*args, **kwargs)

Train the network until the trainer converges.

All arguments are passed to itertrain().

Returns:

training : dict

A dictionary of monitor values computed using the training dataset, at the conclusion of training. This dictionary will at least contain a ‘loss’ key that indicates the value of the loss function. Other keys may be available depending on the trainer being used.

validation : dict

A dictionary of monitor values computed using the validation dataset, at the conclusion of training.
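Since train() passes its arguments to itertrain() and reports the monitors at the end, it can be pictured as exhausting the generator and keeping the last pair. In this hypothetical sketch, `train_until_converged` and `fake_itertrain` are illustrative names; the fake generator stands in for a real training loop.

```python
def train_until_converged(itertrain, *args, **kwargs):
    """Hypothetical sketch: run an itertrain-style generator to completion and
    return the final (training, validation) monitor dictionaries."""
    training, validation = None, None
    for training, validation in itertrain(*args, **kwargs):
        pass  # intermediate pairs are discarded; only the last one is kept
    return training, validation


def fake_itertrain(losses):
    """Stand-in generator: validation loss is simply twice the training loss."""
    for loss in losses:
        yield {'loss': loss}, {'loss': loss * 2}


training, validation = train_until_converged(fake_itertrain, [1.0, 0.5, 0.25])
```

After the call, `training['loss']` is 0.25 and `validation['loss']` is 0.5, the values from the last yielded pair.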