ALS是alternating least squares的缩写,意为交替最小二乘法;而ALS-WR是alternating-least-squares with weighted-λ-regularization的缩写,意为加权正则化交替最小二乘法。该方法常用于基于矩阵分解的推荐系统中。例如:将用户(user)对商品(item)的评分矩阵分解为两个矩阵:一个是用户对商品隐含特征的偏好矩阵,另一个是商品所包含的隐含特征的矩阵。在这个矩阵分解的过程中,评分缺失项得到了填充,也就是说我们可以基于这些填充后的评分来给用户做商品推荐了。
(数据格式为:userId, itemId, rating, timestamp)
1. 对每个userId随机初始化N(10)个factor值,由这些值影响userId的权重。
2. 对每个itemId也随机初始化N(10)个factor值。
3. 固定userId,从userFactors矩阵和rating矩阵中分解出itemFactors矩阵。即 [Item Factors Matrix] = [User Factors Matrix]^-1 * [Rating Matrix]。
4. 固定itemId,从itemFactors矩阵和rating矩阵中分解出userFactors矩阵。即 [User Factors Matrix] = [Item Factors Matrix]^-1 * [Rating Matrix]。
5. 重复迭代第3、第4步,最后可以收敛到稳定的userFactors和itemFactors。
6. 对itemId进行推断就为 userFactors * itemId = rating value;对userId进行推断就为 itemFactors * userId = rating value。

Spark支持ML和MLLIB两种机器学习库,官方推荐的是ML,因为ML功能更全面更灵活,未来会主要支持ML。
/** * @author huangyueran * @category ALS-WR */public class JavaALSExampleByMl { private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMl.class); public static class Rating implements Serializable { // 0::2::3::1424380312 private int userId; // 0 private int movieId; // 2 private float rating; // 3 private long timestamp; // 1424380312 public Rating() { } public Rating(int userId, int movieId, float rating, long timestamp) { this.userId = userId; this.movieId = movieId; this.rating = rating; this.timestamp = timestamp; } public int getUserId() { return userId; } public int getMovieId() { return movieId; } public float getRating() { return rating; } public long getTimestamp() { return timestamp; } public static Rating parseRating(String str) { String[] fields = str.split("::"); if (fields.length != 4) { throw new IllegalArgumentException("Each line must contain 4 fields"); } int userId = Integer.parseInt(fields[0]); int movieId = Integer.parseInt(fields[1]); float rating = Float.parseFloat(fields[2]); long timestamp = Long.parseLong(fields[3]); return new Rating(userId, movieId, rating, timestamp); } } public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); JavaRDDratingsRDD = jsc.textFile("data/sample_movielens_ratings.txt") .map(new Function () { public Rating call(String str) { return Rating.parseRating(str); } }); Dataset ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class); Dataset
[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); // //对数据进行分割,80%为训练样例,剩下的为测试样例。 Dataset
training = splits[0]; Dataset
test = splits[1]; // Build the recommendation model using ALS on the training data ALS als = new ALS().setMaxIter(5) // 设置迭代次数 .setRegParam(0.01) // //正则化参数,使每次迭代平滑一些,此数据集取0.1好像错误率低一些。 .setUserCol("userId").setItemCol("movieId") .setRatingCol("rating"); ALSModel model =; // //调用算法开始训练 Dataset
itemFactors = model.itemFactors();; Dataset
userFactors = model.userFactors();; // Evaluate the model by computing the RMSE on the test data Dataset
rawPredictions = model.transform(test); //对测试数据进行预测 Dataset
predictions = rawPredictions .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType)) .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType)); RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating") .setPredictionCol("prediction"); Double rmse = evaluator.evaluate(predictions);"Root-mean-square error = {} ", rmse); jsc.stop(); }}
/** * @category ALS * @author huangyueran * */public class JavaALSExampleByMlLib { private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMlLib.class); public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local[4]"); JavaSparkContext jsc = new JavaSparkContext(conf); JavaRDDdata = jsc.textFile("data/sample_movielens_ratings.txt"); JavaRDD ratings = Function () { public Rating call(String s) { String[] sarray = StringUtils.split(StringUtils.trim(s), "::"); return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2])); } }); // Build the recommendation model using ALS int rank = 10; int numIterations = 6; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data JavaRDD
/** * @author huangyueran * @category 基于Spark-streaming、kafka的实时推荐模板DEMO 原系统中包含商城项目、logback、flume、hadoop * The real time recommendation template DEMO based on Spark-streaming and Kafka contains the mall project, logback, flume and Hadoop in the original system */public final class SparkALSByStreaming { private static final Logger log = LoggerFactory.getLogger(SparkALSByStreaming.class); private static final String KAFKA_ADDR = "middleware:9092"; private static final String TOPIC = "RECOMMEND_TOPIC"; private static final String HDFS_ADDR = "hdfs://middleware:9000"; private static final String MODEL_PATH = "/spark-als/model"; // 基于Hadoop、Flume、Kafka、spark-streaming、logback、商城系统的实时推荐系统DEMO // Real time recommendation system DEMO based on Hadoop, Flume, Kafka, spark-streaming, logback and mall system // 商城系统采集的数据集格式 Data Format: // 用户ID,商品ID,用户行为评分,时间戳 // UserID,ItemId,Rating,TimeStamp // 53,1286513,9,1508221762 // 53,1172348420,9,1508221762 // 53,1179495514,12,1508221762 // 53,1184890730,3,1508221762 // 53,1210793742,159,1508221762 // 53,1215837445,9,1508221762 public static void main(String[] args) { System.setProperty("HADOOP_USER_NAME", "root"); // 设置权限用户 SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaDirectWordCount").setMaster("local[1]"); final JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(6)); MapkafkaParams = new HashMap (); // key是topic名称,value是线程数量 kafkaParams.put("", KAFKA_ADDR); // 指定broker在哪 HashSet topicsSet = new HashSet (); topicsSet.add(TOPIC); // 指定操作的topic // Create direct kafka stream with brokers and topics // createDirectStream() JavaPairInputDStream messages = KafkaUtils.createDirectStream(jssc, String.class, String.class, StringDecoder.class, StringDecoder.class, kafkaParams, topicsSet); JavaDStream lines = Function , String>() { public String call(Tuple2 tuple2) { return tuple2._2(); } }); JavaDStream ratingsStream = Function () { public Rating call(String s) { String[] sarray = 
StringUtils.split(StringUtils.trim(s), ","); return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2])); } }); // 进行流推荐计算 ratingsStream.foreachRDD(new VoidFunction >() { public void call(JavaRDD ratings) throws Exception { // 获取到原始的数据集 SparkContext sc = ratings.context(); RDD textFileRDD = sc.textFile(HDFS_ADDR + "/flume/logs", 3); // 读取原始数据集文件 JavaRDD originalTextFile = textFileRDD.toJavaRDD(); final JavaRDD originaldatas = Function () { public Rating call(String s) { String[] sarray = StringUtils.split(StringUtils.trim(s), ","); return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2])); } });"========================================");"Original TextFile Count:{}", originalTextFile.count()); // HDFS中已经存储的原始用户行为日志数据"========================================"); // 将原始数据集和新的用户行为数据进行合并 JavaRDD calculations = originaldatas.union(ratings);"Calc Count:{}", calculations.count()); // Build the recommendation model using ALS int rank = 10; // 模型中隐语义因子的个数 int numIterations = 6; // 训练次数 // 得到训练模型 if (!ratings.isEmpty()) { // 如果有用户行为数据 MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(calculations), rank, numIterations, 0.01); // 判断文件是否存在,如果存在 删除文件目录 Configuration hadoopConfiguration = sc.hadoopConfiguration(); hadoopConfiguration.set("fs.defaultFS", HDFS_ADDR); FileSystem fs = FileSystem.get(hadoopConfiguration); Path outpath = new Path(MODEL_PATH); if (fs.exists(outpath)) {"########### 删除" + outpath.getName() + " ###########"); fs.delete(outpath, true); } // 保存model, HDFS_ADDR + MODEL_PATH); // 读取model MatrixFactorizationModel modelLoad = MatrixFactorizationModel.load(sc, HDFS_ADDR + MODEL_PATH); // 为指定用户推荐10个商品(电影) for(int userId=0;userId<30;userId++){ // streaming_sample_movielens_ratings.txt Rating[] recommendProducts = modelLoad.recommendProducts(userId, 10);"get recommend result:{}", Arrays.toString(recommendProducts)); } } } }); // 
========================================================================================== jssc.start(); try { jssc.awaitTermination(); } catch (InterruptedException e) { e.printStackTrace(); } // Local Model try { Thread.sleep(10000000); } catch (InterruptedException e) { e.printStackTrace(); } // jssc.stop(); // jssc.close(); }}
用户ID,商品ID,用户行为评分,时间戳
UserID,ItemId,Rating,TimeStamp
53,1286513,9,1508221762
53,1172348420,9,1508221762
53,1179495514,12,1508221762
53,1184890730,3,1508221762
53,1210793742,159,1508221762
53,1215837445,9,1508221762
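<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming-kafka_2.10</artifactId>
    <version>1.6.3</version>
</dependency>
<dependency>
    <groupId>log4j</groupId>
    <artifactId>log4j</artifactId>
    <version>1.2.17</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-api</artifactId>
    <version>1.7.12</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-log4j12</artifactId>
    <version>1.7.12</version>
</dependency>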