In [37]:
using Flux
using Random, LinearAlgebra

# Set seed for reproducibility
Random.seed!(42)

# Dimensions
m, n, r = 10, 8, 8  # A is m x n, X is m x k, Y is k x n

# Generate a random matrix A
A = randn(m, n)

# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)

function custom_loss(A, X, Y)
    return norm(A-X*Y)^2
    #return norm(A-X*Y)^2 + norm(X)^2 + norm(Y)^2    
end 

model_loss = (X,Y) -> custom_loss(A,X,Y)

# Training loop
niterations = 5000
for epoch in 1:niterations

    gX, gY = gradient(model_loss, X, Y)
    
    @. X = X - 0.01*gX
    @. Y = Y - 0.01*gY
    
    # Print loss every 50 epochs
    if epoch % 50 == 0
        println("Epoch $epoch: Loss = ", model_loss(X, Y))
    end
end

opterr = model_loss(X,Y)

# Display learned X and Y
Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2

@show opterr
@show svderr
Epoch 50: Loss = 2.091547823016724
Epoch 100: Loss = 0.5661487081197589
Epoch 150: Loss = 0.37969498421413045
Epoch 200: Loss = 0.29424961525867044
Epoch 250: Loss = 0.23736222031632534
Epoch 300: Loss = 0.1964354866301492
Epoch 350: Loss = 0.16601387453115432
Epoch 400: Loss = 0.14287411770926364
Epoch 450: Loss = 0.1249276797918005
Epoch 500: Loss = 0.11076102657341395
Epoch 550: Loss = 0.09938956930058786
Epoch 600: Loss = 0.09011258577555409
Epoch 650: Loss = 0.08242286358561295
Epoch 700: Loss = 0.07594817484624594
Epoch 750: Loss = 0.07041219966566804
Epoch 800: Loss = 0.06560780146151443
Epoch 850: Loss = 0.06137839646849488
Epoch 900: Loss = 0.05760476277001347
Epoch 950: Loss = 0.05419557903563848
Epoch 1000: Loss = 0.05108056173010772
Epoch 1050: Loss = 0.048205435866809476
Epoch 1100: Loss = 0.04552821297055821
Epoch 1150: Loss = 0.04301640904246799
Epoch 1200: Loss = 0.040644943537536414
Epoch 1250: Loss = 0.038394535145568016
Epoch 1300: Loss = 0.036250462521420154
Epoch 1350: Loss = 0.034201595160950955
Epoch 1400: Loss = 0.032239626065311194
Epoch 1450: Loss = 0.030358456836600548
Epoch 1500: Loss = 0.02855369955562033
Epoch 1550: Loss = 0.026822269694929714
Epoch 1600: Loss = 0.025162051458471984
Epoch 1650: Loss = 0.023571622051321215
Epoch 1700: Loss = 0.022050025002258982
Epoch 1750: Loss = 0.020596585179147527
Epoch 1800: Loss = 0.01921075984538795
Epoch 1850: Loss = 0.017892021228549455
Epoch 1900: Loss = 0.016639766782187856
Epoch 1950: Loss = 0.015453253753083748
Epoch 2000: Loss = 0.014331554922186804
Epoch 2050: Loss = 0.013273532546789995
Epoch 2100: Loss = 0.012277827650879621
Epoch 2150: Loss = 0.011342861929083665
Epoch 2200: Loss = 0.010466849670845874
Epoch 2250: Loss = 0.009647817286751566
Epoch 2300: Loss = 0.008883628230206997
Epoch 2350: Loss = 0.008172011350082477
Epoch 2400: Loss = 0.00751059097438486
Epoch 2450: Loss = 0.006896917300391609
Epoch 2500: Loss = 0.006328495941524063
Epoch 2550: Loss = 0.0058028157451007
Epoch 2600: Loss = 0.005317374239396081
Epoch 2650: Loss = 0.0048697002867813515
Epoch 2700: Loss = 0.0044573737081773914
Epoch 2750: Loss = 0.004078041800864097
Epoch 2800: Loss = 0.003729432797022102
Epoch 2850: Loss = 0.003409366405908848
Epoch 2900: Loss = 0.003115761651040025
Epoch 2950: Loss = 0.0028466422585831153
Epoch 3000: Loss = 0.0026001398781133342
Epoch 3050: Loss = 0.0023744954257074163
Epoch 3100: Loss = 0.0021680588356420283
Epoch 3150: Loss = 0.001979287493983368
Epoch 3200: Loss = 0.0018067436079380144
Epoch 3250: Loss = 0.0016490907413761518
Epoch 3300: Loss = 0.0015050897213698617
Epoch 3350: Loss = 0.001373594094424699
Epoch 3400: Loss = 0.0012535452854519839
Epoch 3450: Loss = 0.0011439675882342335
Epoch 3500: Loss = 0.0010439630937052696
Epoch 3550: Loss = 0.0009527066421060979
Epoch 3600: Loss = 0.0008694408671250296
Epoch 3650: Loss = 0.0007934713844947298
Epoch 3700: Loss = 0.0007241621641145677
Epoch 3750: Loss = 0.0006609311134525298
Epoch 3800: Loss = 0.000603245890573871
Epoch 3850: Loss = 0.0005506199574433729
Epoch 3900: Loss = 0.0005026088779495058
Epoch 3950: Loss = 0.0004588068602014448
Epoch 4000: Loss = 0.0004188435388616766
Epoch 4050: Loss = 0.00038238099042634487
Epoch 4100: Loss = 0.0003491109722921667
Epoch 4150: Loss = 0.00031875237501530706
Epoch 4200: Loss = 0.0002910488762518231
Epoch 4250: Loss = 0.00026576678436702426
Epoch 4300: Loss = 0.00024269305952397344
Epoch 4350: Loss = 0.00022163350013455508
Epoch 4400: Loss = 0.0002024110828191659
Epoch 4450: Loss = 0.00018486444442131878
Epoch 4500: Loss = 0.00016884649512177784
Epoch 4550: Loss = 0.0001542231522587359
Epoch 4600: Loss = 0.0001408721850611429
Epoch 4650: Loss = 0.0001286821611204695
Epoch 4700: Loss = 0.00011755148604680225
Epoch 4750: Loss = 0.00010738752836612796
Epoch 4800: Loss = 9.810582230878433e-5
Epoch 4850: Loss = 8.962934170755446e-5
Epoch 4900: Loss = 8.188783876490361e-5
Epoch 4950: Loss = 7.481724195830831e-5
Epoch 5000: Loss = 6.835910783088071e-5
opterr = 6.835910783088071e-5
svderr = 1.2555705456793983e-28
Out[37]:
1.2555705456793983e-28
In [46]:
# This will evaluate the gradient of norm(A-XY^2) at a point X, Y
# Let's check against my "analytical" gradient...
X = ones(m,r)
Y = ones(r,n)
gX, gY = gradient(model_loss, X, Y)
gX
gX + (2*(A - X*Y)*Y')
Out[46]:
10×8 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
In [29]:
norm(X), norm(Y)
Out[29]:
(3.8030981067374814, 3.803098106737481)
In [30]:
norm(A-X*Y)
Out[30]:
2.6586124701944938
In [53]:
# Dimensions
m, n, r = 10, 8, 8  # A is m x n, X is m x k, Y is k x n

# Generate a random matrix A
A = randn(m, n)

# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)

function full_function(A, X, Y)
    #return norm(A-X*Y)^2
    return norm(A-X*Y)^2 + norm(X)^2 + norm(Y)^2    
end 

optimization_function = (X,Y) -> full_function(A,X,Y)

opt = Optimisers.Adam(0.01)
state = Optimisers.setup(opt, (X,Y))

# Training loop (Flux calls this epochs, but I don't like confusing them...)
niteration = 5000
for iteration in 1:niteration
    #@show X, Y

    grads = gradient(optimization_function, X, Y)
    
    state, (X,Y) = Optimisers.update(state, (X,Y), grads)
    
    # Print loss every 50 epochs
    if iteration % 50 == 0
        println("Iteration $iteration: Loss = ", optimization_function(X, Y))
    end
end

@show opterr = optimization_function(X,Y)
@show norm(A-X*Y)^2
Iteration 50: Loss = 179.60227608291302
Iteration 100: Loss = 95.79193142197144
Iteration 150: Loss = 71.33033035163707
Iteration 200: Loss = 57.27653930501515
Iteration 250: Loss = 47.99503153083923
Iteration 300: Loss = 41.955722482084234
Iteration 350: Loss = 38.00203752449347
Iteration 400: Loss = 35.352331740381516
Iteration 450: Loss = 33.566439779119
Iteration 500: Loss = 32.37618085148305
Iteration 550: Loss = 31.59633605079715
Iteration 600: Loss = 31.093976245108976
Iteration 650: Loss = 30.775308464352666
Iteration 700: Loss = 30.576066285058488
Iteration 750: Loss = 30.453310272875175
Iteration 800: Loss = 30.378869803582386
Iteration 850: Loss = 30.334512046759293
Iteration 900: Loss = 30.308581208014253
Iteration 950: Loss = 30.293729764249335
Iteration 1000: Loss = 30.285404187704813
Iteration 1050: Loss = 30.280838545929953
Iteration 1100: Loss = 30.278390015150972
Iteration 1150: Loss = 30.277105947348623
Iteration 1200: Loss = 30.276447430899296
Iteration 1250: Loss = 30.276117146388895
Iteration 1300: Loss = 30.275955111934994
Iteration 1350: Loss = 30.275877349604485
Iteration 1400: Loss = 30.275840841097125
Iteration 1450: Loss = 30.275824073791345
Iteration 1500: Loss = 30.275816541832693
Iteration 1550: Loss = 30.27581323357476
Iteration 1600: Loss = 30.27581181339558
Iteration 1650: Loss = 30.27581121790305
Iteration 1700: Loss = 30.275810974195856
Iteration 1750: Loss = 30.275810876937964
Iteration 1800: Loss = 30.275810839129065
Iteration 1850: Loss = 30.2758108248278
Iteration 1900: Loss = 30.275810819570978
Iteration 1950: Loss = 30.275810817695717
Iteration 2000: Loss = 30.275810817047393
Iteration 2050: Loss = 30.27581081683048
Iteration 2100: Loss = 30.275810816760345
Iteration 2150: Loss = 30.275810816738463
Iteration 2200: Loss = 30.275810816731894
Iteration 2250: Loss = 30.275810816729994
Iteration 2300: Loss = 30.27581081672946
Iteration 2350: Loss = 30.27581081672932
Iteration 2400: Loss = 30.275810816729283
Iteration 2450: Loss = 30.275810816729273
Iteration 2500: Loss = 30.275810816729276
Iteration 2550: Loss = 30.275810816729273
Iteration 2600: Loss = 30.275810816729283
Iteration 2650: Loss = 30.27581081672927
Iteration 2700: Loss = 30.275810816729287
Iteration 2750: Loss = 30.275810816729276
Iteration 2800: Loss = 30.27581081672927
Iteration 2850: Loss = 30.27581081672927
Iteration 2900: Loss = 30.275810816729283
Iteration 2950: Loss = 30.275810816729273
Iteration 3000: Loss = 30.275810816729276
Iteration 3050: Loss = 30.275810816729276
Iteration 3100: Loss = 30.275810816729276
Iteration 3150: Loss = 30.275810816729276
Iteration 3200: Loss = 30.275810816729273
Iteration 3250: Loss = 30.275810816729273
Iteration 3300: Loss = 30.275810816729276
Iteration 3350: Loss = 30.275810816729276
Iteration 3400: Loss = 30.275810816729276
Iteration 3450: Loss = 30.275810816729276
Iteration 3500: Loss = 30.27581081672927
Iteration 3550: Loss = 30.27581081672927
Iteration 3600: Loss = 30.27581081672927
Iteration 3650: Loss = 30.27581081672927
Iteration 3700: Loss = 30.275810816729276
Iteration 3750: Loss = 30.275810816729276
Iteration 3800: Loss = 30.275810816729276
Iteration 3850: Loss = 30.275810816729276
Iteration 3900: Loss = 30.275810816729276
Iteration 3950: Loss = 30.27581081672927
Iteration 4000: Loss = 30.275810816729276
Iteration 4050: Loss = 30.27581081672927
Iteration 4100: Loss = 30.275810816729276
Iteration 4150: Loss = 30.275810816729276
Iteration 4200: Loss = 30.275810816729276
Iteration 4250: Loss = 30.275810816729276
Iteration 4300: Loss = 30.275810816729276
Iteration 4350: Loss = 30.275810816729276
Iteration 4400: Loss = 30.27581081672928
Iteration 4450: Loss = 30.275810816729276
Iteration 4500: Loss = 30.275810816729276
Iteration 4550: Loss = 30.275810816729276
Iteration 4600: Loss = 30.275810816729283
Iteration 4650: Loss = 30.275810816729276
Iteration 4700: Loss = 30.275810816729283
Iteration 4750: Loss = 30.275810816729283
Iteration 4800: Loss = 30.275810816729283
Iteration 4850: Loss = 30.275810816729276
Iteration 4900: Loss = 30.275810816729283
Iteration 4950: Loss = 30.275810816729283
Iteration 5000: Loss = 30.275810816729283
opterr = optimization_function(X, Y) = 30.275810816729283
norm(A - X * Y) ^ 2 = 7.141659826689626
Out[53]:
7.141659826689626
In [23]:
using Optim
WARNING: using Optim.Adam in module Main conflicts with an existing identifier.
In [ ]:
m = 10
n = 8
A = randn(m,n)
#A = Matrix(1.0I,m,n)
r = 2
myf = x -> matrix_approx_function(x, A, r)
myg! = (x, storage) -> matrix_approx_gradient!(x, storage, A, r)

soln = optimize(myf, myg!, ones(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
#soln = optimize(myf, myg!, randn(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
x = Optim.minimizer(soln)
@show soln
Uopt = reshape(x[(1:m*r)],m,r)
Vopt = reshape(x[(m*r+1):end],n,r)
objval = 2*myf(x)
opterr = norm(A-Uopt*Vopt')^2

Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2
@show objval
@show opterr
@show svderr
; # hide final output in JuliaBox