基于spark的时间序列预测包Sparkts._的使用

最近研究了一下时间序列预测的使用,网上找到的大部分资源都是使用python来实现的。使用python来实现虽然能满足大部分的需求,但是python有一个缺点,就是只能使用单台机器的计算资源进行计算,如果数据量大,就有可能不能胜任。虽然这种情况很少,但还是有可能会发生,因此就查了一下spark有没有这方面的资料,没想到还真的有,使用spark集群进行计算,速度方面提升明显。

项目地址:https://github.com/sryza/spark-timeseries

首先非常感谢这位博主,我是在学习了他的代码之后才能更好地理解spark-timeseries的使用。

博客链接:http://blog.csdn.net/qq_30232405/article/details/70622400

下面是我对代码的改进,主要调整的是时间类型的通用性,以及让arima模型能自定义pdq参数等,使其能适用于大部分类型的时间格式。

TimeFormatUtils.java

import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.HashMap;
import java.util.regex.Pattern;

public class TimeFormatUtils {


    /**
     * 获取时间类型格式
     *
     * @param timeStr
     * @return
     */
    public static String getDateType(String timeStr) {
        HashMap<String, String> dateRegFormat = new HashMap<String, String>();
        dateRegFormat.put("^\\d{4}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D*$", "yyyy-MM-dd HH:mm:ss");//2014年3月12日 13时5分34秒,2014-03-12 12:05:34,2014/3/12 12:5:34
        dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH:mm");//2014-03-12 12:05
        dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH");//2014-03-12 12
        dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd");//2014-03-12
        dateRegFormat.put("^\\d{4}\\D+\\d{2}$", "yyyy-MM");//2014-03
        dateRegFormat.put("^\\d{4}$", "yyyy");//2014
        dateRegFormat.put("^\\d{14}$", "yyyyMMddHHmmss");//20140312120534
        dateRegFormat.put("^\\d{12}$", "yyyyMMddHHmm");//201403121205
        dateRegFormat.put("^\\d{10}$", "yyyyMMddHH");//2014031212
        dateRegFormat.put("^\\d{8}$", "yyyyMMdd");//20140312
        dateRegFormat.put("^\\d{6}$", "yyyyMM");//201403

        try {
            for (String key : dateRegFormat.keySet()) {
                if (Pattern.compile(key).matcher(timeStr).matches()) {
                    String formater = "";
                    if (timeStr.contains("/"))
                        return dateRegFormat.get(key).replaceAll("-", "/");
                    else
                        return dateRegFormat.get(key);

                }
            }
        } catch (Exception e) {
            System.err.println("-----------------日期格式无效:" + timeStr);
            e.printStackTrace();
        }
        return null;
    }

    public static String fromatData(String time, SimpleDateFormat format) {
        try {
            SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
            return formatter.format(format.parse(time));
        } catch (ParseException e) {
            e.printStackTrace();
        }
        return null;
    }
}

TimeSeriesTrain.scala


import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.{ZoneId, ZonedDateTime}

import com.cloudera.sparkts._
import com.sendi.TimeSeries.Util.TimeFormatUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
  * 时间序列模型time-series的建立
  */
object TimeSeriesTrain {

