”工欲善其事,必先利其器。“—孔子《论语.录灵公》
首页 > 编程 > 构建常规等变 CNN 的原则

构建常规等变 CNN 的原则

发布于2024-07-31
浏览:496

The one principle is simply stated as 'Let the kernel rotate' and we will focus in this article on how you can apply it in your architectures.

Equivariant architectures allow us to train models which are indifferent to certain group actions.

To understand what this exactly means, let us train this simple CNN model on the MNIST dataset (a dataset of handwritten digits from 0-9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Table 1: Test accuracy of the SimpleCNN model

As expected, we get over 95% accuracy on the testing dataset, but what if we rotate the image by 90 degrees? Without any countermeasures applied, the results drop to just slightly better than guessing. This model would be useless for general applications.

In contrast, let us train a similar equivariant architecture with the same number of parameters, where the group actions are exactly the 90-degree rotations.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Table 2: Test accuracy of the EqCNN model with the same amount of parameters as the SimpleCNN model

The accuracy remains the same, and we did not even opt for data augmentation.

These models become even more impressive with 3D data, but we will stick with this example to explore the core idea.

In case you want to test it out for yourself, you can access all code written in both PyTorch and JAX for free under Github-Repo, and training with Docker or Podman is possible with just two commands.

Have fun!

So What is Equivariance?

Equivariant architectures guarantee stability of features under certain group actions. Groups are simple structures where group elements can be combined, reversed, or do nothing.

You can look up the formal definition on Wikipedia if you are interested.

For our purposes, you can think of a group of 90-degree rotations acting on square images. We can rotate an image by 90, 180, 270, or 360 degrees. To reverse the action, we apply a 270, 180, 90, or 0-degree rotation respectively. It is straightforward to see that we can combine, reverse, or do nothing with the group denoted as C4C_4C4 . The image visualizes all actions on an image.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Figure 3: Feature maps for all four rotations after the input image was rotated

I color-coded the corresponding maps. Each feature map is shifted by one. As the final max operator computes the same result for these shifted feature maps, we obtain the same results.

In my code, I did not rotate back after the final convolution, since my kernels condense the image to a one-dimensional array. If you want to expand on this example, you would need to account for this fact.

Accounting for group actions or "kernel rotations" plays a vital role in the design of more sophisticated architectures.

Is it a Free Lunch?

No, we pay in computational speed, inductive bias, and a more complex implementation.

The latter point is somewhat solved with libraries such as E3NN, where most of the heavy math is abstracted away. Nevertheless, one needs to account for a lot during architecture design.

One superficial weakness is the 4x computational cost for computing all rotated feature layers. However, modern hardware with mass parallelization can easily counteract this load. In contrast, training a simple CNN with data augmentation would easily exceed 10x in training time. This gets even worse for 3D rotations where data augmentation would require about 500x the training amount to compensate for all possible rotations.

Overall, equivariance model design is more often than not a price worth paying if one wants stable features.

What is Next?

Equivariant model designs have exploded in recent years, and in this article, we barely scratched the surface. In fact, we did not even exploit the full C4C_4C4 group yet. We could have used full 3D kernels. However, our model already achieves over 95% accuracy, so there is little reason to go further with this example.

Besides CNNs, researchers have successfully translated these principles to continuous groups, including SO(2)SO(2)SO(2) (the group of all rotations in the plane) and SE(3)SE(3)SE(3) (the group of all translations and rotations in 3D space).

In my experience, these models are absolutely mind-blowing and achieve performance, when trained from scratch, comparable to the performance of foundation models trained on multiple times larger datasets.

Let me know if you want me to write more on this topic.

Further References

In case you want a formal introduction to this topic, here is an excellent compilation of papers, covering the complete history of equivariance in Machine Learning.
AEN

I actually plan to create a deep-dive, hands-on tutorial on this topic. You can already sign up for my mailing list, and I will provide you with free versions over time, along with a direct channel for feedback and Q&A.

See you around :)

