Skip to content

Documentation for project/models/mnist.py

Source Code Documentation

The source code documentation is generated from Python docstrings using MkDocs and mkdocstrings.

Classes:

Name Description
MNISTLitModule

Example of a LightningModule for MNIST classification.

MNISTLitModule

MNISTLitModule(optimizer: Optimizer, scheduler: _LRScheduler, input_size: int = 784, lin1_size: int = 256, lin2_size: int = 256, lin3_size: int = 256, output_size: int = 10, compile: bool = True)

Bases: LightningModule

Example of a LightningModule for MNIST classification.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer to use.

required
scheduler _LRScheduler

The learning rate scheduler to use.

required
input_size int

The size of the input layer. Defaults to 784.

784
lin1_size int

The size of the first linear layer. Defaults to 256.

256
lin2_size int

The size of the second linear layer. Defaults to 256.

256
lin3_size int

The size of the third linear layer. Defaults to 256.

256
output_size int

The size of the output layer. Defaults to 10.

10
compile bool

Whether to compile the module or not. Defaults to True.

True

Returns:

Type Description
None

None

Methods:

Name Description
configure_optimizers

Configure optimizers and learning-rate schedulers.

forward

Performs a forward pass through the model self.net.

model_step

Performs a single model step on a batch of data.

on_test_epoch_end

Lightning hook that is called when a test epoch ends.

on_train_epoch_end

Lightning hook that is called when a training epoch ends.

on_train_start

Lightning hook that is called when training begins.

on_validation_epoch_end

Lightning hook that is called when a validation epoch ends.

setup

Called at the beginning of fit (train + validate), validate, test, or predict.

test_step

Performs a single test step on a batch of data from the test set.

training_step

Perform a single training step on a batch of data from the training set.

validation_step

Perform a single validation step on a batch of data from the validation set.

Source code in src/project/models/mnist.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(
    self,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    input_size: int = 784,
    lin1_size: int = 256,
    lin2_size: int = 256,
    lin3_size: int = 256,
    output_size: int = 10,
    compile: bool = True,
) -> None:
    """
    Initializes the MNISTLitModule.

    Builds a fully connected classifier (three Linear -> BatchNorm1d -> ReLU stages
    followed by a final Linear layer), the cross-entropy loss, and the metric objects
    used by the training, validation and test steps.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to use. It is stored as a
            hyperparameter and later invoked as ``optimizer(params=...)``, so a
            partially-applied optimizer constructor is what is actually passed in.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler
            to use. It is stored as a hyperparameter and later invoked as
            ``scheduler(optimizer=...)`` unless it is ``None``.
        input_size (int, optional): The size of the input layer. Defaults to 784.
        lin1_size (int, optional): The size of the first linear layer. Defaults to 256.
        lin2_size (int, optional): The size of the second linear layer. Defaults to 256.
        lin3_size (int, optional): The size of the third linear layer. Defaults to 256.
        output_size (int, optional): The size of the output layer, i.e. the number of
            classes. It is also used as ``num_classes`` for the accuracy metrics.
            Defaults to 10.
        compile (bool, optional): Whether to compile the module or not. Defaults to True.

    Returns:
        None
    """
    super().__init__()

    # This call makes the init params accessible through the `self.hparams` attribute
    # and ensures they are stored in checkpoints.
    # NOTE(review): `net` is not an __init__ argument of this class, so the `ignore`
    # entry appears to be a leftover with no effect. It is kept to leave the saved
    # hyperparameters unchanged.
    self.save_hyperparameters(logger=False, ignore=["net"])

    self.net = nn.Sequential(
        nn.Linear(input_size, lin1_size),
        nn.BatchNorm1d(lin1_size),
        nn.ReLU(),
        nn.Linear(lin1_size, lin2_size),
        nn.BatchNorm1d(lin2_size),
        nn.ReLU(),
        nn.Linear(lin2_size, lin3_size),
        nn.BatchNorm1d(lin3_size),
        nn.ReLU(),
        nn.Linear(lin3_size, output_size),
    )

    # loss function
    self.criterion = torch.nn.CrossEntropyLoss()

    # Metric objects for calculating and averaging accuracy across batches.
    # The class count must match the number of logits the network emits
    # (`output_size`). Hard-coding 10 would break any non-default output size.
    self.train_acc = Accuracy(task="multiclass", num_classes=output_size)
    self.val_acc = Accuracy(task="multiclass", num_classes=output_size)
    self.test_acc = Accuracy(task="multiclass", num_classes=output_size)

    # for averaging loss across batches
    self.train_loss = MeanMetric()
    self.val_loss = MeanMetric()
    self.test_loss = MeanMetric()

    # for tracking best so far validation accuracy
    self.val_acc_best = MaxMetric()

