Skip to content

Commit

Permalink
Improve AHashSet for better shift handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mikera committed Oct 8, 2024
1 parent 2303452 commit 48d2d05
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 41 deletions.
8 changes: 3 additions & 5 deletions convex-core/src/main/java/convex/core/data/AHashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ public ASet<T> disjAll(ACollection<T> b) {
AHashSet<T> result=this;
long n=b.count();
for (long i=0; i<n; i++) {
result=result.excludeRef(b.getElementRef(i));
result=result.excludeHash(b.getElementRef(i).getHash());
}
return result;
}

public abstract AHashSet<T> excludeRef(Ref<?> valueRef);
public abstract AHashSet<T> excludeHash(Hash hash);

public abstract AHashSet<T> includeRef(Ref<T> ref) ;

Expand All @@ -98,7 +98,7 @@ public AHashSet<T> conj(ACell a) {

@Override
public ASet<T> exclude(ACell a) {
return excludeRef(Ref.get(a));
return excludeHash(Cells.getHash(a));
}

@Override
Expand Down Expand Up @@ -142,8 +142,6 @@ public T getByHash(Hash hash) {
if (ref==null) return null;
return ref.getValue();
}

protected abstract AHashSet<T> includeRef(Ref<T> e, int i);

/**
* Tests if this Set contains a given hash
Expand Down
6 changes: 6 additions & 0 deletions convex-core/src/main/java/convex/core/data/ASet.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ public final byte getTag() {
*/
public abstract ASet<T> exclude(ACell a) ;

/**
* Gets the Hash of teh first element in this set
* @return
*/
protected abstract Hash getFirstHash();

/**
* Updates the set to include all the given elements.
* Can be used to implement union of sets
Expand Down
2 changes: 1 addition & 1 deletion convex-core/src/main/java/convex/core/data/Format.java
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ private static <T extends ACell> T read(byte tag, Blob blob, int offset) throws
} catch (BadFormatException e) {
throw e;
} catch (IndexOutOfBoundsException e) {
throw new BadFormatException("Read out of bounds when decoding with tag 0x"+Utils.toHexString(tag));
throw new BadFormatException("Read out of bounds when decoding with tag 0x"+Utils.toHexString(tag),e);
} catch (MissingDataException e) {
throw e;
} catch (Exception e) {
Expand Down
1 change: 1 addition & 0 deletions convex-core/src/main/java/convex/core/data/MapLeaf.java
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ public MapLeaf<K, V> slice(long start, long end) {

@Override
protected Hash getFirstHash() {
if (count==0) return null;
return entries[0].getKeyHash();
}

Expand Down
2 changes: 1 addition & 1 deletion convex-core/src/main/java/convex/core/data/MapTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static <K extends ACell, V extends ACell> MapTree<K, V> create(int shift, AHashM
int digit=child.getFirstHash().getHexDigit(shift);
mask|=(1<<digit);
}
if (Integer.bitCount(shift)!=n) {
if (Integer.bitCount(mask&0xFFFF)!=n) {
throw new IllegalArgumentException("Children do not differ at specified digit");
}
return new MapTree<>(rs,shift,mask,count);
Expand Down
4 changes: 4 additions & 0 deletions convex-core/src/main/java/convex/core/data/Maps.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ public static <K extends ACell, V extends ACell> AHashMap<K, V> read(Blob b, int

public static int MAX_ENCODING_SIZE = Math.max(MapTree.MAX_ENCODING_LENGTH, MapLeaf.MAX_ENCODING_LENGTH);

public static <K extends ACell, V extends ACell> Hash getFirstHash(AHashMap<K, V> map) {
return map.getFirstHash();
}




Expand Down
21 changes: 11 additions & 10 deletions convex-core/src/main/java/convex/core/data/SetLeaf.java
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ public SetLeaf<T> exclude(ACell key) {
}

@Override
public SetLeaf<T> excludeRef(Ref<?> key) {
int i = seekKeyRef(key.getHash());
public SetLeaf<T> excludeHash(Hash hash) {
int i = seekKeyRef(hash);
if (i < 0) return this; // not found
return excludeAt(i);
}
Expand Down Expand Up @@ -544,12 +544,7 @@ protected boolean containsAll(SetLeaf<?> b) {
}

@Override
public AHashSet<T> includeRef(Ref<T> ref) {
return includeRef(ref,0);
}

@Override
protected AHashSet<T> includeRef(Ref<T> e, int shift) {
public AHashSet<T> includeRef(Ref<T> e) {
int n=elements.length;
Hash h=e.getHash();
int pos=0;
Expand All @@ -573,7 +568,7 @@ protected AHashSet<T> includeRef(Ref<T> e, int shift) {
} else {
// Maximum size exceeded, so need to expand to tree.
// Shift required since this might not be the tree root!
return SetTree.create(newEntries, shift);
return SetTree.create(newEntries);
}
}

Expand All @@ -593,7 +588,7 @@ public T get(long index) {
@Override
public AHashSet<T> toCanonical() {
if (count<=MAX_ELEMENTS) return this;
return SetTree.create(elements, 0);
return SetTree.create(elements);
}

@SuppressWarnings("unchecked")
Expand All @@ -609,6 +604,12 @@ public ASet<T> slice(long start, long end) {
return new SetLeaf<T>(nrefs);
}

@Override
protected Hash getFirstHash() {
if (count==0) return null;
return elements[0].getHash();
}




Expand Down
126 changes: 103 additions & 23 deletions convex-core/src/main/java/convex/core/data/SetTree.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package convex.core.data;

import java.util.Arrays;
import java.util.Comparator;

import convex.core.exceptions.BadFormatException;
import convex.core.exceptions.InvalidDataException;
import convex.core.exceptions.Panic;
Expand Down Expand Up @@ -72,13 +75,14 @@ private static <T extends ACell> long computeCount(Ref<AHashSet<T>>[] children)
* @return New SetTree node
*/
@SuppressWarnings("unchecked")
public static <V extends ACell> SetTree<V> create(Ref<V>[] elementRefs, int shift) {
public static <V extends ACell> SetTree<V> create(Ref<V>[] elementRefs) {
int n = elementRefs.length;
if (n <= SetLeaf.MAX_ELEMENTS) {
throw new IllegalArgumentException(
"Insufficient distinct entries for TreeMap construction: " + elementRefs.length);
}

int shift=computeShift(elementRefs);
// construct full child array
Ref<AHashSet<V>>[] children = new Ref[16];
for (int i = 0; i < n; i++) {
Expand All @@ -88,12 +92,72 @@ public static <V extends ACell> SetTree<V> create(Ref<V>[] elementRefs, int shif
if (ref == null) {
children[ix] = SetLeaf.create(e).getRef();
} else {
AHashSet<V> newChild=ref.getValue().includeRef(e,shift+1);
AHashSet<V> newChild=ref.getValue().includeRef(e);
children[ix] = newChild.getRef();
}
}
return (SetTree<V>) createFull(children, shift);
}

/**
* Create MapTree with specific children at specified shift level
* Children must branch at the given shift level
*/
@SafeVarargs
static <V extends ACell> SetTree<V> create(int shift, AHashSet<V> ... children) {
int n=children.length;
Arrays.sort(children,shiftComparator(shift));
@SuppressWarnings("unchecked")
Ref<AHashSet<V>>[] rs=new Ref[n];
long count=0;
short mask=0;
for (int i=0; i<n; i++) {
AHashSet<V> child=children[i];
rs[i]=Ref.get(child);
count+=child.count;
int digit=child.getFirstHash().getHexDigit(shift);
mask|=(1<<digit);
}
if (Integer.bitCount(mask&0xFFFF)!=n) {
throw new IllegalArgumentException("Children do not differ at specified digit");
}
return new SetTree<>(rs,shift,mask,count);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
static Comparator<AHashSet>[] COMPARATORS=new Comparator[64];

@SuppressWarnings("rawtypes")
private static Comparator<AHashSet> shiftComparator(int shift) {
if (COMPARATORS[shift]==null) {
COMPARATORS[shift]=new Comparator<AHashSet>() {
@Override
public int compare(AHashSet o1, AHashSet o2) {
int d1= o1.getFirstHash().getHexDigit(shift);
int d2= o2.getFirstHash().getHexDigit(shift);
return d1-d2;
}
};
};
return COMPARATORS[shift];
}


/**
* Computes the common shift for a vector of entries.
* This is the shift at which the first split occurs, i.e length of common prefix
* @param es Entries
* @return
*/
protected static <V extends ACell> int computeShift(Ref<V>[] es) {
int shift=63; // max possible
Hash h=es[0].getHash();
int n=es.length;
for (int i=1; i<n; i++) {
shift=Math.min(shift, h.commonHexPrefixLength(es[i].getHash(),shift));
}
return shift;
}

/**
* Creates a SetTree given child Refs for each digit
Expand Down Expand Up @@ -200,21 +264,20 @@ protected Ref<T> getRefByHash(Hash hash) {
return children[i].getValue().getRefByHash(hash);
}

@SuppressWarnings("unchecked")
@Override
public AHashSet<T> exclude(ACell key) {
return excludeRef((Ref<T>) Ref.get(key));
return excludeHash(Cells.getHash(key));
}

@Override
public AHashSet<T> excludeRef(Ref<?> keyRef) {
int digit = keyRef.getHash().getHexDigit(shift);
public AHashSet<T> excludeHash(Hash hash) {
int digit =hash.getHexDigit(shift);
int i = Bits.indexForDigit(digit, mask);
if (i < 0) return this; // not present

// dissoc entry from child
AHashSet<T> child = children[i].getValue();
AHashSet<T> newChild = child.excludeRef(keyRef);
AHashSet<T> newChild = child.excludeHash(hash);
if (child == newChild) return this; // no removal, no change

AHashSet<T> result=(newChild.isEmpty())?dissocChild(i):replaceChild(i, newChild.getRef());
Expand Down Expand Up @@ -242,6 +305,10 @@ public AHashSet<T> toCanonical() {
@SuppressWarnings("unchecked")
private AHashSet<T> dissocChild(int i) {
int bsize = children.length;
if (bsize==2) {
// can just return the remaining child
return children[1-i].getValue();
}
AHashSet<T> child = children[i].getValue();
Ref<AHashSet<T>>[] newBlocks = (Ref<AHashSet<T>>[]) new Ref<?>[bsize - 1];
System.arraycopy(children, 0, newBlocks, 0, i);
Expand Down Expand Up @@ -297,14 +364,21 @@ public static int digitForIndex(int index, short mask) {
@Override
public SetTree<T> include(ACell value) {
Ref<T> keyRef = (Ref<T>) Ref.get(value);
return includeRef(keyRef, shift);
return includeRef(keyRef);
}


@Override
protected SetTree<T> includeRef(Ref<T> e, int shift) {
if (this.shift != shift) {
throw new Error("Invalid shift!");
public SetTree<T> includeRef(Ref<T> e) {
Hash kh= e.getHash();
int cshift= kh.commonHexPrefixLength(getFirstHash(), Hash.HEX_LENGTH);

if (cshift<shift) {
// branch at an earlier position
SetLeaf<T> newLeaf=SetLeaf.create(e);
return create(cshift,newLeaf,this);
}

Ref<T> keyRef = e;
int digit = keyRef.getHash().getHexDigit(shift);
int i = Bits.indexForDigit(digit, mask);
Expand All @@ -315,16 +389,11 @@ protected SetTree<T> includeRef(Ref<T> e, int shift) {
} else {
// location needs update
AHashSet<T> child = children[i].getValue();
AHashSet<T> newChild = child.includeRef(e, shift + 1);
AHashSet<T> newChild = child.includeRef(e);
if (child == newChild) return this;
return (SetTree<T>) replaceChild(i, newChild.getRef());
}
}

@Override
public AHashSet<T> includeRef(Ref<T> ref) {
return includeRef(ref,shift);
}

@Override
public int encode(byte[] bs, int pos) {
Expand Down Expand Up @@ -514,7 +583,7 @@ private AHashSet<T> mergeWith(SetLeaf<T> b, int setOp, int shift) {
if (newE != null) {
// include only new keys where function result is not null. Re-use existing
// entry if possible.
result = result.includeRef(newE, shift);
result = result.includeRef(newE);
}
}
return result;
Expand Down Expand Up @@ -575,7 +644,7 @@ protected void validateWithPrefix(Hash base, int digit, int position) throws Inv

Hash firstHash;
try {
firstHash=getElementRef(0).getHash();
firstHash=getFirstHash();
} catch (ClassCastException e) {
throw new InvalidDataException("Bad child type:" +e.getMessage(), this);
}
Expand All @@ -600,9 +669,8 @@ protected void validateWithPrefix(Hash base, int digit, int position) throws Inv

if (child instanceof SetTree) {
SetTree<T> childTree=(SetTree<T>) child;
int expectedShift=shift+1;
if (childTree.shift!=expectedShift) {
throw new InvalidDataException("Wrong child shift ["+childTree.shift+"], expected ["+expectedShift+"]",this);
if (childTree.shift<=shift) {
throw new InvalidDataException("Wrong child shift ["+childTree.shift+"], expected greater than ["+shift+"]",this);
}
}

Expand All @@ -624,7 +692,10 @@ protected void validateWithPrefix(Hash base, int digit, int position) throws Inv

private boolean isValidStructure() {
if (count <= SetLeaf.MAX_ELEMENTS) return false;
if (children.length != Integer.bitCount(mask & 0xFFFF)) return false;
int n=children.length;
if (n<2) return false;

if (n != Integer.bitCount(mask & 0xFFFF)) return false;
for (int i = 0; i < children.length; i++) {
Ref<AHashSet<T>> child = children[i];
if (child == null) return false;
Expand Down Expand Up @@ -731,4 +802,13 @@ public ASet<T> slice(long start, long end) {
return result;
}

// Cache of first hash, we don't want to descend tree repeatedly to find this
private Hash firstHash;

@Override
protected Hash getFirstHash() {
if (firstHash==null) firstHash=children[0].getValue().getFirstHash();
return firstHash;
}

}
2 changes: 1 addition & 1 deletion convex-core/src/main/java/convex/core/data/Sets.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public static <T extends ACell> ASet<T> read(Blob b, int pos) throws BadFormatEx
public static <T extends ACell> AHashSet<T> createWithShift(int shift, ArrayList<Ref<T>> values) {
AHashSet<T> result=Sets.empty();
for (Ref<T> v: values) {
result=result.includeRef(v, shift);
result=result.includeRef(v);
}
return result;
}
Expand Down
Loading

0 comments on commit 48d2d05

Please sign in to comment.