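"""Tune an SVM classifier on the UCI mushroom (agaricus-lepiota) dataset.

Runs a grid search over SVC hyperparameters either locally with
scikit-learn's GridSearchCV or distributed across a Spark cluster with
spark_sklearn's drop-in replacement. The data is read from OCI Object
Storage, or from a local CSV when one is available next to this script.
"""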
import argparse
import os
import pathlib

import oci
import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession
from sklearn import svm, preprocessing
from sklearn.model_selection import GridSearchCV as GridSearchCVNative

# spark_sklearn (the "spark-sklearn" package) provides a Spark-backed
# GridSearchCV with the same interface as scikit-learn's.
from spark_sklearn import GridSearchCV as GridSearchCVSpark


def main():
    # Toggle between the Spark-distributed and plain scikit-learn grid search.
    use_spark = True
    oci_path = "oci://sample-data@paasdevssstest/agaricus-lepiota.csv"
    local_path = os.path.join(
        pathlib.Path(__file__).parent.absolute(), "agaricus-lepiota.csv"
    )

    # Set up Spark.
    conf = SparkConf()

    # Data Flow sets HOME to /home/dataflow; use that to detect cluster mode.
    if os.environ.get("HOME") == "/home/dataflow":
        mode = "cluster"
        path = oci_path
    else:
        mode = "local"
        # Prefer a local copy of the dataset when one exists; otherwise read
        # from Object Storage using credentials from the local OCI config.
        path = local_path if os.path.exists(local_path) else oci_path
        oci_config = oci.config.from_file()
        conf.set("fs.oci.client.auth.tenantId", oci_config["tenancy"])
        conf.set("fs.oci.client.auth.userId", oci_config["user"])
        conf.set("fs.oci.client.auth.fingerprint", oci_config["fingerprint"])
        conf.set("fs.oci.client.auth.pemfilepath", oci_config["key_file"])
        conf.set(
            "fs.oci.client.hostname",
            "https://objectstorage.{0}.oraclecloud.com".format(oci_config["region"]),
        )
    spark_session = (
        SparkSession.builder.appName("svc_mushroom").config(conf=conf).getOrCreate()
    )
print("Running in {} mode".format(mode)) | |
spark_context = spark_session.sparkContext | |
# Handle arguments. | |
parser = argparse.ArgumentParser() | |
parser.add_argument("-p", "--path", help="File Path", default=path) | |
args = parser.parse_args() | |
assert args.path is not None, "Need -p / --path" | |
X, y = load_mushroom_dataframes(args.path, spark_session) | |
    svr = svm.SVC(gamma="auto")
    # Hyperparameter grid to search.
    parameters = {
        "kernel": ("linear", "rbf"),
        "C": range(1, 10),
        "shrinking": [False, True],
    }
    if use_spark:
        # spark_sklearn distributes the cross-validation fits across the
        # cluster; its constructor takes the SparkContext first.
        clf = GridSearchCVSpark(spark_context, svr, parameters)
    else:
        clf = GridSearchCVNative(svr, parameters)
    clf.fit(X, y)
    results = pd.DataFrame(clf.cv_results_)
    print(results.to_csv())


def load_mushroom_csv(path, spark_session):
    """Load the CSV into a pandas DataFrame, locally or through Spark."""
    print("Reading data from " + path)
    if path.startswith("/"):
        with open(path, "rt") as fd:
            return pd.read_csv(fd)
    else:
        spark_df = spark_session.read.csv(path, header=True)
        return spark_df.toPandas()


def load_mushroom_dataframes(path, spark_session):
    """Return (features, labels) as label-encoded pandas objects."""
    unencoded_data = load_mushroom_csv(path, spark_session)
    # Every column in the dataset; "class" is the label, the rest are features.
    # A list (rather than a set) keeps the feature column order deterministic.
    everything = [
        "class",
        "cap-shape",
        "cap-surface",
        "cap-color",
        "bruises",
        "odor",
        "gill-attachment",
        "gill-spacing",
        "gill-size",
        "gill-color",
        "stalk-shape",
        "stalk-root",
        "stalk-surface-above-ring",
        "stalk-surface-below-ring",
        "stalk-color-above-ring",
        "stalk-color-below-ring",
        "veil-type",
        "veil-color",
        "ring-number",
        "ring-type",
        "spore-print-color",
        "population",
        "habitat",
    ]
    # Convert categories to numbers within the dataframe.
    data = unencoded_data.copy()
    le = preprocessing.LabelEncoder()
    for i in range(data.shape[1]):
        data.iloc[:, i] = data.iloc[:, i].fillna("")
        data.iloc[:, i] = le.fit_transform(data.iloc[:, i])
    # Set up the data to predict.
    predict_attribute = "class"
    independent_vars = [col for col in everything if col != predict_attribute]
    X = data[independent_vars]
    y = data[predict_attribute]
    return X, y


if __name__ == "__main__":
    main()
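# Example invocation (illustrative; the script filename is not part of the
# original gist, and a valid ~/.oci/config is assumed for oci:// paths):
#
#   python mushroom_grid_search.py -p /path/to/agaricus-lepiota.csv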