ALS实现电影推荐

xiaoxiao2021-02-27  320

package com.ys.scala

import scala.annotation.tailrec
import scala.util.Random
import scala.util.Try

import org.apache.log4j.Level
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.Rating
import org.apache.spark.rdd.RDD

/**
 * Movie recommender built on Spark MLlib's ALS (alternating least squares).
 *
 * The program reads `ratings.dat` and `movies.dat` (fields separated by `::`),
 * asks the user on the console to rate a few popular movies, trains ALS models
 * over a small hyper-parameter grid, keeps the model with the lowest RMSE on the
 * validation split, reports its RMSE on the test split, prints up to 50
 * recommendations for the user, and compares the model against a
 * "predict the mean rating" baseline.
 */
object ScalaMovieLensALS {

  /** User id under which the interactively entered ratings are stored. */
  private val PersonalUserId = 0

  /**
   * Program entry point: configures logging, creates a local SparkContext and
   * runs the recommendation pipeline.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Suppress noisy log output.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    // Jetty's logger lives under "org.eclipse.jetty", not "org.apache.eclipse.jetty".
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val conf = new SparkConf().setAppName("ScalaMovieLensALS").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      run(sc)
    } finally {
      // Always release the context, even when the pipeline fails.
      sc.stop()
    }
  }

  /** Runs the full load / train / evaluate / recommend pipeline on `sc`. */
  private def run(sc: SparkContext): Unit = {
    // Load ratings as (timestamp % 10, Rating(userId, movieId, rating)).
    // The last timestamp digit is used further down as the split key.
    // Cached because several actions below re-read this RDD.
    val ratings = sc.textFile("ratings.dat").map { line =>
      val fields = line.split("::")
      (fields(3).toLong % 10, Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble))
    }.cache()

    // Load movie titles as a local Map(movieId -> movieName).
    val movies = sc.textFile("movies.dat").map { line =>
      val fields = line.split("::")
      (fields(0).toInt, fields(1))
    }.collect().toMap

    val numRatings = ratings.count()
    val numUsers = ratings.map(_._2.user).distinct().count()
    val numMovies = ratings.map(_._2.product).distinct().count()

    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")

    // Ids of the 50 movies with the most ratings, most-rated first.
    val mostRatedMovieIds = ratings.map(_._2.product)
      .countByValue()
      .toSeq
      .sortBy(-_._2)
      .take(50)
      .map(_._1)

    // Sample roughly 20% of them (fixed seed => same questions on every run)
    // and ask the user to rate those. Ids without a title in movies.dat are skipped.
    val random = new Random(0)
    val selectedMovies = mostRatedMovieIds
      .filter(_ => random.nextDouble() < 0.2)
      .collect { case id if movies.contains(id) => (id, movies(id)) }
    val myRatings = getRatings(selectedMovies)
    // Turn the collected ratings into an RDD so they can join the training set.
    val myRatingsRDD = sc.parallelize(myRatings)

    // Split ratings into train (60%), validation (20%) and test (20%) based on
    // the last digit of the timestamp; the user's own ratings go into train.
    val numPartitions = 4
    val training = ratings.filter(_._1 < 6).values
      .union(myRatingsRDD)
      .repartition(numPartitions)
      .cache()
    val validation = ratings.filter(x => x._1 >= 6 && x._1 < 8).values
      .repartition(numPartitions)
      .cache()
    val test = ratings.filter(_._1 >= 8).values.cache()

    val numTraining = training.count()
    val numValidation = validation.count()
    val numTest = test.count()

    println(s"Training: $numTraining, validation: $numValidation, test: $numTest")

    // Hyper-parameter grid.
    val ranks = List(8, 10, 12)        // number of latent factors
    val lambdas = List(0.1, 1.0, 10.0) // ALS regularization parameter
    val numIterations = List(10, 20)   // number of ALS iterations

    // Only the best model so far is kept, so discarded models can be collected.
    var bestModel: Option[MatrixFactorizationModel] = None
    var bestValidationRmse = Double.MaxValue
    var bestRank = 0
    var bestLambda = -1.0
    var bestNumIter = -1

    for (rank <- ranks; lambda <- lambdas; numIter <- numIterations) {
      val model = ALS.train(training, rank, numIter, lambda)
      val validationRmse = computeRmse(model, validation)
      println(s"RMSE (validation) = $validationRmse for the model trained with rank = $rank , lambda = $lambda ,and numIter = $numIter .")

      // A NaN RMSE never compares as smaller, so such models are never selected.
      if (validationRmse < bestValidationRmse) {
        bestModel = Some(model)
        bestValidationRmse = validationRmse
        bestRank = rank
        bestLambda = lambda
        bestNumIter = numIter
      }
    }

    // Fail with a clear message instead of a bare NoSuchElementException.
    val model = bestModel.getOrElse(
      sys.error("No trained model produced a usable validation RMSE"))

