theanets.recurrent.Classifier

class theanets.recurrent.Classifier(layers, weighted=False, sparse_input=False)

A classifier attempts to match a 1-hot target output.

Unlike a feedforward classifier, where the target labels are provided as a single vector, a recurrent classifier requires a vector of target labels for each time step in the input data. So a recurrent classifier model requires the following inputs for training:

  • x: A three-dimensional array of input data. Each element of axis 0 of x is expected to be one moment in time. Each element of axis 1 of x holds a single sample in a batch of data. Each element of axis 2 of x represents the measurements of a particular input variable across all times and all data items in a batch.
  • labels: A two-dimensional array of integer target labels. Each element of labels is expected to be the class index for a single batch item at a single time step. Axis 0 of this array represents time, and axis 1 represents data samples in a batch.
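
The array layout described above can be sketched with NumPy. The sizes, dtypes, and variable names below are illustrative assumptions, not values required by theanets; the sketch only shows how the axes of x and labels line up.

```python
import numpy as np

# Illustrative sizes only; none of these values come from theanets.
num_steps, batch_size, num_inputs, num_classes = 5, 3, 4, 10

rng = np.random.RandomState(0)

# x: axis 0 is time, axis 1 is the batch item, axis 2 is the input variable.
x = rng.randn(num_steps, batch_size, num_inputs).astype("float32")

# labels: one integer class index per (time step, batch item) pair.
labels = rng.randint(0, num_classes, size=(num_steps, batch_size)).astype("int32")

# The first two axes of x match the two axes of labels.
assert x.shape[:2] == labels.shape
```

Here x[t, b] is the input vector for batch item b at time step t, and labels[t, b] is the target class for that same item and step.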

Methods

  • error(outputs): Build a theano expression for computing the network error.
  • predict_sequence(seed, steps[, streams, rng]): Draw a sequential sample of classes from this network.

Attributes

  • DEFAULT_OUTPUT_ACTIVATION
  • num_params: Number of parameters in the entire network model.
  • params: A list of the learnable theano parameters for this network.
error(outputs)

Build a theano expression for computing the network error.

Parameters:

outputs : dict mapping str to theano expression

A dictionary of all outputs generated by the layers in this network.

Returns:

error : theano expression

A theano expression representing the network error.
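
This page does not state which error the expression computes. A common choice for a classifier with 1-hot targets is the mean cross-entropy, and the NumPy sketch below illustrates that choice for the (time, batch) label layout. Treat it as an assumption rather than as the library's implementation: the function name mean_cross_entropy is hypothetical, and it works on plain arrays, not on theano expressions.

```python
import numpy as np

def mean_cross_entropy(probs, labels):
    """Mean negative log-probability of the target class.

    probs:  (time, batch, classes) array of per-class probabilities.
    labels: (time, batch) array of integer class indices.
    """
    num_steps, batch_size, num_classes = probs.shape
    # Flatten time and batch so every (step, item) pair becomes one row.
    flat_probs = probs.reshape(num_steps * batch_size, num_classes)
    flat_labels = labels.reshape(num_steps * batch_size)
    # Pick out the probability assigned to each target class.
    target_probs = flat_probs[np.arange(flat_labels.size), flat_labels]
    return float(-np.log(target_probs).mean())

# Uniform predictions over 4 classes give an error of log(4).
probs = np.full((2, 3, 4), 0.25)
labels = np.zeros((2, 3), dtype=int)
loss = mean_cross_entropy(probs, labels)
```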

predict_sequence(seed, steps, streams=1, rng=None)

Draw a sequential sample of classes from this network.

Parameters:

seed : list of int

A list of integer class labels to “seed” the classifier.

steps : int

The number of time steps to sample.

streams : int, optional

Number of parallel streams to sample from the model. Defaults to 1.

rng : numpy.random.RandomState or int, optional

A random number generator, or an integer seed for a random number generator. If not provided, the random number generator will be created with an automatically chosen seed.
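
The parameters above can be illustrated with a minimal sampling sketch. It is not the theanets implementation: as_random_state, sample_sequence, next_probs, and toy_probs are hypothetical names, and the return format (one list of class indices per time step) is a choice made for this sketch. It shows one way an rng argument that may be a numpy.random.RandomState, an int, or None can be normalized, and how classes can be drawn one step at a time for several parallel streams.

```python
import numpy as np

def as_random_state(rng=None):
    """Accept a RandomState, an int seed, or None."""
    if isinstance(rng, np.random.RandomState):
        return rng
    # An int is used as the seed; None lets NumPy choose a seed automatically.
    return np.random.RandomState(rng)

def sample_sequence(next_probs, seed, steps, streams=1, rng=None):
    """Draw `steps` class indices for each of `streams` parallel streams.

    next_probs is a hypothetical callable standing in for the network:
    given the classes seen so far, it returns a probability vector over
    classes for the next time step.
    """
    rng = as_random_state(rng)
    # Every stream starts from its own copy of the seed labels.
    histories = [list(seed) for _ in range(streams)]
    samples = []
    for _ in range(steps):
        step = []
        for history in histories:
            p = next_probs(history)
            k = int(rng.choice(len(p), p=p))
            history.append(k)  # the sampled class feeds the next step
            step.append(k)
        samples.append(step)
    return samples

def toy_probs(history, num_classes=3):
    """Toy model: strongly prefer the class after the most recent one."""
    p = np.full(num_classes, 0.1 / (num_classes - 1))
    p[(history[-1] + 1) % num_classes] = 0.9
    return p

out = sample_sequence(toy_probs, seed=[0], steps=4, streams=2, rng=13)
```

Passing the same integer as rng reproduces the same samples, which is the practical reason to accept a seed as well as a generator object.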