Spark随机森林实现票房预测

 更新时间:2019年08月10日 12:43:01   作者:Quincy1994  
这篇文章主要为大家详细介绍了Spark随机森林实现票房预测,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

前言

最近一段时间都在处理电影领域的数据, 而电影票房预测是电影领域数据建模中的一个重要模块, 所以我们针对电影数据做了票房预测建模.

前期工作

一开始的做法是将这个问题看待成回归的问题, 采用GBDT回归树去做. 训练了不同残差的回归树, 然后做集成学习. 考虑的影响因子分别有电影的类型, 豆瓣评分, 导演的 影响力, 演员的影响力, 电影的出品公司. 不过预测的结果并不是那么理想, 准确率为真实值的0.3+/-区间情况下的80%, 且波动性较大, 不容易解析.

后期的改进

总结之前的失败经验, 主要归纳了以下几点:

1.影响因子不够多, 难以建模
2.票房成绩的区间较大(一百万到10亿不等),分布不均匀, 大多数集中与亿级, 所以不适合采用回归方法解决.
3.数据样本量比较少, 不均匀, 预测百万级的电影较多, 影响预测结果

后期, 我们重新规范了数据的输入格式, 即影响因子, 具体如下:

第一行: 电影名字
第二行: 电影票房(也就是用于预测的, 以万为单位)
第三行: 电影类型
第四行: 片长(以分钟为单位)
第五行:上映时间(按月份)
第六行: 制式( 一般分为2D, 3D, IMAX)
第七行: 制作国家
第八行: 导演影响 (以导演的平均票房成绩为衡量, 以万为单位 )
第九行: 演员影响 ( 以所有演员的平均票房成绩为衡量, 以万为单位 )
第十行:制作公司影响 ( 以所有制作公司的平均票房成绩为衡量, 以万为单位 )
第十一行: 发行公式影响 ( 以所有制作公司的平均票房成绩为衡量,以万为单位 )

收集了05-17年的来自中国,日本,美国,英国的电影, 共1058部电影. 由于处理成为分类问题, 故按将电影票房分为以下等级:


在构建模型之前, 先将数据处理成libsvm格式文件, 然后采用随机森林模型训练.

随机森林由许多的决策树组成, 因为这些决策树的形成采用随机的策略, 每个决策树都随机生成, 相互之间独立.模型最后输出的类别是由每个树输出的类别的众数而定.在构建每个决策树的时候采用的策略是信息熵, 决策树为多元分类决策树.随机森林的流程图如下图所示:

随机森林是采用spark-mllib提供的random forest, 由于超过10亿的电影的数据相对比较少, 为了平衡各数据的分布, 采用了过分抽样的方法, 训练模型的代码如下:

public void predict() throws IOException{
  SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
  conf.set("spark.testing.memory", "2147480000");
  SparkContext sc = new SparkContext(conf);
  SQLContext sqlContext = new SQLContext(sc);


  // Load and parse the data file, converting it to a DataFrame.
  DataFrame trainData = sqlContext.read().format("libsvm").load(this.trainFile);
  DataFrame testData = sqlContext.read().format("libsvm").load(this.testFile);

  // Index labels, adding metadata to the label column.
  // Fit on whole dataset to include all labels in index.
  StringIndexerModel labelIndexer = new StringIndexer()
   .setInputCol("label")
   .setOutputCol("indexedLabel")
   .fit(trainData);
  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()
   .setInputCol("features")
   .setOutputCol("indexedFeatures")
   .setMaxCategories(4)
   .fit(trainData);

  // Split the data into training and test sets (30% held out for testing)
//  DataFrame[] splits = trainData.randomSplit(new double[] {0.9, 0.1});
//  trainData = splits[0];
//  testData = splits[1];

  // Train a RandomForest model.
  RandomForestClassifier rf = new RandomForestClassifier()
   .setLabelCol("indexedLabel")
   .setFeaturesCol("indexedFeatures")
   .setNumTrees(20);

  // Convert indexed labels back to original labels.
  IndexToString labelConverter = new IndexToString()
   .setInputCol("prediction")
   .setOutputCol("predictedLabel")
   .setLabels(labelIndexer.labels());

  // Chain indexers and forest in a Pipeline
  Pipeline pipeline = new Pipeline()
   .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});

  // Train model. This also runs the indexers.
  PipelineModel model = pipeline.fit(trainData);

  // Make predictions.
  DataFrame predictions = model.transform(testData);

  // Select example rows to display.
  predictions.select("predictedLabel", "label", "features").show(200);

  // Select (prediction, true label) and compute test error
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
   .setLabelCol("indexedLabel")
   .setPredictionCol("prediction")
   .setMetricName("precision");
  double accuracy = evaluator.evaluate(predictions);
  System.out.println("Test Error = " + (1.0 - accuracy));

  RandomForestClassificationModel rfModel = (RandomForestClassificationModel)(model.stages()[2]);
//  System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());

  DataFrame resultDF = predictions.select("predictedLabel");
  JavaRDD<Row> resultRow = resultDF.toJavaRDD();
  JavaRDD<String> result = resultRow.map(new Result());
  this.resultList = result.collect();
  for(String one: resultList){
   System.out.println(one);
  }
 }

下面为其中一个的决策树情况:

Tree 16 (weight 1.0):
 If (feature 10 in {0.0})
  If (feature 48 <= 110.0)
  If (feature 86 <= 13698.87)
  If (feature 21 in {0.0})
  If (feature 54 in {0.0})
   Predict: 0.0
  Else (feature 54 not in {0.0})
   Predict: 1.0
  Else (feature 21 not in {0.0})
  Predict: 0.0
  Else (feature 86 > 13698.87)
  If (feature 21 in {0.0})
  If (feature 85 <= 39646.9)
   Predict: 2.0
  Else (feature 85 > 39646.9)
   Predict: 3.0
  Else (feature 21 not in {0.0})
  Predict: 3.0
  Else (feature 48 > 110.0)
  If (feature 85 <= 15003.3)
  If (feature 9 in {0.0})
  If (feature 54 in {0.0})
   Predict: 0.0
  Else (feature 54 not in {0.0})
   Predict: 2.0
  Else (feature 9 not in {0.0})
  Predict: 2.0
  Else (feature 85 > 15003.3)
  If (feature 65 in {0.0})
  If (feature 85 <= 66065.0)
   Predict: 3.0
  Else (feature 85 > 66065.0)
   Predict: 2.0
  Else (feature 65 not in {0.0})
  Predict: 3.0
 Else (feature 10 not in {0.0})
  If (feature 51 in {0.0})
  If (feature 85 <= 6958.4)
  If (feature 11 in {0.0})
  If (feature 50 <= 1.0)
   Predict: 1.0
  Else (feature 50 > 1.0)
   Predict: 0.0
  Else (feature 11 not in {0.0})
  Predict: 0.0
  Else (feature 85 > 6958.4)
  If (feature 5 in {0.0})
  If (feature 4 in {0.0})
   Predict: 3.0
  Else (feature 4 not in {0.0})
   Predict: 1.0
  Else (feature 5 not in {0.0})
  Predict: 2.0
  Else (feature 51 not in {0.0})
  If (feature 48 <= 148.0)
  If (feature 0 in {0.0})
  If (feature 6 in {0.0})
   Predict: 2.0
  Else (feature 6 not in {0.0})
   Predict: 0.0
  Else (feature 0 not in {0.0})
  If (feature 50 <= 4.0)
   Predict: 2.0
  Else (feature 50 > 4.0)
   Predict: 3.0
  Else (feature 48 > 148.0)
  If (feature 9 in {0.0})
  If (feature 49 <= 3.0)
   Predict: 2.0
  Else (feature 49 > 3.0)
   Predict: 0.0
  Else (feature 9 not in {0.0})
  If (feature 36 in {0.0})
   Predict: 3.0
  Else (feature 36 not in {0.0})
   Predict: 2.0

后记

该模型预测的平均准确率为80%, 但相对之前的做法规范了很多, 对结果的解析也更加的合理, 不过如何增强预测的效果, 可以考虑更多的因子, 形如:电影是否有前续;电影网站的口碑指数;预告片的播放量;相关微博的阅读数;百度指数等;

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 深入解析Apache Kafka实时流处理平台

    深入解析Apache Kafka实时流处理平台

    这篇文章主要为大家介绍了Apache Kafka实时流处理平台深入解析,从基本概念到实战操作详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • mybatis快速上手并运行程序

    mybatis快速上手并运行程序

    MyBatis 是一款优秀的持久层框架,它支持自定义 SQL、存储过程以及高级映射。MyBatis 免除了几乎所有的 JDBC 代码以及设置参数和获取结果集的工作。MyBatis 可以通过简单的 XML 或注解来配置和映射原始类型、接口和 Java POJO为数据库中的记录
    2022-01-01
  • 在Spring Boot中使用swagger-bootstrap-ui的方法

    在Spring Boot中使用swagger-bootstrap-ui的方法

    这篇文章主要介绍了在Spring Boot中使用swagger-bootstrap-ui的方法,需要的朋友可以参考下
    2018-01-01
  • idea导入springboot项目没有maven的解决

    idea导入springboot项目没有maven的解决

    这篇文章主要介绍了idea导入springboot项目没有maven的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-04-04
  • springboot整合freemarker详解

    springboot整合freemarker详解

    本篇文章主要介绍了springboot整合freemarker详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-05-05
  • Java多态(动力节点Java学院整理)

    Java多态(动力节点Java学院整理)

    多态是指允许不同类的对象对同一消息做出响应。即同一消息可以根据发送对象的不同而采用多种不同的行为方式。接下来通过本文给大家介绍java多态相关知识,感兴趣的朋友一起学习吧
    2017-04-04
  • MyBatis 的 XML 配置文件和缓存使用步骤

    MyBatis 的 XML 配置文件和缓存使用步骤

    MyBatis 包含一个非常强大的查询缓存特性,它可以非常方便地配置和定制,这篇文章主要介绍了MyBatis的XML配置文件和缓存,需要的朋友可以参考下
    2022-01-01
  • Spring注解驱动之@EventListener注解使用方式

    Spring注解驱动之@EventListener注解使用方式

    这篇文章主要介绍了Spring注解驱动之@EventListener注解使用方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-09-09
  • 使用Java进行Json数据的解析(对象数组的相互嵌套)

    使用Java进行Json数据的解析(对象数组的相互嵌套)

    下面小编就为大家带来一篇使用Java进行Json数据的解析(对象数组的相互嵌套)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-08-08
  • Mybatis-Plus开发提速器generator的使用

    Mybatis-Plus开发提速器generator的使用

    本文就介绍这款基于Mybatis-Plus的代码自助生成器,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-07-07

最新评论