    // Evaluate the best model on the test set.
    val testRmse = computeRmse(model, test)
    println(s"The best model was trained with rank = $bestRank and lambda = $bestLambda , and numIter = $bestNumIter , and its RMSE on the test set is $testRmse .")

    // Recommend from the movies the user has not rated yet.
    val myRatedMovieIds = myRatings.map(_.product).toSet
    val candidates = sc.parallelize(movies.keys.filterNot(myRatedMovieIds.contains).toSeq)
    val userId = PersonalUserId // local copy for use inside the Spark closure
    val recommendations = model
      .predict(candidates.map((userId, _)))
      .collect()
      .sortBy(-_.rating)
      .take(50)

    println("Movies recommendation for you: ")
    recommendations.zipWithIndex.foreach { case (r, idx) =>
      // "%2d" right-aligns the 1-based rank.
      println("%2d".format(idx + 1) + ": " + movies(r.product))
    }

    // Naive baseline: always predict the mean rating of train + validation.
    val meanRating = training.union(validation).map(_.rating).mean()
    val baselineRmse = math.sqrt(
      test.map(x => (meanRating - x.rating) * (meanRating - x.rating)).mean())
    val improvement = (baselineRmse - testRmse) / baselineRmse * 100
    println("The best model improves the baseline by " + "%1.2f".format(improvement) + "%.")
  }

  /**
   * Asks the user on the console to rate each of the given movies.
   *
   * A score of 0 means "not seen" and produces no rating; scores 1-5 are stored
   * under user id 0. Invalid input (non-numeric or out of range) re-prompts for
   * the same movie.
   *
   * @param movies (movieId, title) pairs to ask about
   * @return the ratings entered, in the order the movies were asked
   * @throws java.io.EOFException if console input ends before all movies are answered
   * @throws RuntimeException     if the user rated no movie at all
   */
  def getRatings(movies: Seq[(Int, String)]): Seq[Rating] = {
    val prompt = "Please rate following movie (1-5(best), or 0 if not seen):"
    println(prompt)

    val ratings = movies.flatMap { case (movieId, title) =>
      val score = readScore(title, prompt)
      if (score > 0) List(Rating(PersonalUserId, movieId, score.toDouble)) else Nil
    }

    if (ratings.isEmpty) sys.error("No rating provided") else ratings
  }

  /**
   * Reads one score in 0..5 for `title` from the console, re-prompting until the
   * input is valid. End of input is raised as an error rather than retried,
   * because retrying on a closed stream would loop forever.
   */
  @tailrec
  private def readScore(title: String, prompt: String): Int = {
    print(title + ":")
    val line = Console.in.readLine()
    if (line == null) {
      throw new java.io.EOFException("Console has reached end of input")
    }
    Try(line.trim.toInt).toOption match {
      case Some(score) if score >= 0 && score <= 5 => score
      case _ =>
        println(prompt)
        readScore(title, prompt)
    }
  }

  /**
   * Computes the RMSE (root mean squared error) between the ratings in `data`
   * and the ratings `model` predicts for the same (user, product) pairs.
   *
   * Only pairs present in both the data and the model's predictions take part
   * in the error (inner join).
   *
   * @param model trained matrix-factorization model
   * @param data  ratings to evaluate against
   * @return the root mean squared error
   */
  def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating]): Double = {
    val usersProducts = data.map(r => (r.user, r.product))

    val predictions = model.predict(usersProducts).map(p => ((p.user, p.product), p.rating))

    // No sort is needed before averaging; the mean does not depend on key order.
    val squaredErrors = data
      .map(r => ((r.user, r.product), r.rating))
      .join(predictions)
      .values
      .map { case (actual, predicted) =>
        val err = actual - predicted
        err * err
      }

    math.sqrt(squaredErrors.mean())
  }

}

代码中用到的数据文件 movies.dat 和 ratings.dat 可以在 http://download.csdn.net/detail/u013147600/8908241 下载

movies.dat的数据格式如下(每行一条记录,字段以"::"分隔):序号(电影id)::电影名::类型

1::Toy Story (1995)::Animation|Children's|Comedy
2::Jumanji (1995)::Adventure|Children's|Fantasy
3::Grumpier Old Men (1995)::Comedy|Romance
4::Waiting to Exhale (1995)::Comedy|Drama
5::Father of the Bride Part II (1995)::Comedy
6::Heat (1995)::Action|Crime|Thriller
7::Sabrina (1995)::Comedy|Romance
8::Tom and Huck (1995)::Adventure|Children's
9::Sudden Death (1995)::Action
10::GoldenEye (1995)::Action|Adventure|Thriller

ratings.dat的数据格式如下(每行一条记录,字段以"::"分隔):用户id::电影id::评分::时间戳

1::1193::5::978300760
1::661::3::978302109
1::914::3::978301968
1::3408::4::978300275
1::2355::5::978824291
1::1197::3::978302268
1::1287::5::978302039
1::2804::5::978300719
1::594::4::978302268
1::919::4::978301368

转载请注明原文地址: https://www.6miu.com/read-4256.html

最新回复(0)