版本声明 本文转载于:https://dev.to/freiberg-roman/the-1-principle-to-build-regular-equivariant-cnns-338b?1如有侵犯,请联系[email protected]删除
最新教程 更多>
  • 如何从Python中的字符串中删除表情符号:固定常见错误的初学者指南?
    如何从Python中的字符串中删除表情符号:固定常见错误的初学者指南?
    从python import codecs import codecs import codecs 导入 text = codecs.decode('这狗\ u0001f602'.encode('utf-8'),'utf-8') 印刷(文字)#带有...
    编程 发布于2025-06-28
  • 如何为PostgreSQL中的每个唯一标识符有效地检索最后一行?
    如何为PostgreSQL中的每个唯一标识符有效地检索最后一行?
    postgresql:为每个唯一标识符在postgresql中提取最后一行,您可能需要遇到与数据集合中每个不同标识的信息相关的信息。考虑以下数据:[ 1 2014-02-01 kjkj 在数据集中的每个唯一ID中检索最后一行的信息,您可以在操作员上使用Postgres的有效效率: id dat...
    编程 发布于2025-06-28
  • 如何在Chrome中居中选择框文本?
    如何在Chrome中居中选择框文本?
    选择框的文本对齐:局部chrome-inly-ly-ly-lyly solument 您可能希望将文本中心集中在选择框中,以获取优化的原因或提高可访问性。但是,在CSS中的选择元素中手动添加一个文本 - 对属性可能无法正常工作。初始尝试 state)</option> < op...
    编程 发布于2025-06-28
  • 如何使用Regex在PHP中有效地提取括号内的文本
    如何使用Regex在PHP中有效地提取括号内的文本
    php:在括号内提取文本在处理括号内的文本时,找到最有效的解决方案是必不可少的。一种方法是利用PHP的字符串操作函数,如下所示: 作为替代 $ text ='忽略除此之外的一切(text)'; preg_match('#((。 &&& [Regex使用模式来搜索特...
    编程 发布于2025-06-28
  • JavaScript计算两个日期之间天数的方法
    JavaScript计算两个日期之间天数的方法
    How to Calculate the Difference Between Dates in JavascriptAs you attempt to determine the difference between two dates in Javascript, consider this s...
    编程 发布于2025-06-28
  • CSS可以根据任何属性值来定位HTML元素吗?
    CSS可以根据任何属性值来定位HTML元素吗?
    靶向html元素,在CSS 中使用任何属性值,在CSS中,可以基于特定属性(如下所示)基于特定属性的基于特定属性的emants目标元素: 字体家庭:康斯拉斯(Consolas); } 但是,出现一个常见的问题:元素可以根据任何属性值而定位吗?本文探讨了此主题。的目标元素有任何任何属性值,属...
    编程 发布于2025-06-28
  • PHP未来:适应与创新
    PHP未来:适应与创新
    PHP的未来将通过适应新技术趋势和引入创新特性来实现:1)适应云计算、容器化和微服务架构,支持Docker和Kubernetes;2)引入JIT编译器和枚举类型,提升性能和数据处理效率;3)持续优化性能和推广最佳实践。 引言在编程世界中,PHP一直是网页开发的中流砥柱。作为一个从1994年就开始发展...
    编程 发布于2025-06-28
  • 如何检查对象是否具有Python中的特定属性?
    如何检查对象是否具有Python中的特定属性?
    方法来确定对象属性存在寻求一种方法来验证对象中特定属性的存在。考虑以下示例,其中尝试访问不确定属性会引起错误: >>> a = someClass() >>> A.property Trackback(最近的最新电话): 文件“ ”,第1行, AttributeError: SomeClass...
    编程 发布于2025-06-28
  • 如何高效地在一个事务中插入数据到多个MySQL表?
    如何高效地在一个事务中插入数据到多个MySQL表?
    mySQL插入到多个表中,该数据可能会产生意外的结果。虽然似乎有多个查询可以解决问题,但将从用户表的自动信息ID与配置文件表的手动用户ID相关联提出了挑战。使用Transactions和last_insert_id() 插入用户(用户名,密码)值('test','test...
    编程 发布于2025-06-28
  • 表单刷新后如何防止重复提交?
    表单刷新后如何防止重复提交?
    在Web开发中预防重复提交 在表格提交后刷新页面时,遇到重复提交的问题是常见的。要解决这个问题,请考虑以下方法: 想象一下具有这样的代码段,看起来像这样的代码段:)){ //数据库操作... 回声“操作完成”; 死(); } ?> ...
    编程 发布于2025-06-28
  • 您如何在Laravel Blade模板中定义变量?
    您如何在Laravel Blade模板中定义变量?
    在Laravel Blade模板中使用Elegance 在blade模板中如何分配变量对于存储以后使用的数据至关重要。在使用“ {{}}”分配变量的同时,它可能并不总是最优雅的解决方案。幸运的是,Blade通过@php Directive提供了更优雅的方法: $ old_section =“...
    编程 发布于2025-06-28
  • Java字符串非空且非null的有效检查方法
    Java字符串非空且非null的有效检查方法
    检查字符串是否不是null而不是空的 if(str!= null && str.isementy())二手: if(str!= null && str.length()== 0) option 3:trim()。isement(Isement() trim whitespace whitesp...
    编程 发布于2025-06-28
  • 找到最大计数时,如何解决mySQL中的“组函数\”错误的“无效使用”?
    找到最大计数时,如何解决mySQL中的“组函数\”错误的“无效使用”?
    如何在mySQL中使用mySql 检索最大计数,您可能会遇到一个问题,您可能会在尝试使用以下命令:理解错误正确找到由名称列分组的值的最大计数,请使用以下修改后的查询: 计数(*)为c 来自EMP1 按名称组 c desc订购 限制1 查询说明 select语句提取名称列和每个名称...
    编程 发布于2025-06-28
  • 为什么不使用CSS`content'属性显示图像?
    为什么不使用CSS`content'属性显示图像?
    在Firefox extemers属性为某些图像很大,&& && && &&华倍华倍[华氏华倍华氏度]很少见,却是某些浏览属性很少,尤其是特定于Firefox的某些浏览器未能在使用内容属性引用时未能显示图像的情况。这可以在提供的CSS类中看到:。googlepic { 内容:url(&#...
    编程 发布于2025-06-28

免责声明: 提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发到邮箱:[email protected] 我们会第一时间内为您处理。

Copyright© 2022 湘ICP备2022001581号-3