"""Module with classes for record pair classification. Creating and using a classifier normally consists of the following steps: - initialise The classifier is initialised and trained if training data is provided (i.e. if a weight vector dictionary and a match and non-match set is given, the training method will be called). - train Train the classifier using training data. - test Testing the trained classifier using test data (with known match andnon-match status). - classify Use the trained classifier to classify weight vectors with unknown match status. """ # ============================================================================= # Import necessary modules (Python standard modules first, then Febrl modules) import auxiliary import mymath import heapq import logging import math import os import random try: import Numeric imp_numeric = True except: imp_numeric = False if (imp_numeric == True): try: import PyML import PyML.datafunc import PyML.svm imp_pyml = True except: imp_pyml = False else: imp_pyml = False try: import svm imp_svm = True except: imp_svm = False if (imp_pyml == False): logging.warn('Cannot import Numeric and PyML modules') if (imp_svm == False): logging.warn('Cannot import svm module') # ============================================================================= class Classifier: """Base class for classifiers. All classifiers have the following instance variables, which can be set when a classifier is initialised: description A string describing the classifier. train_w_vec_dict A weight vector dictionary that will be used for training. train_match_set A set with record identifier pairs which are assumed to be the match training examples. train_non_match_set A set with record identifier pairs which are assumed to be the non-match training examples. If the last three arguments (train_w_vec_dict, train_match_set, and train_non_match_set) are given when a classifier is initialised, it is trained straight away (so the 'train' method does not have to be called). Default is that these three values are set to None, so no training is done when a classifier is initialised. """ # --------------------------------------------------------------------------- def __init__(self, base_kwargs): """Constructor. """ # General attributes for all classifiers. # self.description = '' # A description of the classifier. self.train_w_vec_dict = None # The dictionary containing weight vectors # used for training. self.train_match_set = None # A set with record identifier pairs that # are matches used for training. self.train_non_match_set = None # A set with record identifier pairs that # are non-matches used for training. # Process base keyword arguments (all data set specific keywords were # processed in the derived class constructor) # for (keyword, value) in base_kwargs.items(): if (keyword.startswith('desc')): auxiliary.check_is_string('description', value) self.description = value elif (keyword.startswith('train_w_vec')): auxiliary.check_is_dictionary('train_w_vec_dict', value) self.train_w_vec_dict = value elif (keyword.startswith('train_mat')): auxiliary.check_is_set('train_match_set', value) self.train_match_set = value elif (keyword.startswith('train_non')): auxiliary.check_is_set('train_non_match_set', value) self.train_non_match_set = value else: logging.exception('Illegal constructor argument keyword: '+keyword) raise Exception # --------------------------------------------------------------------------- def train(self, w_vec_dict, match_set, non_match_set): """Method to train a classifier using the given weight vector dictionary and match and non-match sets of record identifier pairs. See implementations in derived classes for details. """ logging.exception('Override abstract method in derived class') raise Exception # --------------------------------------------------------------------------- def test(self, w_vec_dict, match_set, non_match_set): """Method to test a classifier using the given weight vector dictionary and match and non-match sets of record identifier pairs. Will return a confusion matrix as a list of the form: [TP, FN, FP, TN]. See implementations in derived classes for details. """ logging.exception('Override abstract method in derived class') raise Exception # --------------------------------------------------------------------------- def cross_validate(self, w_vec_dict, match_set, non_match_set): """Method to conduct a cross validation using the given weight vector dictionary and match and non-match sets of record identifier pairs. Will return a confusion matrix as a list of the form: [TP, FN, FP, TN]. See implementations in derived classes for details. """ logging.exception('Override abstract method in derived class') raise Exception # --------------------------------------------------------------------------- def classify(self, w_vec_dict): """Method to classify the given weight vector dictionary using the trained classifier. Will return three sets with record identifier pairs: 1) match set 2) non-match set 3) possible match set (this will always be empty for certain classifiers that only classify into matches and non-matches) See implementations in derived classes for details. """ logging.exception('Override abstract method in derived class') raise Exception # --------------------------------------------------------------------------- def log(self, instance_var_list = None): """Write a log message with the basic classifier instance variables plus the instance variable provided in the given input list (assumed to contain pairs of names (strings) and values). """ logging.info('') logging.info('Classifier: "%s"' % (self.description)) if (self.train_w_vec_dict != None): logging.info(' Number of weight vectors provided: %d' % \ (len(self.train_w_vec_dict))) if (self.train_match_set != None): logging.info(' Number of match training examples provided: %d' % \ (len(self.train_match_set))) if (self.train_non_match_set != None): logging.info(' Number of non-match training examples provided: %d' % \ (len(self.train_non_match_set))) if (instance_var_list != None): logging.info(' Classifier specific variables:') max_name_len = 0 for (name, value) in instance_var_list: max_name_len = max(max_name_len, len(name)) for (name, value) in instance_var_list: pad_spaces = (max_name_len-len(name))*' ' logging.info(' %s %s' % (name+':'+pad_spaces, str(value))) # ============================================================================= # ============================================================================= class SuppVecMachine(Classifier): """Implements a classifier based on a support vector machine (SVM). The 'libsvm' library and its Python interface (module svm.py) will be used. For more details and downloads please see: http://www.csie.ntu.edu.tw/~cjlin/libsvm If this module is not implemented this classifier will be be usable. Note that the cross_validation() method only provides performance measures but no trained SVM model that can be used for classifying weight vectors later on. The train() method needs to be used to get a trained SVM. It is possible to do random sampling of training data from all weight vectors using the 'sample' argument. The arguments that have to be set when this classifier is initialised are: (the kernel type will be mapped to corresponding svm.py argument). kernel_type The kernel type from from libsvm. Default value LINEAR, other possibilities are: POLY, RBF, SIGMOID. C The 'C' parameter from libsvm. Default value is 10. sample A number between 0 and 100 that gives the percentage of weight vectors that will be randomly selected and used for clustering in the training process. If set to 100 (the default) then all given weight vectors will be used. """ # --------------------------------------------------------------------------- def __init__(self, **kwargs): """Constructor. Process the 'svmlib' and 'sample' arguments first, then call the base class constructor. """ # Check if svm module is installed or not # if (imp_svm == False): logging.exception('Module "svm.py" not installed, cannot use ' + \ 'SuppVectorMach classifier') raise Exception self.svm_type = svm.C_SVC self.kernel_type = 'LINEAR' self.C = 10 self.svm_model = None # Will be set in train() method self.sample = 100.0 base_kwargs = {} # Dictionary, will contain unprocessed arguments for base # class constructor for (keyword, value) in kwargs.items(): if (keyword.startswith('kernel')): auxiliary.check_is_string('kernel_type', value) if (value not in ['LINEAR', 'POLY', 'RBF', 'SIGMOID']): logging.exception('Illegal value for kernel type: %s ' % (value) + \ '(possible are: LINEAR, POLY, RBF, SIGMOID)') raise Exception self.kernel_type = value elif (keyword == 'C'): auxiliary.check_is_number('C', value) auxiliary.check_is_not_negative('C', value) self.C = value elif (keyword.startswith('samp')): if (value != None): auxiliary.check_is_percentage('sample', value) self.sample = value else: base_kwargs[keyword] = value Classifier.__init__(self, base_kwargs) # Initialise base class self.log([('SVM kernel type', self.kernel_type), ('C', self.C), ('Sampling rate', self.sample)]) # Log a message # If the weight vector dictionary and both match and non-match sets - - - - # are given start the training process # if ((self.train_w_vec_dict != None) and (self.train_match_set != None) \ and (self.train_non_match_set != None)): self.train(self.train_w_vec_dict, self.train_match_set, (self.train_non_match_set)) # --------------------------------------------------------------------------- def train(self, w_vec_dict, match_set, non_match_set): """Method to train a classifier using the given weight vector dictionary and match and non-match sets of record identifier pairs. Note that all weight vectors must either be in the match or the non-match training sets. """ auxiliary.check_is_dictionary('w_vec_dict', w_vec_dict) auxiliary.check_is_set('match_set', match_set) auxiliary.check_is_set('non_match_set', non_match_set) # Check that match and non-match sets are separate and do cover all weight # vectors given # if (len(match_set.intersection(non_match_set)) > 0): logging.exception('Intersection of match and non-match set not empty') raise Exception if ((len(match_set)+len(non_match_set)) != len(w_vec_dict)): logging.exception('Weight vector dictionary of different length than' + \ ' summed lengths of match and non-match sets: ' + \ '%d / %d+%d=%d' % (len(w_vec_dict), len(match_set), len(non_match_set), len(match_set)+len(non_match_set))) raise Exception self.train_w_vec_dict = w_vec_dict # Save self.train_match_set = match_set self.train_non_match_set = non_match_set logging.info('Train SVM classifier using %d weight vectors' % \ (len(w_vec_dict))) logging.info(' Match and non-match sets with %d and %d entries' % \ (len(match_set), len(non_match_set))) # Sample the weight vectors - - - - - - - - - - - - - - - - - - - - - - - - # if (self.sample == 100.0): use_w_vec_dict = w_vec_dict else: num_w_vec_sample = max(2, int(len(w_vec_dict)*self.sample/100.0)) use_w_vec_dict = {} # Create a new weight vector dictionary with samples rec_id_tuple_sample = random.sample(w_vec_dict.keys(),num_w_vec_sample) assert len(rec_id_tuple_sample) == num_w_vec_sample for rec_id_tuple in rec_id_tuple_sample: use_w_vec_dict[rec_id_tuple] = w_vec_dict[rec_id_tuple] logging.info(' Number of weight vectors to be used for SVM ' + \ 'classification: %d' % (len(use_w_vec_dict))) train_data = [] train_labels = [] for (rec_id_tuple, w_vec) in use_w_vec_dict.iteritems(): train_data.append(w_vec) if (rec_id_tuple in match_set): train_labels.append(1.0) # Match class else: train_labels.append(-1.0) # Non-match class assert len(train_data) == len(train_labels) assert len(train_data) == len(use_w_vec_dict) # Initialise and train the SVM - - - - - - - - - - - - - - - - - - - - - - # if (self.kernel_type == 'LINEAR'): svm_kernel = svm.LINEAR elif (self.kernel_type == 'POLY'): svm_kernel = svm.POLY elif (self.kernel_type == 'RBF'): svm_kernel = svm.RBF elif (self.kernel_type == 'SIGMOID'): svm_kernel = svm.SIGMOID svm_param = svm.svm_parameter(svm_type = svm.C_SVC, C=self.C, kernel_type=svm_kernel) svm_prob = svm.svm_problem(train_labels, train_data) self.svm_model = svm.svm_model(svm_prob, svm_param) logging.info('Trained SVM with %d training examples' % \ (len(use_w_vec_dict))) # --------------------------------------------------------------------------- def test(self, w_vec_dict, match_set, non_match_set): """Method to test a classifier using the given weight vector dictionary and match and non-match sets of record identifier pairs. Weight vectors will be assigned to matches or non-matches according to the SVM classification. No weight vector will be assigned to the possible match set. Will return a confusion matrix as a list of the form: [TP, FN, FP, TN]. """ if (self.svm_model == None): logging.warn('SVM has not been trained, testing not possible') return [0,0,0,0] auxiliary.check_is_dictionary('w_vec_dict', w_vec_dict) auxiliary.check_is_set('match_set', match_set) auxiliary.check_is_set('non_match_set', non_match_set) # Check that match and non-match sets are separate and do cover all weight # vectors given # if (len(match_set.intersection(non_match_set)) > 0): logging.exception('Intersection of match and non-match set not empty') raise Exception if ((len(match_set)+len(non_match_set)) != len(w_vec_dict)): logging.exception('Weight vector dictionary of different length than' + \ ' summed lengths of match and non-match sets: ' + \ '%d / %d+%d=%d' % (len(w_vec_dict), len(match_set), len(non_match_set), len(match_set)+len(non_match_set))) raise Exception num_true_m = 0 num_false_m = 0 num_true_nm = 0 num_false_nm = 0 for (rec_id_tuple, w_vec) in w_vec_dict.iteritems(): if (self.svm_model.predict(w_vec) == 1.0): # Match prediction if (rec_id_tuple in match_set): num_true_m += 1 else: num_false_m += 1 else: # Non-match prediction if (rec_id_tuple in non_match_set): num_true_nm += 1 else: num_false_nm += 1 assert (num_true_m+num_false_nm+num_false_m+num_true_nm) == len(w_vec_dict) logging.info(' Results: TP = %d, FN = %d, FP = %d, TN = %d' % \ (num_true_m,num_false_nm,num_false_m,num_true_nm)) return [num_true_m, num_false_nm, num_false_m, num_true_nm] # -------------------------------------------------------------------------- def cross_validate(self, w_vec_dict, match_set, non_match_set, n=10): """Method to conduct a cross validation using the given weight vector dictionary and match and non-match sets of record identifier pairs. Will return a confusion matrix as a list of the form: [TP, FN, FP, TN]. The cross validation approach is being conducted in svmlib. The complete weight vector dictionary and corresponding labels are given to the svm.py cross validation procedure. All weight vectors are then classified on the generated SVM model and the resulting performance is returned. Note that this method only provides performance measures, but no trained SVM model that can be used for classifying weight vectors later on. The train() method needs to be used to get a trained SVM. """ auxiliary.check_is_integer('n', n) auxiliary.check_is_positive('n', n) auxiliary.check_is_dictionary('w_vec_dict', w_vec_dict) auxiliary.check_is_set('match_set', match_set) auxiliary.check_is_set('non_match_set', non_match_set) # Check that match and non-match sets are separate and do cover all weight # vectors given # if (len(match_set.intersection(non_match_set)) > 0): logging.exception('Intersection of match and non-match set not empty') raise Exception if ((len(match_set)+len(non_match_set)) != len(w_vec_dict)): logging.exception('Weight vector dictionary of different length than' + \ ' summed lengths of match and non-match sets: ' + \ '%d / %d+%d=%d' % (len(w_vec_dict), len(match_set), len(non_match_set), len(match_set)+len(non_match_set))) raise Exception logging.info('') logging.info('Conduct %d-fold cross validation on SVM classifier ' % (n) \ + 'using %d weight vectors' % (len(w_vec_dict))) logging.info(' Match and non-match sets with %d and %d entries' % \ (len(match_set), len(non_match_set))) # Generate the training data and labels for the SVM - - - - - - - - - - - - # train_data = [] train_labels = [] for (rec_id_tuple, w_vec) in w_vec_dict.iteritems(): train_data.append(w_vec) if (rec_id_tuple in match_set): train_labels.append(1.0) # Match class else: train_labels.append(-1.0) # Non-match class # Initialise the SVM - - - - - - - - - - - - - - - - - - - - - - - - - - - # if (self.kernel_type == 'LINEAR'): svm_kernel = svm.LINEAR elif (self.kernel_type == 'POLY'): svm_kernel = svm.POLY elif (self.kernel_type == 'RBF'): svm_kernel = svm.RBF elif (self.kernel_type == 'SIGMOID'): svm_kernel = svm.SIGMOID svm_param = svm.svm_parameter(svm_type = svm.C_SVC, C=self.C, kernel_type=svm_kernel) svm_prob = svm.svm_problem(train_labels, train_data) target_list = svm.cross_validation(svm_prob, svm_param, n) assert len(target_list) == len(w_vec_dict) num_true_m = 0 num_false_m = 0 num_true_nm = 0 num_false_nm = 0 for i in range(len(target_list)): if (target_list[i] == 1.0): # Match prediction if (train_labels[i] == 1.0): # True match num_true_m += 1 else: num_false_m += 1 else: # Non-match prediction if (train_labels[i] == -1.0): # True non-match num_true_nm += 1 else: num_false_nm += 1 assert (num_true_m+num_false_nm+num_false_m+num_true_nm) == len(w_vec_dict) logging.info(' Results: TP = %d, FN = %d, FP = %d, TN = %d' % \ (num_true_m,num_false_nm,num_false_m,num_true_nm)) return [num_true_m, num_false_nm, num_false_m, num_true_nm] # --------------------------------------------------------------------------- def classify(self, w_vec_dict): """Method to classify the given weight vector dictionary using the trained classifier. Will return three sets with record identifier pairs: 1) match set, 2) non-match set, and 3) possible match set. The possible match set will be empty, as this classifier classifies all weight vectors as either matches or non-matches. """ if (self.svm_model == None): logging.warn('SVM has not been trained, classification not possible') return set(), set(), set() auxiliary.check_is_dictionary('w_vec_dict', w_vec_dict) match_set = set() non_match_set = set() poss_match_set = set() for (rec_id_tuple, w_vec) in w_vec_dict.iteritems(): if (self.svm_model.predict(w_vec) == 1.0): # Match prediction match_set.add(rec_id_tuple) else: # Non-match prediction non_match_set.add(rec_id_tuple) assert (len(match_set) + len(non_match_set) + len(poss_match_set)) == \ len(w_vec_dict) logging.info('Classified %d weight vectors: %d as matches, %d as ' % \ (len(w_vec_dict), len(match_set), len(non_match_set)) + \ 'non-matches, and %d as possible matches' % \ (len(poss_match_set))) return match_set, non_match_set, poss_match_set