Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handling of join queries without WHERE/ON clauses #1145

Merged
merged 10 commits into from
Dec 23, 2024
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.evomaster.client.java.distance.heuristics;

import java.util.Arrays;

public class TruthnessUtils {

/**
Expand All @@ -13,7 +15,7 @@ public static double normalizeValue(double v) {
throw new IllegalArgumentException("Negative value: " + v);
}

if(Double.isInfinite(v) || v == Double.MAX_VALUE){
if (Double.isInfinite(v) || v == Double.MAX_VALUE) {
return 1d;
}

Expand All @@ -26,7 +28,6 @@ public static double normalizeValue(double v) {
}



public static Truthness getEqualityTruthness(int a, int b) {
double distance = DistanceHelper.getDistanceToEquality(a, b);
double normalizedDistance = normalizeValue(distance);
Expand Down Expand Up @@ -79,4 +80,52 @@ public static Truthness getTruthnessToEmpty(int len) {
}
return t;
}

public static Truthness andAggregation(Truthness... truthnesses) {
double averageOfTrue = averageOfTrue(truthnesses);
double falseOrAverageFalse = falseOrAverageFalse(truthnesses);
return new Truthness(averageOfTrue, falseOrAverageFalse);
}

private static double averageOfTrue(Truthness... truthnesses) {
checkValidTruthnesses(truthnesses);
double[] getOfTrueValues = Arrays.stream(truthnesses).mapToDouble(Truthness::getOfTrue)
.toArray();
return average(getOfTrueValues);
}

private static void checkValidTruthnesses(Truthness[] truthnesses) {
if (truthnesses == null || truthnesses.length == 0 || Arrays.stream(truthnesses).anyMatch(e -> e == null)) {
throw new IllegalArgumentException("null or empty Truthness instance");
}
}

private static double average(double... values) {
if (values == null || values.length == 0) {
throw new IllegalArgumentException("null or empty values");
}
double total = 0.0;
for (double v : values) {
total += v;
}
return total / values.length;
}

private static double averageOfFalse(Truthness... truthnesses) {
checkValidTruthnesses(truthnesses);
double[] getOfFalseValues = Arrays.stream(truthnesses).mapToDouble(Truthness::getOfFalse)
.toArray();
return average(getOfFalseValues);
}

private static double falseOrAverageFalse(Truthness... truthnesses) {
checkValidTruthnesses(truthnesses);
if (Arrays.stream(truthnesses).anyMatch(t -> t.isFalse())) {
return 1.0d;
} else {
return averageOfFalse(truthnesses);
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public QueryResult(List<VariableDescriptor> variableDescriptorList) {
variableDescriptors.addAll(variableDescriptorList);
}



/**
* WARNING: Constructor only needed for testing
*
Expand Down Expand Up @@ -173,4 +175,17 @@ public QueryResultDto toDto(){

return dto;
}

/**
* Retrieves the table name of this queryResult.
*
* @return the table name of the first {@code VariableDescriptor} in the {@code variableDescriptors} list.
* @throws IllegalStateException if the {@code variableDescriptors} list is empty.
*/
public String getTableName() {
if (variableDescriptors.isEmpty()) {
throw new IllegalStateException("No variable descriptors found");
}
return variableDescriptors.get(0).getTableName();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package org.evomaster.client.java.sql;

import java.util.TreeMap;

public class QueryResultSet {
jgaleotti marked this conversation as resolved.
Show resolved Hide resolved

private final TreeMap<String, QueryResult> queryResults;
jgaleotti marked this conversation as resolved.
Show resolved Hide resolved
private QueryResult queryResultForVirtualTable;

public QueryResultSet() {
this(true);
}

public QueryResultSet(boolean isCaseSensitive) {
queryResults = new TreeMap<>(isCaseSensitive ? null : String.CASE_INSENSITIVE_ORDER);
}

public boolean isCaseInsensitive() {
return queryResults.comparator().equals(String.CASE_INSENSITIVE_ORDER);
}

public void addQueryResult(QueryResult queryResult) {
String tableName = queryResult.seeVariableDescriptors()
.stream()
.findFirst()
.map(VariableDescriptor::getTableName)
.orElse(null);

if (tableName == null) {
handleVirtualTable(queryResult);
} else {
handleNamedTable(tableName, queryResult);
}
}

private void handleNamedTable(String tableName, QueryResult queryResult) {
if (queryResults.containsKey(tableName)) {
throw new IllegalArgumentException("Duplicate table in QueryResultSet: " + tableName);
}
queryResults.put(tableName, queryResult);
}

private void handleVirtualTable(QueryResult queryResult) {
if (queryResultForVirtualTable != null) {
throw new IllegalArgumentException("Duplicate values for virtual table");
}
queryResultForVirtualTable = queryResult;
}

public QueryResult getQueryResultForNamedTable(String tableName) {
return queryResults.get(tableName);
}

public QueryResult getQueryResultForVirtualTable() {
return queryResultForVirtualTable;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,96 @@
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import org.evomaster.client.java.controller.api.dto.database.schema.DbInfoDto;
import org.evomaster.client.java.distance.heuristics.Truthness;
import org.evomaster.client.java.distance.heuristics.TruthnessUtils;
import org.evomaster.client.java.sql.QueryResult;
import org.evomaster.client.java.sql.QueryResultSet;

import static org.evomaster.client.java.sql.internal.SqlParserUtils.getFrom;
import static org.evomaster.client.java.sql.internal.SqlParserUtils.getWhere;
import java.util.Collection;
import java.util.List;

import static org.evomaster.client.java.sql.internal.SqlParserUtils.*;

public class SqlHeuristicsCalculator {

public static double C = 0.1;
public static double C_BETTER = C + C / 2;
jgaleotti marked this conversation as resolved.
Show resolved Hide resolved
public static Truthness TRUE_TRUTHNESS = new Truthness(1, C);

private QueryResultSet queryResultSet;

private SqlHeuristicsCalculator(QueryResult[] data) {
final boolean isCaseSensitive = false;
jgaleotti marked this conversation as resolved.
Show resolved Hide resolved
this.queryResultSet = new QueryResultSet(isCaseSensitive);
for (QueryResult queryResult : data) {
queryResultSet.addQueryResult(queryResult);
}
}

public static SqlDistanceWithMetrics computeDistance(String sqlCommand,
DbInfoDto schema,
TaintHandler taintHandler,
QueryResult... data) {

Statement parsedSqlCommand = SqlParserUtils.parseSqlCommand(sqlCommand);
Expression whereClause= getWhere(parsedSqlCommand);
FromItem fromItem = getFrom(parsedSqlCommand);

if (fromItem == null && whereClause == null) {
return new SqlDistanceWithMetrics(0.0,0,false);
}
SqlHeuristicsCalculator calculator = new SqlHeuristicsCalculator(data);
Truthness t = calculator.computeCommand(parsedSqlCommand);
double distanceToTrue = 1 - t.getOfTrue();
return new SqlDistanceWithMetrics(distanceToTrue, 0, false);
}

private Truthness computeCommand(Statement parsedSqlCommand) {
final Expression whereClause = getWhere(parsedSqlCommand);
final FromItem fromItem = getFrom(parsedSqlCommand);
final List<Join> joins = getJoins(parsedSqlCommand);
if (fromItem == null && joins == null) {
/**
* Result will depend on the contents of the virtual table
*/
return getTruthnessForTable(null);
} else if (fromItem != null && joins == null && whereClause == null) {
return getTruthnessForTable(fromItem);
} else if (fromItem != null && joins != null && whereClause == null) {
final Join join = joins.get(0);
final FromItem leftFromItem = fromItem;
final FromItem rightFromItem = join.getRightItem();
final Collection<Expression> onExpressions = join.getOnExpressions();
if (join.isLeft()) {
return getTruthnessForTable(leftFromItem);
} else if (join.isRight()) {
return getTruthnessForTable(rightFromItem);
} else if (join.isCross()) {
Truthness truthnessLeftTable = getTruthnessForTable(leftFromItem);
Truthness truthnessRightTable = getTruthnessForTable(rightFromItem);
return TruthnessUtils.andAggregation(truthnessLeftTable, truthnessRightTable);
} else {
// inner join?
}


}

return null;
}

private Truthness getTruthnessForTable(FromItem fromItem) {
final QueryResult tableData;
if (fromItem == null) {
tableData = queryResultSet.getQueryResultForVirtualTable();
} else {
if (!SqlParserUtils.isTable(fromItem)) {
throw new IllegalArgumentException("Cannot compute Truthness for form item that it is not a table " + fromItem);
}
String tableName = SqlParserUtils.getTableName(fromItem);
tableData = queryResultSet.getQueryResultForNamedTable(tableName);
}
final int len = tableData.size();
final Truthness t = TruthnessUtils.getTruthnessToEmpty(len).invert();
return t;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;

import java.util.List;

public class SqlParserUtils {

/**
Expand Down Expand Up @@ -95,6 +99,16 @@ public static FromItem getFrom(Statement parsedStatement) {
}
}

public static List<Join> getJoins(Statement parsedStatement) {
if (parsedStatement instanceof Select) {
Select select = (Select) parsedStatement;
PlainSelect plainSelect = select.getPlainSelect();
return plainSelect.getJoins();
} else {
throw new IllegalArgumentException("Cannot get Joins From: " + parsedStatement.toString());
}
}

/**
* This method assumes that the SQL command can be successfully parsed.
*
Expand All @@ -118,4 +132,36 @@ public static boolean canParseSqlStatement(String sqlCommand){
return false;
}
}

/**
* Checks if the given FromItem is a Table.
*
* @param fromItem the FromItem to check
* @return true if the FromItem is a Table, false otherwise
*/
public static boolean isTable(FromItem fromItem) {
return fromItem instanceof Table;
}

/**
* Retrieves the fully qualified name of a table from the provided {@link FromItem}.
* <p>
* This method checks if the given {@code fromItem} is an instance of {@link Table}.
* If it is, the method extracts and returns the fully qualified name of the table.
* Otherwise, it throws an {@link IllegalArgumentException}.
* </p>
*
* @param fromItem the {@link FromItem} instance to extract the table name from.
* @return the fully qualified name of the table as a {@link String}.
* @throws IllegalArgumentException if the provided {@code fromItem} is not an instance of {@link Table}.
* @see net.sf.jsqlparser.schema.Table#getFullyQualifiedName()
*/
public static String getTableName(FromItem fromItem) {
if (fromItem instanceof Table) {
Table table = (Table) fromItem;
return table.getFullyQualifiedName();
} else {
throw new IllegalArgumentException("From item " + fromItem + " is not a table");
}
}
}
Loading
Loading