diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java index 87039492d..e185755af 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java @@ -17,6 +17,7 @@ */ package org.linqs.psl.model.rule.arithmetic; +import org.linqs.psl.config.Options; import org.linqs.psl.database.DatabaseQuery; import org.linqs.psl.database.ResultList; import org.linqs.psl.database.atom.AtomManager; @@ -289,8 +290,10 @@ private void groundForNonSummation(Constant[] constants, Map groundSingleNonSummationRule(constants, variableMap, atomManager, resources); results.addAll(resources.groundRules); - for (GroundRule groundRule : resources.groundRules) { - VizDataCollection.addGroundRule(this, groundRule, variableMap, constants); + if (Options.CLI_VIZ.getBoolean()) { + for (GroundRule groundRule : resources.groundRules) { + VizDataCollection.addGroundRule(this, groundRule, variableMap, constants); + } } resources.groundRules.clear(); resources.accessExceptionAtoms.clear(); @@ -313,8 +316,10 @@ private void groundForSummation(Constant[] constants, Map var groundSingleSummationRule(constants, variableMap, atomManager, resources); results.addAll(resources.groundRules); - for (GroundRule groundRule : resources.groundRules) { - VizDataCollection.addGroundRule(this, groundRule, variableMap, constants); + if (Options.CLI_VIZ.getBoolean()) { + for (GroundRule groundRule : resources.groundRules) { + VizDataCollection.addGroundRule(this, groundRule, variableMap, constants); + } } resources.groundRules.clear(); resources.accessExceptionAtoms.clear(); @@ -346,7 +351,9 @@ private int groundAllNonSummationRule(AtomManager atomManager, GroundRuleStore g for (int groundingIndex = 0; groundingIndex < results.size(); groundingIndex++) { groundSingleNonSummationRule(results.get(groundingIndex), variableMap, atomManager, resources); GroundRule groundRule = resources.groundRules.get(groundingIndex); - VizDataCollection.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex)); + if (Options.CLI_VIZ.getBoolean()) { + VizDataCollection.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex)); + } } int count = resources.groundRules.size(); @@ -422,7 +429,9 @@ private int groundAllSummationRule(AtomManager atomManager, GroundRuleStore grou for (int groundingIndex = 0; groundingIndex < results.size(); groundingIndex++) { groundSingleSummationRule(results.get(groundingIndex), variableMap, atomManager, resources); GroundRule groundRule = resources.groundRules.get(groundingIndex); - VizDataCollection.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex)); + if (Options.CLI_VIZ.getBoolean()) { + VizDataCollection.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex)); + } } int count = resources.groundRules.size(); diff --git a/psl-core/src/main/java/org/linqs/psl/util/VizDataCollection.java b/psl-core/src/main/java/org/linqs/psl/util/VizDataCollection.java index 6b661fa7a..9b689a3c9 100644 --- a/psl-core/src/main/java/org/linqs/psl/util/VizDataCollection.java +++ b/psl-core/src/main/java/org/linqs/psl/util/VizDataCollection.java @@ -16,10 +16,12 @@ import org.slf4j.LoggerFactory; import java.io.IOException; +import java.io.FilterOutputStream; import java.io.PrintStream; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.zip.GZIPOutputStream; @@ -48,9 +50,6 @@ private static synchronized void init() { } public static void outputJSON() { - String[] keyNames = {"truthMap", "rules", "groundRules", "groundAtoms"}; - JSONObject fullJson = new JSONObject(vizData, keyNames); - PrintStream stream = System.out; if (outputPath != null) { @@ -58,18 +57,76 @@ public static void outputJSON() { stream = new PrintStream(outputPath); if (outputPath.endsWith(".gz")) { GZIPOutputStream gzipStream = new GZIPOutputStream(stream, true); - byte[] jsonByteArray = fullJson.toString().getBytes(); - gzipStream.write(jsonByteArray, 0, jsonByteArray.length); + writeToStream(gzipStream); gzipStream.close(); } else { - stream.println(fullJson.toString()); + writeToStream(stream); } stream.close(); } catch (IOException ex) { throw new RuntimeException(); } } else { - stream.println(fullJson.toString()); + writeToStream(stream); + } + } + // Write to stream with JSON formatting. + private static void writeToStream(FilterOutputStream stream) { + // JSON format reference: https://www.json.org/json-en.html. + try { + stream.write("{".getBytes()); + // Write each map as a JSON object, each JSON object is comma delimited. + writeMap(vizData.truthMap, stream, "truthMap"); + stream.write(",".getBytes()); + writeMap(stream, vizData.rules, "rules"); + stream.write(",".getBytes()); + writeMap(stream, vizData.groundRules, "groundRules"); + stream.write(",".getBytes()); + writeMap(stream, vizData.groundAtoms, "groundAtoms"); + + stream.write("}".getBytes()); + } catch (IOException ex) { + throw new RuntimeException(); + } + } + // Write map to stream with JSON formatting. + private static void writeMap(FilterOutputStream stream, Map> map, String key) { + try { + // Each key must be string formatted. + stream.write((" \"" + key + "\" :{").getBytes()); + + Iterator>> iterator = map.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry> entry = iterator.next(); + JSONObject jsonObject = new JSONObject(entry.getValue()); + stream.write((" \"" + entry.getKey() + "\" :" + jsonObject.toString()).getBytes()); + if (iterator.hasNext()) { + stream.write(",".getBytes()); + } + } + stream.write("}".getBytes()); + } catch (IOException ex) { + throw new RuntimeException(); + } + } + // Write map to stream with JSON formatting. + private static void writeMap(Map map, FilterOutputStream stream, String key) { + try { + // Each key must be string formatted. + stream.write((" \"" + key + "\" :{").getBytes()); + + Iterator> iterator = map.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + stream.write((" \"" + entry.getKey() + "\" :" + entry.getValue()).getBytes()); + if (iterator.hasNext()) { + stream.write(",".getBytes()); + } + } + + stream.write("}".getBytes()); + } catch (IOException ex) { + throw new RuntimeException(); } } @@ -116,7 +173,7 @@ public static void dissatisfactionPerGroundRule(GroundRuleStore groundRuleStore) } } } - // TODO: Arithmetic Ground Rules Collection + public static synchronized void addGroundRule(AbstractRule parentRule, GroundRule groundRule, Map variableMap, Constant[] constantsList) { if (groundRule == null) { @@ -136,7 +193,7 @@ public static synchronized void addGroundRule(AbstractRule parentRule, atomCount++; } - // Adds a rule element to RuleMap + // Adds a rule element to RuleMap. String ruleStringID = Integer.toString(System.identityHashCode(parentRule)); Map rulesElementItem = new HashMap(); rulesElementItem.put("text", parentRule.getName());