package com.ys.scala
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.Rating
import scala.util.Random
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.rdd.RDD
/**
 * MovieLens recommender built on Spark MLlib's ALS matrix factorization.
 *
 * Reads `ratings.dat` and `movies.dat` (fields separated by "::"), asks the user on the
 * console to rate a sample of popular movies, tries several ALS parameter combinations
 * against a validation split, reports the test RMSE of the best model, prints up to 50
 * recommendations for the interactive user (user id 0) and compares the model with a
 * mean-rating baseline.
 */
object ScalaMovieLensALS {

  /** A trained model together with the parameters that produced it and its validation RMSE. */
  private case class TrainedModel(
      model: MatrixFactorizationModel,
      validationRmse: Double,
      rank: Int,
      lambda: Double,
      numIter: Int)

  /**
   * Program entry point: configures logging, creates a local SparkContext and runs the
   * recommendation pipeline.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Suppress unnecessary log output.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    // Jetty's loggers live under "org.eclipse.jetty", not "org.apache.eclipse.jetty".
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val conf = new SparkConf().setAppName("ScalaMovieLensALS").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      run(sc)
    } finally {
      // Always release the context, even when the pipeline fails part-way through.
      sc.stop()
    }
  }

  /** Loads the data, collects the user's ratings, selects the best ALS model and prints the results. */
  private def run(sc: SparkContext): Unit = {
    // Each ratings line is userId::movieId::rating::timestamp.
    // Result format: (timestamp % 10, Rating(userId, movieId, rating)); the key drives the split below.
    val ratings = sc.textFile("ratings.dat").map { line =>
      val fields = line.split("::")
      (fields(3).toLong % 10, Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble))
    }

    // movieId -> movieName lookup, collected to the driver.
    val movies = sc.textFile("movies.dat").map { line =>
      val fields = line.split("::")
      (fields(0).toInt, fields(1))
    }.collect().toMap

    val numRatings = ratings.count()
    val numUsers = ratings.map(_._2.user).distinct().count()
    val numMovies = ratings.map(_._2.product).distinct().count()
    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")

    // Ids of the 50 most frequently rated movies, most rated first.
    val mostRatedMovieIds = ratings.map(_._2.product)
      .countByValue()
      .toSeq
      .sortBy(-_._2)
      .take(50)
      .map(_._1)

    // Fixed seed, so the same roughly-20% sample of popular movies is offered on every run.
    val random = new Random(0)
    val selectedMovies = mostRatedMovieIds
      .filter(_ => random.nextDouble() < 0.2)
      .map(id => (id, movies(id)))

    val myRatings = getRatings(selectedMovies)
    // Turn the collected ratings into an RDD so they can be merged into the training set.
    val myRatingsRDD = sc.parallelize(myRatings)

    // Split by the last digit of the timestamp: 0-5 train, 6-7 validation, 8-9 test.
    // The user's own ratings are added to the training set; all splits are cached.
    val numPartitions = 4
    val training = ratings.filter(_._1 < 6).values.union(myRatingsRDD).repartition(numPartitions).cache()
    val validation = ratings.filter(x => x._1 >= 6 && x._1 < 8).values.repartition(numPartitions).cache()
    val test = ratings.filter(_._1 >= 8).values.cache()
    val numTraining = training.count()
    val numValidation = validation.count()
    val numTest = test.count()
    println(s"Training: $numTraining, validation: $numValidation, test: $numTest")

    // Train one model per parameter combination and evaluate each on the validation set.
    val ranks = List(8, 10, 12)        // number of latent factors in the model
    val lambdas = List(0.1, 1.0, 10.0) // ALS regularization parameter
    val numIterations = List(10, 20)   // number of ALS iterations
    val grid = for (rank <- ranks; lambda <- lambdas; numIter <- numIterations) yield (rank, lambda, numIter)

    // Fold over the grid so that only the best model seen so far stays referenced.
    val best = grid.foldLeft(Option.empty[TrainedModel]) { case (bestSoFar, (rank, lambda, numIter)) =>
      val model = ALS.train(training, rank, numIter, lambda)
      val validationRmse = computeRmse(model, validation)
      println(s"RMSE (validation) = $validationRmse for the model trained with rank = $rank , lambda = $lambda ,and numIter = $numIter .")
      // A NaN RMSE never compares as smaller, so such a model is never selected.
      val rmseToBeat = bestSoFar.map(_.validationRmse).getOrElse(Double.MaxValue)
      if (validationRmse < rmseToBeat) {
        Some(TrainedModel(model, validationRmse, rank, lambda, numIter))
      } else {
        bestSoFar
      }
    }
    val chosen = best.getOrElse(sys.error("No model achieved a finite validation RMSE"))

