Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce SurvivalTree.predict's memory use #369

Merged
merged 24 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4c9505f
reduce SurvivalTree.predict's memory use
cpoerschke Jun 8, 2023
a227853
add low_memory=False option to SurvivalTree constructor
cpoerschke Jun 8, 2023
2fa63cb
add test_predict_low_memory in test_tree.py
cpoerschke Jun 8, 2023
a1a1213
development increment: j_delta=2 (tests continue to pass)
cpoerschke Jun 9, 2023
8c62e09
development increment: j_delta=3 (tests fail for some reason)
cpoerschke Jun 9, 2023
67969a7
low memory mode changes
cpoerschke Jun 12, 2023
f1639ff
Merge remote-tracking branch 'origin/master' into pr-6
cpoerschke Jun 12, 2023
b85f924
annotate TODO w.r.t. summing only for event times
cpoerschke Jun 12, 2023
b5d1ac0
address CI feedback
cpoerschke Jun 12, 2023
179d3ef
action review feedback (part 1 of 2)
cpoerschke Jun 13, 2023
9065fae
action review feedback (part 2 of 2)
cpoerschke Jun 13, 2023
3af23c5
int[::1] --> const bint[::1] for LogrankCriterion's is_event_time
cpoerschke Jun 13, 2023
fe5aed5
lint: line-too-long
cpoerschke Jun 13, 2023
4c21964
address CI feedback (part 1 of 2)
cpoerschke Jun 13, 2023
6ef7353
address CI feedback (part 2 of 2)
cpoerschke Jun 13, 2023
bdf7162
action CI feedback
cpoerschke Jun 13, 2023
706352b
Assign self.is_event_time to local variable
sebp Jun 17, 2023
2b81bb3
Add low_memory option to forest classes
sebp Jun 17, 2023
9acd236
Add test case for low-memory mode for forests
sebp Jun 17, 2023
a41f022
Remove test_predict_low_memory
sebp Jun 17, 2023
4440a76
Use type cnp.npy_bool instead of bint
sebp Jun 17, 2023
35b0811
Remove type conversion
sebp Jun 17, 2023
6cbb8fa
Fix code format
sebp Jun 17, 2023
9e1b344
Fix API doc of RandomSurvivalForest
sebp Jun 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sksurv/tree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,29 @@ cdef class LogrankCriterion(Criterion):
DOUBLE_t ratio
DOUBLE_t n_events
DOUBLE_t n_at_risk
DOUBLE_t dest_j0
DOUBLE_t dest_j1

self.riskset_total.at(0, &n_at_risk, &n_events)
ratio = n_events / n_at_risk
dest[0] = ratio # Nelson-Aalen estimator
dest[1] = 1.0 - ratio # Kaplan-Meier estimator

# low memory mode
if self.n_outputs == 1:
dest_j0 = dest[0]
dest_j1 = dest[1]
for i in range(1, self.n_unique_times):
self.riskset_total.at(i, &n_at_risk, &n_events)
if n_at_risk != 0:
ratio = n_events / n_at_risk
dest_j0 += ratio
dest_j1 *= 1.0 - ratio
if True: # TODO: only sum for event times
dest[0] += dest_j0
dest[1] += dest_j1
return

j = 2
for i in range(1, self.n_unique_times):
self.riskset_total.at(i, &n_at_risk, &n_events)
Expand Down
27 changes: 27 additions & 0 deletions sksurv/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

low_memory : boolean, default: False
If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
and ``predict_survival_function`` are not implemented.

Attributes
----------
unique_times_ : array of shape = (n_unique_times,)
Expand Down Expand Up @@ -162,6 +166,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
],
"random_state": ["random_state"],
"max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
"low_memory": ["boolean"],
}

def __init__(
Expand All @@ -175,6 +180,7 @@ def __init__(
max_features=None,
random_state=None,
max_leaf_nodes=None,
low_memory=False,
):
self.splitter = splitter
self.max_depth = max_depth
Expand All @@ -184,6 +190,7 @@ def __init__(
self.max_features = max_features
self.random_state = random_state
self.max_leaf_nodes = max_leaf_nodes
self.low_memory = low_memory

def fit(self, X, y, sample_weight=None, check_input=True):
"""Build a survival tree from the training set (X, y).
Expand Down Expand Up @@ -229,6 +236,11 @@ def fit(self, X, y, sample_weight=None, check_input=True):
# one "class" for CHF, one for survival function
self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2

if self.low_memory:
self.n_outputs_ = 1
# one "class" for the sum over the CHF, one for the sum over the survival function
self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2

# Build tree
self.criterion = "logrank"
criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_)
Expand Down Expand Up @@ -364,6 +376,13 @@ def predict(self, X, check_input=True):
risk_scores : ndarray, shape = (n_samples,)
Predicted risk scores.
"""

if self.low_memory:
check_is_fitted(self, "tree_")
X = self._validate_X_predict(X, check_input, accept_sparse="csr")
pred = self.tree_.predict(X)
return pred[..., 0]

chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
return chf[:, self.is_event_time_].sum(1)

Expand Down Expand Up @@ -424,6 +443,10 @@ def predict_cumulative_hazard_function(self, X, check_input=True, return_array=F
>>> plt.ylim(0, 1)
>>> plt.show()
"""

if self.low_memory:
raise NotImplementedError("predict_cumulative_hazard_function is not implemented in low memory mode.")

check_is_fitted(self, "tree_")
X = self._validate_X_predict(X, check_input, accept_sparse="csr")

Expand Down Expand Up @@ -491,6 +514,10 @@ def predict_survival_function(self, X, check_input=True, return_array=False):
>>> plt.ylim(0, 1)
>>> plt.show()
"""

if self.low_memory:
raise NotImplementedError("predict_survival_function is not implemented in low memory mode.")

check_is_fitted(self, "tree_")
X = self._validate_X_predict(X, check_input, accept_sparse="csr")

Expand Down
33 changes: 33 additions & 0 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,36 @@ def test_predict_sparse(make_whas500):
assert_array_equal(y_pred, y_pred_csr)
assert_array_equal(y_cum_h, y_cum_h_csr)
assert_array_equal(y_surv, y_surv_csr)


def test_predict_low_memory(make_whas500):
seed = 42
whas500 = make_whas500(to_numeric=True)
X, y = whas500.x, whas500.y
# Duplicate values in whas500 lead to assertion errors because of
# tie resolution during tree fitting.
# Using a synthetic dataset resolves this issue.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, ties should not cause any problems.

X = np.random.RandomState(seed).binomial(n=5, p=0.1, size=X.shape)

X_train, X_test, y_train, _ = train_test_split(X, y, random_state=seed)

tree0 = SurvivalTree(min_samples_leaf=10, random_state=seed, low_memory=False)
tree0.fit(X_train, y_train)
y_pred_0 = tree0.predict(X_test)

tree1 = SurvivalTree(min_samples_leaf=10, random_state=seed, low_memory=True)
tree1.fit(X_train, y_train)
y_pred_1 = tree1.predict(X_test)

assert y_pred_0.shape[0] == X_test.shape[0]
assert y_pred_1.shape[0] == X_test.shape[0]

assert_array_almost_equal(y_pred_0, y_pred_1)

msg = r"predict_cumulative_hazard_function is not implemented in low memory mode."
with pytest.raises(NotImplementedError, match=msg):
tree1.predict_cumulative_hazard_function(X_test)

msg = r"predict_survival_function is not implemented in low memory mode."
with pytest.raises(NotImplementedError, match=msg):
tree1.predict_survival_function(X_test)