
The Principle for Building Regular Equivariant CNNs

Published on 2024-07-31

The one principle is simply stated as 'Let the kernel rotate', and this article focuses on how you can apply it in your own architectures.

Equivariant architectures allow us to train models which are indifferent to certain group actions.

To understand what this exactly means, let us train this simple CNN model on the MNIST dataset (a dataset of handwritten digits from 0-9).

import torch
from torch import nn


class SimpleCNN(nn.Module):

    def __init__(self):
        super().__init__()
        # Input: MNIST images of shape (batch, 1, 28, 28)
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)  # 28x28 -> 14x14
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)  # 14x14 -> 7x7
        # A 7x7 kernel on a 7x7 map leaves a single value per channel
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)  # flatten (batch, 16, 1, 1) to (batch, 16)
        logits = self.dense(x)
        return logits

| Accuracy on test | Accuracy on 90-degree rotated test |
| --- | --- |
| 97.3% | 15.1% |

Table 1: Test accuracy of the SimpleCNN model

As expected, we get over 95% accuracy on the testing dataset, but what if we rotate the image by 90 degrees? Without any countermeasures applied, the results drop to just slightly better than guessing. This model would be useless for general applications.
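The rotated-test measurement can be sketched roughly as follows. This is a minimal sketch, not the evaluation code from the repository: the helper name `accuracy_under_rotation` is hypothetical, and the random tensors and the tiny linear model are stand-ins that only show the call. Swap in the MNIST test set and a trained network to reproduce the comparison.

```python
import torch
from torch import nn


@torch.no_grad()
def accuracy_under_rotation(model: nn.Module, images: torch.Tensor,
                            labels: torch.Tensor, k: int = 0) -> float:
    """Fraction of correct predictions after rotating every image by k * 90 degrees."""
    model.eval()
    rotated = torch.rot90(images, k=k, dims=(2, 3))  # rotate the height and width axes
    predictions = model(rotated).argmax(dim=1)
    return (predictions == labels).float().mean().item()


# Stand-in data and model (assumptions for illustration only).
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

acc_upright = accuracy_under_rotation(model, images, labels, k=0)
acc_rotated = accuracy_under_rotation(model, images, labels, k=1)
```

Comparing `acc_upright` with `acc_rotated` for a trained model gives the two columns of the table above.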

In contrast, let us train a similar equivariant architecture with the same number of parameters, where the group actions are exactly the 90-degree rotations.

| Accuracy on test | Accuracy on 90-degree rotated test |
| --- | --- |
| 96.5% | 96.5% |

Table 2: Test accuracy of the EqCNN model with the same number of parameters as the SimpleCNN model

The accuracy remains the same, and we did not even opt for data augmentation.

These models become even more impressive with 3D data, but we will stick with this example to explore the core idea.

If you want to try it out yourself, all code, written in both PyTorch and JAX, is freely available in the accompanying GitHub repository, and training with Docker or Podman takes just two commands.

Have fun!

So What is Equivariance?

Equivariant architectures guarantee stability of features under certain group actions. A group is a simple structure whose elements can be combined, can each be reversed, and include one element that does nothing.

You can look up the formal definition on Wikipedia if you are interested.

For our purposes, you can think of the group of 90-degree rotations acting on square images. We can rotate an image by 90, 180, 270, or 360 degrees. To reverse the action, we apply a 270, 180, 90, or 0-degree rotation, respectively. It is straightforward to see that these rotations can be combined, reversed, or do nothing. This group is denoted $C_4$. Figure 1 visualizes all four actions on an image.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
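The three group properties can be checked directly with `torch.rot90`. This is a small illustrative sketch: the 4x4 toy tensor and the helper name `rotate` are assumptions made for the example, not part of the original code.

```python
import torch

# Toy "image": batch of 1, 1 channel, 4x4 pixels.
image = torch.arange(16.0).reshape(1, 1, 4, 4)


def rotate(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply the group element 'rotate by k * 90 degrees' to the last two axes."""
    return torch.rot90(x, k=k, dims=(2, 3))


# Combine: rotating by 90 and then by 180 degrees equals rotating by 270 degrees.
combined = rotate(rotate(image, 1), 2)
# Reverse: a 90-degree rotation is undone by a 270-degree rotation.
reversed_back = rotate(rotate(image, 1), 3)
# Do nothing: four quarter turns (360 degrees) return the original image.
full_turn = rotate(image, 4)
```

Here `combined` matches `rotate(image, 3)`, while `reversed_back` and `full_turn` both match `image`.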

Now, given an input image $x$, our CNN classifier $f_\theta$, and an arbitrary 90-degree rotation $g$, the equivariant property can be expressed as

$$f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x)$$

Strictly speaking, this is invariance, the special case of equivariance in which the output does not change at all.

Generally speaking, we want our image-based model to produce the same outputs when the input image is rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: when the image rotates, the features rotate too. But as already hinted, we can compute the features for each rotation upfront by rotating the kernel. We could literally rotate the kernel weights, but it is much easier to rotate the input of the layer instead, which avoids interfering with PyTorch's automatic differentiation altogether.
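That rotating the kernel and rotating the input amount to the same thing can be checked numerically. The sketch below uses random tensors and an unpadded `F.conv2d` call as stand-ins, with shapes chosen only for illustration. Applying the kernel to a rotated input gives the same features as applying the inversely rotated kernel to the original input and rotating the result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 8, 8)   # one single-channel image
w = torch.randn(4, 1, 3, 3)   # four 3x3 kernels: (out_channels, in_channels, H, W)

# Option A: rotate the input by 90 degrees and apply the unrotated kernel.
out_rotated_input = F.conv2d(torch.rot90(x, k=1, dims=(2, 3)), w)

# Option B: rotate the kernel the opposite way, convolve the original input,
# then rotate the resulting feature map by 90 degrees.
w_rotated = torch.rot90(w, k=-1, dims=(2, 3))
out_rotated_kernel = torch.rot90(F.conv2d(x, w_rotated), k=1, dims=(2, 3))

# Both routes produce the same feature maps up to floating-point error.
same = torch.allclose(out_rotated_input, out_rotated_kernel, atol=1e-5)
```

Option A never touches the weights, which is why the rest of the article rotates the input rather than the kernel.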

So, in code, our convolution layer

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
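To see that the 3D convolution with a kernel depth of 1 really is the same shared 2D kernel applied to each rotation, one can copy its weights into a `Conv2d` layer and compare the outputs. This is a sketch with random inputs, and the variable names are chosen for illustration only.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv3d = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
conv2d = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)

# Copy the 3D weights (8, 1, 1, 3, 3) into the 2D layer (8, 1, 3, 3).
with torch.no_grad():
    conv2d.weight.copy_(conv3d.weight.squeeze(2))
    conv2d.bias.copy_(conv3d.bias)

x = torch.randn(2, 1, 28, 28)
rotations = [torch.rot90(x, k=k, dims=(2, 3)) for k in range(4)]

with torch.no_grad():
    # Stacked version: shape (batch, channels, 4 rotations, height, width).
    stacked_out = conv3d(torch.stack(rotations, dim=-3))
    # Explicit version: the same 2D layer applied to each rotation separately.
    separate_out = torch.stack([conv2d(r) for r in rotations], dim=-3)

same = torch.allclose(stacked_out, separate_out, atol=1e-4)
```

Because the kernel has depth 1 along the rotation axis, no information is mixed between the four rotated copies. Each copy is processed independently with identical weights.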

The resulting equivariant model has just a few lines more compared to the version above:

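import torch
from torch import nn


class EqCNN(nn.Module):

    def __init__(self):
        super().__init__()
        # Kernel depth 1 along the rotation axis: all four rotations share the same 2D kernel
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))  # 28 -> 26
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))  # 26 -> 13
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))  # 13 -> 11
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))  # 11 -> 5
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))  # 5 -> 1
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        # Rotate the input (batch, 1, 28, 28) by 0, 90, 180, and 270 degrees
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        # Stack the rotations on a new axis: (batch, 1, 4, 28, 28)
        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        # (batch, 16, 4, 1, 1) -> (batch, 16, 4); unlike squeeze(), this keeps a batch of size 1
        x = x.flatten(start_dim=2)
        # Max over the rotation axis makes the output independent of the rotation order
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits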

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is the key: the max operation does not care in which slot the rotated version of a feature ends up.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated

I color-coded the corresponding maps. Each feature map is shifted by one position along the rotation axis. As the final max operator computes the same result for these shifted feature maps, we obtain the same results.
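The shift can also be verified numerically instead of visually. The following sketch uses a single randomly initialized `Conv3d` layer as a stand-in for the first stage. The helper name `rotation_stack` is hypothetical, and the maximum is taken over the rotation axis just as in the last step of the model.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))


def rotation_stack(x: torch.Tensor) -> torch.Tensor:
    """Stack the four rotated copies of x: (batch, 1, H, W) -> (batch, 1, 4, H, W)."""
    return torch.stack([torch.rot90(x, k=k, dims=(2, 3)) for k in range(4)], dim=-3)


x = torch.randn(2, 1, 28, 28)
x_rot = torch.rot90(x, k=1, dims=(2, 3))  # the input rotated by 90 degrees

with torch.no_grad():
    features = conv(rotation_stack(x))          # (2, 8, 4, 26, 26)
    features_rot = conv(rotation_stack(x_rot))  # same maps in a different order

# Rotating the input moves every feature map one slot along the rotation axis.
shifted = torch.allclose(features_rot, torch.roll(features, shifts=-1, dims=2), atol=1e-5)

# A maximum over the rotation axis ignores the ordering, so the result is unchanged.
invariant = torch.allclose(features.amax(dim=2), features_rot.amax(dim=2), atol=1e-5)
```

Both `shifted` and `invariant` come out true for any weights, since each rotated copy passes through exactly the same operations.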

In my code, I did not rotate the feature maps back after the final convolution, since my kernels condense each image to a one-dimensional feature vector with no spatial extent left. If you want to expand on this example and keep spatial outputs, you would need to account for this fact.

Accounting for group actions or "kernel rotations" plays a vital role in the design of more sophisticated architectures.

Is it a Free Lunch?

No, we pay in computational speed, inductive bias, and a more complex implementation.

The last point is partly addressed by libraries such as E3NN, where most of the heavy math is abstracted away. Nevertheless, there is still a lot to account for during architecture design.

One apparent weakness is the 4x computational cost of computing all rotated feature maps. However, modern hardware with massive parallelism can easily absorb this load. In contrast, training a simple CNN with data augmentation would easily exceed 10x in training time. This gets even worse for 3D rotations, where data augmentation would require about 500x the training effort to cover all possible rotations.

Overall, equivariant model design is more often than not a price worth paying if one wants stable features.

What is Next?

Equivariant model designs have exploded in recent years, and in this article, we barely scratched the surface. In fact, we did not even exploit the full $C_4$ group yet. We could have used full 3D kernels. However, our model already achieves over 95% accuracy, so there is little reason to go further with this example.

Besides CNNs, researchers have successfully translated these principles to continuous groups, including $SO(2)$ (the group of all rotations in the plane) and $SE(3)$ (the group of all translations and rotations in 3D space).

In my experience, these models are absolutely mind-blowing and achieve performance, when trained from scratch, comparable to the performance of foundation models trained on multiple times larger datasets.

Let me know if you want me to write more on this topic.

Further References

In case you want a formal introduction to this topic, here is an excellent compilation of papers, covering the complete history of equivariance in Machine Learning.
AEN

I actually plan to create a deep-dive, hands-on tutorial on this topic. You can already sign up for my mailing list, and I will provide you with free versions over time, along with a direct channel for feedback and Q&A.

See you around :)

Source: this article is reprinted from https://dev.to/freiberg-roman/the-1-principle-to-build-regular-equivariant-cnns-338b?1