configure_optimizers

configure_optimizers() -> Dict[str, Any]

Configure optimizers and learning-rate schedulers.

Returns:

Type Description
Dict[str, Any]

A dict containing the configured optimizers and learning-rate schedulers to be used for training.

Source code in src/project/models/mnist.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def configure_optimizers(self) -> Dict[str, Any]:
    """Configure optimizers and learning-rate schedulers.

    The optimizer factory stored in ``self.hparams.optimizer`` is applied to the
    parameters of ``self.trainer.model``. If a scheduler factory is configured, it is
    wrapped in a Lightning scheduler config that steps once per epoch and monitors
    ``"val/loss"``.

    Returns:
        A dict containing the configured optimizers and learning-rate schedulers to be used for training.
    """
    opt = self.hparams.optimizer(params=self.trainer.model.parameters())

    # Without a scheduler, Lightning only needs the optimizer.
    if self.hparams.scheduler is None:
        return {"optimizer": opt}

    lr_sched = self.hparams.scheduler(optimizer=opt)
    scheduler_config = {
        "scheduler": lr_sched,
        "monitor": "val/loss",
        "interval": "epoch",
        "frequency": 1,
    }
    return {"optimizer": opt, "lr_scheduler": scheduler_config}

forward

forward(x: Tensor) -> Tensor

Performs a forward pass through the model self.net.

Parameters:

Name Type Description Default
x Tensor

A tensor of images.

required

Returns:

Type Description
Tensor

A tensor of logits.

Source code in src/project/models/mnist.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Performs a forward pass through the model `self.net`.

    All dimensions after the first (batch) dimension are flattened into a single
    feature axis before the input reaches ``self.net``. For example, a
    ``(batch, channels, height, width)`` image batch becomes
    ``(batch, channels * height * width)``.

    Args:
        x: A tensor of images whose first dimension is the batch dimension. Any
            number of trailing dimensions is accepted, the tensor does not need
            to be contiguous, and the batch may be empty.

    Returns:
        A tensor of logits.
    """
    # `torch.flatten` works where `x.view(batch_size, -1)` does not: on non-4-D
    # inputs, on non-contiguous tensors (e.g. permuted batches), and on empty
    # batches, where `view(0, -1)` is ambiguous. It returns a view when possible
    # and copies only when necessary.
    x = torch.flatten(x, start_dim=1)

    return self.net(x)

model_step

model_step(batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]

Performs a single model step on a batch of data.

Parameters:

Name Type Description Default
batch Tuple[Tensor, Tensor]

A batch of data (a tuple) containing the input tensor of images and target labels.

required

Returns:

Type Description
Tuple[Tensor, Tensor, Tensor]

A tuple containing, in order: a tensor of losses, a tensor of predictions, and a tensor of target labels.

Source code in src/project/models/mnist.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def model_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the model on one batch and score it.

    Args:
        batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.

    Returns:
        A tuple containing (in order):
            - A tensor of losses, as computed by ``self.criterion``.
            - A tensor of predictions: the arg-max class index along dim 1 of the logits.
            - A tensor of target labels (the same object that was passed in).
    """
    images, labels = batch
    scores = self.forward(images)
    batch_loss = self.criterion(scores, labels)
    predicted = scores.argmax(dim=1)
    return batch_loss, predicted, labels

on_test_epoch_end

on_test_epoch_end() -> None

Lightning hook that is called when a test epoch ends.

Source code in src/project/models/mnist.py
198
199
200
def on_test_epoch_end(self) -> None:
    """Lightning hook that is called when a test epoch ends.

    Intentionally a no-op. The test metrics are logged as metric objects in
    ``test_step``, so nothing extra needs to happen here.
    """

on_train_epoch_end

on_train_epoch_end() -> None

Lightning hook that is called when a training epoch ends.

Source code in src/project/models/mnist.py
147
148
149
def on_train_epoch_end(self) -> None:
    """Lightning hook that is called when a training epoch ends.

    Intentionally a no-op. The training metrics are logged as metric objects in
    ``training_step``, so nothing extra needs to happen here.
    """

on_train_start

on_train_start() -> None

Lightning hook that is called when training begins.

Source code in src/project/models/mnist.py
90
91
92
93
94
95
96
def on_train_start(self) -> None:
    """Lightning hook that is called when training begins.

    Clears the validation loss, validation accuracy and best-validation-accuracy
    metrics, in that order.
    """
    # Lightning runs validation sanity-check steps before training by default.
    # Resetting here keeps those sanity-check results out of the real validation metrics.
    for metric in (self.val_loss, self.val_acc, self.val_acc_best):
        metric.reset()