  /**
    * Entry point: trains a time-series model on a Hive table and stores the result.
    *
    * Positional arguments:
    *  - 0: source Hive table
    *  - 1: comma-separated column names, time column first and value column second
    *  - 2, 3: start and end time, in the same textual format
    *  - 4, 5: step unit ("year", "month", "day", "hour" or "minute") and step size
    *  - 6: number of values to forecast
    *  - 7: output table name
    *  - 8: model name, "arima" or "holtwinters"
    *  - 9: for arima, comma-separated p,d,q (fewer than three entries selects auto-fit);
    *       for holtwinters, the seasonal period
    *  - 10: holtwinters only, the model type handed to the HoltWinters fit
    *
    * All arguments are validated before the Spark session is created, so a bad
    * invocation fails immediately instead of after the Hive read.
    *
    * @param args positional arguments as listed above
    * @throws IllegalArgumentException      if arguments are missing or the start time
    *                                       format is not recognised
    * @throws UnsupportedOperationException if the model name is not supported
    */
  def timeSeries(args: Array[String]): Unit = {
    args.foreach(println)

    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    /**
      * 1. Parse and validate the arguments
      */
    require(args.length >= 10, s"expected at least 10 arguments but got ${args.length}")

    val databaseTableName = args(0)
    // Column names: time column first, value column second.
    val hiveColumnName = args(1).split(",").toList
    require(hiveColumnName.size >= 2, "argument 1 must name the time column and the value column")

    val startTime = args(2)
    val endTime = args(3)
    // The date layout is inferred from the start time and reused for all parsing.
    val datePattern = TimeFormatUtils.getDateType(startTime)
    require(datePattern != null, s"unrecognised time format for start time: $startTime")
    val sdf = new SimpleDateFormat(datePattern)

    val timeSpanType = args(4)
    val timeSpan = args(5).toInt
    val predictedN = args(6).toInt
    val outputTableName = args(7)
    val modelName = args(8)

    // Model-specific parameters; fields that do not apply to the chosen model keep a
    // neutral placeholder.
    val (listPDQ, period, holtWintersModelType) = modelName match {
      case "arima" =>
        (args(9).split(",").toList, 0, "")
      case "holtwinters" =>
        require(args.length >= 11, "holtwinters needs the seasonal period and the model type")
        (List(""), args(9).toInt, args(10))
      case _ =>
        throw new UnsupportedOperationException("Currently only supports 'arima' and 'holtwinters'")
    }

    /**
      * 2. Initialise Spark (the master is hard-coded to local[4]) with Hive support
      */
    val sparkSession = SparkSession.builder
      .master("local[4]").appName("SparkTest")
      .enableHiveSupport()
      .getOrCreate()

    /**
      * 3. Read the source data and convert it to (time, key, data) rows
      */
    val timeDataKeyDf = readHiveData(sparkSession, databaseTableName, hiveColumnName)
    val zonedDateDataDf = timeChangeToDate(sparkSession, timeDataKeyDf, hiveColumnName, startTime, sdf)

    /**
      * 4. Build the uniform date-time index: start, end and step
      */
    val dtIndex = getTimeSpan(startTime, endTime, timeSpanType, timeSpan, sdf)

    /**
      * 5. Build the training series and fill missing observations linearly
      */
    val trainTsrdd = TimeSeriesRDD.timeSeriesRDDFromObservations(dtIndex, zonedDateDataDf,
      hiveColumnName(0), hiveColumnName(0) + "Key", hiveColumnName(1))
    trainTsrdd.cache()
    val filledTrainTsrdd = trainTsrdd.fill("linear")

    /**
      * 6. Train the selected model and save its evaluation figures
      */
    val timeSeriesKeyModel = new TimeSeriesKeyModel(predictedN, outputTableName)
    val forecastValue: RDD[(String, Vector)] = modelName match {
      case "arima" =>
        val (forecast, coefficients) = timeSeriesKeyModel.arimaModelTrainKey(filledTrainTsrdd, listPDQ)
        timeSeriesKeyModel.arimaModelKeyEvaluationSave(sparkSession, coefficients, forecast)
        forecast
      case _ =>
        // Only "holtwinters" can reach this branch: the name was validated above.
        val (forecast, sse) = timeSeriesKeyModel.holtWintersModelTrainKey(filledTrainTsrdd, period, holtWintersModelType)
        timeSeriesKeyModel.holtWintersModelKeyEvaluationSave(sparkSession, sse, forecast)
        forecast
    }

    /**
      * 7. Append the forecast to the actual values, attach the dates and save
      */
    timeSeriesKeyModel.actualForcastDateKeySaveInHive(sparkSession, filledTrainTsrdd, forecastValue, predictedN, startTime,
      endTime, timeSpanType, timeSpan, sdf, hiveColumnName)
  }

