-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathDiffJPEG.py
35 lines (31 loc) · 1.16 KB
/
DiffJPEG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Pytorch
import torch
import torch.nn as nn
# Local
from modules import compress_jpeg, decompress_jpeg
from utils import diff_round, quality_to_factor
class DiffJPEG(nn.Module):
    """Simulate a JPEG compress/decompress round trip as a PyTorch module.

    The module chains a ``compress_jpeg`` stage and a ``decompress_jpeg``
    stage (both imported from the local ``modules`` package). Both stages are
    configured with the same rounding function and the same scale factor
    derived from ``quality``.
    """

    def __init__(self, height, width, differentiable=True, quality=80):
        """Build the compression and decompression stages.

        Inputs:
            height (int): Original image height, forwarded to the
                decompression stage.
            width (int): Original image width, forwarded to the
                decompression stage.
            differentiable (bool): If true, use the custom differentiable
                rounding function ``diff_round``; if false, use the standard
                ``torch.round``.
            quality (float): Quality factor for the JPEG compression scheme;
                converted to a scale factor by ``quality_to_factor``.
                NOTE(review): the valid range is defined by
                ``quality_to_factor`` (not visible here) -- confirm there.
        """
        super(DiffJPEG, self).__init__()
        round_fn = diff_round if differentiable else torch.round
        scale_factor = quality_to_factor(quality)
        # Both stages share one rounding function and one scale factor.
        # Keep construction order (compress first) so submodule registration
        # order, and hence state_dict key order, stays stable.
        self.compress = compress_jpeg(rounding=round_fn, factor=scale_factor)
        self.decompress = decompress_jpeg(
            height, width, rounding=round_fn, factor=scale_factor
        )

    def forward(self, x):
        """Compress ``x`` and immediately decompress it.

        Args:
            x: Input passed to the compression stage.
                NOTE(review): presumably an image batch shaped
                (batch, channels, height, width) whose expected value range is
                set by ``compress_jpeg`` -- confirm against that module.

        Returns:
            The output of the decompression stage, rebuilt from the three
            components (y, cb, cr) returned by the compression stage.
        """
        y_chan, cb_chan, cr_chan = self.compress(x)
        return self.decompress(y_chan, cb_chan, cr_chan)