# 009NNInterpret.jl
using MLJ # for machine/fit!/report
using SymbolicRegression # for SRRegressor, Dataset, eval_grad_tree_array
using Zygote # for `enable_autodiff=true`
using SymbolicUtils # needed by node_to_symbolic
using DelimitedFiles # for readdlm/writedlm
using LinearAlgebra # for normalize!
cd(@__DIR__)
myarrayX = readdlm("circle_X.out", ',', Float64)
myarrayGradients = readdlm("circle_gradients.out", ',', Float64)
myinputsize = size(myarrayX, 2)
# Smuggle the gradient targets in through the data matrix X: the custom loss
# below extracts them from the extra columns.
X_concat = cat(myarrayX, myarrayGradients, dims=2)
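# Shape sanity check: the loss below assumes one gradient component per input
# feature for every sample, so the two files must agree in shape.
@assert size(myarrayGradients) == size(myarrayX)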
# SRRegressor still requires a target vector, so supply dummy all-zero labels;
# the custom loss below never reads them.
f(x) = 0.0
y = f.(myarrayX[:, 1])
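# For reference, a minimal sketch of how files like "circle_X.out" and
# "circle_gradients.out" could be produced; the circle function, sample count,
# and input range here are illustrative assumptions, not the original data:
#
#   circle(x) = x[1]^2 + x[2]^2
#   Xdemo = 4 .* rand(200, 2) .- 2                 # 200 samples in [-2, 2]^2
#   Gdemo = mapreduce(x -> Zygote.gradient(circle, x)[1]', vcat, eachrow(Xdemo))
#   normalize!.(eachrow(Gdemo))                    # the loss expects unit-norm gradients
#   writedlm("circle_X.out", Xdemo, ',')
#   writedlm("circle_gradients.out", Gdemo, ',')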
function derivative_loss(tree, dataset::Dataset{T,L}, options, idx) where {T,L}
    # Return infinite loss when structural assumptions are violated; not strictly
    # required, but it tends to prune degenerate trees early.
    tree.degree != 2 && return L(Inf)
    #tree.l.degree != 2 && return L(Inf)
    # Reject any node that references the gradient columns smuggled in through X:
    # only the first `myinputsize` features are real inputs.
    is_invalid_feature(node) = node.degree == 0 && !node.constant && (node.feature > myinputsize)
    any(is_invalid_feature, tree) && return L(Inf)
    # Select the batch indices, if given, and extract the normalized gradients from X
    X = idx === nothing ? dataset.X : view(dataset.X, :, idx)
    ∂y = X[myinputsize+1:2*myinputsize, :]
    # Evaluate both f(x) and ∇f(x), where f is defined by `tree`
    ŷ, ∂ŷ, completed = eval_grad_tree_array(tree, X, options; variable=true)
    !completed && return L(Inf)
    # Keep only the gradients with respect to the real features
    ∂ŷ = ∂ŷ[1:myinputsize, :]
    # Normalize each per-sample gradient to unit Euclidean norm
    normalize!.(eachcol(∂ŷ))
    # Mean squared error between the normalized gradient fields
    mse_grad = sum(i -> (∂ŷ[i] - ∂y[i])^2, eachindex(∂y)) / length(∂y)
    return mse_grad
end
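# In formula form, with ĝ the column-normalized predicted gradients and g the
# precomputed (already unit-norm) gradients over d features and n samples:
#   L = (1 / (d*n)) * Σᵢⱼ (ĝᵢⱼ − gᵢⱼ)²
# Matching normalized gradients makes the loss invariant to transformations of
# f that preserve gradient directions (e.g. positive rescaling or adding a constant).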
model = SRRegressor(;
    binary_operators=[+, -, *, /],
    unary_operators=[sin, exp],
    complexity_of_constants=3,
    complexity_of_operators=[exp => 5, sin => 5],
    loss_function=derivative_loss,
    should_simplify=true,
    should_optimize_constants=true,
    enable_autodiff=true,
    batching=true,
    batch_size=25,
    niterations=200,
    early_stop_condition=1e-10,
    maxsize=30 # for the gravity force eq
)
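# Note: with `batching=true`, SymbolicRegression passes the sampled batch
# indices to `loss_function` as `idx` (and `nothing` for full-dataset
# evaluations), which is why derivative_loss handles both cases above.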
mach = machine(model, X_concat, y)
fit!(mach)
r = report(mach)
eq = r.equations[r.best_idx]
symbolic_eq = node_to_symbolic(eq, model)
r # display the full report (the Pareto front of equations) when run interactively
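# Because only normalized gradients enter the loss, f is recovered only up to
# transformations that preserve gradient directions (e.g. scale and offset).
# Optional inspection (field names follow recent SymbolicRegression.jl versions):
println(r.equation_strings[r.best_idx]) # best expression as a plain string
println(symbolic_eq)                    # same expression as a SymbolicUtils form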