Merge pull request #59 from overlabs/master
minor fix
visualDust authored Nov 20, 2023
2 parents 4e35d8d + b3665d4 commit 91efb3b
Showing 5 changed files with 87 additions and 13 deletions.
31 changes: 31 additions & 0 deletions doc/docs/guide/logging/index.md
@@ -23,6 +23,37 @@ output:
2023-03-18-13:56:03 > test.py > hello world
```

## Using a decorator

`@logger.mention` mentions the decorated function in the log output each time it is called.

```python
@logger.mention
def function_a():
    print('message from a')

def function_b():
    function_a()
    print('message from b')

def function_c():
    function_a()
    function_b()
    print('message from c')

function_c()
```
output:
```html
2023-11-18-06:26:04 > main/function_c > Currently running: function_a
message from a
2023-11-18-06:26:04 > main/function_b > Currently running: function_a
message from a
message from b
message from c
```


## Auto tracing

The default `logger` automatically traces back to the caller of the logger instance.
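For instance, a minimal sketch (the `from neetbox.logging import logger` import follows the usual neetbox layout, and the exact identifier format is only an approximation of the samples above):

```python
from neetbox.logging import logger

def train_step():
    # no caller information is passed in; the logger resolves which
    # file/function issued the call and prefixes the line with it
    logger.log("loss computed")

train_step()  # -> something like: 2023-11-18-06:26:04 > train.py/train_step > loss computed
```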
1 change: 1 addition & 0 deletions neetbox/daemon/__init__.py
@@ -29,6 +29,7 @@ def __attach_daemon(daemon_config):
"ipython, try to set 'allowIpython' to True."
)
return False # ignore if debugging in ipython
pkg.is_installed("flask", try_install_if_not=True)
_online_status = connect_daemon(daemon_config) # try to connect daemon
logger.log("daemon connection status: " + str(_online_status))
if not _online_status: # if no daemon online
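The added `pkg.is_installed("flask", try_install_if_not=True)` line makes sure Flask is available before the daemon connection is attempted. neetbox's own `pkg` helper is not part of this diff; as a rough sketch of the general pattern (check whether a module is importable and optionally install it on demand), something like the following, not the actual neetbox implementation:

```python
import importlib.util
import subprocess
import sys

def is_installed(package: str, try_install_if_not: bool = False) -> bool:
    """Return True if `package` is importable, optionally installing it via pip first."""
    if importlib.util.find_spec(package) is not None:
        return True
    if try_install_if_not:
        # best-effort install into the current interpreter's environment
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        importlib.invalidate_caches()
        return importlib.util.find_spec(package) is not None
    return False
```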
4 changes: 4 additions & 0 deletions neetbox/daemon/_local_http_client.py
@@ -1,5 +1,9 @@
import httpx
import logging

httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.ERROR)

__no_proxy = {
"http://": None,
"https://": None,
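The two added lines raise the `httpx` logger's threshold so that routine request logs from the local HTTP client no longer reach the console. The same standard-library pattern works for any chatty dependency; a standalone sketch (the logger name below is a placeholder, not something neetbox uses):

```python
import logging

logging.basicConfig(level=logging.INFO)

noisy = logging.getLogger("some.chatty.library")  # hypothetical third-party logger name
noisy.info("this line is shown")

noisy.setLevel(logging.ERROR)  # raise the threshold for this logger only
noisy.info("this line is now suppressed")
noisy.error("errors still get through")
```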
26 changes: 13 additions & 13 deletions neetbox/logging/formatting.py
@@ -4,13 +4,25 @@
# URL: https://gong.host
# Date: 20230318

from dataclasses import dataclass
import os
import warnings
from random import random
from typing import Optional


@dataclass
class LogStyle:
    color: Optional[str] = None
    prefix: str = ""
    text_style: Optional[str] = None
    datetime_format: str = "%Y-%m-%d-%H:%M:%S"
    with_identifier: bool = True
    trace_level = 3
    with_datetime: bool = True
    split_char_cmd = " > "
    split_char_identity = "/"
    split_char_txt = " | "

    @classmethod
    def get_supported_colors(cls):
        return ["red", "green", "blue", "cyan", "yellow", "magenta"]
@@ -19,18 +31,6 @@ def get_supported_colors(cls):
    def get_supported_text_style(cls):
        return ["bold", "italic", "blink"]

    def __init__(self) -> None:
        self.color: Optional[str] = None
        self.prefix: str = ""
        self.text_style: Optional[str] = None
        self.datetime_format: str = "%Y-%m-%d-%H:%M:%S"
        self.with_identifier: bool = True
        self.trace_level = 3
        self.with_datetime: bool = True
        self.split_char_cmd = " > "
        self.split_char_identity = "/"
        self.split_char_txt = " | "

    def parse(self, pattern: str):
        # todo
        pass
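With the style options declared as dataclass fields, the generated `__init__` accepts the annotated fields as keyword arguments; a small usage sketch (only the field names shown above are taken from this diff):

```python
from neetbox.logging.formatting import LogStyle

default_style = LogStyle()  # every annotated field keeps its declared default
fancy_style = LogStyle(color="green", prefix="[train] ", text_style="bold")

print(fancy_style.datetime_format)  # -> %Y-%m-%d-%H:%M:%S
print(LogStyle.get_supported_colors())
```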
38 changes: 38 additions & 0 deletions neetbox/torch/nn/functional.py
@@ -0,0 +1,38 @@
import torch
from typing import Union

def one_hot(
    tensor: torch.Tensor,
    num_classes: int,
    ignored_label: Union[str, int] = "negative",
):
    """An advanced version of F.one_hot with ignore-label support, by @visualDust.
    Converts an indexed label mask into a one-hot encoded representation.

    Args:
        tensor (torch.Tensor): indexed label image. Should be of an integer dtype.
        num_classes (int): number of classes
        ignored_label (Union[str, int], optional): a label value to ignore, or "negative"
            to ignore every negative label. Defaults to "negative".

    Returns:
        torch.Tensor: one-hot encoded tensor
    """
    original_shape = tensor.shape
    for _ in range(4 - len(tensor.shape)):
        tensor = tensor.unsqueeze(0)  # H W -> C H W -> B C H W, if applicable
    # start to handle ignored label
    # convert ignored label into a positive index equal to num_classes
    if type(ignored_label) is int:
        tensor[tensor == ignored_label] = num_classes
    elif ignored_label == "negative":
        tensor[tensor < 0] = num_classes

    # check if mask image is valid
    if torch.max(tensor) > num_classes:
        raise RuntimeError("class values must be smaller than num_classes.")
    B, _, H, W = tensor.shape
    one_hot = torch.zeros(B, num_classes + 1, H, W)
    one_hot.scatter_(1, tensor, 1)  # mark 1 on channel (dim=1) with index of mask
    one_hot = one_hot[:, :num_classes]  # remove ignored label(s)
    for _ in range(len(one_hot.shape) - len(original_shape)):
        one_hot.squeeze_(0)  # B C H W -> C H W -> H W, if applicable
    return one_hot
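A quick usage sketch of the new helper (the import path follows the file location in this diff; the shapes follow the code above):

```python
import torch

from neetbox.torch.nn.functional import one_hot

# a 2x2 label mask; -1 marks a pixel that should be ignored
mask = torch.tensor([[0, 1], [2, -1]])

encoded = one_hot(mask, num_classes=3)  # default ignored_label="negative"
print(encoded.shape)  # torch.Size([3, 2, 2]); the ignored pixel is all zeros
```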

1 comment on commit 91efb3b

@vercel vercel bot commented on 91efb3b Nov 20, 2023