-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathspark_load_pipeline_databricks.py
24 lines (22 loc) · 1.13 KB
/
spark_load_pipeline_databricks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from pyspark.ml import Pipeline
from pyspark.ml import PipelineModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
# Score a '|'-delimited CSV of (id, text, label) rows with a previously trained
# logistic-regression PipelineModel and print the resulting accuracy.

# Databricks notebooks pre-define `sqlContext`; create one ourselves so the
# script also runs as a plain spark-submit job instead of failing with NameError.
try:
    sqlContext
except NameError:
    sqlContext = SQLContext(SparkContext.getOrCreate())

# Load the fitted pipeline (tokenizer -> hashingTF -> logistic regression)
# saved earlier at this DBFS path.
model = PipelineModel.load('/FileStore/lrmodel')

# Expected column layout of the input file: id | text | label.
schema_fields = [
    StructField("id", IntegerType(), True),
    StructField("text", StringType(), True),
    StructField("label", DoubleType(), True)]
finalSchema = StructType(fields=schema_fields)

# BUG FIX: `schema` is not a CSV reader *option* -- passing it through
# .options(schema=...) was silently ignored, so every column came back as
# string. Apply it with the dedicated .schema() call instead.
dataset = (sqlContext.read.format('csv')
           .options(header='true', delimiter='|')
           .schema(finalSchema)
           .load('/FileStore/tables/dataset.csv'))

# With the schema actually applied these casts are no-ops; kept as a cheap
# safeguard in case the file's header/types ever drift.
dataset = dataset.withColumn("label", dataset["label"].cast(DoubleType()))
dataset = dataset.withColumn("id", dataset["id"].cast(IntegerType()))

# Run the pipeline and keep only the columns needed to judge accuracy.
result = model.transform(dataset)\
    .select("features", "label", "prediction")

# Accuracy = fraction of rows where the predicted class matches the label.
# Guard against an empty input file to avoid ZeroDivisionError.
total = dataset.count()
correct = result.where(result["label"] == result["prediction"])
accuracy = correct.count() / total if total else 0.0
print("Accuracy of model = " + str(accuracy))