  /**
    * Loads the requested columns from a Hive table and appends a key column.
    *
    * Rows whose time column equals the column's own name are dropped by the query,
    * and rows whose value column is the empty string are filtered out afterwards.
    * The appended column is named `<time column>Key` and is computed as the value
    * column multiplied by the string "0".
    *
    * @param sparkSession      session used to run the query
    * @param databaseTableName table to read from
    * @param hiveColumnName    column names, time column first and value column second
    * @return a DataFrame with the time column, the value column and the key column
    */
  def readHiveData(sparkSession: SparkSession, databaseTableName: String, hiveColumnName: List[String]): DataFrame = {
    val timeColumn = hiveColumnName(0)
    val valueColumn = hiveColumnName(1)
    val keyColumn = timeColumn + "Key"

    val query = "select * from " + databaseTableName + " where " + timeColumn + " !='" + timeColumn + "'"
    val populated = sparkSession.sql(query)
      .select(hiveColumnName.head, hiveColumnName.tail: _*)
      .filter(valueColumn + " != ''")

    populated
      .withColumn(keyColumn, populated(valueColumn) * 0.toString)
      .select(timeColumn, valueColumn, keyColumn)
  }


  /**
    * Converts (time, data, key) rows into typed (timestamp, key, data) rows.
    *
    * The time value is parsed with `sdf` into a `java.sql.Timestamp`, the key is
    * kept as its string form and the data value is parsed as a double.
    *
    * @param sparkSession   session used to build the resulting DataFrame
    * @param timeDataKeyDf  input rows in the order time, data, key
    * @param hiveColumnName column names, time column first and value column second
    * @param startTime      not used by this method
    * @param sdf            parser for the textual time column
    * @return a DataFrame with columns `<time>` (timestamp), `<time>Key` (string) and
    *         `<value>` (double)
    */
  def timeChangeToDate(sparkSession: SparkSession, timeDataKeyDf: DataFrame, hiveColumnName: List[String], startTime: String,
                       sdf: SimpleDateFormat): DataFrame = {
    val typedRows: RDD[Row] = timeDataKeyDf.rdd.map {
      case Row(time, data, key) =>
        val parsed = new Timestamp(sdf.parse(time.toString).getTime)
        Row(parsed, key.toString, data.toString.toDouble)
    }

    // Schema of the rows produced above.
    val timeColumn = hiveColumnName(0)
    val schema = StructType(Seq(
      StructField(timeColumn, TimestampType, true),
      StructField(timeColumn + "Key", StringType, true),
      StructField(hiveColumnName(1), DoubleType, true)))

    sparkSession.createDataFrame(typedRows, schema)
  }

  /**
    * Builds the uniform date-time index spanning the training interval.
    *
    * Both bounds are normalised to "yyyy-MM-dd HH:mm:ss" and interpreted in the
    * system default time zone. Seconds are always set to zero.
    *
    * @param startTime    first instant of the interval, in the layout understood by `sdf`
    * @param endTime      last instant of the interval, in the layout understood by `sdf`
    * @param timeSpanType step unit: "year", "month", "day", "hour" or "minute"
    * @param timeSpan     number of units per step; must be positive
    * @param sdf          parser for `startTime` and `endTime`
    * @return the index from start to end at the requested frequency
    * @throws IllegalArgumentException if a bound cannot be parsed, the unit is unknown
    *                                  or `timeSpan` is not positive
    */
  def getTimeSpan(startTime: String, endTime: String, timeSpanType: String, timeSpan: Int, sdf: SimpleDateFormat): UniformDateTimeIndex = {
    require(timeSpan > 0, s"timeSpan must be positive but was $timeSpan")
    val zone = ZoneId.systemDefault()

    // Normalises one bound and reads its fields back by position from the fixed layout.
    def toZonedDateTime(raw: String, label: String): ZonedDateTime = {
      val normalised = TimeFormatUtils.fromatData(raw, sdf)
      require(normalised != null, s"$label '$raw' cannot be parsed with the detected time format")
      ZonedDateTime.of(
        normalised.substring(0, 4).toInt, normalised.substring(5, 7).toInt, normalised.substring(8, 10).toInt,
        normalised.substring(11, 13).toInt, normalised.substring(14, 16).toInt, 0, 0, zone)
    }

    val start = toZonedDateTime(startTime, "start time")
    val end = toZonedDateTime(endTime, "end time")

    val frequency = timeSpanType match {
      case "year" => new YearFrequency(timeSpan)
      case "month" => new MonthFrequency(timeSpan)
      case "day" => new DayFrequency(timeSpan)
      case "hour" => new HourFrequency(timeSpan)
      case "minute" => new MinuteFrequency(timeSpan)
      case other =>
        throw new IllegalArgumentException(
          s"Unsupported timeSpanType '$other'; expected year, month, day, hour or minute")
    }

    DateTimeIndex.uniformFromInterval(start, end, frequency)
  }
}

