Model
info
The NeuralPy Model class is mostly stable and can be used in any project. The chance of breaking changes in future releases is very low.
The Model class in NeuralPy is a wrapper around a PyTorch model that provides simple methods to train, predict, evaluate, and so on. Every model in NeuralPy is based on this class and inherits from it.
The Model class can be used for training any PyTorch model.
Supported Arguments
- force_cpu=False: (Boolean) If True, uses the CPU even if CUDA is available
- training_device=None: (NeuralPy device class) Device that will be used for training and prediction. If you use training_device, the force_cpu parameter is ignored
- random_state: (Integer) Random state for the device
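The precedence between force_cpu and training_device can be sketched in plain Python. This is only an illustration of the rule described above, not NeuralPy's internal code: the function name resolve_device, the cuda_available flag, and the use of strings in place of NeuralPy device objects are all assumptions made for the example, as is the default of using CUDA when it is available.

```python
def resolve_device(force_cpu=False, training_device=None, cuda_available=False):
    """Return the device that would be picked under the documented precedence.

    Hypothetical helper: strings stand in for NeuralPy device objects.
    """
    # An explicit training_device always wins; force_cpu is ignored in that case.
    if training_device is not None:
        return training_device
    # force_cpu=True pins the model to the CPU even when CUDA is available.
    if force_cpu or not cuda_available:
        return "cpu"
    # Assumed default: use CUDA when it is available and nothing overrides it.
    return "cuda"


print(resolve_device(force_cpu=True, cuda_available=True))        # cpu
print(resolve_device(cuda_available=True))                        # cuda
print(resolve_device(force_cpu=True, training_device="cuda:1"))   # cuda:1
```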
Supported Methods
.compile() Method
In the NeuralPy Model, the .compile() method attaches a loss function and an optimizer to the model. It must be called before training. It can also set metrics for the model so that the NeuralPy Model can evaluate them during training.
Supported Arguments
- optimizer: (NeuralPy Optimizer class) Adds an optimizer to the model
- loss_function: (NeuralPy Loss Function class) Adds a loss function to the model
- metrics: (String[]) Metrics that will be evaluated by the model. Currently only supports accuracy.
Example Code
.fit() Method
The .fit() method is used for training the NeuralPy model.
Supported Arguments
- train_data: (Tuple(NumPy Array, NumPy Array) | Generator) Pass the training data as a tuple like (X, y), where X is the training data and y is the labels for training the model. The function accepts both NumPy arrays and generator functions that return NumPy arrays.
- validation_data=None: (Tuple(NumPy Array, NumPy Array) | Generator) Pass the validation data as a tuple like (X, y), where X is the validation data and y is the labels for validating the model. This field is optional. The function accepts both NumPy arrays and generator functions that return NumPy arrays.
- epochs=10: (Integer) Number of epochs
- batch_size=32: (Integer) Batch size for training
- steps_per_epoch=None: (Integer) Number of steps in each epoch. Used only for generators.
- validation_steps=None: (Integer) Number of validation steps in each epoch. Used only for generators.
- callbacks=None: (Array) Array of callbacks
Example Code
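The sketch below shows one way to prepare generator-based input for .fit(): a generator that yields NumPy batches, plus the steps_per_epoch value that goes with it. It rests on assumptions that the text above does not confirm: that the generator yields (X_batch, y_batch) tuples, and that it may be consumed across several epochs (hence the endless loop). The name batch_generator and the toy data are made up for the example, and the final call is left as a comment because it needs a model that has already been compiled.

```python
import math

import numpy as np


def batch_generator(X, y, batch_size):
    """Yield (X_batch, y_batch) NumPy arrays, restarting after each full pass."""
    n_samples = len(X)
    while True:
        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            yield X[start:end], y[start:end]


# Toy regression data: 100 samples with a single feature.
X_train = np.arange(100, dtype=np.float32).reshape(-1, 1)
y_train = 2.0 * X_train + 1.0

batch_size = 32
# A generator has no length, so the number of batches per epoch is passed explicitly.
steps_per_epoch = math.ceil(len(X_train) / batch_size)

fit_kwargs = {
    "train_data": batch_generator(X_train, y_train, batch_size),
    "epochs": 10,
    "steps_per_epoch": steps_per_epoch,
}

# With a compiled NeuralPy model (assumed to exist as `model`):
# model.fit(**fit_kwargs)
```

With plain NumPy arrays instead of a generator, train_data would be the tuple (X_train, y_train), and steps_per_epoch would not be needed.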
.predict() Method
The .predict() method is used for predicting outputs from the model.
Supported Arguments
- predict_data: (NumPy Array | Generator) Data to be predicted
- predict_steps=None: Number of steps in the generator
- batch_size=None: (Integer) Batch size for predicting. If not provided, the entire data is predicted at once.
Example Code
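The effect of batch_size on prediction can be pictured with a small, library-free sketch. It only illustrates the documented behaviour (one pass over everything when batch_size is None, fixed-size chunks otherwise) and is not NeuralPy's implementation. The helper iter_batches is hypothetical, and a plain list stands in for a NumPy array.

```python
def iter_batches(data, batch_size=None):
    """Yield consecutive slices of `data`; one slice with everything if batch_size is None."""
    if batch_size is None:
        yield data
        return
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


samples = list(range(10))

print([len(batch) for batch in iter_batches(samples)])                 # [10]
print([len(batch) for batch in iter_batches(samples, batch_size=4)])   # [4, 4, 2]
```

A smaller batch_size trades speed for lower peak memory use, which matters mostly when predict_data is too large to process at once.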
.predict_class() Method
The .predict_class() method is used for predicting classes using the trained model. This method works only if accuracy is passed in the metrics parameter of the .compile() method.
Supported Arguments
- predict_data: (NumPy Array | Generator) Data to be predicted
- predict_steps=None: Number of steps in the generator
- batch_size=None: (Integer) Batch size for predicting. If not provided, the entire data is predicted at once.
Example Code
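The text does not say how .predict_class() turns model outputs into classes. A common convention for multi-class models, assumed here purely for illustration, is to pick the index of the largest score in each output row. The helper to_class_indices and the score values are hypothetical.

```python
def to_class_indices(score_rows):
    """Return, for each row of scores, the index of the largest score."""
    return [max(range(len(row)), key=row.__getitem__) for row in score_rows]


# Made-up raw outputs for three samples and three classes.
raw_scores = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.3, 0.5],
]

print(to_class_indices(raw_scores))  # [1, 0, 2]
```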
.evaluate() Method
The .evaluate() method is used for evaluating models using the test dataset.
Supported Arguments
- X: (NumPy Array | Generator) Data to be predicted
- y: (NumPy Array | Generator) Original labels of X
- batch_size=None: (Integer) Batch size for predicting. If not provided, the entire data is predicted at once.
Example Code
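The text names accuracy as the only supported metric but does not spell out how it is computed or what .evaluate() returns. As a reminder of the usual definition (the fraction of predictions that match the labels), here is a minimal sketch. The function accuracy and the sample values are hypothetical and are not taken from NeuralPy.

```python
def accuracy(predicted_classes, true_labels):
    """Return the fraction of predictions that equal the corresponding label."""
    if len(predicted_classes) != len(true_labels):
        raise ValueError("predictions and labels must have the same length")
    if not true_labels:
        raise ValueError("cannot compute accuracy on an empty dataset")
    correct = sum(1 for p, t in zip(predicted_classes, true_labels) if p == t)
    return correct / len(true_labels)


# Three of the four predictions match their labels.
print(accuracy([1, 0, 2, 2], [1, 0, 1, 2]))  # 0.75
```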
.summary() Method
The .summary() method is used for getting a summary of the model.
Supported Arguments
- None
Example Code
.get_model() Method
The .get_model() method is used for getting the PyTorch model from the NeuralPy model. After extraction, the model can be treated just like a regular PyTorch model.
Supported Arguments
- None
Example Code
.set_model() Method
The .set_model() method is used for converting a PyTorch model to a NeuralPy model. After this conversion, the model can be trained using NeuralPy optimizers and loss functions.
Supported Arguments
- model: (PyTorch model) A valid class based on the Sequential PyTorch model.