Offline ALS and Real-Time Streaming Recommendation with Spark
Published: 2019-06-22


Introduction to ALS

ALS is short for alternating least squares; ALS-WR is short for alternating least squares with weighted-λ regularization. Both are widely used in recommender systems built on matrix factorization. For example, the matrix of user ratings for items is factored into two matrices: one holding each user's preferences over a set of latent (hidden) features, the other holding how strongly each item exhibits those features. The factorization fills in the missing entries of the rating matrix, and those reconstructed scores are what let us recommend items to users.

In short, the ALS-WR algorithm works like this (the objective behind these steps is spelled out after the list):

(Data format: userId, itemId, rating, timestamp)
1. For each userId, randomly initialize N (say 10) factor values; this vector represents the user's weights over the latent features.
2. For each itemId, likewise randomly initialize N (say 10) factor values.
3. Fix the user factors and solve for the item factors from the user-factor and rating matrices by least squares. Informally, [Item Factors] ≈ [User Factors]⁺ × [Rating Matrix], where ⁺ denotes a pseudo-inverse: the factor matrices are not square, so this is a least-squares solve rather than a true matrix inversion.
4. Fix the item factors and solve the symmetric problem for the user factors: [User Factors] ≈ [Item Factors]⁺ × [Rating Matrix].
5. Alternate steps 3 and 4 until the user and item factors converge to stable values.
6. Predict a rating as the dot product of the two factor vectors: rating ≈ userFactor · itemFactor.
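Concretely, ALS-WR minimizes the weighted-λ-regularized squared error over the observed ratings (this is the objective from Zhou et al.'s ALS-WR paper; the notation here is my own, not the original post's):

$$
\min_{X,Y}\ \sum_{(u,i)\in K}\left(r_{ui}-x_u^{\top}y_i\right)^2+\lambda\left(\sum_u n_u\lVert x_u\rVert^2+\sum_i m_i\lVert y_i\rVert^2\right)
$$

where $K$ is the set of observed (userId, itemId) pairs, $x_u$ and $y_i$ are the length-N user and item factor vectors, and $n_u$, $m_i$ count how many ratings user $u$ and item $i$ have. With $Y$ fixed, the objective is quadratic in each $x_u$ separately, which gives the closed-form update behind step 3:

$$
x_u=\left(Y_u Y_u^{\top}+\lambda n_u I\right)^{-1} Y_u r_u
$$

where $Y_u$ stacks (as columns) the factor vectors of the items user $u$ rated and $r_u$ holds $u$'s ratings; the update for $y_i$ in step 4 is symmetric.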

Spark ships two machine-learning libraries: MLlib (the RDD-based API) and ML (the DataFrame-based API). The official recommendation is ML, since it is more complete and more flexible, and future development will focus on it.

 

ALS recommendation with the ML API:

import java.io.Serializable;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author huangyueran
 * @category ALS-WR
 */
public class JavaALSExampleByMl {

    private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMl.class);

    public static class Rating implements Serializable {
        // Sample line: 0::2::3::1424380312
        private int userId;      // 0
        private int movieId;     // 2
        private float rating;    // 3
        private long timestamp;  // 1424380312

        public Rating() {
        }

        public Rating(int userId, int movieId, float rating, long timestamp) {
            this.userId = userId;
            this.movieId = movieId;
            this.rating = rating;
            this.timestamp = timestamp;
        }

        public int getUserId() { return userId; }
        public int getMovieId() { return movieId; }
        public float getRating() { return rating; }
        public long getTimestamp() { return timestamp; }

        public static Rating parseRating(String str) {
            String[] fields = str.split("::");
            if (fields.length != 4) {
                throw new IllegalArgumentException("Each line must contain 4 fields");
            }
            return new Rating(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]),
                    Float.parseFloat(fields[2]), Long.parseLong(fields[3]));
        }
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);

        JavaRDD<Rating> ratingsRDD = jsc.textFile("data/sample_movielens_ratings.txt")
                .map(new Function<String, Rating>() {
                    public Rating call(String str) {
                        return Rating.parseRating(str);
                    }
                });
        Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);

        // Split the data: 80% for training, the remaining 20% for testing
        Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Build the recommendation model using ALS on the training data
        ALS als = new ALS()
                .setMaxIter(5)     // number of iterations
                .setRegParam(0.01) // regularization parameter; on this dataset 0.1 seems to give a lower error
                .setUserCol("userId")
                .setItemCol("movieId")
                .setRatingCol("rating");
        ALSModel model = als.fit(training); // train the model

        Dataset<Row> itemFactors = model.itemFactors();
        itemFactors.show(1500);
        Dataset<Row> userFactors = model.userFactors();
        userFactors.show();

        // Evaluate the model by computing the RMSE on the test data
        Dataset<Row> rawPredictions = model.transform(test); // predict on the test set
        Dataset<Row> predictions = rawPredictions
                .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
                .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
        Double rmse = evaluator.evaluate(predictions);
        log.info("Root-mean-square error = {}", rmse);

        jsc.stop();
    }
}
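The ML example above only trains and evaluates RMSE; it never emits recommendations. Since Spark 2.2 (the version in the Maven dependencies below), ALSModel also exposes recommendForAllUsers and recommendForAllItems for that. A minimal sketch, added at the end of main before jsc.stop(). (On Spark 2.2 you can also call als.setColdStartStrategy("drop") before fitting, so NaN predictions for users or items unseen in training don't poison the RMSE.)

// Top-10 movie recommendations per user: one row of (userId, recommendations) each.
Dataset<Row> userRecs = model.recommendForAllUsers(10);
userRecs.show(false);

// Symmetrically, the top-10 users most likely to rate each movie highly.
Dataset<Row> movieRecs = model.recommendForAllItems(10);
movieRecs.show(false);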

ALS recommendation with the MLlib API:

import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

/**
 * @author huangyueran
 * @category ALS
 */
public class JavaALSExampleByMlLib {

    private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMlLib.class);

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local[4]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaRDD<String> data = jsc.textFile("data/sample_movielens_ratings.txt");
        JavaRDD<Rating> ratings = data.map(new Function<String, Rating>() {
            public Rating call(String s) {
                String[] sarray = StringUtils.split(StringUtils.trim(s), "::");
                return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                        Double.parseDouble(sarray[2]));
            }
        });

        // Build the recommendation model using ALS
        int rank = 10;
        int numIterations = 6;
        MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);

        // Evaluate the model on the rating data
        JavaRDD<Tuple2<Object, Object>> userProducts = ratings
                .map(new Function<Rating, Tuple2<Object, Object>>() {
                    public Tuple2<Object, Object> call(Rating r) {
                        return new Tuple2<Object, Object>(r.user(), r.product());
                    }
                });

        // Predicted ratings, keyed by (user, product)
        JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD
                .fromJavaRDD(model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
                        .map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
                            public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
                                return new Tuple2<Tuple2<Integer, Integer>, Double>(
                                        new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
                            }
                        }));

        // Join actual and predicted ratings on (user, product)
        JavaPairRDD<Tuple2<Integer, Integer>, Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD
                .fromJavaRDD(ratings.map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
                    public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
                        return new Tuple2<Tuple2<Integer, Integer>, Double>(
                                new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
                    }
                })).join(predictions);

        // Rating list keyed and sorted by user id
        JavaPairRDD<Integer, Tuple2<Integer, Double>> fromJavaRDD = JavaPairRDD.fromJavaRDD(ratesAndPreds.map(
                new Function<Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>>, Tuple2<Integer, Tuple2<Integer, Double>>>() {
                    public Tuple2<Integer, Tuple2<Integer, Double>> call(
                            Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>> t) throws Exception {
                        return new Tuple2<Integer, Tuple2<Integer, Double>>(t._1._1,
                                new Tuple2<Integer, Double>(t._1._2, t._2._2));
                    }
                })).sortByKey(false);

        // List<Tuple2<Integer, Tuple2<Integer, Double>>> list = fromJavaRDD.collect();
        // for (Tuple2<Integer, Tuple2<Integer, Double>> t : list) {
        //     System.out.println(t._1 + ":" + t._2._1 + "====" + t._2._2);
        // }

        // Mean squared error between actual and predicted ratings
        JavaRDD<Tuple2<Double, Double>> ratesAndPredsValues = ratesAndPreds.values();
        double MSE = JavaDoubleRDD.fromRDD(ratesAndPredsValues.map(new Function<Tuple2<Double, Double>, Object>() {
            public Object call(Tuple2<Double, Double> pair) {
                Double err = pair._1() - pair._2();
                return err * err;
            }
        }).rdd()).mean();

        try {
            FileUtils.deleteDirectory(new File("result"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        ratesAndPreds.repartition(1).saveAsTextFile("result/ratesAndPreds");

        // Recommend 10 items (movies) for a given user (id 2)
        Rating[] recommendProducts = model.recommendProducts(2, 10);
        log.info("get recommend result:{}", Arrays.toString(recommendProducts));

        // Recommend the top-N users for every item:
        // model.recommendUsersForProducts(10);
        // Recommend the top-N items for every user:
        // model.recommendProductsForUsers(10);

        model.userFeatures().saveAsTextFile("result/userFea");
        model.productFeatures().saveAsTextFile("result/productFea");
        log.info("Mean Squared Error = {}", MSE);
    }
}

The two examples above perform offline ALS recommendation with Spark. A third approach uses Spark Streaming to compute recommendations online (in real time) over fresh data buffered in a message queue such as Kafka.

 

Real-time ALS recommendation with Spark Streaming:

The ALS computation in Spark Streaming is only one link in the chain; a real project involves plenty of other moving parts.

For example, user-behavior logs are instrumented in the application, collected and monitored by Flume, and stored in HDFS, while Kafka consumes and buffers the massive stream of behavior events.

Finally, the model produced by the Spark training job is persisted offline, and the web tier pulls and caches it to serve recommendations to users, as sketched right below.
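As a rough illustration of that serving step (a sketch only; the hostname and path merely mirror the constants in the streaming job below), the web tier can reload the persisted MLlib model from HDFS and answer top-N requests from it:

// Hypothetical serving-side sketch: load the ALS model that the streaming job
// persists to HDFS, cache it, and serve top-N recommendations from it.
SparkConf conf = new SparkConf().setAppName("ModelServer").setMaster("local[1]");
JavaSparkContext jsc = new JavaSparkContext(conf);

MatrixFactorizationModel model = MatrixFactorizationModel.load(
        jsc.sc(), "hdfs://middleware:9000/spark-als/model");

// Recommend 10 items for user 53; a real web app would keep the model cached
// and reload it whenever the streaming job writes a new version.
Rating[] top10 = model.recommendProducts(53, 10);
for (Rating r : top10) {
    System.out.println(r.product() + " -> " + r.rating());
}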

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import kafka.serializer.StringDecoder;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka.KafkaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

/**
 * @author huangyueran
 * @category Real-time recommendation template DEMO based on Spark Streaming and Kafka.
 *           The original system also involves a mall project, logback, Flume and Hadoop.
 */
public final class SparkALSByStreaming {

    private static final Logger log = LoggerFactory.getLogger(SparkALSByStreaming.class);

    private static final String KAFKA_ADDR = "middleware:9092";
    private static final String TOPIC = "RECOMMEND_TOPIC";
    private static final String HDFS_ADDR = "hdfs://middleware:9000";
    private static final String MODEL_PATH = "/spark-als/model";

    // Real-time recommendation DEMO built on Hadoop, Flume, Kafka, Spark Streaming,
    // logback and a mall system. Data format collected by the mall system:
    // UserID,ItemId,Rating,TimeStamp
    // 53,1286513,9,1508221762
    // 53,1172348420,9,1508221762
    // 53,1179495514,12,1508221762
    // 53,1184890730,3,1508221762
    // 53,1210793742,159,1508221762
    // 53,1215837445,9,1508221762

    public static void main(String[] args) {
        System.setProperty("HADOOP_USER_NAME", "root"); // user for HDFS permissions

        SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaDirectWordCount").setMaster("local[1]");
        final JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(6));

        Map<String, String> kafkaParams = new HashMap<String, String>();
        kafkaParams.put("metadata.broker.list", KAFKA_ADDR); // where the Kafka brokers live

        HashSet<String> topicsSet = new HashSet<String>();
        topicsSet.add(TOPIC); // the topic to consume

        // Create a direct Kafka stream for the given brokers and topics
        JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(jssc, String.class,
                String.class, StringDecoder.class, StringDecoder.class, kafkaParams, topicsSet);

        JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
            public String call(Tuple2<String, String> tuple2) {
                return tuple2._2();
            }
        });

        JavaDStream<Rating> ratingsStream = lines.map(new Function<String, Rating>() {
            public Rating call(String s) {
                String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
                return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                        Double.parseDouble(sarray[2]));
            }
        });

        // Run the recommendation computation on each micro-batch
        ratingsStream.foreachRDD(new VoidFunction<JavaRDD<Rating>>() {
            public void call(JavaRDD<Rating> ratings) throws Exception {
                SparkContext sc = ratings.context();

                // Load the raw behavior logs already stored in HDFS
                RDD<String> textFileRDD = sc.textFile(HDFS_ADDR + "/flume/logs", 3);
                JavaRDD<String> originalTextFile = textFileRDD.toJavaRDD();
                final JavaRDD<Rating> originaldatas = originalTextFile.map(new Function<String, Rating>() {
                    public Rating call(String s) {
                        String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
                        return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                                Double.parseDouble(sarray[2]));
                    }
                });

                log.info("========================================");
                log.info("Original TextFile Count:{}", originalTextFile.count());
                log.info("========================================");

                // Merge the historical dataset with the newly arrived behavior data
                JavaRDD<Rating> calculations = originaldatas.union(ratings);
                log.info("Calc Count:{}", calculations.count());

                // Build the recommendation model using ALS
                int rank = 10;         // number of latent factors
                int numIterations = 6; // number of training iterations

                if (!ratings.isEmpty()) { // only retrain when new behavior data arrived
                    MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(calculations), rank,
                            numIterations, 0.01);

                    // Delete the old model directory in HDFS if it exists
                    Configuration hadoopConfiguration = sc.hadoopConfiguration();
                    hadoopConfiguration.set("fs.defaultFS", HDFS_ADDR);
                    FileSystem fs = FileSystem.get(hadoopConfiguration);
                    Path outpath = new Path(MODEL_PATH);
                    if (fs.exists(outpath)) {
                        log.info("########### deleting {} ###########", outpath.getName());
                        fs.delete(outpath, true);
                    }

                    // Persist the model, then load it back
                    model.save(sc, HDFS_ADDR + MODEL_PATH);
                    MatrixFactorizationModel modelLoad = MatrixFactorizationModel.load(sc, HDFS_ADDR + MODEL_PATH);

                    // Recommend 10 items (movies) for each of the first 30 users
                    for (int userId = 0; userId < 30; userId++) { // streaming_sample_movielens_ratings.txt
                        Rating[] recommendProducts = modelLoad.recommendProducts(userId, 10);
                        log.info("get recommend result:{}", Arrays.toString(recommendProducts));
                    }
                }
            }
        });

        jssc.start();
        try {
            jssc.awaitTermination();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // jssc.stop();
        // jssc.close();
    }
}
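To feed the job while testing, any producer that writes UserID,ItemId,Rating,TimeStamp lines to RECOMMEND_TOPIC will do. Below is a minimal sketch using the standard Kafka Java client; note the kafka-clients dependency is an assumption on top of the Maven list below, and this producer requires a 0.8.2+ broker:

import java.util.Properties;

import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;

public class RecommendTopicProducer {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "middleware:9092"); // same broker as KAFKA_ADDR above
        props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
        props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");

        Producer<String, String> producer = new KafkaProducer<String, String>(props);
        // One behavior event in the UserID,ItemId,Rating,TimeStamp format used above
        producer.send(new ProducerRecord<String, String>("RECOMMEND_TOPIC", "53,1286513,9,1508221762"));
        producer.close();
    }
}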

User behavior dataset

Data format collected by the mall system:

UserID,ItemId,Rating,TimeStamp (the rating is a user-behavior score)
53,1286513,9,1508221762
53,1172348420,9,1508221762
53,1179495514,12,1508221762
53,1184890730,3,1508221762
53,1210793742,159,1508221762
53,1215837445,9,1508221762
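Note that Rating here is a behavior score accumulated from actions such as clicks and purchases (hence values like 159), not an explicit 1-to-5 star rating. For implicit feedback like this, MLlib also offers ALS.trainImplicit, which treats the score as a confidence weight rather than a rating to reproduce. A hedged sketch (lambda and alpha are illustrative values, not from the original project):

// ratings is the JavaRDD<Rating> parsed as in the MLlib example above.
double lambda = 0.01; // regularization, as in the explicit-feedback demo
double alpha = 1.0;   // confidence scaling for implicit feedback (illustrative)
MatrixFactorizationModel implicitModel =
        ALS.trainImplicit(JavaRDD.toRDD(ratings), 10, 6, lambda, alpha);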

 

Maven dependencies

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming-kafka_2.10</artifactId>
    <version>1.6.3</version>
</dependency>
<dependency>
    <groupId>log4j</groupId>
    <artifactId>log4j</artifactId>
    <version>1.2.17</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-api</artifactId>
    <version>1.7.12</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-log4j12</artifactId>
    <version>1.7.12</version>
</dependency>

The code and the datasets above can be found in the project on GitHub.

 

 

Reposted from: https://my.oschina.net/u/4074730/blog/3011783
