import pandas as pd
import numpy as np
import json
import spacy
import nltk
import string
import re
# Load the manually labelled similarity data (ground truth)
data_xlsx = pd.read_excel('./similarity_manually_label.xlsx', 'Sheet1', index_col=0)
actual_matrix = np.array(data_xlsx.values)
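# actual_matrix is the ground truth: a base x test grid in which 1 marks
# a pair manually labelled as duplicates and 0 a non-duplicate pair
# (inferred from how it is compared against predict_matrix below)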
# Load the base and test JSON data
with open('./text_similarity_base.json') as data_file:    
    text_similarity_base = json.load(data_file)
    
with open('./text_similarity_test.json') as data_file:    
    text_similarity_test = json.load(data_file)
# Create base and test data frames
base_df = pd.DataFrame.from_dict(text_similarity_base, orient='columns')
test_df = pd.DataFrame.from_dict(text_similarity_test, orient='columns')
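# Note: the JSON files are not included in this gist. Given the
# orient='columns' load and the use of base_df['text'] below, each file
# is assumed to map column names to lists of sentences, e.g.
# {"text": ["first sentence", "second sentence", ...]}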
nltk.download('punkt')      # tokenizer models used by nltk.word_tokenize
nltk.download('stopwords')  # stopword list used by nltk.corpus.stopwords below
# Text pre-processing functions
stemmer = nltk.stem.porter.PorterStemmer()
# Translation table for str.translate: maps every punctuation char to None
remove_punctuation_map = dict((ord(char), None) for char in string.punctuation)
stopwords = nltk.corpus.stopwords.words('english')

def tokenize(text):
    return nltk.word_tokenize(text)

def stem_tokens(tokens):
    return [stemmer.stem(item) for item in tokens]

def remove_stopwords(tokens):
    return [item for item in tokens if item not in stopwords]

def keep_alphabetic(tokens):
    return [item for item in tokens if item.isalpha()]

def reduce_lengthening(tokens):
    pattern = re.compile(r"(.)\1{2,}")
    return [pattern.sub(r"\1\1", item) for item in tokens]

def normalize(text):
    """Lowercase, drop punctuation, remove stopwords, keep only alphabetic
    tokens, reduce lengthening, then stem."""
    lower_text_without_punctuation = text.lower().translate(remove_punctuation_map)
    return ' '.join(
        stem_tokens(
            reduce_lengthening(
                keep_alphabetic(
                    remove_stopwords(
                        tokenize(lower_text_without_punctuation))))))
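# Quick sanity check of the pipeline on an illustrative sentence (not
# from the dataset): case, punctuation and stopwords are gone, the
# elongated word is collapsed before stemming, and the plural is reduced
# by the Porter stemmer:
# normalize("The quick brown foxes are running sooooo fast!!!")
# -> 'quick brown fox run soo fast'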
# Text cleansing
base_df['normalized_text'] = base_df['text'].apply(normalize)
test_df['normalized_text'] = test_df['text'].apply(normalize)
# Vectors-only spaCy model (spaCy 2.x); install it first if missing:
#   python -m spacy download en_vectors_web_lg
nlp = spacy.load('en_vectors_web_lg')
# Similarity thresholds to evaluate
thresholds = [
    0,
    0.125,
    0.25,
    0.375,
    0.5,
    0.625,
    0.75,
    0.8,
    0.825,
    0.85,
    0.875,
    0.9,
    0.91,
    0.92,
    0.93,
    0.94,
    0.95,
    0.96,
    0.97,
    0.98,
    0.99,
    1
]

base_sentences = base_df['normalized_text'].values
test_sentences = test_df['normalized_text'].values
base_count = len(base_sentences)
test_count = len(test_sentences)
# Parse each sentence once up front; calling nlp() inside the nested loop
# below would redo the same work for every base/test pair
base_docs = [nlp(sentence) for sentence in base_sentences]
test_docs = [nlp(sentence) for sentence in test_sentences]

def calculate_similarity(threshold):
    predict_matrix = np.zeros((base_count, test_count), dtype=int)
    tp_count = 0
    tn_count = 0
    fp_count = 0
    fn_count = 0

    # Prepare predictions: cosine similarity of the averaged word vectors,
    # thresholded into duplicate / non-duplicate
    for base_index, base_doc in enumerate(base_docs):
        for test_index, test_doc in enumerate(test_docs):
            similarity = test_doc.similarity(base_doc)

            if similarity >= threshold:
                predict_matrix[base_index][test_index] = 1 # 1 means duplicate
            else:
                predict_matrix[base_index][test_index] = 0 # 0 means non-duplicate

    # Compare predictions against the manually labelled matrix
    for i in range(base_count):
        for j in range(test_count):
            actual = actual_matrix[i][j]
            predict = predict_matrix[i][j]

            if actual == 0 and predict == 0: # true negative
                tn_count += 1
            elif actual == 1 and predict == 1: # true positive
                tp_count += 1
            elif actual == 1 and predict == 0: # false negative
                fn_count += 1
            elif actual == 0 and predict == 1: # false positive
                fp_count += 1

    accuracy = (tn_count + tp_count) / (tn_count + tp_count + fn_count + fp_count)

    print("threshold:", threshold)
    print("true negative:", tn_count)
    print("true positive:", tp_count)
    print("false negative:", fn_count)
    print("false positive:", fp_count)
    print("accuracy:", accuracy)
    print("\n======================================\n")
print("Base count: %d, Test count: %d, Total = %d\n" % (base_count, test_count, base_count * test_count))

for threshold in thresholds:
    calculate_similarity(threshold)
Base count: 94, Test count: 20, Total = 1880

threshold: 0
true negative: 0
true positive: 25
false negative: 0
false positive: 1855
accuracy: 0.013297872340425532

======================================

threshold: 0.125
true negative: 20
true positive: 25
false negative: 0
false positive: 1835
accuracy: 0.023936170212765957

======================================

threshold: 0.25
true negative: 22
true positive: 25
false negative: 0
false positive: 1833
accuracy: 0.025

======================================

threshold: 0.375
true negative: 44
true positive: 24
false negative: 1
false positive: 1811
accuracy: 0.036170212765957444

======================================

threshold: 0.5
true negative: 138
true positive: 23
false negative: 2
false positive: 1717
accuracy: 0.08563829787234042

======================================

threshold: 0.625
true negative: 452
true positive: 20
false negative: 5
false positive: 1403
accuracy: 0.251063829787234

======================================

threshold: 0.75
true negative: 1198
true positive: 14
false negative: 11
false positive: 657
accuracy: 0.6446808510638298

======================================

threshold: 0.8
true negative: 1514
true positive: 10
false negative: 15
false positive: 341
accuracy: 0.8106382978723404

======================================

threshold: 0.825
true negative: 1633
true positive: 6
false negative: 19
false positive: 222
accuracy: 0.8718085106382979

======================================

threshold: 0.85
true negative: 1735
true positive: 3
false negative: 22
false positive: 120
accuracy: 0.924468085106383

======================================

threshold: 0.875
true negative: 1806
true positive: 0
false negative: 25
false positive: 49
accuracy: 0.9606382978723405

======================================

threshold: 0.9
true negative: 1843
true positive: 0
false negative: 25
false positive: 12
accuracy: 0.9803191489361702

======================================

threshold: 0.91
true negative: 1850
true positive: 0
false negative: 25
false positive: 5
accuracy: 0.9840425531914894

======================================

threshold: 0.92
true negative: 1851
true positive: 0
false negative: 25
false positive: 4
accuracy: 0.9845744680851064

======================================

threshold: 0.93
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 0.94
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 0.95
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 0.96
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 0.97
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 0.98
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 0.99
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================

threshold: 1
true negative: 1855
true positive: 0
false negative: 25
false positive: 0
accuracy: 0.9867021276595744

======================================
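# With only 25 of the 1880 pairs labelled as duplicates, accuracy is
# dominated by true negatives: predicting "non-duplicate" everywhere
# already scores 1855/1880 ≈ 0.987, which is exactly the plateau reached
# from threshold 0.93 upwards, where the model finds no duplicates at
# all. Precision, recall and F1 are more informative for a class this
# rare. A minimal sketch (not part of the original gist), fed with the
# counts printed above:

def report_metrics(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    print("precision: %.3f, recall: %.3f, F1: %.3f" % (precision, recall, f1))

# e.g. the counts reported for threshold 0.75:
report_metrics(14, 657, 11)  # precision: 0.021, recall: 0.560, F1: 0.040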