source
Metric
def Metric(
name:str, device
):
Base metric class.
source
Mean
def Mean(
name:str, device
):
Mean metric, used for loss.
source
Accuracy
def Accuracy(
name:str, device
):
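Accuracy metric: the fraction of elements where the two tensors passed to `update_state` are equal, accumulated across calls until `reset_state()` is called.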
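Example usage:

```python
import torch

a = Accuracy("mean", "cpu")
print(a, a.empty)
a.update_state(torch.Tensor([3,2,2,1]), torch.Tensor([1,2,2,1]))
print(a, a.empty)
a.update_state(torch.Tensor([1,2,2,3]), torch.Tensor([1,2,2,3]))
print(a, a.empty)
a.reset_state()
print(a, a.empty)
```

Output:

```
mean=nan True
mean=0.75 False
mean=0.875 False
mean=nan True
```

The third value is cumulative over both updates: 7 of the 8 elements seen so far match (3/4 from the first call, 4/4 from the second).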
Back to top