# [DC_01] Introduction to PySpark

2022/07

[Datacamp course - Introduction to PySpark](https://app.datacamp.com/learn/courses/introduction-to-pyspark)
[my_notes(hackmd.ver)](https://hackmd.io/VBqBDscUTXSe48TN8SRnAQ?view)

:::info
[TOC]
:::

## 1. Getting to know PySpark

In this chapter, you'll learn how Spark manages data and how you can read and write tables from Python.

### What is Spark, anyway?

Spark is a platform for cluster computing. Spark lets you spread data and computations over **clusters** with multiple **nodes** (think of each node as a separate computer). Splitting up your data makes it easier to work with very large datasets because each node only works with a small amount of data.

As each node works on its own subset of the total data, it also carries out a part of the total calculations required, so that both data processing and computation are performed in parallel over the nodes in the cluster. Parallel computation can make certain types of programming tasks much faster.

However, with greater computing power comes greater complexity.

Deciding whether or not Spark is the best solution for your problem takes some experience, but you can consider questions like:

- Is my data too big to work with on a single machine?
- Can my calculations be easily parallelized?

### Using Spark in Python

The first step in using Spark is connecting to a cluster.

In practice, the cluster will be hosted on a remote machine that's connected to all other nodes. There will be one computer, called the **master**, that manages splitting up the data and the computations. The master is connected to the rest of the computers in the cluster, which are called **workers**. The master sends the workers data and calculations to run, and they send their results back to the master.

When you're just getting started with Spark, it's simpler to just run a cluster locally.
Thus, for this course, instead of connecting to another computer, all computations will be run on DataCamp's servers in a simulated cluster.

Creating the connection is as simple as creating an instance of the `SparkContext` class. The class constructor takes a few optional arguments that allow you to specify the attributes of the cluster you're connecting to.

An object holding all these attributes can be created with the `SparkConf()` constructor. Take a look at the documentation for all the details!

For the rest of this course you'll have a `SparkContext` called `sc` already available in your workspace.

### Examining The SparkContext

```python
# Verify SparkContext
print(sc)
# >>> <SparkContext master=local[*] appName=pyspark-shell>

# Print Spark version
print(sc.version)
# >>> 3.2.0
```

### Using DataFrames

Spark's core data structure is the Resilient Distributed Dataset (RDD). This is a low-level object that lets Spark work its magic by splitting data across multiple nodes in the cluster. However, RDDs are hard to work with directly, so in this course you'll be using the Spark DataFrame abstraction built on top of RDDs.

The Spark DataFrame was designed to behave a lot like a SQL table (a table with variables in the columns and observations in the rows). DataFrames are not only easier to understand but also better optimized for complicated operations than RDDs.

When you start modifying and combining columns and rows of data, there are many ways to arrive at the same result, but some often take much longer than others. When using RDDs, it's up to the data scientist to figure out the right way to optimize the query, but the DataFrame implementation has much of this optimization built in!

To start working with Spark DataFrames, you first have to create a `SparkSession` object from your `SparkContext`. You can think of the `SparkContext` as your connection to the cluster and the `SparkSession` as your interface with that connection.
Remember, for the rest of this course you'll have a `SparkSession` called `spark` available in your workspace!

### Creating a SparkSession

```python
# Import SparkSession from pyspark.sql
from pyspark.sql import SparkSession

# Create my_spark (getOrCreate returns the existing session if there is one)
my_spark = SparkSession.builder.getOrCreate()

# Print my_spark
print(my_spark)
# >>> <pyspark.sql.session.SparkSession object at 0x7f6f5bc62910>
```

### Viewing tables

```python
# Print the tables in the catalog
print(spark.catalog.listTables())
# >>> [Table(name='flights', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]
```

### Are you query-ious?

```python
# Don't change this query
query = "FROM flights SELECT * LIMIT 10"

# Get the first 10 rows of flights
flights10 = spark.sql(query)

# Show the results
flights10.show()
```

output

```
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
|2014|    1| 15|    1037|        7|    1352|        2|     WN| N646SW|    48|   PDX| DEN|     121|     991|  10|    37|
|2014|    7|  2|     847|       42|    1041|       51|     WN| N422WN|  1520|   PDX| OAK|      90|     543|   8|    47|
|2014|    5| 12|    1655|       -5|    1842|      -18|     VX| N361VA|   755|   SEA| SFO|      98|     679|  16|    55|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
```

### Pandafy a Spark DataFrame

```python
# Don't change this query
query = "SELECT origin, dest, COUNT(*) as N FROM flights GROUP BY origin, dest"

# Run the query
flight_counts = spark.sql(query)

# Convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()

# Print the head of pd_counts
print(pd_counts.head())
#   origin dest    N
# 0    SEA  RNO    8
# 1    SEA  DTW   98
# 2    SEA  CLE    2
# 3    SEA  LAX  450
# 4    PDX  SEA  144
```

### Put some Spark in your data

![](https://i.imgur.com/krrmYru.png)

```python
import numpy as np
import pandas as pd

# Create pd_temp
pd_temp = pd.DataFrame(np.random.random(10))

# Create spark_temp from pd_temp (createDataFrame is a SparkSession method)
spark_temp = spark.createDataFrame(pd_temp)

# Examine the tables in the catalog
print(spark.catalog.listTables())

# Add spark_temp to the catalog as a temporary view named "temp"
spark_temp.createOrReplaceTempView("temp")

# Examine the tables in the catalog again
print(spark.catalog.listTables())

# <script.py> output:
# [Table(name='temp', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]
# [Table(name='temp', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]
```

### Dropping the middle man

```python
# Don't change this file path
file_path = "/usr/local/share/datasets/airports.csv"

# Read in the airports data
# header=True uses the first row as column names; without inferSchema=True,
# every column is read as a string
airports = spark.read.csv(file_path, header=True)

# Show the data
airports.show()
```

output

```
+---+--------------------+----------------+-----------------+----+---+---+
|faa|                name|             lat|              lon| alt| tz|dst|
+---+--------------------+----------------+-----------------+----+---+---+
|04G|   Lansdowne Airport|      41.1304722|      -80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|      32.4605722|      -85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|      41.9893408|      -88.1012428| 801| -6|  A|
|06N|     Randall Airport|       41.431912|      -74.3915611| 523| -5|  A|
|09J|Jekyll Island Air...|      31.0744722|      -81.4277778|  11| -4|  A|
|0A9|Elizabethton Muni...|      36.3712222|      -82.1734167|1593| -4|  A|
|0G6|Williams County A...|      41.4673056|      -84.5067778| 730| -5|  A|
|0G7|Finger Lakes Regi...|      42.8835647|      -76.7812318| 492| -5|  A|
|0P2|Shoestring Aviati...|      39.7948244|      -76.6471914|1000| -5|  U|
|0S9|Jefferson County ...|      48.0538086|     -122.8106436| 108| -8|  A|
|0W3|Harford County Ai...|      39.5668378|      -76.2024028| 409| -5|  A|
|10C|  Galt Field Airport|      42.4028889|      -88.3751111| 875| -6|  U|
|17G|Port Bucyrus-Craw...|      40.7815556|      -82.9748056|1003| -5|  A|
|19A|Jackson County Ai...|      34.1758638|      -83.5615972| 951| -4|  U|
|1A3|Martin Campbell F...|      35.0158056|      -84.3468333|1789| -4|  A|
|1B9| Mansfield Municipal|      42.0001331|      -71.1967714| 122| -5|  A|
|1C9|Frazier Lake Airpark|54.0133333333333|-124.768333333333| 152| -8|  A|
|1CS|Clow Internationa...|      41.6959744|      -88.1292306| 670| -6|  U|
|1G3|  Kent State Airport|      41.1513889|      -81.4151111|1134| -4|  A|
|1OH|     Fortman Airport|      40.5553253|      -84.3866186| 885| -5|  U|
+---+--------------------+----------------+-----------------+----+---+---+
only showing top 20 rows
```

## 2. Manipulating data

In this chapter, you'll learn about the `pyspark.sql` module, which provides optimized data queries to your Spark session.
### Creating columns

```python
# Create the DataFrame flights
flights = spark.table("flights")

# Show the head
flights.show()

# Add duration_hrs (air_time is in minutes)
flights = flights.withColumn("duration_hrs", flights.air_time / 60)
```

output

```
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
...
|2014|    6|  5|    1733|      -12|    1945|      -10|     OO| N215AG|  3488|   PDX| BUR|     111|     817|  17|    33|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 20 rows
```

### SQL in a nutshell

As you move forward, it will help to have a basic understanding of SQL. A more in-depth look can be found [here](https://www.datacamp.com/courses/intro-to-sql-for-data-science).

A SQL query returns a table derived from one or more tables contained in a database.

Every SQL query is made up of commands that tell the database what you want to do with the data. The two commands that every query has to contain are `SELECT` and `FROM`.

The `SELECT` command is followed by the columns you want in the resulting table.

The `FROM` command is followed by the name of the table that contains those columns. The minimal SQL query is:

```sql
SELECT * FROM my_table;
```

The `*` selects all columns, so this returns the entire table named `my_table`.

Similar to `.withColumn()`, you can do column-wise computations within a `SELECT` statement.
For example,

```sql
SELECT origin, dest, air_time / 60 FROM flights;
```

returns a table with the origin, destination, and duration in hours for each flight.

Another commonly used command is `WHERE`. This command filters the rows of the table based on some logical condition you specify. The resulting table contains the rows where your condition is true. For example, if you had a table of students and grades you could do:

```sql
SELECT * FROM students
WHERE grade = 'A';
```

to select all the columns and the rows containing information about students who got As.

### SQL in a nutshell (2)

Another common database task is aggregation. That is, reducing your data by breaking it into chunks and summarizing each chunk.

This is done in SQL using the `GROUP BY` command. This command breaks your data into groups and applies a function from your `SELECT` statement to each group.

For example, if you wanted to count the number of flights from each of the two origin airports, you could use the query

```sql
SELECT origin, COUNT(*) FROM flights
GROUP BY origin;
```

`GROUP BY origin` tells SQL that you want the output to have a row for each unique value of the `origin` column. The `SELECT` statement selects the values you want to populate each of the columns. Here, we want to `COUNT()` every row in each of the groups.

It's possible to `GROUP BY` more than one column. When you do this, the resulting table has a row for every combination of the unique values in each column. The following query counts the number of flights from SEA and PDX to every destination airport:

```sql
SELECT origin, dest, COUNT(*) FROM flights
GROUP BY origin, dest;
```

The output will have a row for every combination of the values in `origin` and `dest` (i.e., one row for each origin/destination pair that appears in the data). There will also be a column with the `COUNT()` of all the rows in each group.
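To make the `GROUP BY` semantics concrete, here is a plain-Python sketch (not Spark) of what the grouping queries above compute. The three rows are made-up sample data and `count_by` is a hypothetical helper; a `WHERE` clause corresponds to filtering the rows before grouping.

```python
from collections import Counter

# Made-up rows standing in for the flights table (only two columns)
sample_flights = [
    {"origin": "SEA", "dest": "LAX"},
    {"origin": "SEA", "dest": "LAX"},
    {"origin": "PDX", "dest": "SEA"},
]


def count_by(rows, *keys):
    """Hypothetical helper: count rows per unique combination of the key
    columns, like SELECT <keys>, COUNT(*) ... GROUP BY <keys>."""
    return Counter(tuple(row[k] for k in keys) for row in rows)


# GROUP BY origin
print(count_by(sample_flights, "origin"))

# GROUP BY origin, dest
print(count_by(sample_flights, "origin", "dest"))

# WHERE origin = 'SEA' is a row filter applied before grouping
sea_only = [row for row in sample_flights if row["origin"] == "SEA"]
print(count_by(sea_only, "origin"))
```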
### Filtering Data

```python
# Filter flights by passing a string
long_flights1 = flights.filter("distance > 1000")

# Filter flights by passing a column of boolean values
long_flights2 = flights.filter(flights.distance > 1000)

# Print the data to check they're equal
long_flights1.show()
long_flights2.show()
```

output

```
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|
...
|2014|   11| 19|    1319|       -6|    1821|      -14|     DL| N309US|  2164|   PDX| MSP|     169|    1426|  13|    19|
|2014|    5| 21|     515|        0|     757|        0|     US| N172US|   593|   SEA| PHX|     143|    1107|   5|    15|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 20 rows

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|
...
|2014|   11| 19|    1319|       -6|    1821|      -14|     DL| N309US|  2164|   PDX| MSP|     169|    1426|  13|    19|
|2014|    5| 21|     515|        0|     757|        0|     US| N172US|   593|   SEA| PHX|     143|    1107|   5|    15|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 20 rows
```

### Selecting

```python
# Select the first set of columns (by column name)
selected1 = flights.select("tailnum", "origin", "dest")

# Select the second set of columns (by column object)
temp = flights.select(flights.origin, flights.dest, flights.carrier)

# Define first filter
filterA = flights.origin == "SEA"

# Define second filter
filterB = flights.dest == "PDX"

# Filter the data, first by filterA then by filterB
selected2 = temp.filter(filterA).filter(filterB)
```

### Selecting II

```python
# Define avg_speed (distance per hour of air time)
avg_speed = (flights.distance / (flights.air_time / 60)).alias("avg_speed")

# Select the correct columns
speed1 = flights.select("origin", "dest", "tailnum", avg_speed)

# Create the same table using a SQL expression
speed2 = flights.selectExpr("origin", "dest", "tailnum", "distance/(air_time/60) as avg_speed")
```

### Aggregating

```python
# Find the shortest flight from PDX in terms of distance
flights.filter(flights.origin == "PDX").groupBy().min("distance").show()

# Find the longest flight from SEA in terms of air time
flights.filter(flights.origin == "SEA").groupBy().max("air_time").show()
```

output

```
+-------------+
|min(distance)|
+-------------+
|          106|
+-------------+

+-------------+
|max(air_time)|
+-------------+
|          409|
+-------------+
```

### Aggregating II

```python
# Average duration of Delta flights departing from SEA (air_time is in minutes)
flights.filter(flights.origin == "SEA").filter(flights.carrier == "DL").groupBy().avg("air_time").show()

# Total hours in the air
flights.withColumn("duration_hrs", flights.air_time / 60).groupBy().sum("duration_hrs").show()
```

output

```
+------------------+
|     avg(air_time)|
+------------------+
|188.20689655172413|
+------------------+

+------------------+
| sum(duration_hrs)|
+------------------+
|25289.600000000126|
+------------------+
```

### Grouping and Aggregating I

```python
# Group by tailnum
by_plane = flights.groupBy("tailnum")

# Number of flights each plane made
by_plane.count().show()

# Group by origin
by_origin = flights.groupBy("origin")

# Average duration of flights from PDX and SEA
by_origin.avg("air_time").show()
```

output

```
+-------+-----+
|tailnum|count|
+-------+-----+
| N442AS|   38|
| N102UW|    2|
| N36472|    4|
...
| N654AW|    2|
| N336NW|    1|
+-------+-----+
only showing top 20 rows

+------+------------------+
|origin|     avg(air_time)|
+------+------------------+
|   SEA| 160.4361496051259|
|   PDX|137.11543248288737|
+------+------------------+
```

### Grouping and Aggregating II

```python
# Import pyspark.sql.functions as F
import pyspark.sql.functions as F

# Group by month and dest
by_month_dest = flights.groupBy("month", "dest")

# Average departure delay by month and destination
by_month_dest.avg("dep_delay").show()

# Standard deviation of departure delay
by_month_dest.agg(F.stddev("dep_delay")).show()
```

output

```
+-----+----+--------------------+
|month|dest|      avg(dep_delay)|
+-----+----+--------------------+
|   11| TUS| -2.3333333333333335|
|   11| ANC|   7.529411764705882|
|    1| BUR|               -1.45|
|    1| PDX| -5.6923076923076925|
|    6| SBA|                -2.5|
|    5| LAX|-0.15789473684210525|
|   10| DTW|                 2.6|
|    6| SIT|                -1.0|
|   10| DFW|  18.176470588235293|
|    3| FAI|                -2.2|
|   10| SEA|                -0.8|
|    2| TUS| -0.6666666666666666|
|   12| OGG|  25.181818181818183|
|    9| DFW|   4.066666666666666|
|    5| EWR|               14.25|
|    3| RDM|                -6.2|
|    8| DCA|                 2.6|
|    7| ATL|   4.675675675675675|
|    4| JFK| 0.07142857142857142|
|   10| SNA| -1.1333333333333333|
+-----+----+--------------------+
only showing top 20 rows

+-----+----+----------------------+
|month|dest|stddev_samp(dep_delay)|
+-----+----+----------------------+
|   11| TUS|    3.0550504633038935|
|   11| ANC|    18.604716401245316|
|    1| BUR|     15.22627576540667|
...
|    4| JFK|     8.156774303176903|
|   10| SNA|    13.726234873756304|
+-----+----+----------------------+
only showing top 20 rows
```

### Joining

Another very common data operation is the join. Joins are a whole topic unto themselves, so in this course we'll just look at simple joins. If you'd like to learn more about joins, you can take a look [here](https://www.datacamp.com/courses/joining-data-with-pandas).

A join will combine two different tables along a column that they share. This column is called the key. Examples of keys here include the `tailnum` and `carrier` columns from the `flights` table.

For example, suppose that you want to know more information about the plane that flew a flight than just the tail number. This information isn't in the `flights` table because the same plane flies many different flights over the course of two years, so including this information in every row would result in a lot of duplication. To avoid this, you'd have a second table that has only one row for each plane and whose columns list all the information about the plane, including its tail number. You could call this table `planes`.

When you join the `flights` table to this table of airplane information, you're adding all the columns from the `planes` table to the `flights` table. To fill these columns with information, you'll look at the tail number from the `flights` table and find the matching one in the `planes` table, and then use that row to fill out all the new columns.

Now you'll have a much bigger table than before, but every row has all the information about the plane that flew that flight!
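The lookup described above is essentially a left outer join (the same kind used with `how="leftouter"`). Below is a minimal plain-Python sketch of it, assuming rows are dicts and the planes table is indexed by tail number; `left_join` is a hypothetical helper and the plane attributes are made-up values, not the course data. A flight with no matching plane keeps its row and gets `None` (Spark would show `null`) in the new columns.

```python
def left_join(left_rows, right_by_key, key, right_cols):
    """Hypothetical helper: keep every left row and add right_cols from the
    right-hand row with the same key, or None when there is no match."""
    joined = []
    for row in left_rows:
        match = right_by_key.get(row[key])
        extra = {col: (match[col] if match is not None else None) for col in right_cols}
        joined.append({**row, **extra})
    return joined


# Made-up flights rows (only the columns needed here)
sample_flights = [
    {"flight": 1780, "tailnum": "N846VA"},
    {"flight": 851, "tailnum": "N559AS"},
    {"flight": 1, "tailnum": "N00000"},  # made-up plane missing from the planes table
]

# Made-up planes table: one entry per plane, keyed by tail number
planes_by_tailnum = {
    "N846VA": {"plane_year": 2010, "seats": 150},
    "N559AS": {"plane_year": 2006, "seats": 160},
}

result = left_join(sample_flights, planes_by_tailnum, "tailnum", ["plane_year", "seats"])
for row in result:
    print(row)
```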
### Joining II

```python
# Examine the data
airports.show()

# Rename the faa column so it matches the dest column in flights
airports = airports.withColumnRenamed("faa", "dest")

# Join the DataFrames (a left outer join keeps every flight)
flights_with_airports = flights.join(airports, on="dest", how="leftouter")

# Examine the new DataFrame
flights_with_airports.show()
```

output

```
+---+--------------------+----------------+-----------------+----+---+---+
|faa|                name|             lat|              lon| alt| tz|dst|
+---+--------------------+----------------+-----------------+----+---+---+
|04G|   Lansdowne Airport|      41.1304722|      -80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|      32.4605722|      -85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|      41.9893408|      -88.1012428| 801| -6|  A|
...
|1CS|Clow Internationa...|      41.6959744|      -88.1292306| 670| -6|  U|
|1G3|  Kent State Airport|      41.1513889|      -81.4151111|1134| -4|  A|
|1OH|     Fortman Airport|      40.5553253|      -84.3866186| 885| -5|  U|
+---+--------------------+----------------+-----------------+----+---+---+
only showing top 20 rows

+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+--------------------+---------+-----------+----+---+---+
|dest|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|air_time|distance|hour|minute|                name|      lat|        lon| alt| tz|dst|
+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+--------------------+---------+-----------+----+---+---+
| LAX|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA|     132|     954|   6|    58|    Los Angeles Intl|33.942536|-118.408075| 126| -8|  A|
| HNL|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA|     360|    2677|  10|    40|       Honolulu Intl|21.318681|-157.922428|  13|-10|  N|
| SFO|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA|     111|     679|  14|    43|  San Francisco Intl|37.618972|-122.374889|  13| -8|  A|
...
| MDW|2014|    8| 11|    1017|       -3|    1613|       -7|     WN| N8634A|   827|   SEA|     216|    1733|  10|    17| Chicago Midway Intl|41.785972| -87.752417| 620| -6|  A|
| BOS|2014|    1| 13|    2156|       -9|     607|      -15|     AS| N597AS|    24|   SEA|     290|    2496|  21|    56|General Edward La...|42.364347| -71.005181|  19| -5|  A|
| BUR|2014|    6|  5|    1733|      -12|    1945|      -10|     OO| N215AG|  3488|   PDX|     111|     817|  17|    33|            Bob Hope|34.200667|-118.358667| 778| -8|  A|
+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+--------------------+---------+-----------+----+---+---+
only showing top 20 rows
```

## 3. Getting started with machine learning pipelines

PySpark has built-in, cutting-edge machine learning routines, along with utilities to create full machine learning pipelines. You'll learn about them in this chapter.

### Machine Learning Pipelines

In the next two chapters you'll step through every stage of the machine learning pipeline, from data intake to model evaluation. Let's get to it!

At the core of the `pyspark.ml` module are the `Transformer` and `Estimator` classes. Almost every other class in the module behaves similarly to these two basic classes.

`Transformer` classes have a `.transform()` method that takes a DataFrame and returns a new DataFrame, usually the original one with a new column appended. For example, you might use the class `Bucketizer` to create discrete bins from a continuous feature, or a fitted `PCAModel` to reduce the dimensionality of your dataset using principal component analysis.

`Estimator` classes all implement a `.fit()` method. These methods also take a DataFrame, but instead of returning another DataFrame they return a model object. This can be something like a `StringIndexerModel` for including categorical data saved as strings in your models, or a `RandomForestClassificationModel` that uses the random forest algorithm for classification (`RandomForestRegressionModel` is its regression counterpart).
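The `Estimator`/`Transformer` contract can be sketched in plain Python. In this minimal sketch, `MeanCenterer` and `MeanCentererModel` are hypothetical classes made up for illustration (they are not part of `pyspark.ml`), and a list of dicts stands in for a DataFrame: `.fit()` learns something from the data and returns a model, and the model's `.transform()` returns new rows with an extra column appended.

```python
class MeanCenterer:
    """Hypothetical Estimator: .fit() learns the column mean and returns a model."""

    def __init__(self, input_col, output_col):
        self.input_col = input_col
        self.output_col = output_col

    def fit(self, rows):
        values = [row[self.input_col] for row in rows]
        mean = sum(values) / len(values)
        return MeanCentererModel(self.input_col, self.output_col, mean)


class MeanCentererModel:
    """Hypothetical Transformer: .transform() appends a mean-centered column."""

    def __init__(self, input_col, output_col, mean):
        self.input_col = input_col
        self.output_col = output_col
        self.mean = mean

    def transform(self, rows):
        # Build new dicts so the input "DataFrame" is left unchanged
        return [{**row, self.output_col: row[self.input_col] - self.mean} for row in rows]


rows = [{"air_time": 100}, {"air_time": 140}]
model = MeanCenterer("air_time", "air_time_centered").fit(rows)
centered = model.transform(rows)
print(centered)
```

Keeping the learned state (here, the mean) inside the returned model object mirrors how, for example, a fitted `StringIndexerModel` carries the string-to-index mapping it learned.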
### Join the DataFrames

```python
# Rename year column so it doesn't clash with the flights' year column
planes = planes.withColumnRenamed("year", "plane_year")

# Join the DataFrames
model_data = flights.join(planes, on="tailnum", how="leftouter")
```

### Data types

Good work! Before you get started modeling, it's important to know that Spark's machine learning routines only handle numeric data. That means all of the columns you use as model inputs must be either integers or decimals (called 'doubles' in Spark).

When we imported our data, we let Spark guess what kind of information each column held. Unfortunately, Spark doesn't always guess right, and you can see that some of the columns in our DataFrame are strings containing numbers as opposed to actual numeric values.

To remedy this, you can use the `.cast()` method in combination with the `.withColumn()` method. It's important to note that `.cast()` works on columns, while `.withColumn()` works on DataFrames.

The only argument you need to pass to `.cast()` is the kind of value you want to create, in string form. For example, to create integers, you'll pass the argument `"integer"` and for decimal numbers you'll use `"double"`.

You can put this call to `.cast()` inside a call to `.withColumn()` to overwrite the already existing column, just like you did in the previous chapter!
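Here is a rough plain-Python analogue of what `.cast("integer")` does to a column of strings. `cast_integer` is a hypothetical helper, and the behavior it imitates is an assumption worth stating: with Spark's default (non-ANSI) SQL settings, a string that can't be parsed as an integer becomes `null` instead of raising an error, which is one reason the later steps drop rows with NULL values.

```python
def cast_integer(value):
    """Hypothetical helper, roughly like casting a string column to "integer":
    unparseable values become None (the analogue of a SQL NULL)."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


# Made-up raw column as read from a CSV: every value is a string
arr_delay_raw = ["-5", "34", "NA"]
arr_delay = [cast_integer(v) for v in arr_delay_raw]
print(arr_delay)  # [-5, 34, None]
```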
### String to integer

```python
# Cast the columns to integers
model_data = model_data.withColumn("arr_delay", model_data.arr_delay.cast("integer"))
model_data = model_data.withColumn("air_time", model_data.air_time.cast("integer"))
model_data = model_data.withColumn("month", model_data.month.cast("integer"))
model_data = model_data.withColumn("plane_year", model_data.plane_year.cast("integer"))
```

### Create a new column

```python
# Create the column plane_age
model_data = model_data.withColumn("plane_age", model_data.year - model_data.plane_year)
```

### Making a Boolean

```python
# Create is_late
model_data = model_data.withColumn("is_late", model_data.arr_delay > 0)

# Convert to an integer
model_data = model_data.withColumn("label", model_data.is_late.cast("integer"))

# Remove missing values
model_data = model_data.filter("arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL")
```

### Strings and factors

As you know, Spark requires numeric data for modeling. So far this hasn't been an issue; even boolean columns can easily be converted to integers without any trouble. But you'll also be using the airline and the plane's destination as features in your model. These are coded as strings and there isn't any obvious way to convert them to a numeric data type.

Fortunately, PySpark has functions for handling this built into the `pyspark.ml.feature` submodule. You can create what are called 'one-hot vectors' to represent the carrier and the destination of each flight. A one-hot vector is a way of representing a categorical feature where every observation has a vector in which all elements are zero except for at most one element, which has a value of one (1).

Each element in the vector corresponds to a level of the feature, so it's possible to tell what the right level is by seeing which element of the vector is equal to one (1).

The first step to encoding your categorical feature is to create a `StringIndexer`.
Members of this class are `Estimators` that take a DataFrame with a column of strings and map each unique string to a number. Then, the `Estimator` returns a `Transformer` that takes a DataFrame, attaches the mapping to it as metadata, and returns a new DataFrame with a numeric column corresponding to the string column.

The second step is to encode this numeric column as a one-hot vector using a `OneHotEncoder`. This works exactly the same way as the `StringIndexer` by creating an `Estimator` and then a `Transformer`. The end result is a column that encodes your categorical feature as a vector that's suitable for machine learning routines!

This may seem complicated, but don't worry! All you have to remember is that you need to create a `StringIndexer` and a `OneHotEncoder`, and the `Pipeline` will take care of the rest.

### Carrier

```python
from pyspark.ml.feature import StringIndexer, OneHotEncoder

# Create a StringIndexer
carr_indexer = StringIndexer(inputCol="carrier", outputCol="carrier_index")

# Create a OneHotEncoder
carr_encoder = OneHotEncoder(inputCol="carrier_index", outputCol="carrier_fact")
```

### Destination

```python
from pyspark.ml.feature import StringIndexer, OneHotEncoder

# Create a StringIndexer
dest_indexer = StringIndexer(inputCol="dest", outputCol="dest_index")

# Create a OneHotEncoder
dest_encoder = OneHotEncoder(inputCol="dest_index", outputCol="dest_fact")
```

### Assemble a vector

```python
from pyspark.ml.feature import VectorAssembler

# Make a VectorAssembler that combines the feature columns into one vector column
vec_assembler = VectorAssembler(inputCols=["month", "air_time", "carrier_fact", "dest_fact", "plane_age"], outputCol="features")
```

### Create the pipeline

```python
# Import Pipeline
from pyspark.ml import Pipeline

# Make the pipeline
flights_pipe = Pipeline(stages=[dest_indexer, dest_encoder, carr_indexer, carr_encoder, vec_assembler])
```

### Test vs. Train

After you've cleaned your data and gotten it ready for modeling, one of the most important steps is to split the data into a test set and a train set. After that, don't touch your test data until you think you have a good model!
As you're building models and forming hypotheses, you can test them on your training data to get an idea of their performance.

Once you've got your favorite model, you can see how well it predicts the new data in your test set. This never-before-seen data will give you a much more realistic idea of your model's performance in the real world when you're trying to predict or classify new data.

In Spark it's important to make sure you split the data **after** all the transformations. This is because operations like `StringIndexer` don't always produce the same index even when given the same list of strings.

### Transform the data

```python
# Fit and transform the data
piped_data = flights_pipe.fit(model_data).transform(model_data)
```

### Split the data

```python
# Split the data into training and test sets (roughly 60% / 40%)
training, test = piped_data.randomSplit([0.6, 0.4])
```

## 4. Model tuning and selection

In this last chapter, you'll apply what you've learned to create a model that predicts which flights will be delayed.

### What is logistic regression?

The model you'll be fitting in this chapter is called a **logistic regression**. This model is very similar to a linear regression, but instead of predicting a numeric variable, it predicts the probability (between 0 and 1) of an event.

To use this as a classification algorithm, all you have to do is assign a cutoff point to these probabilities. If the predicted probability is above the cutoff point, you classify that observation as a 'yes' (in this case, the flight being late); if it's below, you classify it as a 'no'.

You'll tune this model by testing different values for several **hyperparameters**. A **hyperparameter** is just a value in the model that's not estimated from the data, but rather is supplied by the user to maximize performance. For this course it's not necessary to understand the mathematics behind all of these values. What's important is that you'll try out a few different choices and pick the best one.
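A minimal plain-Python sketch of the two steps described above: a linear score squashed into (0, 1) by the sigmoid, then a cutoff that turns the probability into a yes/no label. The weights and bias are hand-picked, hypothetical numbers (not fitted to the flights data). In Spark, `LogisticRegression` learns the weights for you and exposes the cutoff as its `threshold` parameter, which defaults to 0.5.

```python
import math


def predict_proba(weights, bias, features):
    """Logistic regression prediction: sigmoid of a linear score."""
    score = bias + sum(w * x for w, x in zip(weights, features))
    return 1.0 / (1.0 + math.exp(-score))


def classify(prob, cutoff=0.5):
    """1 ('late') if the probability is above the cutoff, else 0."""
    return 1 if prob > cutoff else 0


# Hypothetical, hand-picked coefficients for two features
weights, bias = [0.8, -0.5], -0.2
p = predict_proba(weights, bias, [1.0, 0.4])
print(round(p, 3), classify(p), classify(p, cutoff=0.7))
```

Raising the cutoff makes the classifier more conservative about predicting 'late', which is why the same probability can map to different labels.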
### Create the modeler

```python
# Import LogisticRegression
from pyspark.ml.classification import LogisticRegression

# Create a LogisticRegression Estimator
lr = LogisticRegression()
```

### Cross validation

In the next few exercises you'll be tuning your logistic regression model using a procedure called **k-fold cross validation**. This is a method of estimating the model's performance on unseen data (like your `test` DataFrame).

It works by splitting the training data into a few different partitions. The exact number is up to you, but in this course you'll be using PySpark's default value of three. Once the data is split up, one of the partitions is set aside, and the model is fit to the others. Then the error is measured against the held-out partition. This is repeated for each of the partitions, so that every block of data is held out and used as a test set exactly once. Then the error on each of the partitions is averaged. This is called the cross validation error of the model, and is a good estimate of the actual error on the held-out data.

You'll be using cross validation to choose the hyperparameters by creating a grid of the possible pairs of values for the two hyperparameters, `elasticNetParam` and `regParam`, and using the cross validation error to compare all the different models so you can choose the best one!
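The fold bookkeeping described above can be sketched in plain Python. `k_fold_indices` and `cross_val_error` are hypothetical helpers, and for readability this sketch uses contiguous folds; Spark's `CrossValidator` handles the split for you, assigns rows to folds randomly (controlled by its `seed` parameter), and uses `numFolds=3` by default.

```python
def k_fold_indices(n_rows, k=3):
    """Yield (train_indices, held_out_indices) for k folds over rows
    0..n_rows-1, so that every row is held out exactly once."""
    base, extra = divmod(n_rows, k)
    start = 0
    for fold in range(k):
        # The first `extra` folds get one additional row
        size = base + (1 if fold < extra else 0)
        held_out = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_rows))
        yield train, held_out
        start += size


def cross_val_error(fold_errors):
    """Cross validation error: the average of the per-fold errors."""
    return sum(fold_errors) / len(fold_errors)


# Illustrative usage with 7 rows and 3 folds
for train, held_out in k_fold_indices(7):
    print("train:", train, "held out:", held_out)
```

In a grid search, this whole loop runs once per hyperparameter combination, and the combination with the best averaged score wins.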
### Create the evaluator

```python
# Import the evaluation submodule
import pyspark.ml.evaluation as evals

# Create a BinaryClassificationEvaluator
evaluator = evals.BinaryClassificationEvaluator(metricName="areaUnderROC")
```

### Make a grid

```python
import numpy as np

# Import the tuning submodule
import pyspark.ml.tuning as tune

# Create the parameter grid
grid = tune.ParamGridBuilder()

# Add the hyperparameters: a range of regParam values and two elasticNetParam values
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))
grid = grid.addGrid(lr.elasticNetParam, [0, 1])

# Build the grid (a list with one parameter map per combination)
grid = grid.build()
```

### Make the validator

```python
# Create the CrossValidator (numFolds defaults to 3)
cv = tune.CrossValidator(estimator=lr,
                         estimatorParamMaps=grid,
                         evaluator=evaluator)
```

### Fit the model(s)

```python
# Cross validation fits one model per fold for every point in the grid, which is
# computationally expensive, so in the course a single lr is fit instead.
# With enough compute you would run:
# models = cv.fit(training)
# best_lr = models.bestModel

# Call lr.fit()
best_lr = lr.fit(training)

# Print best_lr
print(best_lr)
# >>> LogisticRegressionModel: uid=LogisticRegression_74bab76f5f45, numClasses=2, numFeatures=83
```

### Evaluating binary classifiers

For this course we'll be using a common metric for binary classification algorithms called the **AUC**, or area under the curve. In this case, the curve is the ROC, or receiver operating characteristic curve. The details of what these things actually measure aren't important for this course. All you need to know is that for our purposes, the closer the AUC is to one (1), the better the model is!

### Evaluate the model

```python
# Use the model to predict the test set
test_results = best_lr.transform(test)

# Evaluate the predictions
print(evaluator.evaluate(test_results))
# >>> 0.7123313100891033
```
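To build some intuition for the AUC value printed above, here is a plain-Python sketch based on a standard equivalence: the area under the ROC curve equals the probability that a randomly chosen positive example (a late flight) receives a higher score than a randomly chosen negative one, with ties counted as half. `roc_auc` is a hypothetical helper and the labels and scores are made up. Spark's `BinaryClassificationEvaluator` computes the area from the ROC curve itself rather than with this O(n²) pairwise loop, so treat this as an illustration only.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly,
    counting ties as half a win."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Made-up labels (1 = late) and predicted probabilities
labels = [1, 0, 1, 0, 1]
scores = [0.9, 0.3, 0.6, 0.7, 0.8]
print(roc_auc(labels, scores))
```

Under this reading, an AUC of 0.5 corresponds to random guessing and 1.0 to a perfect ranking of late flights above on-time ones.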