ForeTiS.model._torch_model

Module Contents

Classes

TorchModel

Parent class based on BaseModel for all PyTorch models to share functionalities.

class ForeTiS.model._torch_model.TorchModel(optuna_trial, datasets, featureset_name, optimize_featureset, pca_transform=None, current_model_name=None, batch_size=None, n_epochs=None, target_column=None)

Bases: ForeTiS.model._base_model.BaseModel, abc.ABC

Parent class based on BaseModel for all PyTorch models to share functionalities. See BaseModel for more information.

Attributes

Inherited attributes

See BaseModel.

Additional attributes

  • batch_size (int): Batch size for batch-based training

  • n_epochs (int): Number of epochs for optimization

  • num_monte_carlo (int): Number of Monte Carlo iterations for the Bayesian neural networks

  • optimizer (torch.optim.optimizer.Optimizer): Optimizer for model fitting

  • loss_fn: Loss function for model fitting

  • early_stopping_patience (int): Number of epochs without improvement before early stopping

  • early_stopping_point (int): Epoch at which early stopping occurred

  • device (torch.device): Device to use, e.g. GPU

  • X_scaler (sklearn.preprocessing.StandardScaler): Standard scaler for the X data

Parameters:
  • optuna_trial (optuna.trial.Trial) – Trial of optuna for optimization

  • datasets (list) – all datasets that are available

  • current_model_name (str) – name of the current model, matching the name of its .py file in the model package

  • batch_size (int) – batch size for neural network models

  • n_epochs (int) – number of epochs for neural network models

  • target_column (str) – the target column for the prediction

  • featureset_name (str) –

  • optimize_featureset (bool) –

  • pca_transform (bool) –

train_val_loop(train, val)

Implementation of a train and validation loop for PyTorch models. See BaseModel for more information.

Parameters:
  • train (pandas.DataFrame) –

  • val (pandas.DataFrame) –

Return type:

numpy.array
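The loop combines per-epoch training, validation, and the early stopping controlled by early_stopping_patience and early_stopping_point. The following is a minimal sketch of such a loop in plain PyTorch, not the ForeTiS implementation. The function name train_val_loop_sketch and its arguments are hypothetical, and it assumes that early_stopping_point records the number of epochs that produced the best validation loss.

```python
import copy

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_val_loop_sketch(model, optimizer, loss_fn, train_loader, val_loader,
                          n_epochs, early_stopping_patience):
    """Hypothetical helper: train with early stopping, return validation predictions.

    Assumption: the stopping point is the epoch count that gave the best validation loss.
    """
    best_loss, best_epoch = float("inf"), 0
    best_state = copy.deepcopy(model.state_dict())
    early_stopping_point = None
    for epoch in range(1, n_epochs + 1):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
        # Mean validation loss over all validation samples for this epoch
        model.eval()
        with torch.no_grad():
            val_loss = sum(loss_fn(model(x), y).item() * len(x)
                           for x, y in val_loader) / len(val_loader.dataset)
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
            best_state = copy.deepcopy(model.state_dict())
        elif epoch - best_epoch >= early_stopping_patience:
            # No improvement for `early_stopping_patience` epochs: stop early
            early_stopping_point = best_epoch
            break
    model.load_state_dict(best_state)  # restore the best weights seen
    model.eval()
    with torch.no_grad():
        preds = torch.cat([model(x) for x, _ in val_loader]).numpy()
    return preds, early_stopping_point


# Usage on toy regression data
torch.manual_seed(0)
X = torch.randn(64, 3)
y = X.sum(dim=1, keepdim=True)
train_loader = DataLoader(TensorDataset(X[:48], y[:48]), batch_size=16)
val_loader = DataLoader(TensorDataset(X[48:], y[48:]), batch_size=16)
model = nn.Linear(3, 1)
preds, stop = train_val_loop_sketch(
    model, torch.optim.Adam(model.parameters(), lr=0.01), nn.MSELoss(),
    train_loader, val_loader, n_epochs=50, early_stopping_patience=5)
print(preds.shape)  # (16, 1)
```

Restoring the best weights means the returned predictions come from the best validation epoch rather than the last one.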

train_val_loader(train, val)

Get the DataLoaders for the training and validation data

Parameters:
  • train (pandas.DataFrame) – training data

  • val (pandas.DataFrame) – validation data

Returns:

train_loader, val_loader, val
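A minimal sketch of how two DataLoaders could be built from training and validation DataFrames, assuming the features are every column except target_column and that a StandardScaler (as in X_scaler) is fitted on the training features only, so no validation statistics leak into training. The helper name is hypothetical, shuffling is disabled on the assumption that chronological order matters, and the sketch leaves out the third return value (val), whose exact content this reference does not describe.

```python
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset


def train_val_loader_sketch(train, val, target_column, batch_size):
    """Hypothetical helper: scale features (fit on train only) and wrap in DataLoaders."""
    feature_cols = [c for c in train.columns if c != target_column]
    x_scaler = StandardScaler()
    X_train = x_scaler.fit_transform(train[feature_cols].to_numpy())
    X_val = x_scaler.transform(val[feature_cols].to_numpy())  # transform only

    def to_loader(X, y):
        dataset = TensorDataset(torch.tensor(X, dtype=torch.float32),
                                torch.tensor(y, dtype=torch.float32).reshape(-1, 1))
        # Assumption: keep chronological order, so no shuffling
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    return (to_loader(X_train, train[target_column].to_numpy()),
            to_loader(X_val, val[target_column].to_numpy()))


# Usage with a hypothetical two-column dataset
df = pd.DataFrame({"temp": np.arange(10.0), "sales": np.arange(10.0) * 2})
train_loader, val_loader = train_val_loader_sketch(df.iloc[:8], df.iloc[8:], "sales", batch_size=4)
X_batch, y_batch = next(iter(train_loader))
print(X_batch.shape, y_batch.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```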

train_one_epoch(train_loader, scaler)

Train one epoch

Parameters:

train_loader (torch.utils.data.DataLoader) – DataLoader with training data
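One training epoch is a single pass over the DataLoader with the usual zero-grad, forward, backward, and step sequence. The sketch below is an illustration under assumptions, not the ForeTiS code: the function name and the returned mean loss are hypothetical, and it omits the scaler argument of the documented signature because this reference does not describe it.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(model, optimizer, loss_fn, train_loader, device="cpu"):
    """Hypothetical helper: one pass over the training data, returns the mean loss."""
    model.train()  # enable training behaviour of dropout / batch norm
    total_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # clear gradients from the previous step
        loss = loss_fn(model(inputs), targets)
        loss.backward()        # backpropagate
        optimizer.step()       # update the weights
        total_loss += loss.item() * len(inputs)
    return total_loss / len(train_loader.dataset)


# Usage: the training loss should shrink over a few epochs
torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [2.0], [3.0]])
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
losses = [train_one_epoch_sketch(model, optimizer, nn.MSELoss(), loader) for _ in range(20)]
```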

validate_one_epoch(val_loader)

Validate one epoch

Parameters:

val_loader (torch.utils.data.DataLoader) – DataLoader with validation data

Returns:

loss based on the loss criterion

Return type:

float
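Validation mirrors a training epoch but switches the model to evaluation mode and disables gradient tracking. A minimal sketch, assuming the returned float is the sample-weighted mean of the batch losses; the function name is hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # gradients are not needed for validation
def validate_one_epoch_sketch(model, loss_fn, val_loader, device="cpu"):
    """Hypothetical helper: return the mean validation loss as a Python float."""
    model.eval()  # disable dropout, use running batch-norm statistics
    total_loss = 0.0
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        total_loss += loss_fn(model(inputs), targets).item() * len(inputs)
    return total_loss / len(val_loader.dataset)


# Usage: a model that matches the data exactly gives a loss close to 0
model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[1.0, -1.0]]))
    model.bias.zero_()
torch.manual_seed(0)
X = torch.randn(10, 2)
y = X[:, :1] - X[:, 1:]
val_loss = validate_one_epoch_sketch(model, nn.MSELoss(), DataLoader(TensorDataset(X, y), batch_size=4))
```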

retrain(retrain)

Implementation of the retraining for PyTorch models. See BaseModel for more information.

Parameters:

retrain (pandas.DataFrame) –
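Retraining runs on the full data without a validation split, so there is no signal for early stopping. A common approach, assumed here rather than confirmed by this reference, is to train for the number of epochs recorded in early_stopping_point, falling back to n_epochs if early stopping never triggered. Both helper names below are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def epochs_for_retraining(n_epochs, early_stopping_point=None):
    """Assumed rule: reuse the epoch count found by early stopping, else n_epochs."""
    return n_epochs if early_stopping_point is None else early_stopping_point


def retrain_sketch(model, optimizer, loss_fn, retrain_loader, n_epochs, early_stopping_point=None):
    """Hypothetical helper: train on all data for a fixed number of epochs."""
    n = epochs_for_retraining(n_epochs, early_stopping_point)
    model.train()
    for _ in range(n):
        for inputs, targets in retrain_loader:
            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            optimizer.step()
    return n


# Usage: early stopping previously triggered after 3 epochs
torch.manual_seed(0)
X = torch.randn(32, 2)
y = X.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(X, y), batch_size=8)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
n_run = retrain_sketch(model, optimizer, nn.MSELoss(), loader, n_epochs=100, early_stopping_point=3)
```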

update(update, period)

Implementation of the model update for PyTorch models. See BaseModel for more information.

Parameters:
  • update (pandas.DataFrame) –

  • period (int) –

predict(X_in)

Implementation of a prediction based on input features for PyTorch models. See BaseModel for more information.

Parameters:

X_in (pandas.DataFrame) –

Return type:

numpy.array
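At prediction time the features should be scaled with statistics fitted on the training data (transform only, never refit), and the model should run in evaluation mode without gradients. A minimal sketch under those assumptions; the function name and the flattening of the output to a 1-D array are hypothetical choices, and the Monte Carlo averaging used for the Bayesian models is not shown.

```python
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def predict_sketch(model, x_scaler, X_in, batch_size=32):
    """Hypothetical helper: predict for a feature DataFrame, return a flat NumPy array."""
    # Transform only: keep the statistics fitted on the training data
    X = torch.tensor(x_scaler.transform(X_in.to_numpy()), dtype=torch.float32)
    loader = DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=False)
    model.eval()
    with torch.no_grad():
        preds = [model(batch) for (batch,) in loader]
    return torch.cat(preds).reshape(-1).numpy()


# Usage with hypothetical feature columns
torch.manual_seed(0)
train_features = pd.DataFrame({"temp": [10.0, 12.0, 14.0, 16.0], "price": [1.0, 1.5, 2.0, 2.5]})
x_scaler = StandardScaler().fit(train_features.to_numpy())
model = nn.Linear(2, 1)
X_in = pd.DataFrame({"temp": [11.0, 15.0], "price": [1.2, 2.2]})
predictions = predict_sketch(model, x_scaler, X_in)
print(predictions.shape)  # (2,)
```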

get_loss(outputs, targets)

Calculate the loss based on the outputs and targets

Parameters:
  • outputs (torch.Tensor) – outputs of the model

  • targets (torch.Tensor) – targets of the dataset

Returns:

loss

Return type:

torch.Tensor
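A subtle pitfall when computing a loss is a shape mismatch between outputs and targets. The sketch below assumes the model returns shape (N, 1) while the targets may arrive as (N,); without alignment, nn.MSELoss broadcasts the two to (N, N) and returns a wrong value (PyTorch emits a UserWarning in that case). The helper name is hypothetical and not necessarily what get_loss does internally.

```python
import torch
from torch import nn


def get_loss_sketch(loss_fn, outputs, targets):
    """Hypothetical helper: reshape targets to the output shape, then apply loss_fn."""
    return loss_fn(outputs, targets.reshape(outputs.shape))


outputs = torch.tensor([[1.0], [2.0], [3.0]])  # model output, shape (3, 1)
targets = torch.tensor([1.0, 2.0, 5.0])        # targets, shape (3,)
loss = get_loss_sketch(nn.MSELoss(), outputs, targets)
print(loss.item())  # mean of (0, 0, 4), i.e. about 1.3333
# Without reshaping, MSELoss would broadcast to (3, 3) and return 4.0
```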

get_dataloader(X, y=None, only_transform=None, predict=False, shuffle=False)

Get a PyTorch DataLoader using the specified data and batch size

Parameters:
  • X (numpy.array) – feature matrix to use

  • y (numpy.array) – optional target vector to use

  • only_transform (bool) – whether to only transform or not

  • predict (bool) – whether to use the data for predictions or not

  • shuffle (bool) – shuffle parameter for DataLoader

Returns:

PyTorch DataLoader

Return type:

torch.utils.data.DataLoader
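The flags of get_dataloader can be read as follows, which is an assumption rather than documented behaviour: only_transform applies the already fitted X_scaler instead of refitting it, predict builds a dataset of features only, and shuffle is passed through to the DataLoader. The class below is a hypothetical stand-in that holds the state such a method would need.

```python
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset


class DataLoaderFactorySketch:
    """Hypothetical stand-in for the model state used by get_dataloader."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.X_scaler = StandardScaler()

    def get_dataloader(self, X, y=None, only_transform=None, predict=False, shuffle=False):
        # Assumed semantics: fit the scaler on training data, only transform otherwise
        X = self.X_scaler.transform(X) if only_transform else self.X_scaler.fit_transform(X)
        X_tensor = torch.tensor(X, dtype=torch.float32)
        if predict or y is None:
            dataset = TensorDataset(X_tensor)  # features only, for inference
        else:
            y_tensor = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)
            dataset = TensorDataset(X_tensor, y_tensor)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)


# Usage
factory = DataLoaderFactorySketch(batch_size=4)
X_train = np.arange(20.0).reshape(10, 2)
y_train = np.arange(10.0)
train_loader = factory.get_dataloader(X_train, y_train, shuffle=True)
test_loader = factory.get_dataloader(X_train[:3] + 100, only_transform=True, predict=True)
print(len(next(iter(train_loader))))  # 2 (features and targets)
print(len(next(iter(test_loader))))   # 1 (features only)
```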

static common_hyperparams()

Add hyperparameters that are common to all PyTorch models, so they do not need to be included in the optimization of every child model separately. See also BaseModel for more information.
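Sharing common hyperparameters amounts to merging one shared search-space definition with each child model's own entries. The sketch below shows that idea with plain dictionaries; the dictionary format ("datatype", "list_of_values", bounds) and all hyperparameter names and ranges are hypothetical, not ForeTiS's actual search space.

```python
def common_hyperparams_sketch():
    """Hypothetical search-space entries shared by all PyTorch models."""
    return {
        "batch_size": {"datatype": "categorical", "list_of_values": [16, 32, 64]},
        "n_epochs": {"datatype": "categorical", "list_of_values": [100, 500, 1000]},
        "learning_rate": {"datatype": "float", "lower_bound": 1e-5, "upper_bound": 1e-1, "log": True},
        "optimizer": {"datatype": "categorical", "list_of_values": ["adam", "sgd"]},
    }


def child_model_hyperparams_sketch():
    """Hypothetical entries specific to one child model, e.g. an MLP."""
    return {
        "n_layers": {"datatype": "int", "lower_bound": 1, "upper_bound": 3},
        "dropout": {"datatype": "float", "lower_bound": 0.0, "upper_bound": 0.5},
    }


# The child defines only its own entries; the shared ones are merged in once.
# On a key collision, the child's entry wins because it is unpacked last.
search_space = {**common_hyperparams_sketch(), **child_model_hyperparams_sketch()}
print(sorted(search_space))
# ['batch_size', 'dropout', 'learning_rate', 'n_epochs', 'n_layers', 'optimizer']
```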

static get_torch_object_for_string(string_to_get)

Get the torch object for a specific string, e.g. for a string suggested to optuna as a categorical hyperparameter

Parameters:

string_to_get (str) – string for which to retrieve the torch object

Returns:

torch object
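Optuna's categorical suggestions are plain values such as strings, so they need to be mapped back to torch objects. A minimal sketch of such a lookup table; the function name and the set of supported strings are hypothetical. Activation and loss modules are stored as instances, while optimizers are stored as classes because they can only be instantiated once the model parameters exist.

```python
import torch


def get_torch_object_for_string_sketch(string_to_get):
    """Hypothetical lookup from a hyperparameter string to a torch object."""
    string_to_object = {
        "relu": torch.nn.ReLU(),
        "tanh": torch.nn.Tanh(),
        "mse": torch.nn.MSELoss(),
        "adam": torch.optim.Adam,  # optimizer classes, instantiated later
        "sgd": torch.optim.SGD,
    }
    try:
        return string_to_object[string_to_get]
    except KeyError:
        raise ValueError(f"Unknown torch object name: {string_to_get!r}") from None


# Usage: build a small network and its optimizer from strings
model = torch.nn.Sequential(
    torch.nn.Linear(2, 4),
    get_torch_object_for_string_sketch("relu"),
    torch.nn.Linear(4, 1),
)
optimizer_class = get_torch_object_for_string_sketch("adam")
optimizer = optimizer_class(model.parameters(), lr=1e-3)
```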