在PyTorch中自定义fit()函数中的操作代码

 更新时间:2024年05月17日 09:33:02   作者:MUKAMO  
当在进行有监督学习时,我们可以使用fit()函数对模型进行训练,通过迭代优化模型的参数,使其能够更好地拟合训练数据,本文给大家介绍了在PyTorch中自定义fit()函数中的操作代码,感兴趣的同学可以跟着小编一起来看看

1、绪论

当在进行有监督学习时,我们可以使用fit()函数对模型进行训练,通过迭代优化模型的参数,使其能够更好地拟合训练数据。

但当我们希望控制每一个小细节时,就可以完全从头开始编写自己的训练循环。此时就需要一个自定义的训练算法,但是如果我们同时又想受益于fit()的便捷功能,如回调、内置分布支持或步骤融合,该怎么办呢?

Keras的一个核心原则是复杂性的渐进披露。我们总是能够逐渐进入更底层的工作流程。如果高级功能不完全符合我们的要求,我们就能够通过自定义fit()在保留相应数量高级便利性的同时,对小细节获得更多的控制。

当我们需要自定义fit()的行为时,你应该重写Model类的训练步骤函数。这是fit()函数为每一批数据调用的函数。然后,你就可以像平常一样调用fit()——而它将会运行你自己的学习算法。

2、运行准备

2.1 设置

运行前请按照如下进行设置

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np

2.2 示例代码

一下我们从一个简单的例子开始感受在PyTorch中自定义fit()函数中的操作的方法。

首先需要创建一个新的类,它继承自keras.Model。
建立这个新类后,只需要重写train_step(self, data)方法。
运行上述方法将返回一个字典,该字典将指标名称(包括损失)映射到它们的当前值。
输入参数data是传递给fit作为训练数据的内容:

  • 如果通过调用fit(x, y, ...)传递NumPy数组,那么data将是元组(x, y)
  • 如果通过调用fit(dataset, ...)传递一个torch.utils.data.DataLoadertf.data.Dataset,那么data将是数据集在每个批次中生成的内容。

train_step()方法的主体中,我们实现了一个常规的训练更新。重要的是,我们通过self.compute_loss()计算损失,该方法封装了在compile()方法中传递的损失函数。

类似地,我们对self.metrics中的指标调用metric.update_state(y, y_pred),以更新在compile()方法中传递的指标的状态,并在最后查询self.metrics以检索它们的当前值。

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

运行代码,输出如下所示

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 551us/step - mae: 0.6533 - loss: 0.6036
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 522us/step - mae: 0.4013 - loss: 0.2522
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 516us/step - mae: 0.3813 - loss: 0.2256

<keras.src.callbacks.history.History at 0x299b7baf0>

3、底层操作

当然,实际操作过程中也可以在compile()方法中不传递损失函数,而是在train_step中手动处理所有事情。同样地,对于指标也是如此。

下面是一个更底层级别操作的例子,它仅使用compile()来配置优化器:

我们首先创建Metric实例来跟踪我们的损失和MAE分数(在__init__()方法中)。
通过一个自定义的train_step(),更新这些指标的状态(通过在其上调用update_state()),然后查询它们(通过result())以返回它们的当前平均值,这些值将由进度条显示并传递给任何回调。
请注意,运行过程中需要在每个epoch之间调用reset_states()来重置指标!否则,调用result()将返回从训练开始以来的平均值,而通常是使用每个epoch的平均值。框架可以为我们做这件事:只需将你想要重置的任何指标列在模型的metrics属性中即可。模型将在每个fit() epoch的开始或evaluate()调用的开始时调用reset_states()来重置这些对象的状态。

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.loss_fn(y, y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 461us/step - loss: 0.2470 - mae: 0.3953
Epoch 2/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 456us/step - loss: 0.2386 - mae: 0.3910
Epoch 3/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 456us/step - loss: 0.2359 - mae: 0.3901
Epoch 4/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 480us/step - loss: 0.2013 - mae: 0.3572
Epoch 5/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 463us/step - loss: 0.1903 - mae: 0.3480

<keras.src.callbacks.history.History at 0x299c5eec0>

3.1 支持样本权重和分类权重

在文章开始的基本示例没有提到样本权重,那么如果想要支持fit()方法的sample_weight和class_weight参数,可以按照以下步骤进行:

从data参数中解包sample_weight

将其传递给compute_loss和update_state(当然,如果你不是依赖于compile()方法来设置损失和指标,也可以手动应用它)

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
        )

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 499us/step - mae: 1.4332 - loss: 1.0769
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 520us/step - mae: 0.9250 - loss: 0.5614
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 502us/step - mae: 0.6069 - loss: 0.2653

<keras.src.callbacks.history.History at 0x299c82bf0>

3.2 提供自定义的评估步骤

如果想要在调用model.evaluate()时自定义评估步骤,我们怎么做呢?那么我们将以完全相同的方式重写test_step

class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 325us/step - mae: 0.4427 - loss: 0.2993

[0.2726495862007141, 0.42286917567253113]

4、完整的应用示例

