#WARNING: This is a naive implementation of Memory-based
#Collaborative Filtering.  It only returns the prediction
#for user 22 movie 377.  Tested under pyspark 2.4.3
#Copyright 2019 Christopher W. Clifton

from __future__ import print_function

import sys
import numpy as np
from operator import add
from pyspark.sql import SparkSession

def parseRating(line):
    """Parse one MovieLens rating record into a float numpy array.

    Expected format: userId <tab> movieId <tab> rating <tab> timestamp.
    Fields are split on any run of whitespace, so tabs, spaces, a trailing
    tab or a carriage return are all tolerated.

    Returns a 1-D float array of the numeric fields, in input order.
    Raises ValueError if the line is blank or a field is not numeric.
    """
    fields = line.strip().split()
    if not fields:
        # A blank line has no record to parse; fail here rather than later
        # with an IndexError when callers index into the result.
        raise ValueError("empty rating record")
    return np.array([float(x) for x in fields])


def userTotal(rating1, rating2):
    """Merge two per-user accumulators of the form (sum, sum of squares, count).

    Intended as the combine function for reduceByKey: the result is a tuple
    holding the element-wise sum of the first three entries of each argument.
    """
    # Position 0 = rating sum, 1 = squared-rating sum, 2 = rating count.
    return tuple(rating1[pos] + rating2[pos] for pos in range(3))
    
if __name__ == "__main__":

    # This script predicts a single, fixed (user, movie) rating.
    TARGET_USER = 22
    TARGET_MOVIE = 377

    if len(sys.argv) != 2:
        print("Usage: mbcf <file>", file=sys.stderr)
        sys.exit(-1)

    spark = SparkSession\
            .builder\
            .appName("PythonMemCF")\
            .getOrCreate()

    # Stop the Spark session on every exit path, including the early exits
    # below and any exception raised by a Spark action.
    try:
        # One text line per rating; blank lines are dropped because
        # parseRating cannot turn them into a record.
        ratingsin = spark.read.text(sys.argv[1]).rdd \
                         .map(lambda r: r[0]) \
                         .filter(lambda line: line.strip())
        # Each record is a numpy array indexed as
        # [0]=userId, [1]=movieId, [2]=rating.
        data = ratingsin.map(parseRating).cache()

        # Per user: (sum of ratings, sum of squared ratings, number of ratings)
        userTotals = data.map(lambda u: (u[0], (u[2], u[2]**2, 1))) \
                         .reduceByKey(userTotal)

        # lookup() returns a list of the values for the key; it is empty when
        # the target user has no ratings, which would otherwise surface later
        # as an IndexError inside a Spark task.
        targetTotalsList = userTotals.lookup(TARGET_USER)
        if not targetTotalsList:
            print("No ratings found for user %d" % TARGET_USER,
                  file=sys.stderr)
            sys.exit(-1)
        targetSum, targetSumSq, targetCount = targetTotalsList[0]
        # Norm of the target user's rating vector, computed once on the
        # driver instead of once per record inside every task.
        targetNorm = np.sqrt(targetSumSq)

        targetMovies = data.filter(lambda u: u[0] == TARGET_USER)

        # Vector-space (cosine) weight numerator: for every other user, the
        # sum over movies rated by both of (their rating * target's rating).
        userWeightNum = data.filter(lambda u: u[0] != TARGET_USER) \
                            .map(lambda m: (m[1], (m[0], m[2]))) \
                            .join(targetMovies.map(lambda m: (m[1], m[2]))) \
                            .map(lambda u: (u[1][0][0], u[1][0][1] * u[1][1])) \
                            .foldByKey(0, add)

        # Divide by both users' rating-vector norms.  After the join each
        # record is (user, (numerator, (sum, sumSq, count))).
        userWeight = userWeightNum.join(userTotals) \
                                  .map(lambda u: (u[0], u[1][0] /
                                                  (targetNorm *
                                                   np.sqrt(u[1][1][1]))))

        # Other users' ratings of the target movie with their weight and
        # totals: (user, ((rating, weight), (sum, sumSq, count))).
        ratingsTarget = data.filter(lambda x: x[1] == TARGET_MOVIE) \
                            .map(lambda u: (u[0], u[2])) \
                            .join(userWeight) \
                            .join(userTotals)

        # Numerator: sum of weight * (rating - that user's mean rating).
        # Denominator: sum of |weight|.
        weightedDeviation, totalWeight = ratingsTarget \
            .map(lambda u: (u[1][0][1] *
                            (u[1][0][0] - u[1][1][0] / u[1][1][2]),
                            abs(u[1][0][1]))) \
            .fold((0, 0), lambda x, y: (x[0] + y[0], x[1] + y[1]))

        # Without any weighted neighbour who rated the movie, the deviation
        # term is 0/0; report that instead of crashing or printing nan.
        if totalWeight == 0:
            print("No similar user has rated movie %d; cannot predict"
                  % TARGET_MOVIE, file=sys.stderr)
            sys.exit(-1)

        # Finally, the rating, as an offset from the target user's average
        rating = targetSum / targetCount + weightedDeviation / totalWeight

        print("Predicted Rating for User %d Movie %d: %s"
              % (TARGET_USER, TARGET_MOVIE, rating))
    finally:
        spark.stop()