    // Evaluate the best model on the test set.
    val testRmse = computeRmse(chosen.model, test)
    println(s"The best model was trained with rank = ${chosen.rank} and lambda = ${chosen.lambda} , and numIter = ${chosen.numIter} , and its RMSE on the test set is $testRmse .")

    // Recommend from the movies the user has not rated yet.
    val myRatedMovieIds = myRatings.map(_.product).toSet
    val candidates = sc.parallelize(movies.keys.filter(!myRatedMovieIds.contains(_)).toSeq)
    val recommendations = chosen.model
      .predict(candidates.map((0, _)))
      .collect()
      .sortBy(-_.rating)
      .take(50)

    println("Movies recommendation for you: ")
    recommendations.zipWithIndex.foreach { case (r, idx) =>
      // "%2d" prints the 1-based rank of the recommendation.
      println("%2d".format(idx + 1) + ": " + movies(r.product))
    }

    // Naive baseline: always predict the mean rating of training + validation.
    val meanRating = training.union(validation).map(_.rating).mean()
    val baselineRmse = math.sqrt(test.map(x => (meanRating - x.rating) * (meanRating - x.rating)).mean())
    val improvement = (baselineRmse - testRmse) / baselineRmse * 100
    println("The best model improves the baseline by " + "%1.2f".format(improvement) + "%.")
  }

  /**
   * Asks the user on the console to rate each of the given movies.
   *
   * A score of 1-5 is recorded as a rating by user id 0; a score of 0 means "not seen" and
   * the movie is skipped. Any other input repeats the question. When standard input is
   * exhausted the current movie is skipped instead of asking again.
   *
   * @param movies (movieId, title) pairs to ask about
   * @return the ratings that were entered, in the order the movies were given
   * @throws RuntimeException if no rating at all was provided
   */
  def getRatings(movies: Seq[(Int, String)]): Seq[Rating] = {
    val prompt = "Please rate following movie (1-5(best), or 0 if not seen):"
    println(prompt)

    // Repeats the question until the answer is an integer in 0..5 or the input ends.
    @scala.annotation.tailrec
    def ask(movieId: Int, title: String): Option[Rating] = {
      print(title + ":")
      Console.out.flush()
      // readLine() returns null at end of input.
      Option(Console.in.readLine()) match {
        case None => None
        case Some(line) =>
          val score = try Some(line.trim.toInt) catch { case _: NumberFormatException => None }
          score match {
            case Some(0) => None
            case Some(s) if s >= 1 && s <= 5 => Some(Rating(0, movieId, s.toDouble))
            case _ =>
              println(prompt)
              ask(movieId, title)
          }
      }
    }

    val ratings = movies.flatMap { case (movieId, title) => ask(movieId, title).toList }
    if (ratings.isEmpty) {
      sys.error("No rating provided")
    } else {
      ratings
    }
  }

  /**
   * Computes the RMSE (root mean squared error) between the model's predictions and the
   * actual ratings in `data`.
   *
   * Only (user, product) pairs for which the model returns a prediction take part in the
   * mean, because actual and predicted ratings are combined with an inner join.
   *
   * @param model the trained matrix factorization model
   * @param data  the ratings to evaluate against
   * @return the square root of the mean squared difference between actual and predicted ratings
   */
  def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating]): Double = {
    val actual = data.map(r => ((r.user, r.product), r.rating))
    val predicted = model.predict(actual.keys).map(p => ((p.user, p.product), p.rating))
    // No sorting here: the mean does not depend on the order of the joined pairs.
    val squaredErrors = actual.join(predicted).values.map { case (observed, estimate) =>
      val err = observed - estimate
      err * err
    }
    math.sqrt(squaredErrors.mean())
  }
}
代码中用到的数据 movies.dat 和 ratings.dat 可以在 http://download.csdn.net/detail/u013147600/8908241 下载。
movies.dat 的数据格式如下(每行一条记录,字段之间以 "::" 分隔):序号::电影名::类型
1::Toy Story (1995)::Animation|Children's|Comedy
2::Jumanji (1995)::Adventure|Children's|Fantasy
3::Grumpier Old Men (1995)::Comedy|Romance
4::Waiting to Exhale (1995)::Comedy|Drama
5::Father of the Bride Part II (1995)::Comedy
6::Heat (1995)::Action|Crime|Thriller
7::Sabrina (1995)::Comedy|Romance
8::Tom and Huck (1995)::Adventure|Children's
9::Sudden Death (1995)::Action
10::GoldenEye (1995)::Action|Adventure|Thriller
ratings.dat 的数据格式如下(每行一条记录,字段之间以 "::" 分隔):用户id::电影id::评分::时间戳
1::1193::5::978300760
1::661::3::978302109
1::914::3::978301968
1::3408::4::978300275
1::2355::5::978824291
1::1197::3::978302268
1::1287::5::978302039
1::2804::5::978300719
1::594::4::978302268
1::919::4::978301368