为了整合前面所学的知识,我们将通过一个端到端的GAN(生成对抗网络)示例来演示在PyTorch中自定义fit()函数中的操作。

在这个例子中,我们将考虑:

  • 一个用于生成28x28x1图像的生成器网络。
  • 一个用于将28x28x1图像分类为两个类别(“假”和“真”)的判别器网络。
  • 每个网络都有一个优化器。
  • 一个用于训练判别器的损失函数。

首先,我们需要定义生成器和判别器的网络结构。这里为了简洁,我们不会详细写出每个层的定义,但你可以想象生成器网络将噪声作为输入并输出图像,而判别器网络将图像作为输入并输出一个概率值,该值表示输入图像是真实的(来自训练集)还是假的(由生成器生成)。

下面是GAN训练的大致流程:

  1. 初始化生成器和判别器网络

    • 定义生成器和判别器的模型结构。
    • 编译判别器网络,并指定一个损失函数(如二元交叉熵)和优化器(如Adam)。
  2. 训练判别器

    • 对于一批真实图像,计算判别器的损失(使用真实标签1)。
    • 通过生成器生成一批假图像,并计算判别器对假图像的损失(使用假标签0)。
    • 将两个损失相加,并对判别器执行一次梯度下降更新。
  3. 训练生成器

    • 生成一批假图像。
    • 使用判别器对这些假图像进行预测,得到概率值。
    • 使用判别器的预测作为标签(我们想要生成器生成的图像被判别器认为是真实的),计算生成器的损失(这通常是通过将判别器的预测传递给某种损失函数,如二元交叉熵或均方误差,来实现的)。
    • 使用计算出的损失对生成器执行一次梯度下降更新。
    • 注意:在训练生成器时,我们需要将判别器的权重设置为不可训练(因为我们只希望更新生成器的权重)。这可以通过在训练生成器之前调用discriminator.trainable = False来实现。
  4. 循环迭代

    • 重复步骤2和3多次,以训练GAN。
  5. 在测试集上评估GAN

    • 使用训练好的生成器生成图像,并可视化这些图像以评估GAN的性能。
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

下面是一个功能完整的GAN类,它重写了compile()方法以使用自己的签名,并在train_step中以17行代码实现了整个GAN算法:

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        real_images = torch.tensor(real_images, device=device)
        combined_images = torch.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = torch.concat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            axis=0,
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        grads = g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

以下是运行结果

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
# Create a DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gan.fit(dataloader, epochs=1)
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 1582s 1s/step - d_loss: 0.3581 - g_loss: 2.0571

<keras.src.callbacks.history.History at 0x299ce1840>

以上就是在PyTorch中自定义fit()函数中的操作代码的详细内容,更多关于PyTorch自定义fit()的资料请关注脚本之家其它相关文章!

相关文章

  • 设置python3为默认python的方法

    设置python3为默认python的方法

    我们知道在Windows下多版本共存的配置方法就是改可执行文件的名字,配置环境变量。接下来通过本文给大家介绍设置python3为默认python的方法,一起看看吧
    2018-10-10
  • 深入了解Django中间件及其方法

    深入了解Django中间件及其方法

    这篇文章主要介绍了简单了解Django中间件及其方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • 如何利用python检测图片是否包含二维码

    如何利用python检测图片是否包含二维码

    这篇文章主要介绍了如何利用python检测图片是否包含二维码,帮助大家更好的利用python处理图片,感兴趣的朋友可以了解下
    2020-10-10
  • 对python函数签名的方法详解

    对python函数签名的方法详解

    今天小编就为大家分享一篇对python函数签名的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • python目标检测非极大抑制NMS与Soft-NMS

    python目标检测非极大抑制NMS与Soft-NMS

    这篇文章主要weidajia 介绍了python目标检测非极大抑制NMS与Soft-NMS实现过程,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • 用pandas划分数据集实现训练集和测试集

    用pandas划分数据集实现训练集和测试集

    这篇文章主要介绍了用pandas划分数据集实现训练集和测试集,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-07-07
  • 详解Python用三种方式统计词频的方法

    详解Python用三种方式统计词频的方法

    这篇文章主要介绍了Python用三种方式统计词频,每种方法给大家介绍的非常详细,需要的朋友可以参考下
    2019-07-07
  • python+django+selenium搭建简易自动化测试

    python+django+selenium搭建简易自动化测试

    这篇文章主要介绍了python+django+selenium搭建简易自动化测试,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-08-08
  • Python实现针对给定单链表删除指定节点的方法

    Python实现针对给定单链表删除指定节点的方法

    这篇文章主要介绍了Python实现针对给定单链表删除指定节点的方法,结合实例形式分析了Python单链表的定义、节点添加、删除、打印等相关操作技巧,需要的朋友可以参考下
    2018-04-04
  • python针对mysql数据库的连接、查询、更新、删除操作示例

    python针对mysql数据库的连接、查询、更新、删除操作示例

    这篇文章主要介绍了python针对mysql数据库的连接、查询、更新、删除操作,结合实例形式详细分析了Python操作mysql数据库的连接与增删改查相关实现技巧,需要的朋友可以参考下
    2019-09-09

最新评论