In [10]:
using Flux
using Random, LinearAlgebra

# Set seed for reproducibility
Random.seed!(42)

# Dimensions
m, n, r = 10, 8, 8  # A is m x n, X is m x k, Y is k x n

# Generate a random matrix A
A = randn(m, n)

# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)

function custom_loss(A, X, Y)
    return norm(A-X*Y)^2
    #return norm(A-X*Y)^2 + norm(X)^2 + norm(Y)^2    
end 

model_loss = (X,Y) -> custom_loss(A,X,Y)

# Training loop
niterations = 10000
for epoch in 1:niterations

    gX, gY = gradient(model_loss, X, Y)
    
    @. X = X - 0.01*gX
    @. Y = Y - 0.01*gY
    
    # Print loss every 50 epochs
    if epoch % 50 == 0
        println("Iterations $epoch: Loss = ", model_loss(X, Y))
    end
end

opterr = model_loss(X,Y)

# Display learned X and Y
Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2

@show opterr
@show svderr
Iterations 50: Loss = 2.091547823016724
Iterations 100: Loss = 0.5661487081197589
Iterations 150: Loss = 0.37969498421413045
Iterations 200: Loss = 0.29424961525867044
Iterations 250: Loss = 0.23736222031632534
Iterations 300: Loss = 0.1964354866301492
Iterations 350: Loss = 0.16601387453115432
Iterations 400: Loss = 0.14287411770926364
Iterations 450: Loss = 0.1249276797918005
Iterations 500: Loss = 0.11076102657341395
Iterations 550: Loss = 0.09938956930058786
Iterations 600: Loss = 0.09011258577555409
Iterations 650: Loss = 0.08242286358561295
Iterations 700: Loss = 0.07594817484624594
Iterations 750: Loss = 0.07041219966566804
Iterations 800: Loss = 0.06560780146151443
Iterations 850: Loss = 0.06137839646849488
Iterations 900: Loss = 0.05760476277001347
Iterations 950: Loss = 0.05419557903563848
Iterations 1000: Loss = 0.05108056173010772
Iterations 1050: Loss = 0.048205435866809476
Iterations 1100: Loss = 0.04552821297055821
Iterations 1150: Loss = 0.04301640904246799
Iterations 1200: Loss = 0.040644943537536414
Iterations 1250: Loss = 0.038394535145568016
Iterations 1300: Loss = 0.036250462521420154
Iterations 1350: Loss = 0.034201595160950955
Iterations 1400: Loss = 0.032239626065311194
Iterations 1450: Loss = 0.030358456836600548
Iterations 1500: Loss = 0.02855369955562033
Iterations 1550: Loss = 0.026822269694929714
Iterations 1600: Loss = 0.025162051458471984
Iterations 1650: Loss = 0.023571622051321215
Iterations 1700: Loss = 0.022050025002258982
Iterations 1750: Loss = 0.020596585179147527
Iterations 1800: Loss = 0.01921075984538795
Iterations 1850: Loss = 0.017892021228549455
Iterations 1900: Loss = 0.016639766782187856
Iterations 1950: Loss = 0.015453253753083748
Iterations 2000: Loss = 0.014331554922186804
Iterations 2050: Loss = 0.013273532546789995
Iterations 2100: Loss = 0.012277827650879621
Iterations 2150: Loss = 0.011342861929083665
Iterations 2200: Loss = 0.010466849670845874
Iterations 2250: Loss = 0.009647817286751566
Iterations 2300: Loss = 0.008883628230206997
Iterations 2350: Loss = 0.008172011350082477
Iterations 2400: Loss = 0.00751059097438486
Iterations 2450: Loss = 0.006896917300391609
Iterations 2500: Loss = 0.006328495941524063
Iterations 2550: Loss = 0.0058028157451007
Iterations 2600: Loss = 0.005317374239396081
Iterations 2650: Loss = 0.0048697002867813515
Iterations 2700: Loss = 0.0044573737081773914
Iterations 2750: Loss = 0.004078041800864097
Iterations 2800: Loss = 0.003729432797022102
Iterations 2850: Loss = 0.003409366405908848
Iterations 2900: Loss = 0.003115761651040025
Iterations 2950: Loss = 0.0028466422585831153
Iterations 3000: Loss = 0.0026001398781133342
Iterations 3050: Loss = 0.0023744954257074163
Iterations 3100: Loss = 0.0021680588356420283
Iterations 3150: Loss = 0.001979287493983368
Iterations 3200: Loss = 0.0018067436079380144
Iterations 3250: Loss = 0.0016490907413761518
Iterations 3300: Loss = 0.0015050897213698617
Iterations 3350: Loss = 0.001373594094424699
Iterations 3400: Loss = 0.0012535452854519839
Iterations 3450: Loss = 0.0011439675882342335
Iterations 3500: Loss = 0.0010439630937052696
Iterations 3550: Loss = 0.0009527066421060979
Iterations 3600: Loss = 0.0008694408671250296
Iterations 3650: Loss = 0.0007934713844947298
Iterations 3700: Loss = 0.0007241621641145677
Iterations 3750: Loss = 0.0006609311134525298
Iterations 3800: Loss = 0.000603245890573871
Iterations 3850: Loss = 0.0005506199574433729
Iterations 3900: Loss = 0.0005026088779495058
Iterations 3950: Loss = 0.0004588068602014448
Iterations 4000: Loss = 0.0004188435388616766
Iterations 4050: Loss = 0.00038238099042634487
Iterations 4100: Loss = 0.0003491109722921667
Iterations 4150: Loss = 0.00031875237501530706
Iterations 4200: Loss = 0.0002910488762518231
Iterations 4250: Loss = 0.00026576678436702426
Iterations 4300: Loss = 0.00024269305952397344
Iterations 4350: Loss = 0.00022163350013455508
Iterations 4400: Loss = 0.0002024110828191659
Iterations 4450: Loss = 0.00018486444442131878
Iterations 4500: Loss = 0.00016884649512177784
Iterations 4550: Loss = 0.0001542231522587359
Iterations 4600: Loss = 0.0001408721850611429
Iterations 4650: Loss = 0.0001286821611204695
Iterations 4700: Loss = 0.00011755148604680225
Iterations 4750: Loss = 0.00010738752836612796
Iterations 4800: Loss = 9.810582230878433e-5
Iterations 4850: Loss = 8.962934170755446e-5
Iterations 4900: Loss = 8.188783876490361e-5
Iterations 4950: Loss = 7.481724195830831e-5
Iterations 5000: Loss = 6.835910783088071e-5
Iterations 5050: Loss = 6.246012185993381e-5
Iterations 5100: Loss = 5.707164400986692e-5
Iterations 5150: Loss = 5.214929495828022e-5
Iterations 5200: Loss = 4.765257933709538e-5
Iterations 5250: Loss = 4.3544542654561886e-5
Iterations 5300: Loss = 3.979145886186964e-5
Iterations 5350: Loss = 3.636254580030131e-5
Iterations 5400: Loss = 3.322970601405161e-5
Iterations 5450: Loss = 3.0367290641250534e-5
Iterations 5500: Loss = 2.7751884302955842e-5
Iterations 5550: Loss = 2.5362109098887773e-5
Iterations 5600: Loss = 2.3178445990580994e-5
Iterations 5650: Loss = 2.1183072008975297e-5
Iterations 5700: Loss = 1.9359711865829693e-5
Iterations 5750: Loss = 1.7693502677433463e-5
Iterations 5800: Loss = 1.617087062660398e-5
Iterations 5850: Loss = 1.4779418495651612e-5
Iterations 5900: Loss = 1.3507823099787599e-5
Iterations 5950: Loss = 1.2345741738505414e-5
Iterations 6000: Loss = 1.1283726862300677e-5
Iterations 6050: Loss = 1.0313148224707296e-5
Iterations 6100: Loss = 9.426121855429314e-6
Iterations 6150: Loss = 8.61544525028296e-6
Iterations 6200: Loss = 7.874538227934553e-6
Iterations 6250: Loss = 7.1973889528499404e-6
Iterations 6300: Loss = 6.578504668725269e-6
Iterations 6350: Loss = 6.012866727414332e-6
Iterations 6400: Loss = 5.495889535460493e-6
Iterations 6450: Loss = 5.023383074014882e-6
Iterations 6500: Loss = 4.591518678525342e-6
Iterations 6550: Loss = 4.196797792514818e-6
Iterations 6600: Loss = 3.836023435009318e-6
Iterations 6650: Loss = 3.5062741443110902e-6
Iterations 6700: Loss = 3.2048801817677257e-6
Iterations 6750: Loss = 2.929401798262728e-6
Iterations 6800: Loss = 2.6776093835862707e-6
Iterations 6850: Loss = 2.4474653346242448e-6
Iterations 6900: Loss = 2.2371074927645178e-6
Iterations 6950: Loss = 2.0448340140349103e-6
Iterations 7000: Loss = 1.869089547444632e-6
Iterations 7050: Loss = 1.7084526079274066e-6
Iterations 7100: Loss = 1.5616240401995408e-6
Iterations 7150: Loss = 1.4274164789251081e-6
Iterations 7200: Loss = 1.3047447188194008e-6
Iterations 7250: Loss = 1.1926169158745868e-6
Iterations 7300: Loss = 1.090126547732652e-6
Iterations 7350: Loss = 9.964450675236286e-7
Iterations 7400: Loss = 9.10815191176191e-7
Iterations 7450: Loss = 8.325447634320019e-7
Iterations 7500: Loss = 7.610011525613419e-7
Iterations 7550: Loss = 6.956061280851128e-7
Iterations 7600: Loss = 6.358311798147705e-7
Iterations 7650: Loss = 5.811932400970531e-7
Iterations 7700: Loss = 5.312507744708986e-7
Iterations 7750: Loss = 4.856002089583236e-7
Iterations 7800: Loss = 4.4387266494429344e-7
Iterations 7850: Loss = 4.057309751334595e-7
Iterations 7900: Loss = 3.708669563458725e-7
Iterations 7950: Loss = 3.3899891702201004e-7
Iterations 8000: Loss = 3.0986937920724994e-7
Iterations 8050: Loss = 2.8324299654339726e-7
Iterations 8100: Loss = 2.589046513808985e-7
Iterations 8150: Loss = 2.3665771558951264e-7
Iterations 8200: Loss = 2.163224609684961e-7
Iterations 8250: Loss = 1.9773460638062864e-7
Iterations 8300: Loss = 1.807439898387137e-7
Iterations 8350: Loss = 1.6521335479008007e-7
Iterations 8400: Loss = 1.51017240774682e-7
Iterations 8450: Loss = 1.3804096947034427e-7
Iterations 8500: Loss = 1.2617971792249396e-7
Iterations 8550: Loss = 1.1533767145725407e-7
Iterations 8600: Loss = 1.0542724942100497e-7
Iterations 8650: Loss = 9.636839748385193e-8
Iterations 8700: Loss = 8.80879407835405e-8
Iterations 8750: Loss = 8.051899267420461e-8
Iterations 8800: Loss = 7.360041429922142e-8
Iterations 8850: Loss = 6.727632062050114e-8
Iterations 8900: Loss = 6.149562890441936e-8
Iterations 8950: Loss = 5.621164601953549e-8
Iterations 9000: Loss = 5.1381691203743314e-8
Iterations 9050: Loss = 4.69667512569366e-8
Iterations 9100: Loss = 4.2931165366594105e-8
Iterations 9150: Loss = 3.9242337021973384e-8
Iterations 9200: Loss = 3.587047068731033e-8
Iterations 9250: Loss = 3.278833110679159e-8
Iterations 9300: Loss = 2.9971023292656866e-8
Iterations 9350: Loss = 2.7395791423329467e-8
Iterations 9400: Loss = 2.504183502269875e-8
Iterations 9450: Loss = 2.2890140936637582e-8
Iterations 9500: Loss = 2.092332974850959e-8
Iterations 9550: Loss = 1.9125515393841845e-8
Iterations 9600: Loss = 1.748217683727073e-8
Iterations 9650: Loss = 1.5980040778726607e-8
Iterations 9700: Loss = 1.4606974437213959e-8
Iterations 9750: Loss = 1.3351887548844835e-8
Iterations 9800: Loss = 1.2204642786063424e-8
Iterations 9850: Loss = 1.1155973874523015e-8
Iterations 9900: Loss = 1.0197410746274112e-8
Iterations 9950: Loss = 9.321211124555337e-9
Iterations 10000: Loss = 8.520297987205108e-9
opterr = 8.520297987205108e-9
svderr = 1.2555705456793983e-28
Out[10]:
1.2555705456793983e-28
In [3]:
# This will evaluate the gradient of norm(A-XY^2) at a point X, Y
# Let's check against my "analytical" gradient...
X = ones(m,r)
Y = ones(r,n)
gX, gY = gradient(model_loss, X, Y)
gX
gX + (2*(A - X*Y)*Y')
Out[3]:
10×8 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
In [29]:
norm(X), norm(Y)
Out[29]:
(3.8030981067374814, 3.803098106737481)
In [30]:
norm(A-X*Y)
Out[30]:
2.6586124701944938
In [5]:
# Dimensions
m, n, r = 10, 8, 8  # A is m x n, X is m x k, Y is k x n

# Generate a random matrix A
A = randn(m, n)

# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)

function full_function(A, X, Y)
    #return norm(A-X*Y)^2
    return norm(A-X*Y)^2 + norm(X)^2 + norm(Y)^2    
end 

optimization_function = (X,Y) -> full_function(A,X,Y)

opt = Optimisers.Adam(0.01)
state = Optimisers.setup(opt, (X,Y))

# Training loop (Flux calls this epochs, but I don't like confusing them...)
niteration = 5000
for iteration in 1:niteration
    #@show X, Y

    grads = gradient(optimization_function, X, Y)
    
    state, (X,Y) = Optimisers.update(state, (X,Y), grads)
    
    # Print loss every 50 epochs
    if iteration % 50 == 0
        println("Iteration $iteration: Loss = ", optimization_function(X, Y))
    end
end

@show opterr = optimization_function(X,Y)
@show norm(A-X*Y)^2
Iteration 50: Loss = 220.904522671741
Iteration 100: Loss = 114.74149638440048
Iteration 150: Loss = 82.61458301083897
Iteration 200: Loss = 66.22401683960052
Iteration 250: Loss = 55.52050523200967
Iteration 300: Loss = 48.703229900422
Iteration 350: Loss = 44.8066391176972
Iteration 400: Loss = 42.486276770455035
Iteration 450: Loss = 41.04934383794148
Iteration 500: Loss = 40.162565316476744
Iteration 550: Loss = 39.62032106042476
Iteration 600: Loss = 39.290744995720395
Iteration 650: Loss = 39.09141976721878
Iteration 700: Loss = 38.97136555408673
Iteration 750: Loss = 38.89915824094566
Iteration 800: Loss = 38.85560367418039
Iteration 850: Loss = 38.82913678944861
Iteration 900: Loss = 38.81287419558346
Iteration 950: Loss = 38.802746043355576
Iteration 1000: Loss = 38.796345197509225
Iteration 1050: Loss = 38.79223869079882
Iteration 1100: Loss = 38.7895642781357
Iteration 1150: Loss = 38.78779636818757
Iteration 1200: Loss = 38.78661019243293
Iteration 1250: Loss = 38.78580237375574
Iteration 1300: Loss = 38.785243916077846
Iteration 1350: Loss = 38.784851999650215
Iteration 1400: Loss = 38.78457282622818
Iteration 1450: Loss = 38.7843710482461
Iteration 1500: Loss = 38.78422317465724
Iteration 1550: Loss = 38.78411340988787
Iteration 1600: Loss = 38.78403100019623
Iteration 1650: Loss = 38.783968525312844
Iteration 1700: Loss = 38.783920789658325
Iteration 1750: Loss = 38.783884097698376
Iteration 1800: Loss = 38.78385577739258
Iteration 1850: Loss = 38.7838338648007
Iteration 1900: Loss = 38.78381689377257
Iteration 1950: Loss = 38.78380375434026
Iteration 2000: Loss = 38.783793596150886
Iteration 2050: Loss = 38.783785761554384
Iteration 2100: Loss = 38.783779738357424
Iteration 2150: Loss = 38.78377512576706
Iteration 2200: Loss = 38.78377160931973
Iteration 2250: Loss = 38.783768942054685
Iteration 2300: Loss = 38.78376693012621
Iteration 2350: Loss = 38.783765421647516
Iteration 2400: Loss = 38.783764297942014
Iteration 2450: Loss = 38.78376346662529
Iteration 2500: Loss = 38.78376285610403
Iteration 2550: Loss = 38.783762411186814
Iteration 2600: Loss = 38.78376208957735
Iteration 2650: Loss = 38.7837618590728
Iteration 2700: Loss = 38.78376169532962
Iteration 2750: Loss = 38.78376158008687
Iteration 2800: Loss = 38.783761499759095
Iteration 2850: Loss = 38.783761444327666
Iteration 2900: Loss = 38.783761406472856
Iteration 2950: Loss = 38.78376138089887
Iteration 3000: Loss = 38.78376136381347
Iteration 3050: Loss = 38.78376135253012
Iteration 3100: Loss = 38.78376134516678
Iteration 3150: Loss = 38.78376134042032
Iteration 3200: Loss = 38.7837613373993
Iteration 3250: Loss = 38.783761335501424
Iteration 3300: Loss = 38.783761334325106
Iteration 3350: Loss = 38.78376133360603
Iteration 3400: Loss = 38.783761333172706
Iteration 3450: Loss = 38.78376133291539
Iteration 3500: Loss = 38.78376133276486
Iteration 3550: Loss = 38.78376133267817
Iteration 3600: Loss = 38.783761332629055
Iteration 3650: Loss = 38.78376133260166
Iteration 3700: Loss = 38.78376133258663
Iteration 3750: Loss = 38.78376133257855
Iteration 3800: Loss = 38.78376133257426
Iteration 3850: Loss = 38.78376133257205
Iteration 3900: Loss = 38.7837613325709
Iteration 3950: Loss = 38.783761332570336
Iteration 4000: Loss = 38.78376133257007
Iteration 4050: Loss = 38.783761332569924
Iteration 4100: Loss = 38.78376133256985
Iteration 4150: Loss = 38.783761332569846
Iteration 4200: Loss = 38.783761332569824
Iteration 4250: Loss = 38.78376133256981
Iteration 4300: Loss = 38.783761332569824
Iteration 4350: Loss = 38.78376133256981
Iteration 4400: Loss = 38.78376133256982
Iteration 4450: Loss = 38.78376133256982
Iteration 4500: Loss = 38.7837613325698
Iteration 4550: Loss = 38.78376133256981
Iteration 4600: Loss = 38.78376133256983
Iteration 4650: Loss = 38.78376133256981
Iteration 4700: Loss = 38.78376133256982
Iteration 4750: Loss = 38.78376133256981
Iteration 4800: Loss = 38.7837613325698
Iteration 4850: Loss = 38.78376133256981
Iteration 4900: Loss = 38.78376133256981
Iteration 4950: Loss = 38.78376133256981
Iteration 5000: Loss = 38.78376133256981
opterr = optimization_function(X, Y) = 38.78376133256981
norm(A - X * Y) ^ 2 = 7.192773609225721
Out[5]:
7.192773609225721
In [6]:
using Optim
In [8]:
m = 10
n = 8
A = randn(m,n)
#A = Matrix(1.0I,m,n)
r = 2
myf = x -> matrix_approx_function(x, A, r)
myg! = (x, storage) -> matrix_approx_gradient!(x, storage, A, r)

soln = optimize(myf, myg!, ones(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
#soln = optimize(myf, myg!, randn(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
x = Optim.minimizer(soln)
@show soln
Uopt = reshape(x[(1:m*r)],m,r)
Vopt = reshape(x[(m*r+1):end],n,r)
objval = 2*myf(x)
opterr = norm(A-Uopt*Vopt')^2

Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2
@show objval
@show opterr
@show svderr
; # hide final output in JuliaBox
UndefVarError: `matrix_approx_gradient!` not defined in `Main`
Suggestion: check for spelling errors or missing imports.

Stacktrace:
 [1] (::var"#15#16")(x::Vector{Float64}, storage::Vector{Float64})
   @ Main ./In[8]:7
 [2] (::NLSolversBase.var"#fg!#8"{var"#13#14", var"#15#16"})(gx::Vector{Float64}, x::Vector{Float64})
   @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/abstract.jl:13
 [3] value_gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
   @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
 [4] initial_state(method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, options::Optim.Options{Float64, Nothing}, d::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, initial_x::Vector{Float64})
   @ Optim ~/.julia/packages/Optim/HvjCd/src/multivariate/solvers/first_order/bfgs.jl:94
 [5] optimize
   @ ~/.julia/packages/Optim/HvjCd/src/multivariate/optimize/optimize.jl:36 [inlined]
 [6] optimize(f::Function, g::Function, initial_x::Vector{Float64}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, options::Optim.Options{Float64, Nothing}; inplace::Bool, autodiff::Symbol)
   @ Optim ~/.julia/packages/Optim/HvjCd/src/multivariate/optimize/interface.jl:156
 [7] optimize(f::Function, g::Function, initial_x::Vector{Float64}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, options::Optim.Options{Float64, Nothing})
   @ Optim ~/.julia/packages/Optim/HvjCd/src/multivariate/optimize/interface.jl:151
 [8] top-level scope
   @ In[8]:9