Spark from Beginner to Expert: SQL, DataFrame and Dataset

  • Link to the official documentation

  • Git repository for the code accompanying this article

  • Topics covered in this article (Spark from Beginner to Expert: SQL, DataFrame and Dataset)
  • SparkSession

    • Creating a SparkSession

      import org.apache.spark.sql.SparkSession
      val spark = SparkSession
        .builder()
        .appName("Spark SQL basic example")
        .config("spark.some.config.option", "some-value")
        .getOrCreate()
      
    • Reading files

      • JSON files

        val df = spark.read.json("examples/src/main/resources/people.json")
        // returns a DataFrame
        
      • Text files

        val peopleRDD = spark.sparkContext
          .textFile("examples/src/main/resources/people.txt")
        // returns an RDD[String]
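
        For comparison, a minimal sketch (assuming the same people.txt file) that reads the text through the DataFrame/Dataset reader instead of the RDD API:

        // returns a Dataset[String]; each element is one line of the file
        val peopleTextDS = spark.read.textFile("examples/src/main/resources/people.txt")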
        
  • Suppressing INFO-level log output

    import org.apache.log4j.{Level, Logger}
    val rootLogger = Logger.getRootLogger()
    rootLogger.setLevel(Level.ERROR)
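
    Alternatively, a minimal sketch (assuming the SparkSession created above) that sets the log level through the SparkContext instead of log4j directly:

    // valid levels include ALL, DEBUG, INFO, WARN, ERROR, FATAL, OFF
    spark.sparkContext.setLogLevel("ERROR")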
    
  • DataFrame

    • Creating a DataFrame

      val df = spark.read.json("examples/src/main/resources/people.json")
      
    • DataFrame operations

      • Required import

        // operating on a DataFrame requires this import (it enables, among other things, the $"col" syntax)
        import spark.implicits._
      
      • Display

        • Show the contents

          df.show()
          
        • View the schema

          df.printSchema()
          
      • Column selection

        • Select a single column

          df.select("name").show()
          
        • Select several columns

          df.select("name", "age").show()
          
      • Column operations

        • Using $

          df.select($"name", $"age" + 1).show()
          
        • Using map

          • By index

            // people.json is inferred with columns (age, name), so index 1 refers to the name column
            df.map(teenager => "Name: " + teenager(1)).show()
            
          • By column name

            df.map(teenager => "Name: " + teenager.getAs[String]("name")).show()
            
        • Using withColumn

          import org.apache.spark.sql.functions.col
          df.withColumn("age+1",col("age")+1).show()
          
      • Row selection

        • filter

          df.filter($"age" > 21).show()
          
        • where

          df.where("age is not null")
          
      • Aggregate functions

        • Aggregation producing a single result

          df.groupBy("age").count().show()
          
        • Aggregation producing multiple results

          // max and min come from org.apache.spark.sql.functions
          import org.apache.spark.sql.functions.{col, max, min}
          df.groupBy(col("age"))
            .agg(
              max(col("name")).alias("name_max"),
              min(col("name")).alias("name_min"))
            .show()
          
      • Filling null values

        df.na.fill(0.0)
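
        The na functions can also target specific columns or drop rows, as in this minimal sketch (not part of the original article):

        // fill nulls in the age column with 0, then drop any rows that still contain nulls
        df.na.fill(Map("age" -> 0)).na.drop()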
        
      • Views

        • Temporary view (session scope)

          • Create

            df.createOrReplaceTempView("people")
            
          • Query

            val sqlDF = spark.sql("SELECT * FROM people")
            
        • Global view (application scope), stored in the system database global_temp

          • Create

            df.createGlobalTempView("people")
            
          • Query

            • Query within the same session

              spark.sql("SELECT * FROM global_temp.people").show()
              
            • Query from another session

              spark.newSession().sql("SELECT * FROM global_temp.people").show()
              
    • Converting an RDD to a DataFrame

      • The toDF method

        // toDF on an RDD of case classes requires the implicit encoders
        import spark.implicits._
        case class Person(name: String, age: Long)
        val peopleDF = spark.sparkContext
          .textFile("examples/src/main/resources/people.txt")
          .map(_.split(","))
          .map(attributes => Person(attributes(0), attributes(1).trim.toInt))
          .toDF()
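
        A typical follow-up, sketched here after the Spark SQL programming guide: register the DataFrame as a temporary view and query it with SQL.

        peopleDF.createOrReplaceTempView("people")
        val teenagersDF = spark.sql("SELECT name, age FROM people WHERE age BETWEEN 13 AND 19")
        teenagersDF.show()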
        
      • Specifying an explicit schema

        import org.apache.spark.sql.types._
        import org.apache.spark.sql.Row
        val rowRDD = spark.sparkContext
          .textFile("/Users/didi/打车金/asd/examples/src/main/resources/people.txt")
          .map(_.split(","))
          // age must be an Int here to match the IntegerType declared in the schema below
          .map(attributes => Row(attributes(0), attributes(1).trim.toInt))
        val schema = StructType(Array(
          StructField("name", StringType, nullable = true),
          StructField("age", IntegerType, nullable = true)))
        val schemaDF = spark.createDataFrame(rowRDD, schema)
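
        To verify that the declared types were applied, a minimal sketch (the view name is only illustrative):

        schemaDF.printSchema()
        schemaDF.createOrReplaceTempView("people_with_schema")
        spark.sql("SELECT name FROM people_with_schema WHERE age IS NOT NULL").show()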
        
  • Dataset

    • Creating a Dataset

      • From a Seq

        // assumes the Person case class defined above and: import spark.implicits._
        val caseClassDS = Seq(Person("Andy", 32)).toDS()
        val primitiveDS = Seq(1, 2, 3).toDS()
        
      • From a JSON file

        case class Person(name: String, age: Long)
        val path = "examples/src/main/resources/people.json"
        val peopleDS = spark.read.json(path).as[Person]
        
    • Dataset operations

      primitiveDS.map(_ + 1).collect() // Returns: Array(2, 3, 4)
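
      Typed operations also work on Datasets of case classes, as in this minimal sketch (assuming the caseClassDS defined above and import spark.implicits._):

      caseClassDS.filter(_.age > 21).map(_.name).show()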
      
  • UDF (user-defined aggregate functions)

    • SQL (untyped: UserDefinedAggregateFunction)

      // AggregateFunctions.scala
      
      import org.apache.log4j.{Level, Logger}
      import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
      import org.apache.spark.sql.types._
      import org.apache.spark.sql.{Row, SparkSession}
      
      object AggregateFunctions {
      
        object MyAverage extends UserDefinedAggregateFunction {
          // Data types of input arguments of this aggregate function
          def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
          // Data types of values in the aggregation buffer
          def bufferSchema: StructType = {
            StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
          }
          // The data type of the returned value
          def dataType: DataType = DoubleType
          // Whether this function always returns the same output on the identical input
          def deterministic: Boolean = true
          // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
          // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
          // the opportunity to update its values. Note that arrays and maps inside the buffer are still
          // immutable.
          def initialize(buffer: MutableAggregationBuffer): Unit = {
            buffer(0) = 0L
            buffer(1) = 0L
          }
          // Updates the given aggregation buffer `buffer` with new input data from `input`
          def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
            if (!input.isNullAt(0)) {
              buffer(0) = buffer.getLong(0) + input.getLong(0)
              buffer(1) = buffer.getLong(1) + 1
            }
          }
          // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
          def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
            buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
            buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
          }
          // Calculates the final result
          def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
        }
      
        def main(args: Array[String]): Unit = {
      
          val spark = SparkSession
            .builder()
            .master("local")
            .appName("Spark SQL basic example")
            .config("spark.some.config.option", "some-value")
            .getOrCreate()
      
          val rootLogger = Logger.getRootLogger()
          rootLogger.setLevel(Level.ERROR)
      
          spark.udf.register("myAverage", MyAverage)
      
          //    Creating DataFrames
          val df = spark.read.json("/Users/didi/learn/learnSpark/src/main/resources/employees.json")
          df.createOrReplaceTempView("employees")
          df.show()
      
          val result = spark.sql("select myAverage(salary) as average_salary from employees")
          result.show()
      
        }
      
      }
      
    • Dataset (typed: Aggregator)

      // AggregateFunctions2.scala
      import org.apache.log4j.{Level, Logger}
      import org.apache.spark.sql.expressions.Aggregator
      import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
      
      object AggregateFunctions2 {
      
        case class Employee(name: String, salary: Long)
        case class Average(var sum: Long, var count: Long)
      
        object MyAverage extends Aggregator[Employee, Average, Double] {
      
          def zero: Average = Average(0L, 0L)
      
          def reduce(buffer: Average, employee: Employee): Average = {
            buffer.sum += employee.salary
            buffer.count += 1
            buffer
          }
      
          def merge(b1: Average, b2: Average): Average ={
            b1.sum += b2.sum
            b1.count += b2.count
            b1
          }
      
          def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
      
          def bufferEncoder: Encoder[Average] = Encoders.product
      
          def outputEncoder: Encoder[Double] = Encoders.scalaDouble
      
        }
      
        def main(args: Array[String]): Unit = {
      
            val spark = SparkSession
              .builder()
              .master("local")
              .appName("Spark SQL basic example")
              .config("spark.some.config.option", "some-value")
              .getOrCreate()
            import spark.implicits._
      
            val rootLogger = Logger.getRootLogger()
            rootLogger.setLevel(Level.ERROR)
      
            //    Creating DataFrames
            val ds = spark.read.json("/Users/didi/learn/learnSpark/src/main/resources/employees.json").as[Employee]
            ds.show()
      
            val averageSalary = MyAverage.toColumn.name("average_salary")
            val result = ds.select(averageSalary)
      
            result.show()
      
      
        }
      
      }
      
      
    Original author: enjoy_算法工程师
    Original link: https://www.jianshu.com/p/d689014ac0a8
    This article is reposted from the web and shared only for knowledge-sharing purposes; if it infringes any rights, please contact the blogger to have it removed.