TimeSeriesKeyModel.scala

import java.text.SimpleDateFormat
import java.util.Calendar

import com.cloudera.sparkts.TimeSeriesRDD
import com.cloudera.sparkts.models.{ARIMA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import scala.collection.mutable.ArrayBuffer

/**
  * 时间序列模型(处理的数据多一个key列)
  * Created by llq on 2017/5/3.
  */
class TimeSeriesKeyModel {
  // Number of future values to forecast.
  private var predictedN = 1
  // Name of the table the results are stored in.
  private var outputTableName = "time_series.timeseries_output"

  /**
    * Creates a model helper with an explicit forecast horizon and output table.
    *
    * @param predictedN      number of future values to forecast
    * @param outputTableName name of the table the results are stored in
    */
  def this(predictedN: Int, outputTableName: String) {
    this()
    this.predictedN = predictedN
    this.outputTableName = outputTableName
  }

  /**
    * Fits one ARIMA model per key and forecasts the next `predictedN` values.
    *
    * When `listPDQ` has at least three entries they are used as p, d and q;
    * otherwise the order is chosen by `ARIMA.autoFit`. The coefficients of every
    * model and the forecasts are printed on the driver.
    *
    * @param trainTsrdd training series, one vector per key
    * @param listPDQ    p, d and q as strings, or fewer than three entries for auto-fit
    * @return a pair of (forecast values per key, model description per key), where the
    *         description is (coefficients, (p, d, q), logLikelihoodCSS, approxAIC) as strings
    * @throws NumberFormatException if an explicit p, d or q entry is not an integer
    */
  def arimaModelTrainKey(trainTsrdd: TimeSeriesRDD[String], listPDQ: List[String]): (RDD[(String, Vector)], RDD[(String, (String, (String, String, String), String, String))]) = {
    // Local copy so the closures below use a plain Int rather than a field of this class.
    val predictedN = this.predictedN

    // Parse the order once on the driver, so a malformed value fails before any Spark work.
    val explicitOrder: Option[(Int, Int, Int)] =
      if (listPDQ.size >= 3) Some((listPDQ(0).toInt, listPDQ(1).toInt, listPDQ(2).toInt)) else None

    // (key, fitted model, training vector). Cached because several actions, here and in
    // the callers, are derived from it; without the cache each one would refit every model.
    val arimaAndVectorRdd = trainTsrdd.map { case (key, series) =>
      val model = explicitOrder match {
        case Some((p, d, q)) => ARIMA.fitModel(p, d, q, series)
        case None => ARIMA.autoFit(series)
      }
      (key, model, series)
    }.cache()

    // Model description: coefficients, (p, d, q), CSS log-likelihood and approximate AIC.
    val coefficients = arimaAndVectorRdd.map { case (key, model, series) =>
      (key, (model.coefficients.mkString(","),
        (model.p.toString, model.d.toString, model.q.toString),
        model.logLikelihoodCSS(series).toString,
        model.approxAIC(series).toString))
    }

    coefficients.collect().foreach { case (key, (coefs, (p, d, q), _, _)) =>
      println(key + " coefficients:" + coefs + "=>" + "(p=" + p + ",d=" + d + ",q=" + q + ")")
    }

    // Only the last `predictedN` entries of the model's forecast output are kept.
    val forecastValue = arimaAndVectorRdd.map { case (key, model, series) =>
      (key, Vectors.dense(model.forecast(series, predictedN).toArray.takeRight(predictedN)))
    }

    println("Arima forecast of next " + predictedN + " observations:")
    // Collect first so the values are printed by the driver rather than by the executors.
    forecastValue.collect().foreach(println)
    (forecastValue, coefficients)
  }


  /**
    * Builds the ARIMA evaluation table, prints it and stores it.
    *
    * Every output row holds, as strings: coefficients, (p,d,q), logLikelihoodCSS,
    * aic, mean, variance, standardDeviation, max, min, range and count. The table is
    * written in overwrite mode to `outputTableName + "_arima_evaluation"`.
    *
    * NOTE(review): the statistics are computed once, from element 0 of each
    * transposed column, and the same figures are attached to every key. Confirm this
    * is intended when more than one key is present.
    *
    * @param sparkSession  session used to build and save the table
    * @param coefficients  model description per key
    * @param forecastValue forecast values per key
    */
  def arimaModelKeyEvaluationSave(sparkSession: SparkSession, coefficients: RDD[(String, (String, (String, String, String), String, String))], forecastValue: RDD[(String, Vector)]): Unit = {
    // One array of forecast values per key.
    val forecastArrays = forecastValue.map { case (_, forecast) => forecast.toArray }

    // Transpose: tag each number with its column and row, group by column, order by row.
    val indexedCells = forecastArrays.zipWithIndex.flatMap { case (values, rowIndex) =>
      values.zipWithIndex.map { case (number, columnIndex) => columnIndex -> (rowIndex, number) }
    }
    val transposed = indexedCells.groupByKey.sortByKey().values
      .map(cells => cells.toSeq.sortBy(_._1).map(_._2))
    val summary = Statistics.colStats(transposed.map(column => Vectors.dense(column(0))))

    // Mean, variance, standard deviation, max, min, range and count, rendered as strings.
    val variance = summary.variance.toArray(0)
    val largest = summary.max.toArray(0)
    val smallest = summary.min.toArray(0)
    val statistics = (summary.mean.toArray(0).toString, variance.toString, math.sqrt(variance).toString,
      largest.toString, smallest.toString, (largest - smallest).toString, summary.count.toString)

    // Model description joined with the statistics, one row per key.
    val evaluationRows = coefficients
      .join(forecastValue.map { case (key, _) => (key, statistics) })
      .map { case (_, ((coefs, pdq, logLikelihoodCSS, aic), (mean, vari, stdDev, max, min, range, count))) =>
        Row(coefs, pdq.toString, logLikelihoodCSS, aic, mean, vari, stdDev, max, min, range, count)
      }

    val columnNames = "coefficients,pdq,logLikelihoodCSS,aic,mean,variance,standardDeviation,max,min,range,count".split(",")
    val schema = StructType(columnNames.map(columnName => StructField(columnName, StringType, true)))
    val evaluationDf = sparkSession.createDataFrame(evaluationRows, schema)

    println("Evaluation in Arima:")
    evaluationDf.show()

    // Persist the evaluation table, replacing any previous content.
    evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_arima_evaluation")
  }


  /**
    * Fits one Holt-Winters model per key and forecasts the next `predictedN` values.
    *
    * @param trainTsrdd           training series, one vector per key
    * @param period               seasonal period handed to the fit
    * @param holtWintersModelType model type handed to the fit
    * @return a pair of (forecast values per key, sum of squared errors per key)
    */
  def holtWintersModelTrainKey(trainTsrdd: TimeSeriesRDD[String], period: Int, holtWintersModelType: String): (RDD[(String, Vector)], RDD[(String, Double)]) = {
    // The import block visible for this file only brings ARIMA into scope.
    import com.cloudera.sparkts.models.HoltWinters

    // Local copy so the closures below use a plain Int rather than a field of this class.
    val predictedN = this.predictedN

    // (key, fitted model, training vector). Cached because both the forecast and the SSE
    // are derived from it; without the cache every model would be fitted repeatedly.
    val holtWintersAndVectorRdd = trainTsrdd.map { case (key, series) =>
      (key, HoltWinters.fitModel(series, period, holtWintersModelType), series)
    }.cache()

    // A fresh destination vector is allocated for every series. sparkts' forecast fills the
    // vector it is given and returns it, so a single shared vector would let every key
    // alias the same storage.
    val forecast = holtWintersAndVectorRdd.map { case (key, model, series) =>
      (key, model.forecast(series, Vectors.dense(new Array[Double](predictedN))))
    }
    println("HoltWinters forecast of next " + predictedN + " observations:")
    // Collect first so the values are printed by the driver rather than by the executors.
    forecast.collect().foreach(println)

    // Evaluation metric: sum of squared errors of each model on its training series.
    val sse = holtWintersAndVectorRdd.map { case (key, model, series) =>
      (key, model.sse(series))
    }
    (forecast, sse)
  }

  /**
    * Builds the Holt-Winters evaluation table, prints it and stores it.
    *
    * Every output row holds, as strings: sse, mean, variance, standardDeviation,
    * max, min, range and count. The table is written in overwrite mode to
    * `outputTableName + "_holtwinters_evaluation"`.
    *
    * NOTE(review): the statistics are computed once, from element 0 of each
    * transposed column, and the same figures are attached to every key. Confirm this
    * is intended when more than one key is present.
    *
    * @param sparkSession  session used to build and save the table
    * @param sse           sum of squared errors per key
    * @param forecastValue forecast values per key
    */
  def holtWintersModelKeyEvaluationSave(sparkSession: SparkSession, sse: RDD[(String, Double)], forecastValue: RDD[(String, Vector)]): Unit = {
    // One array of forecast values per key.
    val forecastArrays = forecastValue.map { case (_, forecast) => forecast.toArray }

    // Transpose: tag each number with its column and row, group by column, order by row.
    val indexedCells = forecastArrays.zipWithIndex.flatMap { case (values, rowIndex) =>
      values.zipWithIndex.map { case (number, columnIndex) => columnIndex -> (rowIndex, number) }
    }
    val transposed = indexedCells.groupByKey.sortByKey().values
      .map(cells => cells.toSeq.sortBy(_._1).map(_._2))
    val summary = Statistics.colStats(transposed.map(column => Vectors.dense(column(0))))

    // Mean, variance, standard deviation, max, min, range and count, rendered as strings.
    val variance = summary.variance.toArray(0)
    val largest = summary.max.toArray(0)
    val smallest = summary.min.toArray(0)
    val statistics = (summary.mean.toArray(0).toString, variance.toString, math.sqrt(variance).toString,
      largest.toString, smallest.toString, (largest - smallest).toString, summary.count.toString)

    // SSE joined with the statistics, one row per key.
    val evaluationRows = sse
      .join(forecastValue.map { case (key, _) => (key, statistics) })
      .map { case (_, (squaredError, (mean, vari, stdDev, max, min, range, count))) =>
        Row(squaredError.toString, mean, vari, stdDev, max, min, range, count)
      }

    val columnNames = "sse,mean,variance,standardDeviation,max,min,range,count".split(",")
    val schema = StructType(columnNames.map(columnName => StructField(columnName, StringType, true)))
    val evaluationDf = sparkSession.createDataFrame(evaluationRows, schema)

    println("Evaluation in HoltWinters:")
    evaluationDf.show()

    // Persist the evaluation table, replacing any previous content.
    evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_holtwinters_evaluation")
  }

  /**
    * Stores (date, value) rows in the output table, replacing its previous content.
    *
    * Both columns are stored as strings and are named after the first two entries
    * of `hiveColumnName`.
    *
    * @param sparkSession   session used to build and save the table
    * @param dateDataRdd    rows of (date, value)
    * @param hiveColumnName column names, time column first and value column second
    */
  private def keySaveInHive(sparkSession: SparkSession, dateDataRdd: RDD[Row], hiveColumnName: List[String]): Unit = {
    val fieldNames = (hiveColumnName(0) + " " + hiveColumnName(1)).split(" ")
    val fields = for (fieldName <- fieldNames) yield StructField(fieldName, StringType, true)

    sparkSession.createDataFrame(dateDataRdd, StructType(fields))
      .write.mode(SaveMode.Overwrite).saveAsTable(outputTableName)
  }


  /**
    * Appends the forecast to the actual values, pairs every value with its date and
    * saves the (date, value) rows.
    *
    * The joined series are collected to the driver and the dates are generated there,
    * so the pairing is done locally and only the finished rows are turned into an RDD.
    *
    * NOTE(review): `keySaveInHive` writes to the same table in overwrite mode on every
    * call, so with more than one key only the last one saved remains. Confirm this is
    * acceptable for multi-key input.
    *
    * @param sparkSession   session used to build and save the table
    * @param trainTsrdd     actual values, one vector per key
    * @param forecastValue  forecast values, one vector per key
    * @param predictedN     number of forecast values
    * @param startTime      first instant of the training interval
    * @param endTime        last instant of the training interval
    * @param timeSpanType   step unit
    * @param timeSpan       number of units per step
    * @param sdf            parser and formatter for the dates
    * @param hiveColumnName column names, time column first and value column second
    * @throws IllegalArgumentException if the number of generated dates differs from
    *                                  the number of values of a series
    */
  def actualForcastDateKeySaveInHive(sparkSession: SparkSession, trainTsrdd: TimeSeriesRDD[String], forecastValue: RDD[(String, Vector)],
                                     predictedN: Int, startTime: String, endTime: String, timeSpanType: String, timeSpan: Int,
                                     sdf: SimpleDateFormat, hiveColumnName: List[String]): Unit = {
    // Actual values followed by the forecast, rendered as strings, per key.
    val valuesByKey = trainTsrdd.map { case (key, actual) => (key, actual.toArray) }
      .join(forecastValue.map { case (key, forecast) => (key, forecast.toArray) })
      .mapValues { case (actual, forecast) => (actual ++ forecast).map(_.toString) }
      .collect()

    // Dates from the start of the training data to the end of the forecast.
    val dates = productStartDatePredictDate(predictedN, timeSpanType, timeSpan, sdf, startTime, endTime).toList

    valuesByKey.foreach { case (key, values) =>
      require(dates.length == values.length,
        s"series '$key' has ${values.length} values but ${dates.length} dates were generated")
      val rows = dates.zip(values).map { case (date, value) => Row(date, value) }
      keySaveInHive(sparkSession, sparkSession.sparkContext.parallelize(rows), hiveColumnName)
    }
  }

  /**
    * Generates the formatted dates from the start of the training data to the end of
    * the forecast.
    *
    * The number of steps between start and end is the number of complete units
    * between them divided by `timeSpan`; `predictedN` further steps are appended.
    * Every date is computed from the start in one jump, so month-end clamping does
    * not accumulate (Jan 31, Feb 29, Mar 31 instead of Jan 31, Feb 29, Mar 29), and
    * day steps follow the calendar rather than fixed 24-hour blocks.
    *
    * @param predictedN   number of forecast steps appended after the end time
    * @param timeSpanType step unit: "year", "month", "day", "hour" or "minute"
    * @param timeSpan     number of units per step; must be positive
    * @param format       parser for the bounds and formatter for the result
    * @param startTime    first date, in the layout understood by `format`
    * @param endTime      last training date, in the layout understood by `format`
    * @return the formatted dates in ascending order; empty when the step count is negative
    * @throws IllegalArgumentException if the unit is unknown or `timeSpan` is not positive
    */
  def productStartDatePredictDate(predictedN: Int, timeSpanType: String, timeSpan: Int,
                                  format: SimpleDateFormat, startTime: String, endTime: String): ArrayBuffer[String] = {
    import java.time.temporal.ChronoUnit
    import java.util.Date

    require(timeSpan > 0, s"timeSpan must be positive but was $timeSpan")
    val unit = timeSpanType match {
      case "year" => ChronoUnit.YEARS
      case "month" => ChronoUnit.MONTHS
      case "day" => ChronoUnit.DAYS
      case "hour" => ChronoUnit.HOURS
      case "minute" => ChronoUnit.MINUTES
      case other =>
        throw new IllegalArgumentException(
          s"Unsupported timeSpanType '$other'; expected year, month, day, hour or minute")
    }

    // Work in the formatter's own zone so parsing, stepping and formatting agree.
    val zone = format.getTimeZone.toZoneId
    val start = format.parse(startTime).toInstant.atZone(zone)
    val end = format.parse(endTime).toInstant.atZone(zone)

    // Index of the last date to produce: steps inside the interval plus the forecast steps.
    val lastStep = unit.between(start, end) / timeSpan + predictedN

    val dates = new ArrayBuffer[String]()
    var step = 0L
    while (step <= lastStep) {
      dates += format.format(Date.from(start.plus(step * timeSpan, unit).toInstant))
      step += 1
    }
    dates
  }
}
    原文作者:e辉
    原文地址: https://www.jianshu.com/p/1cb74b1adc84
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