on_validation_epoch_end

on_validation_epoch_end() -> None

Lightning hook that is called when a validation epoch ends.

Source code in src/project/models/mnist.py
168
169
170
171
172
173
174
175
176
def on_validation_epoch_end(self) -> None:
    """Lightning hook that is called when a validation epoch ends.

    Feeds this epoch's validation accuracy into the running-maximum metric and logs
    the best value seen so far as ``"val/acc_best"``.
    """
    epoch_acc = self.val_acc.compute()
    self.val_acc_best(epoch_acc)

    # Log the computed value rather than the metric object itself. Otherwise
    # Lightning would reset the metric after each epoch and lose the running best.
    best_so_far = self.val_acc_best.compute()
    self.log("val/acc_best", best_so_far, sync_dist=True, prog_bar=True)

setup

setup(stage: str) -> None

Called at the beginning of fit (train + validate), validate, test, or predict.

This is a good place to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

Name Type Description Default
stage str

One of "fit", "validate", "test", or "predict".

required
Source code in src/project/models/mnist.py
202
203
204
205
206
207
208
209
210
211
212
def setup(self, stage: str) -> None:
    """Called at the beginning of fit (train + validate), validate, test, or predict.

    This is a good place to build models dynamically or adjust something about them. This
    hook is called on every process when using DDP.

    Args:
        stage: One of "fit", "validate", "test", or "predict".
    """
    if self.hparams.compile and stage == "fit":
        self.net = torch.compile(self.net)

test_step

test_step(batch: Tuple[Tensor, Tensor], batch_idx: int) -> None

Performs a single test step on a batch of data from the test set.

Parameters:

Name Type Description Default
batch Tuple[Tensor, Tensor]

A batch of data containing the input tensor of images and target labels.

required
batch_idx int

The index of the current batch.

required
Source code in src/project/models/mnist.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def test_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> None:
    """Performs a single test step on a batch of data from the test set.

    Args:
        batch (Tuple[torch.Tensor, torch.Tensor]): A batch of data containing the input tensor
            of images and target labels.
        batch_idx (int): The index of the current batch.
    """
    loss, preds, targets = self.model_step(batch)

    # update and log metrics
    self.test_loss(loss)
    self.test_acc(preds, targets)
    self.log(
        "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
    )
    self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

training_step

training_step(batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor

Perform a single training step on a batch of data from the training set.

Parameters:

Name Type Description Default
batch Tuple[Tensor, Tensor]

A batch of data containing the input tensor of images and target labels.

required
batch_idx int

The index of the current batch.

required

Returns:

Type Description
Tensor

A tensor of losses between model predictions and targets.

Source code in src/project/models/mnist.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def training_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
    """Process one batch from the training set and update the training metrics.

    Args:
        batch (Tuple[torch.Tensor, torch.Tensor]): A batch of data containing the input tensor
            of images and target labels.
        batch_idx (int): The index of the current batch (unused).

    Returns:
        torch.Tensor: A tensor of losses between model predictions and targets.
    """
    batch_loss, predictions, labels = self.model_step(batch)

    # Accumulate the epoch-level training loss and accuracy.
    self.train_loss(batch_loss)
    self.train_acc(predictions, labels)

    # Log both metric objects once per epoch and show them in the progress bar.
    epoch_logging = {"on_step": False, "on_epoch": True, "prog_bar": True}
    self.log("train/loss", self.train_loss, **epoch_logging)
    self.log("train/acc", self.train_acc, **epoch_logging)

    # The loss must be returned, otherwise Lightning has nothing to backpropagate.
    return batch_loss

validation_step

validation_step(batch: Tuple[Tensor, Tensor], batch_idx: int) -> None

Perform a single validation step on a batch of data from the validation set.

Parameters:

Name Type Description Default
batch Tuple[Tensor, Tensor]

A tuple containing the input tensor of images and target labels.

required
batch_idx int

The index of the current batch.

required
Source code in src/project/models/mnist.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def validation_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> None:
    """Evaluate one batch from the validation set and update the validation metrics.

    Args:
        batch: A tuple containing the input tensor of images and target labels.
        batch_idx: The index of the current batch (unused).
    """
    batch_loss, predictions, labels = self.model_step(batch)

    # Accumulate the epoch-level validation loss and accuracy.
    self.val_loss(batch_loss)
    self.val_acc(predictions, labels)

    # Log both metric objects once per epoch and show them in the progress bar.
    epoch_logging = {"on_step": False, "on_epoch": True, "prog_bar": True}
    self.log("val/loss", self.val_loss, **epoch_logging)
    self.log("val/acc", self.val_acc, **epoch_logging)