using Flux
using Random, LinearAlgebra
# Set seed for reproducibility
Random.seed!(42)
# Dimensions
m, n, r = 10, 8, 8 # A is m x n, X is m x r, Y is r x n
# Generate a random matrix A
A = randn(m, n)
# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)
function custom_loss(A, X, Y)
    return norm(A - X*Y)^2
    # return norm(A - X*Y)^2 + norm(X)^2 + norm(Y)^2  # regularized variant, used later
end
model_loss = (X,Y) -> custom_loss(A,X,Y)
# Training loop: plain gradient descent with a fixed step size
niterations = 10000
for epoch in 1:niterations
    gX, gY = gradient(model_loss, X, Y)
    @. X = X - 0.01*gX   # in-place gradient step on X
    @. Y = Y - 0.01*gY   # in-place gradient step on Y
    # Print loss every 50 iterations
    if epoch % 50 == 0
        println("Iterations $epoch: Loss = ", model_loss(X, Y))
    end
end
opterr = model_loss(X,Y)
# Compare against the best rank-r approximation from the SVD
Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2
@show opterr
@show svderr
Iterations 50: Loss = 2.091547823016724
Iterations 100: Loss = 0.5661487081197589
Iterations 150: Loss = 0.37969498421413045
⋮
Iterations 9900: Loss = 1.0197410746274112e-8
Iterations 9950: Loss = 9.321211124555337e-9
Iterations 10000: Loss = 8.520297987205108e-9
opterr = 8.520297987205108e-9
svderr = 1.2555705456793983e-28
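# Quick sanity check (a sketch using the SVD computed above): by Eckart-Young, the
# best rank-r approximation error equals the sum of the squared singular values
# beyond the first r. With r = 8 = min(m, n) nothing is discarded, so svderr is ~0.
@show sum(abs2, Ssvd[r+1:end]; init=0.0)   # matches svderr up to rounding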
# This will evaluate the gradient of norm(A - X*Y)^2 at a point X, Y
# Let's check against my "analytical" gradient...
X = ones(m,r)
Y = ones(r,n)
gX, gY = gradient(model_loss, X, Y)
gX
gX + (2*(A - X*Y)*Y')   # should be (numerically) zero: the analytical gradient of norm(A - X*Y)^2 with respect to X is -2*(A - X*Y)*Y'
10×8 Matrix{Float64}:
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
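# The same check for the other factor (a sketch): the analytical gradient of
# norm(A - X*Y)^2 with respect to Y is -2*X'*(A - X*Y), so adding that back to the
# AD gradient should again give a matrix of zeros.
gY + (2*X'*(A - X*Y))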
norm(X), norm(Y)
(3.8030981067374814, 3.803098106737481)
norm(A-X*Y)
2.6586124701944938
# Dimensions
m, n, r = 10, 8, 8 # A is m x n, X is m x r, Y is r x n
# Generate a random matrix A
A = randn(m, n)
# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)
function full_function(A, X, Y)
    # return norm(A - X*Y)^2
    return norm(A - X*Y)^2 + norm(X)^2 + norm(Y)^2
end
optimization_function = (X,Y) -> full_function(A,X,Y)
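# Quick check (a sketch, analogous to the earlier gradient check): with the penalty
# terms the analytical gradients are -2*(A - X*Y)*Y' + 2*X for X and
# -2*X'*(A - X*Y) + 2*Y for Y, so the AD gradients should match them up to
# floating-point error.
gXr, gYr = gradient(optimization_function, X, Y)
@show norm(gXr - (-2*(A - X*Y)*Y' + 2*X))
@show norm(gYr - (-2*X'*(A - X*Y) + 2*Y))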
using Optimisers   # Adam, setup, and update live in Optimisers.jl
opt = Optimisers.Adam(0.01)
state = Optimisers.setup(opt, (X,Y))
# Training loop (Flux calls these epochs, but I don't like confusing the two...)
niteration = 5000
for iteration in 1:niteration
    #@show X, Y
    grads = gradient(optimization_function, X, Y)
    state, (X, Y) = Optimisers.update(state, (X, Y), grads)
    # Print loss every 50 iterations
    if iteration % 50 == 0
        println("Iteration $iteration: Loss = ", optimization_function(X, Y))
    end
end
@show opterr = optimization_function(X,Y)
@show norm(A-X*Y)^2
Iteration 50: Loss = 220.904522671741
Iteration 100: Loss = 114.74149638440048
Iteration 150: Loss = 82.61458301083897
⋮
Iteration 4900: Loss = 38.78376133256981
Iteration 4950: Loss = 38.78376133256981
Iteration 5000: Loss = 38.78376133256981
opterr = optimization_function(X, Y) = 38.78376133256981
norm(A - X * Y) ^ 2 = 7.192773609225721
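# Why the regularized loss plateaus near 38.78 while norm(A - X*Y)^2 stays near 7.19
# (a sketch, assuming the optimizer reached the global optimum): with r >= rank(A),
# minimizing norm(A - X*Y)^2 + norm(X)^2 + norm(Y)^2 is equivalent to soft-thresholding
# the singular values of A by 1, so the optimal value has a closed form via svdvals(A).
sigma = svdvals(A)
closedform = sum(s >= 1 ? 2s - 1 : s^2 for s in sigma)   # optimal regularized loss
reconerr   = sum(min.(sigma, 1).^2)                      # optimal norm(A - X*Y)^2 under the penalty
@show closedform   # should be close to opterr above
@show reconerr     # should be close to norm(A - X*Y)^2 above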
using Optim
m = 10
n = 8
A = randn(m,n)
#A = Matrix(1.0I,m,n)
r = 2
myf = x -> matrix_approx_function(x, A, r)
myg! = (storage, x) -> matrix_approx_gradient!(x, storage, A, r)   # Optim calls g!(storage, x), gradient buffer first
soln = optimize(myf, myg!, ones(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
#soln = optimize(myf, myg!, randn(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
x = Optim.minimizer(soln)
@show soln
Uopt = reshape(x[(1:m*r)],m,r)
Vopt = reshape(x[(m*r+1):end],n,r)
objval = 2*myf(x)
opterr = norm(A-Uopt*Vopt')^2
Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2
@show objval
@show opterr
@show svderr
; # hide final output in JuliaBox
UndefVarError: `matrix_approx_gradient!` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
[1] (::var"#15#16")(x::Vector{Float64}, storage::Vector{Float64})
@ Main ./In[8]:7
[2] (::NLSolversBase.var"#fg!#8"{var"#13#14", var"#15#16"})(gx::Vector{Float64}, x::Vector{Float64})
@ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/abstract.jl:13
[3] value_gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
@ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
[4] initial_state(method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, options::Optim.Options{Float64, Nothing}, d::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, initial_x::Vector{Float64})
@ Optim ~/.julia/packages/Optim/HvjCd/src/multivariate/solvers/first_order/bfgs.jl:94
[5] optimize
@ ~/.julia/packages/Optim/HvjCd/src/multivariate/optimize/optimize.jl:36 [inlined]
[6] optimize(f::Function, g::Function, initial_x::Vector{Float64}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, options::Optim.Options{Float64, Nothing}; inplace::Bool, autodiff::Symbol)
@ Optim ~/.julia/packages/Optim/HvjCd/src/multivariate/optimize/interface.jl:156
[7] optimize(f::Function, g::Function, initial_x::Vector{Float64}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, options::Optim.Options{Float64, Nothing})
@ Optim ~/.julia/packages/Optim/HvjCd/src/multivariate/optimize/interface.jl:151
[8] top-level scope
@ In[8]:9
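# The UndefVarError above comes from matrix_approx_function and matrix_approx_gradient!
# never being defined in this notebook. Here is a minimal sketch that would make the
# cell run, assuming the objective is f(x) = 0.5*norm(A - U*V')^2 with
# U = reshape(x[1:m*r], m, r) and V = reshape(x[m*r+1:end], n, r)
# (consistent with objval = 2*myf(x) above):
function matrix_approx_function(x, A, r)
    m, n = size(A)
    U = reshape(x[1:m*r], m, r)
    V = reshape(x[(m*r+1):end], n, r)
    return 0.5*norm(A - U*V')^2
end
function matrix_approx_gradient!(x, storage, A, r)
    m, n = size(A)
    U = reshape(x[1:m*r], m, r)
    V = reshape(x[(m*r+1):end], n, r)
    R = A - U*V'                           # residual
    storage[1:m*r]       .= vec(-R*V)      # gradient with respect to U: -(A - U*V')*V
    storage[(m*r+1):end] .= vec(-R'*U)     # gradient with respect to V: -(A - U*V')'*U
    return storage
end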