Skip to content

Commit

Permalink
Added support for the 'nbins_cats' hyperparameter. HT: Eric Blood. Fixes
Browse files Browse the repository at this point in the history
 #4, fixes #16
  • Loading branch information
vruusmann committed Aug 22, 2024
1 parent 698ffb2 commit 4f50e83
Show file tree
Hide file tree
Showing 8 changed files with 1,218 additions and 1,196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,23 +201,9 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte
} else

{
if(equal != 0){
if(feature instanceof CategoricalFeature){
CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

GenmodelBitSet bitSet = new GenmodelBitSet(0);

if(equal == 8){
bitSet.fill2(compressedTree, byteBuffer);
} else

if(equal == 12){
bitSet.fill3(compressedTree, byteBuffer);
} else

{
throw new IllegalArgumentException("Node type " + equal + " is not supported");
}

String name = categoricalFeature.getName();
List<?> values = categoricalFeature.getValues();

Expand All @@ -226,19 +212,55 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte
List<Object> leftValues = new ArrayList<>();
List<Object> rightValues = new ArrayList<>();

for(int i = 0; i < values.size(); i++){
Object value = values.get(i);
if(equal != 0){
GenmodelBitSet bitSet = new GenmodelBitSet(0);

if(!valueFilter.test(value)){
continue;
} // End if
if(equal == 8){
bitSet.fill2(compressedTree, byteBuffer);
} else

if(!bitSet.contains(i)){
leftValues.add(value);
if(equal == 12){
bitSet.fill3(compressedTree, byteBuffer);
} else

{
rightValues.add(value);
throw new IllegalArgumentException("Node type " + equal + " is not supported");
}

for(int i = 0; i < values.size(); i++){
Object value = values.get(i);

if(!valueFilter.test(value)){
continue;
} // End if

if(!bitSet.contains(i)){
leftValues.add(value);
} else

{
rightValues.add(value);
}
}
} else

{
Double splitVal = (double)byteBuffer.get4f();

for(int i = 0; i < values.size(); i++){
Object value = values.get(i);

if(!valueFilter.test(value)){
continue;
} // End if

if(i < splitVal){
leftValues.add(value);
} else

{
rightValues.add(value);
}
}
}

Expand Down
Loading

0 comments on commit 4f50e83

Please sign in to comment.